diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..e7c53aecf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.9 + hooks: + # Run the linter. + - id: ruff + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/README.md b/README.md index 73f211d86..eb5bb7d98 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,19 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :) -对于新功能的添加,请先通过 Issue 讨论。 +### 如何贡献 + +你可以通过查看问题或帮助审核 PR(拉取请求)来贡献。任何问题或 PR 都欢迎参与,以促进社区贡献。当然,这些只是建议,你可以以任何方式进行贡献。对于新功能的添加,请先通过 Issue 讨论。 + +### 开发环境 + +AstrBot 使用 `ruff` 进行代码格式化和检查。 + +```bash +git clone https://github.com/Soulter/AstrBot +pip install pre-commit +pre-commit install +``` ## 🌟 支持 diff --git a/astrbot/__init__.py b/astrbot/__init__.py index 5dff017b0..73d64f303 100644 --- a/astrbot/__init__.py +++ b/astrbot/__init__.py @@ -1,2 +1,3 @@ from .core.log import LogManager -logger = LogManager.GetLogger(log_name='astrbot') \ No newline at end of file + +logger = LogManager.GetLogger(log_name="astrbot") diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index 090a26386..e8a9d23a9 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -4,10 +4,4 @@ from astrbot.core import html_renderer from astrbot.core import sp from astrbot.core.star.register import register_llm_tool as llm_tool -__all__ = [ - "AstrBotConfig", - "logger", - "html_renderer", - "llm_tool", - "sp" -] \ No newline at end of file +__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"] diff --git a/astrbot/api/all.py b/astrbot/api/all.py index 5c7046d35..2463dbc2b 100644 --- a/astrbot/api/all.py +++ b/astrbot/api/all.py @@ -1,4 +1,3 @@ - from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core import html_renderer @@ -6,8 +5,11 @@ from astrbot.core.star.register import register_llm_tool as llm_tool # event from astrbot.core.message.message_event_result import ( - MessageEventResult, MessageChain, CommandResult, EventResultType -) + MessageEventResult, + MessageChain, + CommandResult, + EventResultType, +) from astrbot.core.platform import AstrMessageEvent # star register @@ -18,10 +20,16 @@ from astrbot.core.star.register import ( register_regex as regex, register_platform_adapter_type as platform_adapter_type, ) -from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType -from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType +from astrbot.core.star.filter.event_message_type import ( + EventMessageTypeFilter, + EventMessageType, +) +from astrbot.core.star.filter.platform_adapter_type import ( + PlatformAdapterTypeFilter, + PlatformAdapterType, +) from astrbot.core.star.register import ( - register_star as register # 注册插件(Star) + register_star as register, # 注册插件(Star) ) from astrbot.core.star import Context, Star from astrbot.core.star.config import * @@ -32,9 +40,14 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData # platform from astrbot.core.platform import ( - AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata + AstrMessageEvent, + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, ) from astrbot.core.platform.register import register_platform_adapter -from .message_components import * +from .message_components import * \ No newline at end of file diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index 646800f55..dd737e3ff 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -11,33 +11,39 @@ from astrbot.core.star.register import ( register_on_llm_response as on_llm_response, register_llm_tool as llm_tool, register_on_decorating_result as on_decorating_result, - register_after_message_sent as after_message_sent + register_after_message_sent as after_message_sent, ) -from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType -from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType +from astrbot.core.star.filter.event_message_type import ( + EventMessageTypeFilter, + EventMessageType, +) +from astrbot.core.star.filter.platform_adapter_type import ( + PlatformAdapterTypeFilter, + PlatformAdapterType, +) from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType from astrbot.core.star.filter.custom_filter import CustomFilter __all__ = [ - 'command', - 'command_group', - 'event_message_type', - 'regex', - 'platform_adapter_type', - 'permission_type', - 'EventMessageTypeFilter', - 'EventMessageType', - 'PlatformAdapterTypeFilter', - 'PlatformAdapterType', - 'PermissionTypeFilter', - 'CustomFilter', - 'custom_filter', - 'PermissionType', - 'on_astrbot_loaded', - 'on_llm_request', - 'llm_tool', - 'on_decorating_result', - 'after_message_sent', - 'on_llm_response' -] \ No newline at end of file + "command", + "command_group", + "event_message_type", + "regex", + "platform_adapter_type", + "permission_type", + "EventMessageTypeFilter", + "EventMessageType", + "PlatformAdapterTypeFilter", + "PlatformAdapterType", + "PermissionTypeFilter", + "CustomFilter", + "custom_filter", + "PermissionType", + "on_astrbot_loaded", + "on_llm_request", + "llm_tool", + "on_decorating_result", + "after_message_sent", + "on_llm_response", +] diff --git a/astrbot/api/message_components.py b/astrbot/api/message_components.py index 97d456b39..ff9add858 100644 --- a/astrbot/api/message_components.py +++ b/astrbot/api/message_components.py @@ -1 +1 @@ -from astrbot.core.message.components import * \ No newline at end of file +from astrbot.core.message.components import * diff --git a/astrbot/api/platform/__init__.py b/astrbot/api/platform/__init__.py index 6a6cc77ec..dcc02bb49 100644 --- a/astrbot/api/platform/__init__.py +++ b/astrbot/api/platform/__init__.py @@ -1,6 +1,21 @@ from astrbot.core.platform import ( - AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata + AstrMessageEvent, + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, ) from astrbot.core.platform.register import register_platform_adapter -from astrbot.core.message.components import * \ No newline at end of file +from astrbot.core.message.components import * + +__all__ = [ + "AstrMessageEvent", + "Platform", + "AstrBotMessage", + "MessageMember", + "MessageType", + "PlatformMetadata", + "register_platform_adapter", +] diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 0158d2a89..557273acd 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,2 +1,17 @@ from astrbot.core.provider import Provider, STTProvider, Personality -from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse \ No newline at end of file +from astrbot.core.provider.entites import ( + ProviderRequest, + ProviderType, + ProviderMetaData, + LLMResponse, +) + +__all__ = [ + "Provider", + "STTProvider", + "Personality", + "ProviderRequest", + "ProviderType", + "ProviderMetaData", + "LLMResponse", +] diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py index dfe1eb77d..236cf70e6 100644 --- a/astrbot/api/star/__init__.py +++ b/astrbot/api/star/__init__.py @@ -1,6 +1,16 @@ from astrbot.core.star.register import ( - register_star as register # 注册插件(Star) + register_star as register, # 注册插件(Star) ) from astrbot.core.star import Context, Star from astrbot.core.star.config import * + +__all__ = [ + "register", + "Context", + "Star", + "DEFAULT_CONFIG", + "VERSION", + "DB_PATH", + "AstrBotConfig", +] diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 6e22ba94d..9749dee24 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,6 +1,6 @@ import os import asyncio -from .log import LogManager, LogBroker +from .log import LogManager, LogBroker # noqa from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.pip_installer import PipInstaller @@ -11,16 +11,16 @@ from astrbot.core.config import AstrBotConfig os.makedirs("data", exist_ok=True) astrbot_config = AstrBotConfig() -t2i_base_url = astrbot_config.get('t2i_endpoint', 'https://t2i.soulter.top/text2img') +t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") html_renderer = HtmlRenderer(t2i_base_url) -logger = LogManager.GetLogger(log_name='astrbot') +logger = LogManager.GetLogger(log_name="astrbot") + +if os.environ.get("TESTING", ""): + logger.setLevel("DEBUG") -if os.environ.get('TESTING', ""): - logger.setLevel('DEBUG') - db_helper = SQLiteDatabase(DB_PATH) -sp = SharedPreferences() # 简单的偏好设置存储 -pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', '')) +sp = SharedPreferences() # 简单的偏好设置存储 +pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", "")) web_chat_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32) WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" diff --git a/astrbot/core/config/__init__.py b/astrbot/core/config/__init__.py index 095d8c773..e49ac88a5 100644 --- a/astrbot/core/config/__init__.py +++ b/astrbot/core/config/__init__.py @@ -1,2 +1,9 @@ from .default import DEFAULT_CONFIG, VERSION, DB_PATH -from .astrbot_config import * \ No newline at end of file +from .astrbot_config import * + +__all__ = [ + "DEFAULT_CONFIG", + "VERSION", + "DB_PATH", + "AstrBotConfig", +] diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 8341d12c2..09e66ce17 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -8,79 +8,82 @@ from typing import Dict ASTRBOT_CONFIG_PATH = "data/cmd_config.json" logger = logging.getLogger("astrbot") + class RateLimitStrategy(enum.Enum): STALL = "stall" DISCARD = "discard" + class AstrBotConfig(dict): - '''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。 - + """从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。 + - 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。 - 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。 - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 - ''' - + """ + def __init__( - self, - config_path: str = ASTRBOT_CONFIG_PATH, + self, + config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, - schema: dict = None + schema: dict = None, ): super().__init__() - + # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 - object.__setattr__(self, 'config_path', config_path) - object.__setattr__(self, 'default_config', default_config) - object.__setattr__(self, 'schema', schema) - + object.__setattr__(self, "config_path", config_path) + object.__setattr__(self, "default_config", default_config) + object.__setattr__(self, "schema", schema) + if schema: default_config = self._config_schema_to_default_config(schema) - + if not self.check_exist(): - '''不存在时载入默认配置''' + """不存在时载入默认配置""" with open(config_path, "w", encoding="utf-8-sig") as f: json.dump(default_config, f, indent=4, ensure_ascii=False) with open(config_path, "r", encoding="utf-8-sig") as f: conf_str = f.read() - if conf_str.startswith(u'/ufeff'): # remove BOM - conf_str = conf_str.encode('utf8')[3:].decode('utf8') + if conf_str.startswith("/ufeff"): # remove BOM + conf_str = conf_str.encode("utf8")[3:].decode("utf8") conf = json.loads(conf_str) - + # 检查配置完整性,并插入 has_new = self.check_config_integrity(default_config, conf) self.update(conf) if has_new: self.save_config() - + self.update(conf) def _config_schema_to_default_config(self, schema: dict) -> dict: - '''将 Schema 转换成 Config''' + """将 Schema 转换成 Config""" conf = {} - + def _parse_schema(schema: dict, conf: dict): for k, v in schema.items(): - if v['type'] not in DEFAULT_VALUE_MAP: - raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}") - if 'default' in v: - default = v['default'] + if v["type"] not in DEFAULT_VALUE_MAP: + raise TypeError( + f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}" + ) + if "default" in v: + default = v["default"] else: - default = DEFAULT_VALUE_MAP[v['type']] - - if v['type'] == 'object': + default = DEFAULT_VALUE_MAP[v["type"]] + + if v["type"] == "object": conf[k] = {} - _parse_schema(v['items'], conf[k]) + _parse_schema(v["items"], conf[k]) else: conf[k] = default - + _parse_schema(schema, conf) return conf - def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): - '''检查配置完整性,如果有新的配置项则返回 True''' + """检查配置完整性,如果有新的配置项则返回 True""" has_new = False for key, value in refer_conf.items(): if key not in conf: @@ -94,25 +97,27 @@ class AstrBotConfig(dict): conf[key] = value has_new = True elif isinstance(value, dict): - has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key) + has_new |= self.check_config_integrity( + value, conf[key], path + "." + key if path else key + ) return has_new - + def save_config(self, replace_config: Dict = None): - '''将配置写入文件 - + """将配置写入文件 + 如果传入 replace_config,则将配置替换为 replace_config - ''' + """ if replace_config: self.update(replace_config) with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) - + def __getattr__(self, item): try: return self[item] except KeyError: return None - + def __delattr__(self, key): try: del self[key] @@ -124,4 +129,4 @@ class AstrBotConfig(dict): self[key] = value def check_exist(self) -> bool: - return os.path.exists(self.config_path) \ No newline at end of file + return os.path.exists(self.config_path) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index d905219a4..6cba41142 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,14 +6,16 @@ from typing import Dict, List from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation -class ConversationManager(): - '''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。''' + +class ConversationManager: + """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" + def __init__(self, db_helper: BaseDatabase): self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 self._start_periodic_save() - + def _start_periodic_save(self): asyncio.create_task(self._periodic_save()) @@ -26,83 +28,83 @@ class ConversationManager(): sp.put("session_conversation", self.session_conversations) async def new_conversation(self, unified_msg_origin: str) -> str: - '''新建对话,并将当前会话的对话转移到新对话''' + """新建对话,并将当前会话的对话转移到新对话""" conversation_id = str(uuid.uuid4()) - self.db.new_conversation( - user_id=unified_msg_origin, - cid=conversation_id - ) + self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id) self.session_conversations[unified_msg_origin] = conversation_id sp.put("session_conversation", self.session_conversations) return conversation_id - + async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): - '''切换会话的对话''' + """切换会话的对话""" self.session_conversations[unified_msg_origin] = conversation_id sp.put("session_conversation", self.session_conversations) - - async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None): - '''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话''' + + async def delete_conversation( + self, unified_msg_origin: str, conversation_id: str = None + ): + """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话""" conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: - self.db.delete_conversation( - user_id=unified_msg_origin, - cid=conversation_id - ) + self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id) del self.session_conversations[unified_msg_origin] sp.put("session_conversation", self.session_conversations) - + async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: - '''获取会话当前的对话 ID''' + """获取会话当前的对话 ID""" return self.session_conversations.get(unified_msg_origin, None) - - async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation: - '''获取会话的对话''' + + async def get_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> Conversation: + """获取会话的对话""" return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) - + async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: - '''获取会话的所有对话''' + """获取会话的所有对话""" return self.db.get_conversations(unified_msg_origin) - - async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]): - '''更新会话的对话''' + + async def update_conversation( + self, unified_msg_origin: str, conversation_id: str, history: List[Dict] + ): + """更新会话的对话""" if conversation_id: self.db.update_conversation( user_id=unified_msg_origin, cid=conversation_id, - history=json.dumps(history) + history=json.dumps(history), ) - + async def update_conversation_title(self, unified_msg_origin: str, title: str): - '''更新会话的对话标题''' + """更新会话的对话标题""" conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: self.db.update_conversation_title( - user_id=unified_msg_origin, - cid=conversation_id, - title=title + user_id=unified_msg_origin, cid=conversation_id, title=title ) - - async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str): - '''更新会话的对话 Persona ID''' + + async def update_conversation_persona_id( + self, unified_msg_origin: str, persona_id: str + ): + """更新会话的对话 Persona ID""" conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: self.db.update_conversation_persona_id( - user_id=unified_msg_origin, - cid=conversation_id, - persona_id=persona_id + user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id ) - - async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10): + + async def get_human_readable_context( + self, unified_msg_origin, conversation_id, page=1, page_size=10 + ): conversation = await self.get_conversation(unified_msg_origin, conversation_id) history = json.loads(conversation.history) contexts = [] temp_contexts = [] for record in history: - if record['role'] == "user": + if record["role"] == "user": temp_contexts.append(f"User: {record['content']}") - elif record['role'] == "assistant": + elif record["role"] == "assistant": temp_contexts.append(f"Assistant: {record['content']}") contexts.insert(0, temp_contexts) temp_contexts = [] @@ -111,9 +113,9 @@ class ConversationManager(): contexts = [item for sublist in contexts for item in sublist] # 计算分页 - paged_contexts = contexts[(page-1)*page_size:page*page_size] + paged_contexts = contexts[(page - 1) * page_size : page * page_size] total_pages = len(contexts) // page_size if len(contexts) % page_size != 0: total_pages += 1 - - return paged_contexts, total_pages \ No newline at end of file + + return paged_contexts, total_pages diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 924dd4137..e5485edb4 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,109 +22,119 @@ from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map + class AstrBotCoreLifecycle: def __init__(self, log_broker: LogBroker, db: BaseDatabase): self.log_broker = log_broker self.astrbot_config = astrbot_config self.db = db - - os.environ['https_proxy'] = self.astrbot_config['http_proxy'] - os.environ['http_proxy'] = self.astrbot_config['http_proxy'] - os.environ['no_proxy'] = 'localhost' - + + os.environ["https_proxy"] = self.astrbot_config["http_proxy"] + os.environ["http_proxy"] = self.astrbot_config["http_proxy"] + os.environ["no_proxy"] = "localhost" + async def initialize(self): - logger.info("AstrBot v"+ VERSION) + logger.info("AstrBot v" + VERSION) if os.environ.get("TESTING", ""): logger.setLevel("DEBUG") else: - logger.setLevel(self.astrbot_config['log_level']) + logger.setLevel(self.astrbot_config["log_level"]) self.event_queue = Queue() self.event_queue.closed = False - + self.provider_manager = ProviderManager(self.astrbot_config, self.db) - + self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) - + self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) - + self.conversation_manager = ConversationManager(self.db) - + self.star_context = Context( - self.event_queue, - self.astrbot_config, + self.event_queue, + self.astrbot_config, self.db, self.provider_manager, self.platform_manager, self.conversation_manager, - self.knowledge_db_manager + self.knowledge_db_manager, ) self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) - + await self.plugin_manager.reload() - '''扫描、注册插件、实例化插件类''' - + """扫描、注册插件、实例化插件类""" + await self.provider_manager.initialize() - '''根据配置实例化各个 Provider''' - - self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager)) + """根据配置实例化各个 Provider""" + + self.pipeline_scheduler = PipelineScheduler( + PipelineContext(self.astrbot_config, self.plugin_manager) + ) await self.pipeline_scheduler.initialize() - '''初始化消息事件流水线调度器''' - - self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror']) + """初始化消息事件流水线调度器""" + + self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"]) self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) self.start_time = int(time.time()) self.curr_tasks: List[asyncio.Task] = [] - + await self.platform_manager.initialize() - '''根据配置实例化各个平台适配器''' - + """根据配置实例化各个平台适配器""" + def _load(self): - event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") - + event_bus_task = asyncio.create_task( + self.event_bus.dispatch(), name="event_bus" + ) + extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) - + tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: - self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name())) - + self.curr_tasks.append( + asyncio.create_task(self._task_wrapper(task), name=task.get_name()) + ) + self.start_time = int(time.time()) - + async def _task_wrapper(self, task: asyncio.Task): try: await task except asyncio.CancelledError: pass except Exception as e: - logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") for line in traceback.format_exc().split("\n"): logger.error(f"| {line}") logger.error("-------") - + async def start(self): self._load() logger.info("AstrBot 启动完成。") - + # 执行启动完成事件钩子 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAstrBotLoadedEvent) + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnAstrBotLoadedEvent + ) for handler in handlers: try: - logger.info(f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + logger.info( + f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) await handler.handler() except BaseException: logger.error(traceback.format_exc()) - + await asyncio.gather(*self.curr_tasks, return_exceptions=True) - + async def stop(self): self.event_queue.closed = True for task in self.curr_tasks: task.cancel() - + await self.provider_manager.terminate() - + for task in self.curr_tasks: try: await task @@ -132,14 +142,18 @@ class AstrBotCoreLifecycle: pass except Exception as e: logger.error(f"任务 {task.get_name()} 发生错误: {e}") - + def restart(self): self.event_queue.closed = True - threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start() - + threading.Thread( + target=self.astrbot_updator._reboot, name="restart", daemon=True + ).start() + def load_platform(self) -> List[asyncio.Task]: tasks = [] platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: - tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)) - return tasks \ No newline at end of file + tasks.append( + asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name) + ) + return tasks diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 03474ecbf..494927014 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -3,111 +3,117 @@ from dataclasses import dataclass from typing import List from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation + @dataclass class BaseDatabase(abc.ABC): - ''' + """ 数据库基类 - ''' + """ + def __init__(self) -> None: pass - + def insert_base_metrics(self, metrics: dict): - '''插入基础指标数据''' - self.insert_platform_metrics(metrics['platform_stats']) - self.insert_plugin_metrics(metrics['plugin_stats']) - self.insert_command_metrics(metrics['command_stats']) - self.insert_llm_metrics(metrics['llm_stats']) - + """插入基础指标数据""" + self.insert_platform_metrics(metrics["platform_stats"]) + self.insert_plugin_metrics(metrics["plugin_stats"]) + self.insert_command_metrics(metrics["command_stats"]) + self.insert_llm_metrics(metrics["llm_stats"]) + @abc.abstractmethod def insert_platform_metrics(self, metrics: dict): - '''插入平台指标数据''' + """插入平台指标数据""" raise NotImplementedError - + @abc.abstractmethod def insert_plugin_metrics(self, metrics: dict): - '''插入插件指标数据''' + """插入插件指标数据""" raise NotImplementedError - + @abc.abstractmethod def insert_command_metrics(self, metrics: dict): - '''插入指令指标数据''' + """插入指令指标数据""" raise NotImplementedError - + @abc.abstractmethod def insert_llm_metrics(self, metrics: dict): - '''插入 LLM 指标数据''' + """插入 LLM 指标数据""" raise NotImplementedError - + @abc.abstractmethod def update_llm_history(self, session_id: str, content: str, provider_type: str): - '''更新 LLM 历史记录。当不存在 session_id 时插入''' + """更新 LLM 历史记录。当不存在 session_id 时插入""" raise NotImplementedError - + @abc.abstractmethod - def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]: - '''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有''' + def get_llm_history( + self, session_id: str = None, provider_type: str = None + ) -> List[LLMHistory]: + """获取 LLM 历史记录, 如果 session_id 为 None, 返回所有""" raise NotImplementedError - + @abc.abstractmethod def get_base_stats(self, offset_sec: int = 86400) -> Stats: - '''获取基础统计数据''' + """获取基础统计数据""" raise NotImplementedError - + @abc.abstractmethod def get_total_message_count(self) -> int: - '''获取总消息数''' + """获取总消息数""" raise NotImplementedError - + @abc.abstractmethod def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: - '''获取基础统计数据(合并)''' + """获取基础统计数据(合并)""" raise NotImplementedError @abc.abstractmethod def insert_atri_vision_data(self, vision_data: ATRIVision): - '''插入 ATRI 视觉数据''' + """插入 ATRI 视觉数据""" raise NotImplementedError - + @abc.abstractmethod def get_atri_vision_data(self) -> List[ATRIVision]: - '''获取 ATRI 视觉数据''' + """获取 ATRI 视觉数据""" raise NotImplementedError - + @abc.abstractmethod - def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision: - '''通过 url 或 path 获取 ATRI 视觉数据''' + def get_atri_vision_data_by_path_or_id( + self, url_or_path: str, id: str + ) -> ATRIVision: + """通过 url 或 path 获取 ATRI 视觉数据""" raise NotImplementedError - + @abc.abstractmethod def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: - '''通过 user_id 和 cid 获取 Conversation''' + """通过 user_id 和 cid 获取 Conversation""" raise NotImplementedError - + @abc.abstractmethod def new_conversation(self, user_id: str, cid: str): - '''新建 Conversation''' + """新建 Conversation""" raise NotImplementedError - + @abc.abstractmethod def get_conversations(self, user_id: str) -> List[Conversation]: raise NotImplementedError @abc.abstractmethod def update_conversation(self, user_id: str, cid: str, history: str): - '''更新 Conversation''' + """更新 Conversation""" raise NotImplementedError - + @abc.abstractmethod def delete_conversation(self, user_id: str, cid: str): - '''删除 Conversation''' + """删除 Conversation""" raise NotImplementedError - + @abc.abstractmethod def update_conversation_title(self, user_id: str, cid: str, title: str): - '''更新 Conversation 标题''' + """更新 Conversation 标题""" raise NotImplementedError - + @abc.abstractmethod def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): - '''更新 Conversation Persona ID''' - raise NotImplementedError \ No newline at end of file + """更新 Conversation Persona ID""" + raise NotImplementedError diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index c905a50ba..59041d6dd 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,48 +1,57 @@ -'''指标数据''' +"""指标数据""" from dataclasses import dataclass, field from typing import List + @dataclass -class Platform(): - name: str - count: int - timestamp: int - -@dataclass -class Provider(): - name: str - count: int - timestamp: int - -@dataclass -class Plugin(): - name: str - count: int - timestamp: int - -@dataclass -class Command(): +class Platform: name: str count: int timestamp: int + @dataclass -class Stats(): +class Provider: + name: str + count: int + timestamp: int + + +@dataclass +class Plugin: + name: str + count: int + timestamp: int + + +@dataclass +class Command: + name: str + count: int + timestamp: int + + +@dataclass +class Stats: platform: List[Platform] = field(default_factory=list) command: List[Command] = field(default_factory=list) llm: List[Provider] = field(default_factory=list) - + + @dataclass -class LLMHistory(): - '''LLM 聊天时持久化的信息''' +class LLMHistory: + """LLM 聊天时持久化的信息""" + provider_type: str session_id: str content: str - + + @dataclass -class ATRIVision(): - '''Deprecated''' +class ATRIVision: + """Deprecated""" + id: str url_or_path: str caption: str @@ -52,19 +61,21 @@ class ATRIVision(): session_id: str sender_nickname: str timestamp: int = -1 - + + @dataclass -class Conversation(): - '''LLM 对话存储 - +class Conversation: + """LLM 对话存储 + 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 - ''' + """ + user_id: str cid: str history: str = "" - '''字符串格式的列表。''' + """字符串格式的列表。""" created_at: int = 0 updated_at: int = 0 title: str = "" - persona_id: str = "" \ No newline at end of file + persona_id: str = "" diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index fecd251df..e41fc3b38 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,13 +1,7 @@ import sqlite3 import os import time -from astrbot.core.db.po import ( - Platform, - Stats, - LLMHistory, - ATRIVision, - Conversation -) +from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation from . import BaseDatabase from typing import Tuple @@ -16,21 +10,21 @@ class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: super().__init__() self.db_path = db_path - + with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f: sql = f.read() - + # 初始化数据库 self.conn = self._get_conn(self.db_path) c = self.conn.cursor() c.executescript(sql) self.conn.commit() - + # 检查 webchat_conversation 的 title 字段是否存在 c.execute( - ''' + """ PRAGMA table_info(webchat_conversation) - ''' + """ ) res = c.fetchall() has_title = False @@ -42,26 +36,26 @@ class SQLiteDatabase(BaseDatabase): has_persona_id = True if not has_title: c.execute( - ''' + """ ALTER TABLE webchat_conversation ADD COLUMN title TEXT; - ''' + """ ) self.conn.commit() if not has_persona_id: c.execute( - ''' + """ ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; - ''' + """ ) self.conn.commit() - + c.close() - + def _get_conn(self, db_path: str) -> sqlite3.Connection: conn = sqlite3.connect(self.db_path) conn.text_factory = str return conn - + def _exec_sql(self, sql: str, params: Tuple = None): conn = self.conn try: @@ -69,22 +63,23 @@ class SQLiteDatabase(BaseDatabase): except sqlite3.ProgrammingError: conn = self._get_conn(self.db_path) c = conn.cursor() - + if params: c.execute(sql, params) c.close() else: c.execute(sql) c.close() - + conn.commit() - + def insert_platform_metrics(self, metrics: dict): for k, v in metrics.items(): self._exec_sql( - ''' + """ INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?) - ''', (k, v, int(time.time())) + """, + (k, v, int(time.time())), ) def insert_plugin_metrics(self, metrics: dict): @@ -93,40 +88,46 @@ class SQLiteDatabase(BaseDatabase): def insert_command_metrics(self, metrics: dict): for k, v in metrics.items(): self._exec_sql( - ''' + """ INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?) - ''', (k, v, int(time.time())) + """, + (k, v, int(time.time())), ) def insert_llm_metrics(self, metrics: dict): for k, v in metrics.items(): self._exec_sql( - ''' + """ INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?) - ''', (k, v, int(time.time())) + """, + (k, v, int(time.time())), ) def update_llm_history(self, session_id: str, content: str, provider_type: str): res = self.get_llm_history(session_id, provider_type) if res: self._exec_sql( - ''' + """ UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ? - ''', (content, session_id, provider_type) + """, + (content, session_id, provider_type), ) else: self._exec_sql( - ''' + """ INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?) - ''', (provider_type, session_id, content) + """, + (provider_type, session_id, content), ) - def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple: + def get_llm_history( + self, session_id: str = None, provider_type: str = None + ) -> Tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + where_clause = "" if session_id or provider_type: where_clause += " WHERE " @@ -138,11 +139,12 @@ class SQLiteDatabase(BaseDatabase): if has: where_clause += " AND " where_clause += f"provider_type = '{provider_type}'" - + c.execute( - ''' + """ SELECT * FROM llm_history - ''' + where_clause + """ + + where_clause ) res = c.fetchall() histories = [] @@ -152,129 +154,134 @@ class SQLiteDatabase(BaseDatabase): return histories def get_base_stats(self, offset_sec: int = 86400) -> Stats: - '''获取 offset_sec 秒前到现在的基础统计数据''' + """获取 offset_sec 秒前到现在的基础统计数据""" where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - + try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT * FROM platform - ''' + where_clause + """ + + where_clause ) - + platform = [] for row in c.fetchall(): platform.append(Platform(*row)) - + # c.execute( # ''' # SELECT * FROM command # ''' + where_clause # ) - + # command = [] # for row in c.fetchall(): # command.append(Command(*row)) - + # c.execute( # ''' # SELECT * FROM llm # ''' + where_clause # ) - + # llm = [] # for row in c.fetchall(): # llm.append(Provider(*row)) - + c.close() - + return Stats(platform, [], []) - + def get_total_message_count(self) -> int: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT SUM(count) FROM platform - ''' + """ ) res = c.fetchone() c.close() return res[0] - + def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: - '''获取 offset_sec 秒前到现在的基础统计数据(合并)''' + """获取 offset_sec 秒前到现在的基础统计数据(合并)""" where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - + try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT name, SUM(count), timestamp FROM platform - ''' + where_clause + " GROUP BY name" + """ + + where_clause + + " GROUP BY name" ) - + platform = [] for row in c.fetchall(): platform.append(Platform(*row)) - + c.close() - + return Stats(platform, [], []) - - + def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ? - ''', (user_id, cid) + """, + (user_id, cid), ) - + res = c.fetchone() c.close() - + if not res: return - + return Conversation(*res) - + def new_conversation(self, user_id: str, cid: str): history = "[]" updated_at = int(time.time()) created_at = updated_at self._exec_sql( - ''' + """ INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?) - ''', (user_id, cid, history, updated_at, created_at) + """, + (user_id, cid, history, updated_at, created_at), ) - + def get_conversations(self, user_id: str) -> Tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC - ''', (user_id,) + """, + (user_id,), ) - + res = c.fetchall() c.close() conversations = [] @@ -284,82 +291,101 @@ class SQLiteDatabase(BaseDatabase): updated_at = row[2] title = row[3] persona_id = row[4] - conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id)) + conversations.append( + Conversation("", cid, "[]", created_at, updated_at, title, persona_id) + ) return conversations - + def update_conversation(self, user_id: str, cid: str, history: str): - '''更新对话,并且同时更新时间''' + """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( - ''' + """ UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? - ''', (history, updated_at, user_id, cid) + """, + (history, updated_at, user_id, cid), ) - - + def update_conversation_title(self, user_id: str, cid: str, title: str): self._exec_sql( - ''' + """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? - ''', (title, user_id, cid) + """, + (title, user_id, cid), ) - + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): self._exec_sql( - ''' + """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? - ''', (persona_id, user_id, cid) + """, + (persona_id, user_id, cid), ) - + def delete_conversation(self, user_id: str, cid: str): self._exec_sql( - ''' + """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? - ''', (user_id, cid) + """, + (user_id, cid), ) def insert_atri_vision_data(self, vision: ATRIVision): ts = int(time.time()) keywords = ",".join(vision.keywords) self._exec_sql( - ''' + """ INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - ''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts) + """, + ( + vision.id, + vision.url_or_path, + vision.caption, + vision.is_meme, + keywords, + vision.platform_name, + vision.session_id, + vision.sender_nickname, + ts, + ), ) - + def get_atri_vision_data(self) -> Tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT * FROM atri_vision - ''' + """ ) - + res = c.fetchall() visions = [] for row in res: visions.append(ATRIVision(*row)) c.close() return visions - - def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision: + + def get_atri_vision_data_by_path_or_id( + self, url_or_path: str, id: str + ) -> ATRIVision: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - + c.execute( - ''' + """ SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ? - ''', (url_or_path, id) + """, + (url_or_path, id), ) - + res = c.fetchone() c.close() if res: return ATRIVision(*res) - return None \ No newline at end of file + return None diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index d6039c15f..2688bd400 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -4,6 +4,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler from astrbot.core import logger from .platform import AstrMessageEvent + class EventBus: def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): self.event_queue = event_queue @@ -14,9 +15,13 @@ class EventBus: event: AstrMessageEvent = await self.event_queue.get() self._print_event(event) asyncio.create_task(self.pipeline_scheduler.execute(event)) - - def _print_event(self, event: AstrMessageEvent): + + def _print_event(self, event: AstrMessageEvent): if event.get_sender_name(): - logger.info(f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}") + logger.info( + f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" + ) else: - logger.info(f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}") \ No newline at end of file + logger.info( + f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" + ) diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 0ab5fe852..e578c2f1f 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -7,31 +7,35 @@ from typing import List CACHED_SIZE = 200 log_color_config = { - 'DEBUG': 'green', 'INFO': 'bold_cyan', - 'WARNING': 'bold_yellow', 'ERROR': 'red', - 'CRITICAL': 'bold_red', 'RESET': 'reset', - 'asctime': 'green' + "DEBUG": "green", + "INFO": "bold_cyan", + "WARNING": "bold_yellow", + "ERROR": "red", + "CRITICAL": "bold_red", + "RESET": "reset", + "asctime": "green", } + class LogBroker: def __init__(self): self.log_cache = deque(maxlen=CACHED_SIZE) self.subscribers: List[Queue] = [] - + def register(self) -> Queue: - '''给每个订阅者返回一个带有日志缓存的队列''' + """给每个订阅者返回一个带有日志缓存的队列""" q = Queue(maxsize=CACHED_SIZE + 10) for log in self.log_cache: q.put_nowait(log) self.subscribers.append(q) return q - + def unregister(self, q: Queue): - '''取消订阅''' + """取消订阅""" self.subscribers.remove(q) - + def publish(self, log_entry: str): - '''发布消息''' + """发布消息""" self.log_cache.append(log_entry) for q in self.subscribers: try: @@ -39,6 +43,7 @@ class LogBroker: except asyncio.QueueFull: pass + class LogQueueHandler(logging.Handler): def __init__(self, log_broker: LogBroker): super().__init__() @@ -48,26 +53,26 @@ class LogQueueHandler(logging.Handler): log_entry = self.format(record) self.log_broker.publish(log_entry) -class LogManager: +class LogManager: @classmethod - def GetLogger(cls, log_name: str = 'default'): + def GetLogger(cls, log_name: str = "default"): logger = logging.getLogger(log_name) if logger.hasHandlers(): return logger console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG) console_formatter = colorlog.ColoredFormatter( - fmt='%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s', - datefmt='%H:%M:%S', - log_colors=log_color_config + fmt="%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s", + datefmt="%H:%M:%S", + log_colors=log_color_config, ) console_handler.setFormatter(console_formatter) logger.setLevel(logging.DEBUG) logger.addHandler(console_handler) - + return logger - + @classmethod def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): handler = LogQueueHandler(log_broker) @@ -75,5 +80,9 @@ class LogManager: if logger.handlers: handler.setFormatter(logger.handlers[0].formatter) else: - handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) - logger.addHandler(handler) \ No newline at end of file + handler.setFormatter( + logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + ) + logger.addHandler(handler) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 477e7eeec..d2f48d388 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -1,4 +1,4 @@ -''' +""" MIT License Copyright (c) 2021 Lxns-Network @@ -20,7 +20,7 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import base64 import json @@ -29,20 +29,21 @@ import typing as T from enum import Enum from pydantic.v1 import BaseModel + class ComponentType(Enum): - Plain = "Plain" # 纯文本消息 - Face = "Face" # QQ表情 - Record = "Record" # 语音 - Video = "Video" # 视频 - At = "At" # At - Node = "Node" # 转发消息的一个节点 - Nodes = "Nodes" # 转发消息的多个节点 - Poke = "Poke" # QQ 戳一戳 - Image = "Image" # 图片 - Reply = "Reply" # 回复 - Forward = "Forward" # 转发消息 - File = "File" # 文件 - + Plain = "Plain" # 纯文本消息 + Face = "Face" # QQ表情 + Record = "Record" # 语音 + Video = "Video" # 视频 + At = "At" # At + Node = "Node" # 转发消息的一个节点 + Nodes = "Nodes" # 转发消息的多个节点 + Poke = "Poke" # QQ 戳一戳 + Image = "Image" # 图片 + Reply = "Reply" # 回复 + Forward = "Forward" # 转发消息 + File = "File" # 文件 + RPS = "RPS" # TODO Dice = "Dice" # TODO Shake = "Shake" # TODO @@ -71,10 +72,14 @@ class BaseMessageComponent(BaseModel): k = "type" if isinstance(v, bool): v = 1 if v else 0 - output += ",%s=%s" % (k, str(v).replace("&", "&") \ - .replace(",", ",") \ - .replace("[", "[") \ - .replace("]", "]")) + output += ",%s=%s" % ( + k, + str(v) + .replace("&", "&") + .replace(",", ",") + .replace("[", "[") + .replace("]", "]"), + ) output += "]" return output @@ -86,10 +91,7 @@ class BaseMessageComponent(BaseModel): if k == "_type": k = "type" data[k] = v - return { - "type": self.type.lower(), - "data": data - } + return {"type": self.type.lower(), "data": data} class Plain(BaseMessageComponent): @@ -103,9 +105,9 @@ class Plain(BaseMessageComponent): def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 if not self.convert: return self.text - return self.text.replace("&", "&") \ - .replace("[", "[") \ - .replace("]", "]") + return ( + self.text.replace("&", "&").replace("[", "[").replace("]", "]") + ) class Face(BaseMessageComponent): @@ -274,7 +276,7 @@ class Image(BaseMessageComponent): c: T.Optional[int] = 2 # 额外 path: T.Optional[str] = "" - file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 + file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): # for k in _.keys(): @@ -343,14 +345,16 @@ class Forward(BaseMessageComponent): def __init__(self, **_): super().__init__(**_) + class Node(BaseMessageComponent): - '''群合并转发消息''' + """群合并转发消息""" + type: ComponentType = "Node" - id: T.Optional[int] = 0 # 忽略 - name: T.Optional[str] = "" # qq昵称 - uin: T.Optional[int] = 0 # qq号 - content: T.Optional[T.Union[str, list]] = "" # 子消息段列表 - seq: T.Optional[T.Union[str, list]] = "" # 忽略 + id: T.Optional[int] = 0 # 忽略 + name: T.Optional[str] = "" # qq昵称 + uin: T.Optional[int] = 0 # qq号 + content: T.Optional[T.Union[str, list]] = "" # 子消息段列表 + seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 def __init__(self, content: T.Union[str, list], **_): @@ -364,25 +368,24 @@ class Node(BaseMessageComponent): def toString(self): # logger.warn("Protocol: node doesn't support stringify") return "" - + + class Nodes(BaseMessageComponent): type: ComponentType = "Nodes" nodes: T.List[Node] - + def __init__(self, nodes: T.List[Node], **_): super().__init__(nodes=nodes, **_) - + def toDict(self): - return { - "messages": [node.toDict() for node in self.nodes] - } + return {"messages": [node.toDict() for node in self.nodes]} class Xml(BaseMessageComponent): type: ComponentType = "Xml" data: str resid: T.Optional[int] = 0 - + def __init__(self, **_): super().__init__(**_) @@ -432,14 +435,16 @@ class Unknown(BaseMessageComponent): def toString(self): return "" + class File(BaseMessageComponent): - ''' + """ 目前此消息段只适配了 Napcat。 - ''' + """ + type: ComponentType = "File" - name: T.Optional[str] = "" # 名字 - file: T.Optional[str] = "" # url(本地路径) - + name: T.Optional[str] = "" # 名字 + file: T.Optional[str] = "" # url(本地路径) + def __init__(self, name: str, file: str): super().__init__(name=name, file=file) @@ -471,5 +476,5 @@ ComponentTypes = { "cardimage": CardImage, "tts": TTS, "unknown": Unknown, - 'file': File, + "file": File, } diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 33eae64bc..89aff17a8 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -5,145 +5,151 @@ from dataclasses import dataclass, field from astrbot.core.message.components import BaseMessageComponent, Plain, Image from typing_extensions import deprecated + @dataclass -class MessageChain(): - '''MessageChain 描述了一整条消息中带有的所有组件。 +class MessageChain: + """MessageChain 描述了一整条消息中带有的所有组件。 现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。 - + Attributes: `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 - ''' - + """ + chain: List[BaseMessageComponent] = field(default_factory=list) - use_t2i_: Optional[bool] = None # None 为跟随用户设置 - + use_t2i_: Optional[bool] = None # None 为跟随用户设置 + def message(self, message: str): - '''添加一条文本消息到消息链 `chain` 中。 - + """添加一条文本消息到消息链 `chain` 中。 + Example: CommandResult().message("Hello ").message("world!") # 输出 Hello world! - ''' + """ self.chain.append(Plain(message)) return self - + @deprecated("请使用 message 方法代替。") def error(self, message: str): - '''添加一条错误消息到消息链 `chain` 中 - + """添加一条错误消息到消息链 `chain` 中 + Example: - + CommandResult().error("解析失败") - - ''' + + """ self.chain.append(Plain(message)) return self - + def url_image(self, url: str): - '''添加一条图片消息(https 链接)到消息链 `chain` 中。 - + """添加一条图片消息(https 链接)到消息链 `chain` 中。 + Note: 如果需要发送本地图片,请使用 `file_image` 方法。 - + Example: - + CommandResult().image("https://example.com/image.jpg") - - ''' + + """ self.chain.append(Image.fromURL(url)) return self - + def file_image(self, path: str): - '''添加一条图片消息(本地文件路径)到消息链 `chain` 中。 - + """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 + Note: 如果需要发送网络图片,请使用 `url_image` 方法。 - + CommandResult().image("image.jpg") - ''' + """ self.chain.append(Image.fromFileSystem(path)) return self - + def use_t2i(self, use_t2i: bool): - '''设置是否使用文本转图片服务。 - + """设置是否使用文本转图片服务。 + Args: use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 - ''' + """ self.use_t2i_ = use_t2i return self + class EventResultType(enum.Enum): - '''用于描述事件处理的结果类型。 - + """用于描述事件处理的结果类型。 + Attributes: CONTINUE: 事件将会继续传播 STOP: 事件将会终止传播 - ''' + """ + CONTINUE = enum.auto() STOP = enum.auto() - + + class ResultContentType(enum.Enum): - '''用于描述事件结果的内容的类型。 - ''' + """用于描述事件结果的内容的类型。""" + LLM_RESULT = enum.auto() - '''调用 LLM 产生的结果''' + """调用 LLM 产生的结果""" GENERAL_RESULT = enum.auto() - '''普通的消息结果''' + """普通的消息结果""" + + @dataclass class MessageEventResult(MessageChain): - '''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。 + """MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。 现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。 - + Attributes: `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 `result_type` (EventResultType): 事件处理的结果类型。 - ''' - - result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE) - - result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT) - - def stop_event(self) -> 'MessageEventResult': - '''终止事件传播。 - ''' + """ + + result_type: Optional[EventResultType] = field( + default_factory=lambda: EventResultType.CONTINUE + ) + + result_content_type: Optional[ResultContentType] = field( + default_factory=lambda: ResultContentType.GENERAL_RESULT + ) + + def stop_event(self) -> "MessageEventResult": + """终止事件传播。""" self.result_type = EventResultType.STOP return self - - def continue_event(self) -> 'MessageEventResult': - '''继续事件传播。 - ''' + + def continue_event(self) -> "MessageEventResult": + """继续事件传播。""" self.result_type = EventResultType.CONTINUE return self - + def is_stopped(self) -> bool: - ''' + """ 是否终止事件传播。 - ''' + """ return self.result_type == EventResultType.STOP - - def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult': - '''设置事件处理的结果类型。 - + + def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": + """设置事件处理的结果类型。 + Args: result_type (EventResultType): 事件处理的结果类型。 - ''' + """ self.result_content_type = typ return self - + def is_llm_result(self) -> bool: - '''是否为 LLM 结果。 - ''' + """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT - + def get_plain_text(self) -> str: - '''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。 - ''' + """获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - - -CommandResult = MessageEventResult \ No newline at end of file + + +CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index c2dd62054..76844f6fd 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -1,4 +1,7 @@ -from astrbot.core.message.message_event_result import MessageEventResult, EventResultType +from astrbot.core.message.message_event_result import ( + MessageEventResult, + EventResultType, +) from .waking_check.stage import WakingCheckStage from .whitelist_check.stage import WhitelistCheckStage @@ -10,14 +13,14 @@ from .result_decorate.stage import ResultDecorateStage from .respond.stage import RespondStage STAGES_ORDER = [ - "WakingCheckStage", # 检查是否需要唤醒 - "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 - "RateLimitStage", # 检查会话是否超过频率限制 - "ContentSafetyCheckStage", # 检查内容安全 - "PreProcessStage", # 预处理 - "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 - "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 - "RespondStage" # 发送消息 + "WakingCheckStage", # 检查是否需要唤醒 + "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 + "RateLimitStage", # 检查会话是否超过频率限制 + "ContentSafetyCheckStage", # 检查内容安全 + "PreProcessStage", # 预处理 + "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 + "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 + "RespondStage", # 发送消息 ] __all__ = [ @@ -30,5 +33,5 @@ __all__ = [ "ResultDecorateStage", "RespondStage", "MessageEventResult", - "EventResultType" -] \ No newline at end of file + "EventResultType", +] diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index 7f781a1e7..bafef1b05 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -6,25 +6,32 @@ from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core import logger from .strategies.strategy import StrategySelector + @register_stage class ContentSafetyCheckStage(Stage): - '''检查内容安全 - + """检查内容安全 + 当前只会检查文本的。 - ''' + """ async def initialize(self, ctx: PipelineContext): - config = ctx.astrbot_config['content_safety'] + config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) - async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]: - '''检查内容安全''' + async def process( + self, event: AstrMessageEvent, check_text: str = None + ) -> Union[None, AsyncGenerator[None, None]]: + """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) if not ok: if event.is_at_or_wake_command: - event.set_result(MessageEventResult().message("你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。")) - yield + event.set_result( + MessageEventResult().message( + "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。" + ) + ) + yield event.stop_event() logger.info(f"内容安全检查不通过,原因:{info}") return diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py index 5962f27d8..5701f0634 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py @@ -1,8 +1,8 @@ import abc from typing import Tuple + class ContentSafetyStrategy(abc.ABC): - @abc.abstractmethod def check(self, content: str) -> Tuple[bool, str]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py index 05d93af6c..73296b90e 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -1,30 +1,30 @@ -''' +""" 使用此功能应该先 pip install baidu-aip -''' +""" + from . import ContentSafetyStrategy from aip import AipContentCensor + class BaiduAipStrategy(ContentSafetyStrategy): def __init__(self, appid: str, ak: str, sk: str) -> None: self.app_id = appid self.api_key = ak self.secret_key = sk - self.client = AipContentCensor(self.app_id, - self.api_key, - self.secret_key) + self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) def check(self, content: str): res = self.client.textCensorUserDefined(content) - if 'conclusionType' not in res: + if "conclusionType" not in res: return False, "" - if res['conclusionType'] == 1: + if res["conclusionType"] == 1: return True, "" else: - if 'data' not in res: + if "data" not in res: return False, "" - count = len(res['data']) + count = len(res["data"]) info = f"百度审核服务发现 {count} 处违规:\n" - for i in res['data']: + for i in res["data"]: info += f"{i['msg']};\n" - info += "\n判断结果:"+res['conclusion'] - return False, info \ No newline at end of file + info += "\n判断结果:" + res["conclusion"] + return False, info diff --git a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py index 9f09c1ae8..e117f47a5 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py @@ -4,20 +4,23 @@ import json import base64 from . import ContentSafetyStrategy + class KeywordsStrategy(ContentSafetyStrategy): def __init__(self, extra_keywords: list) -> None: self.keywords = [] if extra_keywords is None: extra_keywords = [] self.keywords.extend(extra_keywords) - keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words') + keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words") # internal keywords if os.path.exists(keywords_path): with open(keywords_path, "r", encoding="utf-8") as f: - self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords']) + self.keywords.extend( + json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"] + ) def check(self, content: str) -> bool: for keyword in self.keywords: if re.search(keyword, content): return False, "内容安全检查不通过,匹配到敏感词。" - return True, "" \ No newline at end of file + return True, "" diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index 57efd22f9..af960328f 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -2,6 +2,7 @@ from . import ContentSafetyStrategy from typing import List, Tuple from astrbot import logger + class StrategySelector: def __init__(self, config: dict) -> None: self.enabled_strategies: List[ContentSafetyStrategy] = [] diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index a6b41f8bf..1abbca4e1 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star import PluginManager + @dataclass class PipelineContext: astrbot_config: AstrBotConfig - plugin_manager: PluginManager \ No newline at end of file + plugin_manager: PluginManager diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 15bfac2e2..96d7ff4b7 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -7,42 +7,45 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core import logger from astrbot.core.message.components import Plain, Record, Image + @register_stage class PreProcessStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager - - self.stt_settings: dict = self.config.get('provider_stt_settings', {}) - self.platform_settings: dict = self.config.get('platform_settings', {}) - - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - '''在处理事件之前的预处理''' + self.stt_settings: dict = self.config.get("provider_stt_settings", {}) + self.platform_settings: dict = self.config.get("platform_settings", {}) + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + """在处理事件之前的预处理""" # 路径映射 - if mappings := self.platform_settings.get('path_mapping', []): + if mappings := self.platform_settings.get("path_mapping", []): # 支持 Record,Image 消息段的路径映射。 message_chain = event.get_messages() - + for idx, component in enumerate(message_chain): if isinstance(component, (Record, Image)) and component.url: for mapping in mappings: from_, to_ = mapping.split(":") from_ = from_.removesuffix("/") to_ = to_.removesuffix("/") - + url = component.url.removeprefix("file://") if url.startswith(from_): component.url = url.replace(from_, to_, 1) logger.debug(f"路径映射: {url} -> {component.url}") message_chain[idx] = component - + # STT - if self.stt_settings.get('enable', False): + if self.stt_settings.get("enable", False): # TODO: 独立 - stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst + stt_provider = ( + self.plugin_manager.context.provider_manager.curr_stt_provider_inst + ) if stt_provider: message_chain = event.get_messages() for idx, component in enumerate(message_chain): diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index f17c7d7bb..ede3bb9d8 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -1,13 +1,17 @@ -''' +""" 本地 Agent 模式的 LLM 调用 Stage -''' +""" + import traceback import json from typing import Union, AsyncGenerator from ...context import PipelineContext from ..stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType +from astrbot.core.message.message_event_result import ( + MessageEventResult, + ResultContentType, +) from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric @@ -15,31 +19,39 @@ from astrbot.core.provider.entites import ProviderRequest, LLMResponse from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map + class LLMRequestSubStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx - self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list - self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str - + self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list + self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][ + "wake_prefix" + ] # str + for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): - logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。") - self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):] - + logger.info( + f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。" + ) + self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :] + self.conv_manager = ctx.plugin_manager.context.conversation_manager - - async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]: + + async def process( + self, event: AstrMessageEvent, _nested: bool = False + ) -> Union[None, AsyncGenerator[None, None]]: req: ProviderRequest = None - + provider = self.ctx.plugin_manager.context.get_using_provider() if provider is None: return - + if event.get_extra("provider_request"): req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。" - + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) + if req.conversation: req.contexts = json.loads(req.conversation.history) else: @@ -47,132 +59,176 @@ class LLMRequestSubStage(Stage): if self.provider_wake_prefix: if not event.message_str.startswith(self.provider_wake_prefix): return - req.prompt = event.message_str[len(self.provider_wake_prefix):] + req.prompt = event.message_str[len(self.provider_wake_prefix) :] req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file req.image_urls.append(image_url) - + # 获取对话上下文 - conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin) + conversation_id = await self.conv_manager.get_curr_conversation_id( + event.unified_msg_origin + ) if not conversation_id: - conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin) + conversation_id = await self.conv_manager.new_conversation( + event.unified_msg_origin + ) req.session_id = event.unified_msg_origin - conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id) + conversation = await self.conv_manager.get_conversation( + event.unified_msg_origin, conversation_id + ) req.conversation = conversation req.contexts = json.loads(conversation.history) event.set_extra("provider_request", req) - + if not req.prompt and not req.image_urls: return - + # 执行请求 LLM 前事件钩子。 # 装饰 system_prompt 等功能 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnLLMRequestEvent + ) for handler in handlers: try: - logger.debug(f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + logger.debug( + f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) await handler.handler(event, req) except BaseException: logger.error(traceback.format_exc()) - + if event.is_stopped(): - logger.info(f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。") + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) return - + if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) - + try: logger.debug(f"提供商请求 Payload: {req}") if _nested: - req.func_tool = None # 暂时不支持递归工具调用 - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM - + req.func_tool = None # 暂时不支持递归工具调用 + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent) + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnLLMResponseEvent + ) for handler in handlers: try: - logger.debug(f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + logger.debug( + f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) await handler.handler(event, llm_response) except BaseException: logger.error(traceback.format_exc()) - + if event.is_stopped(): - logger.info(f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。") + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) return - - + # 保存到历史记录 await self._save_to_history(event, req, llm_response) - - await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) - if llm_response.role == 'assistant': + await Metric.upload( + llm_tick=1, + model_name=provider.get_model(), + provider_type=provider.meta().type, + ) + + if llm_response.role == "assistant": # text completion - event.set_result(MessageEventResult().message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT)) - elif llm_response.role == 'err': - event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}")) - elif llm_response.role == 'tool': + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) + ) + elif llm_response.role == "tool": # function calling function_calling_result = {} - logger.info(f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}") - for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args): + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args in zip( + llm_response.tools_call_name, llm_response.tools_call_args + ): func_tool = req.func_tool.get_func(func_tool_name) - logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) try: # 尝试调用工具函数 - wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args) + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) async for resp in wrapper: - if resp is not None: # 有 return 返回 + if resp is not None: # 有 return 返回 function_calling_result[func_tool_name] = resp else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 except BaseException as e: logger.warning(traceback.format_exc()) - function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e) + function_calling_result[func_tool_name] = ( + "When calling the function, an error occurred: " + str(e) + ) if function_calling_result: # 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。 # 我们重新执行一遍这个 stage - req.func_tool = None # 暂时不支持递归工具调用 + req.func_tool = None # 暂时不支持递归工具调用 extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n" for tool_name, tool_result in function_calling_result.items(): - extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n" + extra_prompt += ( + f"Tool: {tool_name}\nTool Result: {tool_result}\n" + ) req.prompt += extra_prompt async for _ in self.process(event, _nested=True): yield else: if llm_response.completion_text: - event.set_result(MessageEventResult().message(llm_response.completion_text)) + event.set_result( + MessageEventResult().message(llm_response.completion_text) + ) except BaseException as e: logger.error(traceback.format_exc()) - event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}")) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) + ) return - - async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse): + + async def _save_to_history( + self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + ): if not req or not req.conversation or not llm_response: return - + if llm_response.role == "assistant": # 文本回复 contexts = req.contexts - new_record = { - "role": "user", - "content": req.prompt - } + new_record = {"role": "user", "content": req.prompt} contexts.append(new_record) - contexts.append({ - "role": "assistant", - "content": llm_response.completion_text - }) - contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) + contexts.append( + {"role": "assistant", "content": llm_response.completion_text} + ) + contexts_to_save = list( + filter(lambda item: "_no_save" not in item, contexts) + ) await self.conv_manager.update_conversation( - event.unified_msg_origin, - req.conversation.cid, - history=contexts_to_save - ) \ No newline at end of file + event.unified_msg_origin, req.conversation.cid, history=contexts_to_save + ) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 703f9469e..d369e53ed 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -1,6 +1,7 @@ -''' +""" 本地 Agent 模式的 AstrBot 插件调用 Stage -''' +""" + from ...context import PipelineContext from ..stage import Stage from typing import Dict, Any, List, AsyncGenerator, Union @@ -11,17 +12,23 @@ from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.star.star import star_map import traceback + class StarRequestSubStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: self.curr_provider = ctx.plugin_manager.context.get_using_provider() - self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix'] - self.identifier = ctx.astrbot_config['provider_settings']['identifier'] + self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"] + self.identifier = ctx.astrbot_config["provider_settings"]["identifier"] self.ctx = ctx - - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") - handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params") + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + activated_handlers: List[StarHandlerMetadata] = event.get_extra( + "activated_handlers" + ) + handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra( + "handlers_parsed_params" + ) if not handlers_parsed_params: handlers_parsed_params = {} for handler in activated_handlers: @@ -29,19 +36,21 @@ class StarRequestSubStage(Stage): try: if handler.handler_module_path not in star_map: continue - logger.debug(f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}") + logger.debug( + f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" + ) wrapper = self._call_handler(self.ctx, event, handler.handler, **params) async for ret in wrapper: yield ret - event.clear_result() # 清除上一个 handler 的结果 + event.clear_result() # 清除上一个 handler 的结果 except Exception as e: logger.error(traceback.format_exc()) logger.error(f"Star {handler.handler_full_name} handle error: {e}") - + if event.is_at_or_wake_command: ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}" event.set_result(MessageEventResult().message(ret)) yield event.clear_result() - - event.stop_event() \ No newline at end of file + + event.stop_event() diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index c22ab4d92..4c52a4a3e 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -8,23 +8,26 @@ from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entites import ProviderRequest from astrbot.core import logger + @register_stage class ProcessStage(Stage): - async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager self.llm_request_sub_stage = LLMRequestSubStage() await self.llm_request_sub_stage.initialize(ctx) - + self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - '''处理事件 - ''' - activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + """处理事件""" + activated_handlers: List[StarHandlerMetadata] = event.get_extra( + "activated_handlers" + ) # 有插件 Handler 被激活 if activated_handlers: async for resp in self.star_request_sub_stage.process(event): @@ -40,20 +43,26 @@ class ProcessStage(Stage): yield else: yield - + # 调用 LLM 相关请求 - if not self.ctx.astrbot_config['provider_settings'].get('enable', True): + if not self.ctx.astrbot_config["provider_settings"].get("enable", True): return - - if not event._has_send_oper and event.is_at_or_wake_command and not event.call_llm: + + if ( + not event._has_send_oper + and event.is_at_or_wake_command + and not event.call_llm + ): # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 - if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): + if ( + event.get_result() and not event.get_result().is_stopped() + ) or not event.get_result(): # 事件没有终止传播 provider = self.ctx.plugin_manager.context.get_using_provider() - + if not provider: logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") return - + async for _ in self.llm_request_sub_stage.process(event): - yield \ No newline at end of file + yield diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 2d0437e34..7550d84e0 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -5,7 +5,6 @@ from typing import DefaultDict, Deque, Union, AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core import logger from astrbot.core.config.astrbot_config import RateLimitStrategy @@ -32,11 +31,19 @@ class RateLimitStage(Stage): """ 初始化限流器,根据配置设置限流参数。 """ - self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count'] - self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time']) - self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard + self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][ + "count" + ] + self.rate_limit_time = timedelta( + seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"] + ) + self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][ + "strategy" + ] # stall or discard - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: """ 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 @@ -59,21 +66,29 @@ class RateLimitStage(Stage): # 达到限流阈值,计算下一个窗口的时间 next_window_time = timestamps[0] + self.rate_limit_time stall_duration = (next_window_time - now).total_seconds() - + match self.rl_strategy: case RateLimitStrategy.STALL.value: - logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。") + logger.info( + f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。" + ) await asyncio.sleep(stall_duration) case RateLimitStrategy.DISCARD.value: # event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。")) - logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。") + logger.info( + f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。" + ) return event.stop_event() - - self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration)) + + self._remove_expired_timestamps( + timestamps, now + timedelta(seconds=stall_duration) + ) timestamps.append(now) - def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None: + def _remove_expired_timestamps( + self, timestamps: Deque[datetime], now: datetime + ) -> None: """ 移除时间窗口外的时间戳。 @@ -83,4 +98,4 @@ class RateLimitStage(Stage): """ expiry_threshold: datetime = now - self.rate_limit_time while timestamps and timestamps[0] < expiry_threshold: - timestamps.popleft() + timestamps.popleft() diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index cbd8c197d..54d319de0 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -12,40 +12,56 @@ from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map from astrbot.core.message.components import Plain, Reply, At + + @register_stage class RespondStage(Stage): async def initialize(self, ctx: PipelineContext): self.ctx = ctx - - self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention'] - self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote'] - + + self.reply_with_mention = ctx.astrbot_config["platform_settings"][ + "reply_with_mention" + ] + self.reply_with_quote = ctx.astrbot_config["platform_settings"][ + "reply_with_quote" + ] + # 分段回复 - self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable'] - self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result'] - - self.interval_method = ctx.astrbot_config['platform_settings']['segmented_reply']['interval_method'] - self.log_base = float(ctx.astrbot_config['platform_settings']['segmented_reply']['log_base']) - interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval'] + self.enable_seg: bool = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["enable"] + self.only_llm_result = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["only_llm_result"] + + self.interval_method = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["interval_method"] + self.log_base = float( + ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"] + ) + interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][ + "interval" + ] interval_str_ls = interval_str.replace(" ", "").split(",") try: self.interval = [float(t) for t in interval_str_ls] except BaseException as e: - logger.error(f'解析分段回复的间隔时间失败。{e}') + logger.error(f"解析分段回复的间隔时间失败。{e}") self.interval = [1.5, 3.5] logger.info(f"分段回复间隔时间:{self.interval}") - + async def _word_cnt(self, text: str) -> int: - '''分段回复 统计字数''' + """分段回复 统计字数""" if all(ord(c) < 128 for c in text): word_count = len(text.split()) else: word_count = len([c for c in text if c.isalnum()]) return word_count - + async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: - '''分段回复 计算间隔时间''' - if self.interval_method == 'log': + """分段回复 计算间隔时间""" + if self.interval_method == "log": if isinstance(comp, Plain): wc = await self._word_cnt(comp.text) i = math.log(wc + 1, self.log_base) @@ -56,15 +72,20 @@ class RespondStage(Stage): # random return random.uniform(self.interval[0], self.interval[1]) - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() if result is None: return if len(result.chain) > 0: await event._pre_send() - - if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result): + + if self.enable_seg and ( + (self.only_llm_result and result.is_llm_result()) + or not self.only_llm_result + ): decorated_comps = [] if self.reply_with_mention: for comp in result.chain: @@ -86,18 +107,26 @@ class RespondStage(Stage): else: await event.send(result) await event._post_send() - logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") - - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent) + logger.info( + f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" + ) + + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnAfterMessageSentEvent + ) for handler in handlers: try: - logger.debug(f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + logger.debug( + f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) await handler.handler(event) except BaseException: logger.error(traceback.format_exc()) - + if event.is_stopped(): - logger.info(f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。") + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) return - - event.clear_result() \ No newline at end of file + + event.clear_result() diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 5db4d19b5..d80945d5b 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -12,74 +12,107 @@ from astrbot.core import html_renderer from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map + @register_stage class ResultDecorateStage(Stage): async def initialize(self, ctx: PipelineContext): self.ctx = ctx - self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix'] - self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention'] - self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote'] - self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold'] + self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] + self.reply_with_mention = ctx.astrbot_config["platform_settings"][ + "reply_with_mention" + ] + self.reply_with_quote = ctx.astrbot_config["platform_settings"][ + "reply_with_quote" + ] + self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"] try: self.t2i_word_threshold = int(self.t2i_word_threshold) if self.t2i_word_threshold < 50: self.t2i_word_threshold = 50 except BaseException: self.t2i_word_threshold = 150 - - self.forward_threshold = ctx.astrbot_config['platform_settings']['forward_threshold'] - - # 分段回复 - self.words_count_threshold = int(ctx.astrbot_config['platform_settings']['segmented_reply']['words_count_threshold']) - self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable'] - self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result'] - self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex'] - self.content_cleanup_rule = ctx.astrbot_config['platform_settings']['segmented_reply']['content_cleanup_rule'] - + self.forward_threshold = ctx.astrbot_config["platform_settings"][ + "forward_threshold" + ] + + # 分段回复 + self.words_count_threshold = int( + ctx.astrbot_config["platform_settings"]["segmented_reply"][ + "words_count_threshold" + ] + ) + self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["enable"] + self.only_llm_result = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["only_llm_result"] + self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"] + self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ]["content_cleanup_rule"] + # exception - self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response'] + self.content_safe_check_reply = ctx.astrbot_config["content_safety"][ + "also_use_in_response" + ] self.content_safe_check_stage = None if self.content_safe_check_reply: for stage in registered_stages: if stage.__class__.__name__ == "ContentSafetyCheckStage": self.content_safe_check_stage = stage - - - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() if result is None or not result.chain: return - + # 回复时检查内容安全 - if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result(): + if ( + self.content_safe_check_reply + and self.content_safe_check_stage + and result.is_llm_result() + ): text = "" for comp in result.chain: if isinstance(comp, Plain): text += comp.text - async for _ in self.content_safe_check_stage.process(event, check_text=text): + async for _ in self.content_safe_check_stage.process( + event, check_text=text + ): yield - + # 发送消息前事件钩子 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent) + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnDecoratingResultEvent + ) for handler in handlers: try: - logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") + logger.debug( + f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) await handler.handler(event) if event.get_result() is None or not event.get_result().chain: - logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。") + logger.debug( + f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。" + ) except BaseException: logger.error(traceback.format_exc()) - + if event.is_stopped(): - logger.info(f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。") + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) return - + # 需要再获取一次。插件可能直接对 chain 进行了替换。 result = event.get_result() if result is None: return - + if len(result.chain) > 0: # 回复前缀 if self.reply_prefix: @@ -87,10 +120,12 @@ class ResultDecorateStage(Stage): if isinstance(comp, Plain): comp.text = self.reply_prefix + comp.text break - - # 分段回复 + + # 分段回复 if self.enable_segmented_reply: - if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result: + if ( + self.only_llm_result and result.is_llm_result() + ) or not self.only_llm_result: new_chain = [] for comp in result.chain: if isinstance(comp, Plain): @@ -113,9 +148,12 @@ class ResultDecorateStage(Stage): # 非 Plain 类型的消息段不分段 new_chain.append(comp) result.chain = new_chain - + # TTS - if self.ctx.astrbot_config['provider_tts_settings']['enable'] and result.is_llm_result(): + if ( + self.ctx.astrbot_config["provider_tts_settings"]["enable"] + and result.is_llm_result() + ): tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst new_chain = [] for comp in result.chain: @@ -125,9 +163,13 @@ class ResultDecorateStage(Stage): audio_path = await tts_provider.get_audio(comp.text) logger.info("TTS 结果: " + audio_path) if audio_path: - new_chain.append(Record(file=audio_path, url=audio_path)) + new_chain.append( + Record(file=audio_path, url=audio_path) + ) else: - logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}") + logger.error( + f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}" + ) new_chain.append(comp) except BaseException: logger.error(traceback.format_exc()) @@ -136,9 +178,11 @@ class ResultDecorateStage(Stage): else: new_chain.append(comp) result.chain = new_chain - + # 文本转图片 - elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_: + elif ( + result.use_t2i_ is None and self.ctx.astrbot_config["t2i"] + ) or result.use_t2i_: plain_str = "" for comp in result.chain: if isinstance(comp, Plain): @@ -153,35 +197,38 @@ class ResultDecorateStage(Stage): logger.error("文本转图片失败,使用文本发送。") return if time.time() - render_start > 3: - logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。") + logger.warning( + "文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。" + ) if url: result.chain = [Image.fromURL(url)] - + # 触发转发消息 has_forwarded = False - if event.get_platform_name() == 'aiocqhttp': + if event.get_platform_name() == "aiocqhttp": word_cnt = 0 for comp in result.chain: if isinstance(comp, Plain): word_cnt += len(comp.text) if word_cnt > self.forward_threshold: node = Node( - uin=event.get_self_id(), - name="AstrBot", - content=[ - *result.chain - ] + uin=event.get_self_id(), name="AstrBot", content=[*result.chain] ) result.chain = [node] has_forwarded = True - + if not has_forwarded: # at 回复 - if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE: - result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name())) + if ( + self.reply_with_mention + and event.get_message_type() != MessageType.FRIEND_MESSAGE + ): + result.chain.insert( + 0, At(qq=event.get_sender_id(), name=event.get_sender_name()) + ) if len(result.chain) > 1 and isinstance(result.chain[1], Plain): result.chain[1].text = "\n" + result.chain[1].text - + # 引用回复 if self.reply_with_quote: if not any(isinstance(item, File) for item in result.chain): diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index a59181cf5..66874b80f 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -5,17 +5,18 @@ from typing import AsyncGenerator from astrbot.core.platform import AstrMessageEvent from astrbot.core import logger -class PipelineScheduler(): + +class PipelineScheduler: def __init__(self, context: PipelineContext): - registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__)) + registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__)) self.ctx = context - + async def initialize(self): for stage in registered_stages: # logger.debug(f"初始化阶段 {stage.__class__ .__name__}") - + await stage.initialize(self.ctx) - + async def _process_stages(self, event: AstrMessageEvent, from_stage=0): for i in range(from_stage, len(registered_stages)): stage = registered_stages[i] @@ -24,28 +25,32 @@ class PipelineScheduler(): if isinstance(coro, AsyncGenerator): async for _ in coro: if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + logger.debug( + f"阶段 {stage.__class__.__name__} 已终止事件传播。" + ) break await self._process_stages(event, i + 1) if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + logger.debug( + f"阶段 {stage.__class__.__name__} 已终止事件传播。" + ) break else: await coro if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - + async def execute(self, event: AstrMessageEvent): - '''执行 pipeline''' + """执行 pipeline""" await self._process_stages(event) - + if not event._has_send_oper and event.get_platform_name() == "webchat": await event.send(None) - - logger.debug("pipeline 执行完毕。") \ No newline at end of file + + logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 88b6e7943..f1f8587a6 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -8,39 +8,39 @@ from .context import PipelineContext from astrbot.core.message.message_event_result import MessageEventResult, CommandResult registered_stages: List[Stage] = [] -'''维护了所有已注册的 Stage 实现类''' +"""维护了所有已注册的 Stage 实现类""" + def register_stage(cls): - '''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类 - ''' + """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" registered_stages.append(cls()) return cls - + + class Stage(abc.ABC): - '''描述一个 Pipeline 的某个阶段 - ''' - + """描述一个 Pipeline 的某个阶段""" + @abc.abstractmethod async def initialize(self, ctx: PipelineContext) -> None: - '''初始化阶段 - ''' + """初始化阶段""" raise NotImplementedError - + @abc.abstractmethod - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - '''处理事件 - ''' + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + """处理事件""" raise NotImplementedError - + async def _call_handler( - self, + self, ctx: PipelineContext, - event: AstrMessageEvent, + event: AstrMessageEvent, handler: Awaitable, *args, **kwargs, ) -> AsyncGenerator[None, None]: - '''调用 Handler。''' + """调用 Handler。""" # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) ready_to_call = None try: @@ -49,7 +49,7 @@ class Stage(abc.ABC): # 向下兼容 logger.debug(str(e)) ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs) - + if isinstance(ready_to_call, AsyncGenerator): _has_yielded = False async for ret in ready_to_call: @@ -69,4 +69,4 @@ class Stage(abc.ABC): event.set_result(ret) yield else: - yield ret \ No newline at end of file + yield ret diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 5693b5603..d18a0b5c8 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -3,12 +3,12 @@ from ..context import PipelineContext from typing import Union, AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageEventResult, MessageChain -from astrbot.core.message.components import At, Reply +from astrbot.core.message.components import At from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map -from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.filter.permission import PermissionTypeFilter + @register_stage class WakingCheckStage(Stage): """检查是否需要唤醒。唤醒机器人有如下几点条件: @@ -77,10 +77,12 @@ class WakingCheckStage(Stage): # 检查插件的 handler filter activated_handlers = [] handlers_parsed_params = {} # 注册了指令的 handler - - for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent): + + for handler in star_handlers_registry.get_handlers_by_event_type( + EventType.AdapterMessageEvent + ): # filter 需满足 AND 逻辑关系 - passed = True + passed = True permission_not_pass = False if len(handler.event_filters) == 0: continue @@ -106,10 +108,14 @@ class WakingCheckStage(Stage): if passed: if permission_not_pass: if self.no_permission_reply: - await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。")) + await event.send( + MessageChain().message( + f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。" + ) + ) event.stop_event() return - + is_wake = True event.is_wake = True @@ -118,7 +124,7 @@ class WakingCheckStage(Stage): handlers_parsed_params[handler.handler_full_name] = event.get_extra( "parsed_params" ) - + event.clear_extra() event.set_extra("activated_handlers", activated_handlers) diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py index f13c79029..024de0834 100644 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -5,38 +5,55 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core import logger + @register_stage class WhitelistCheckStage(Stage): - '''检查是否在群聊/私聊白名单 - ''' + """检查是否在群聊/私聊白名单""" + async def initialize(self, ctx: PipelineContext) -> None: - self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list'] - self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist'] - self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group'] - self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend'] - self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log'] - - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + self.enable_whitelist_check = ctx.astrbot_config["platform_settings"][ + "enable_id_white_list" + ] + self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"] + self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][ + "wl_ignore_admin_on_group" + ] + self.wl_ignore_admin_on_friend = ctx.astrbot_config["platform_settings"][ + "wl_ignore_admin_on_friend" + ] + self.wl_log = ctx.astrbot_config["platform_settings"]["id_whitelist_log"] + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: if not self.enable_whitelist_check: # 白名单检查未启用 return - + if len(self.whitelist) == 0: # 白名单为空,不检查 return - - if event.get_platform_name() == 'webchat': + + if event.get_platform_name() == "webchat": # WebChat 豁免 return - + # 检查是否在白名单 if self.wl_ignore_admin_on_group: - if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE: + if ( + event.role == "admin" + and event.get_message_type() == MessageType.GROUP_MESSAGE + ): return if self.wl_ignore_admin_on_friend: - if event.role == 'admin' and event.get_message_type() == MessageType.FRIEND_MESSAGE: + if ( + event.role == "admin" + and event.get_message_type() == MessageType.FRIEND_MESSAGE + ): return if event.unified_msg_origin not in self.whitelist: if self.wl_log: - logger.info(f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。") - event.stop_event() \ No newline at end of file + logger.info( + f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。" + ) + event.stop_event() diff --git a/astrbot/core/platform/__init__.py b/astrbot/core/platform/__init__.py index d7db721ef..48ea57b8a 100644 --- a/astrbot/core/platform/__init__.py +++ b/astrbot/core/platform/__init__.py @@ -1,4 +1,13 @@ from .platform import Platform from .astr_message_event import AstrMessageEvent from .platform_metadata import PlatformMetadata -from .astrbot_message import AstrBotMessage, MessageMember, MessageType \ No newline at end of file +from .astrbot_message import AstrBotMessage, MessageMember, MessageType + +__all__ = [ + "Platform", + "AstrMessageEvent", + "PlatformMetadata", + "AstrBotMessage", + "MessageMember", + "MessageType", +] diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 9569e281e..96b5dbacd 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -5,73 +5,85 @@ from .platform_metadata import PlatformMetadata from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.platform.message_type import MessageType from typing import List, Union -from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward +from astrbot.core.message.components import ( + Plain, + Image, + BaseMessageComponent, + Face, + At, + AtAll, + Forward, +) from astrbot.core.utils.metrics import Metric from astrbot.core.provider.entites import ProviderRequest from astrbot.core.db.po import Conversation + @dataclass class MessageSesion: platform_name: str message_type: MessageType session_id: str - + def __str__(self): return f"{self.platform_name}:{self.message_type.value}:{self.session_id}" - + @staticmethod def from_str(session_str: str): platform_name, message_type, session_id = session_str.split(":") return MessageSesion(platform_name, MessageType(message_type), session_id) + class AstrMessageEvent(abc.ABC): - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform_meta: PlatformMetadata, - session_id: str,): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + ): self.message_str = message_str - '''纯文本的消息''' + """纯文本的消息""" self.message_obj = message_obj - '''消息对象, AstrBotMessage。带有完整的消息结构。''' + """消息对象, AstrBotMessage。带有完整的消息结构。""" self.platform_meta = platform_meta - '''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp''' + """消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp""" self.session_id = session_id - '''用户的会话 ID。可以直接使用下面的 unified_msg_origin''' + """用户的会话 ID。可以直接使用下面的 unified_msg_origin""" self.role = "member" - '''用户是否是管理员。如果是管理员,这里是 admin''' + """用户是否是管理员。如果是管理员,这里是 admin""" self.is_wake = False - '''是否唤醒(是否通过 WakingStage)''' + """是否唤醒(是否通过 WakingStage)""" self.is_at_or_wake_command = False - '''是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)''' + """是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)""" self._extras = {} self.session = MessageSesion( platform_name=platform_meta.name, message_type=message_obj.type, - session_id=session_id + session_id=session_id, ) self.unified_msg_origin = str(self.session) - '''统一的消息来源字符串。格式为 platform_name:message_type:session_id''' + """统一的消息来源字符串。格式为 platform_name:message_type:session_id""" self._result: MessageEventResult = None - '''消息事件的结果''' - - self._has_send_oper = False - '''在此次事件中是否有过至少一次发送消息的操作''' + """消息事件的结果""" + + self._has_send_oper = False + """在此次事件中是否有过至少一次发送消息的操作""" self.call_llm = False - '''是否在此消息事件中禁止默认的 LLM 请求''' - + """是否在此消息事件中禁止默认的 LLM 请求""" + # back_compability self.platform = platform_meta - + def get_platform_name(self): return self.platform_meta.name - + def get_message_str(self) -> str: - ''' + """ 获取消息字符串。 - ''' + """ return self.message_str - + def _outline_chain(self, chain: List[BaseMessageComponent]) -> str: outline = "" for i in chain: @@ -91,188 +103,185 @@ class AstrMessageEvent(abc.ABC): else: outline += f"[{i.type}]" return outline - + def get_message_outline(self) -> str: - ''' + """ 获取消息概要。 - + 除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。 - ''' + """ return self._outline_chain(self.message_obj.message) - + def get_messages(self) -> List[BaseMessageComponent]: - ''' + """ 获取消息链。 - ''' + """ return self.message_obj.message - + def get_message_type(self) -> MessageType: - ''' + """ 获取消息类型。 - ''' + """ return self.message_obj.type - + def get_session_id(self) -> str: - ''' + """ 获取会话id。 - ''' + """ return self.session_id - + def get_group_id(self) -> str: - ''' + """ 获取群组id。如果不是群组消息,返回空字符串。 - ''' + """ return self.message_obj.group_id - + def get_self_id(self) -> str: - ''' + """ 获取机器人自身的id。 - ''' + """ return self.message_obj.self_id - + def get_sender_id(self) -> str: - ''' + """ 获取消息发送者的id。 - ''' + """ return self.message_obj.sender.user_id - + def get_sender_name(self) -> str: - ''' + """ 获取消息发送者的名称。(可能会返回空字符串) - ''' + """ return self.message_obj.sender.nickname - + def set_extra(self, key, value): - ''' + """ 设置额外的信息。 - ''' + """ self._extras[key] = value - - def get_extra(self, key = None): - ''' + + def get_extra(self, key=None): + """ 获取额外的信息。 - ''' + """ if key is None: return self._extras return self._extras.get(key, None) - + def clear_extra(self): - ''' + """ 清除额外的信息。 - ''' + """ self._extras.clear() - + def is_private_chat(self) -> bool: - ''' + """ 是否是私聊。 - ''' + """ return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value - + def is_wake_up(self) -> bool: - ''' + """ 是否是唤醒机器人的事件。 - ''' + """ return self.is_wake - + def is_admin(self) -> bool: - ''' + """ 是否是管理员。 - ''' + """ return self.role == "admin" - + async def send(self, message: MessageChain): - ''' + """ 发送消息到消息平台。 - ''' - await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name) + """ + await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) self._has_send_oper = True - + async def _pre_send(self): - '''调度器会在执行 send() 前调用该方法''' + """调度器会在执行 send() 前调用该方法""" pass - + async def _post_send(self): - '''调度器会在执行 send() 后调用该方法''' + """调度器会在执行 send() 后调用该方法""" pass - - + def set_result(self, result: Union[MessageEventResult, str]): - '''设置消息事件的结果。 - + """设置消息事件的结果。 + Note: 事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。 - + 如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。 - + Example: ``` async def ban_handler(self, event: AstrMessageEvent): if event.get_sender_id() in self.blacklist: event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP) return - + async def check_count(self, event: AstrMessageEvent): self.count += 1 event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE)) return ``` - ''' + """ if isinstance(result, str): result = MessageEventResult().message(result) self._result = result - + def stop_event(self): - '''终止事件传播。 - ''' + """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) else: self._result.stop_event() - + def continue_event(self): - '''继续事件传播。 - ''' + """继续事件传播。""" if self._result is None: self.set_result(MessageEventResult().continue_event()) else: self._result.continue_event() - + def is_stopped(self) -> bool: - ''' + """ 是否终止事件传播。 - ''' + """ if self._result is None: - return False # 默认是继续传播 - return self._result.is_stopped() - + return False # 默认是继续传播 + return self._result.is_stopped() + def should_call_llm(self, call_llm: bool): - ''' + """ 是否在此消息事件中禁止默认的 LLM 请求。 - + 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 - ''' + """ self.call_llm = call_llm - + def get_result(self) -> MessageEventResult: - ''' + """ 获取消息事件的结果。 - ''' + """ return self._result - + def clear_result(self): - ''' + """ 清除消息事件的结果。 - ''' + """ self._result = None - - '''消息链相关''' - + + """消息链相关""" + def make_result(self) -> MessageEventResult: - ''' + """ 创建一个空的消息事件结果。 - + Example: - + ```python # 纯文本回复 yield event.make_result().message("Hi") @@ -280,76 +289,76 @@ class AstrMessageEvent(abc.ABC): yield event.make_result().url_image("https://example.com/image.jpg") yield event.make_result().file_image("image.jpg") ``` - ''' + """ return MessageEventResult() - + def plain_result(self, text: str) -> MessageEventResult: - ''' + """ 创建一个空的消息事件结果,只包含一条文本消息。 - ''' + """ return MessageEventResult().message(text) - + def image_result(self, url_or_path: str) -> MessageEventResult: - ''' + """ 创建一个空的消息事件结果,只包含一条图片消息。 - + 根据开头是否包含 http 来判断是网络图片还是本地图片。 - ''' + """ if url_or_path.startswith("http"): return MessageEventResult().url_image(url_or_path) return MessageEventResult().file_image(url_or_path) - + def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult: - ''' + """ 创建一个空的消息事件结果,包含指定的消息链。 - ''' + """ mer = MessageEventResult() mer.chain = chain return mer - - '''LLM 请求相关''' - + + """LLM 请求相关""" + def request_llm( self, prompt: str, - func_tool_manager = None, + func_tool_manager=None, session_id: str = None, image_urls: List[str] = [], contexts: List = [], system_prompt: str = "", - conversation: Conversation = None + conversation: Conversation = None, ) -> ProviderRequest: - ''' + """ 创建一个 LLM 请求。 - + Examples: ```py yield event.request_llm(prompt="hi") ``` prompt: 提示词 - + system_prompt: 系统提示词 - + session_id: 已经过时,留空即可 - + image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。 - + contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation,将会忽略 conversation。 - + func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。 - + conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。 - ''' - + """ + if len(contexts) > 0 and conversation: conversation = None - + return ProviderRequest( - prompt = prompt, - session_id = session_id, - image_urls = image_urls, - func_tool = func_tool_manager, - contexts = contexts, - system_prompt = system_prompt, - conversation=conversation - ) \ No newline at end of file + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool_manager, + contexts=contexts, + system_prompt=system_prompt, + conversation=conversation, + ) diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 1ca4f3109..ea55eaf4b 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -4,26 +4,29 @@ from dataclasses import dataclass from astrbot.core.message.components import BaseMessageComponent from .message_type import MessageType + @dataclass -class MessageMember(): +class MessageMember: user_id: str # 发送者id nickname: str = None + class AstrBotMessage: - ''' + """ AstrBot 的消息对象 - ''' + """ + type: MessageType # 消息类型 self_id: str # 机器人的识别id session_id: str # 会话id。取决于 unique_session 的设置。 message_id: str # 消息id - group_id: str = "" # 群组id,如果为私聊,则为空 + group_id: str = "" # 群组id,如果为私聊,则为空 sender: MessageMember # 发送者 message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 raw_message: object timestamp: int # 消息时间戳 - + def __init__(self) -> None: self.timestamp = int(time.time()) diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index a21b0672e..e3e63ef21 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -8,100 +8,121 @@ from .register import platform_cls_map from astrbot.core import logger from .sources.webchat.webchat_adapter import WebChatAdapter -class PlatformManager(): + +class PlatformManager: def __init__(self, config: AstrBotConfig, event_queue: Queue): self.platform_insts: List[Platform] = [] - '''加载的 Platform 的实例''' - + """加载的 Platform 的实例""" + self._inst_map = {} - - self.platforms_config = config['platform'] - self.settings = config['platform_settings'] + + self.platforms_config = config["platform"] + self.settings = config["platform_settings"] self.event_queue = event_queue async def initialize(self): - '''初始化所有平台适配器''' + """初始化所有平台适配器""" for platform in self.platforms_config: await self.load_platform(platform) - + # 网页聊天 webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) self.platform_insts.append(webchat_inst) - asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat"))) - + asyncio.create_task( + self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")) + ) + async def load_platform(self, platform_config: dict): - '''实例化一个平台''' - if not platform_config['enable']: + """实例化一个平台""" + if not platform_config["enable"]: return - - logger.info(f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...") - + + logger.info( + f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ..." + ) + # 动态导入 try: - match platform_config['type']: + match platform_config["type"]: case "aiocqhttp": - from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401 + from .sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, # noqa: F401 + ) case "qq_official": - from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 + from .sources.qqofficial.qqofficial_platform_adapter import ( + QQOfficialPlatformAdapter, # noqa: F401 + ) case "qq_official_webhook": - from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401 + from .sources.qqofficial_webhook.qo_webhook_adapter import ( + QQOfficialWebhookPlatformAdapter, # noqa: F401 + ) case "gewechat": - from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401 + from .sources.gewechat.gewechat_platform_adapter import ( + GewechatPlatformAdapter, # noqa: F401 + ) case "lark": - from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 - case "telegram": - from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401 + from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 + case "telegram": + from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401 except (ImportError, ModuleNotFoundError) as e: - logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。") + logger.error( + f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" + ) except Exception as e: logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") - - - if platform_config['type'] not in platform_cls_map: - logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误") + + if platform_config["type"] not in platform_cls_map: + logger.error( + f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误" + ) return - cls_type = platform_cls_map[platform_config['type']] + cls_type = platform_cls_map[platform_config["type"]] inst = cls_type(platform_config, self.settings, self.event_queue) - self._inst_map[platform_config['id']] = inst + self._inst_map[platform_config["id"]] = inst self.platform_insts.append(inst) - - asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform"))) - + + asyncio.create_task( + self._task_wrapper( + asyncio.create_task( + inst.run(), name=platform_config["id"] + "_platform" + ) + ) + ) + async def _task_wrapper(self, task: asyncio.Task): try: await task except asyncio.CancelledError: pass except Exception as e: - logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") for line in traceback.format_exc().split("\n"): logger.error(f"| {line}") logger.error("-------") - + async def reload(self, platform_config: dict): # 还未实现完成,不要调用此方法 - - if platform_config['id'] in self._inst_map: + + if platform_config["id"] in self._inst_map: # 正在运行 - if getattr(self._inst_map[platform_config['id']], 'terminate', None): + if getattr(self._inst_map[platform_config["id"]], "terminate", None): logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...") - await self._inst_map[platform_config['id']].terminate() + await self._inst_map[platform_config["id"]].terminate() logger.info(f"{platform_config['id']} 平台适配器已终止。") - del self._inst_map[platform_config['id']] - self.platform_insts.remove(self._inst_map[platform_config['id']]) + del self._inst_map[platform_config["id"]] + self.platform_insts.remove(self._inst_map[platform_config["id"]]) else: logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。") - + # 再启动新的实例 await self.load_platform(platform_config) - + else: # 先将 _inst_map 中在 platform_config 中不存在的实例删除 - config_ids = [platform['id'] for platform in self.platforms_config] + config_ids = [platform["id"] for platform in self.platforms_config] for key in list(self._inst_map.keys()): if key not in config_ids: - if getattr(self._inst_map[key], 'terminate', None): + if getattr(self._inst_map[key], "terminate", None): logger.info(f"正在尝试终止 {key} 平台适配器 ...") await self._inst_map[key].terminate() logger.info(f"{key} 平台适配器已终止。") @@ -109,9 +130,9 @@ class PlatformManager(): self.platform_insts.remove(self._inst_map[key]) else: logger.warning(f"可能无法正常终止 {key} 平台适配器。") - + # 再启动新的实例 await self.load_platform(platform_config) - + def get_insts(self): - return self.platform_insts \ No newline at end of file + return self.platform_insts diff --git a/astrbot/core/platform/message_type.py b/astrbot/core/platform/message_type.py index 6149277d9..25b7cdc48 100644 --- a/astrbot/core/platform/message_type.py +++ b/astrbot/core/platform/message_type.py @@ -1,6 +1,7 @@ from enum import Enum + class MessageType(Enum): - GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 - FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - OTHER_MESSAGE = 'OtherMessage' # 其他类型的消息,如系统消息等 \ No newline at end of file + GROUP_MESSAGE = "GroupMessage" # 群组形式的消息 + FRIEND_MESSAGE = "FriendMessage" # 私聊、好友等单聊消息 + OTHER_MESSAGE = "OtherMessage" # 其他类型的消息,如系统消息等 diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 1ae8afb4d..8ed0be039 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -7,48 +7,51 @@ from astrbot.core.message.message_event_result import MessageChain from .astr_message_event import MessageSesion from astrbot.core.utils.metrics import Metric + class Platform(abc.ABC): def __init__(self, event_queue: Queue): super().__init__() # 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。 self._event_queue = event_queue - + @abc.abstractmethod def run(self) -> Awaitable[Any]: - ''' + """ 得到一个平台的运行实例,需要返回一个协程对象。 - ''' + """ raise NotImplementedError - + async def terminate(self): - ''' + """ 终止一个平台的运行实例。 - ''' + """ pass - + @abc.abstractmethod def meta(self) -> PlatformMetadata: - ''' + """ 得到一个平台的元数据。 - ''' + """ raise NotImplementedError - - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain) -> Awaitable[Any]: - ''' + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ) -> Awaitable[Any]: + """ 通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 - + 异步方法。 - ''' - await Metric.upload(msg_event_tick = 1, adapter_name = self.meta().name) - + """ + await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) + def commit_event(self, event: AstrMessageEvent): - ''' + """ 提交一个事件到事件队列。 - ''' + """ self._event_queue.put_nowait(event) - + def get_client(self): - ''' + """ 获取平台的客户端对象。 - ''' - pass \ No newline at end of file + """ + pass diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 721192355..48fe23af7 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -1,12 +1,14 @@ from dataclasses import dataclass + + @dataclass -class PlatformMetadata(): +class PlatformMetadata: name: str - '''平台的名称''' + """平台的名称""" description: str - '''平台的描述''' - + """平台的描述""" + default_config_tmpl: dict = None - '''平台的默认配置模板''' + """平台的默认配置模板""" adapter_display_name: str = None - '''显示在 WebUI 配置页中的平台名称,如空则是 name''' \ No newline at end of file + """显示在 WebUI 配置页中的平台名称,如空则是 name""" diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index b746cdb95..fa65392a8 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -3,42 +3,46 @@ from .platform_metadata import PlatformMetadata from astrbot.core import logger platform_registry: List[PlatformMetadata] = [] -'''维护了通过装饰器注册的平台适配器''' +"""维护了通过装饰器注册的平台适配器""" platform_cls_map: Dict[str, Type] = {} -'''维护了平台适配器名称和适配器类的映射''' +"""维护了平台适配器名称和适配器类的映射""" + def register_platform_adapter( - adapter_name: str, - desc: str, + adapter_name: str, + desc: str, default_config_tmpl: dict = None, - adapter_display_name: str = None + adapter_display_name: str = None, ): - '''用于注册平台适配器的带参装饰器。 - - default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 - ''' + """用于注册平台适配器的带参装饰器。 + + default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 + """ + def decorator(cls): if adapter_name in platform_cls_map: - raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。") - + raise ValueError( + f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。" + ) + # 添加必备选项 if default_config_tmpl: - if 'type' not in default_config_tmpl: - default_config_tmpl['type'] = adapter_name - if 'enable' not in default_config_tmpl: - default_config_tmpl['enable'] = False - if 'id' not in default_config_tmpl: - default_config_tmpl['id'] = adapter_name + if "type" not in default_config_tmpl: + default_config_tmpl["type"] = adapter_name + if "enable" not in default_config_tmpl: + default_config_tmpl["enable"] = False + if "id" not in default_config_tmpl: + default_config_tmpl["id"] = adapter_name pm = PlatformMetadata( name=adapter_name, description=desc, default_config_tmpl=default_config_tmpl, - adapter_display_name=adapter_display_name + adapter_display_name=adapter_display_name, ) platform_registry.append(pm) platform_cls_map[adapter_name] = cls logger.debug(f"平台适配器 {adapter_name} 已注册") return cls - + return decorator diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 5f54f69a2..ce38296e6 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -5,19 +5,22 @@ from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp from astrbot.core.utils.io import file_to_base64, download_image_by_url + class AiocqhttpMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id, bot: CQHttp): + def __init__( + self, message_str, message_obj, platform_meta, session_id, bot: CQHttp + ): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot - + @staticmethod async def _parse_onebot_json(message_chain: MessageChain): - '''解析成 OneBot json 格式''' + """解析成 OneBot json 格式""" ret = [] for segment in message_chain.chain: d = segment.toDict() if isinstance(segment, Plain): - d['type'] = 'text' + d["type"] = "text" elif isinstance(segment, (Image, Record)): # convert to base64 if segment.file and segment.file.startswith("file:///"): @@ -30,41 +33,48 @@ class AiocqhttpMessageEvent(AstrMessageEvent): bs64_data = segment.file else: bs64_data = file_to_base64(segment.file) - d['data'] = { - 'file': bs64_data, + d["data"] = { + "file": bs64_data, } elif isinstance(segment, At): - d['data'] = { - 'qq': str(segment.qq) # 转换为字符串 + d["data"] = { + "qq": str(segment.qq) # 转换为字符串 } ret.append(d) return ret async def send(self, message: MessageChain): ret = await AiocqhttpMessageEvent._parse_onebot_json(message) - + send_one_by_one = False for seg in message.chain: if isinstance(seg, (Node, Nodes)): # 转发消息不能和普通消息混在一起发送 send_one_by_one = True break - + if send_one_by_one: for seg in message.chain: if isinstance(seg, Nodes): # 带有多个节点的合并转发消息 payload = seg.toDict() if self.get_group_id(): - payload['group_id'] = self.get_group_id() - await self.bot.call_action('send_group_forward_msg', **payload) + payload["group_id"] = self.get_group_id() + await self.bot.call_action("send_group_forward_msg", **payload) else: - payload['user_id'] = self.get_sender_id() - await self.bot.call_action('send_private_forward_msg', **payload) + payload["user_id"] = self.get_sender_id() + await self.bot.call_action( + "send_private_forward_msg", **payload + ) else: - await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg]))) + await self.bot.send( + self.message_obj.raw_message, + await AiocqhttpMessageEvent._parse_onebot_json( + MessageChain([seg]) + ), + ) await asyncio.sleep(0.5) else: await self.bot.send(self.message_obj.raw_message, ret) - - await super().send(message) \ No newline at end of file + + await super().send(message) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 83689d6b6..9287044d9 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -5,7 +5,13 @@ import logging import uuid from typing import Awaitable, Any from aiocqhttp import CQHttp, Event -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) from astrbot.api.event import MessageChain from .aiocqhttp_message_event import * # noqa: F403 from astrbot.api.message_components import * # noqa: F403 @@ -16,55 +22,64 @@ from ...register import register_platform_adapter from aiocqhttp.exceptions import ActionFailed from astrbot.core.utils.io import download_file -@register_platform_adapter("aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。") + +@register_platform_adapter( + "aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。" +) class AiocqhttpAdapter(Platform): - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) - + self.config = platform_config self.settings = platform_settings - self.unique_session = platform_settings['unique_session'] - self.host = platform_config['ws_reverse_host'] - self.port = platform_config['ws_reverse_port'] - + self.unique_session = platform_settings["unique_session"] + self.host = platform_config["ws_reverse_host"] + self.port = platform_config["ws_reverse_port"] + self.metadata = PlatformMetadata( "aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", ) - + self.stop = False - - self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) - + + self.bot = CQHttp( + use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180 + ) + @self.bot.on_request() async def request(event: Event): abm = await self.convert_message(event) if abm: await self.handle_msg(abm) - + @self.bot.on_notice() async def notice(event: Event): abm = await self.convert_message(event) if abm: await self.handle_msg(abm) - - @self.bot.on_message('group') + + @self.bot.on_message("group") async def group(event: Event): abm = await self.convert_message(event) if abm: await self.handle_msg(abm) - - @self.bot.on_message('private') + + @self.bot.on_message("private") async def private(event: Event): abm = await self.convert_message(event) if abm: await self.handle_msg(abm) - + @self.bot.on_websocket_connection def on_websocket_connection(_): logger.info("aiocqhttp(OneBot v11) 适配器已连接。") - - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) match session.message_type.value: case MessageType.GROUP_MESSAGE.value: @@ -73,94 +88,104 @@ class AiocqhttpAdapter(Platform): _, group_id = session.session_id.split("_") await self.bot.send_group_msg(group_id=group_id, message=ret) else: - await self.bot.send_group_msg(group_id=session.session_id, message=ret) + await self.bot.send_group_msg( + group_id=session.session_id, message=ret + ) case MessageType.FRIEND_MESSAGE.value: await self.bot.send_private_msg(user_id=session.session_id, message=ret) await super().send_by_session(session, message_chain) - + async def convert_message(self, event: Event) -> AstrBotMessage: logger.debug(f"[aiocqhttp] RawMessage {event}") - - if event['post_type'] == 'message': + + if event["post_type"] == "message": abm = await self._convert_handle_message_event(event) - elif event['post_type'] == 'notice': + elif event["post_type"] == "notice": abm = await self._convert_handle_notice_event(event) - elif event['post_type'] == 'request': + elif event["post_type"] == "request": abm = await self._convert_handle_request_event(event) - + return abm - + async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage: - '''OneBot V11 请求类事件''' + """OneBot V11 请求类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember( - user_id=event.user_id, - nickname=event.user_id - ) + abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id) abm.type = MessageType.OTHER_MESSAGE - if 'group_id' in event and event['group_id']: + if "group_id" in event and event["group_id"]: abm.type = MessageType.GROUP_MESSAGE abm.group_id = str(event.group_id) else: abm.type = MessageType.FRIEND_MESSAGE if self.unique_session and abm.type == MessageType.GROUP_MESSAGE: abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id) - abm.message_str = '' + abm.message_str = "" abm.message = [] abm.timestamp = int(time.time()) abm.message_id = uuid.uuid4().hex abm.raw_message = event return abm - + async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage: - '''OneBot V11 通知类事件''' + """OneBot V11 通知类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember( - user_id=event.user_id, - nickname=event.user_id - ) + abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id) abm.type = MessageType.OTHER_MESSAGE - if 'group_id' in event and event['group_id']: + if "group_id" in event and event["group_id"]: abm.group_id = str(event.group_id) abm.type = MessageType.GROUP_MESSAGE else: - abm.type = MessageType.FRIEND_MESSAGE + abm.type = MessageType.FRIEND_MESSAGE if self.unique_session and abm.type == MessageType.GROUP_MESSAGE: - abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id + abm.session_id = ( + abm.sender.user_id + "_" + str(event.group_id) + ) # 也保留群组 id else: - abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id + abm.session_id = ( + str(event.group_id) + if abm.type == MessageType.GROUP_MESSAGE + else abm.sender.user_id + ) abm.message_str = "" abm.message = [] abm.raw_message = event abm.timestamp = int(time.time()) abm.message_id = uuid.uuid4().hex - - if 'sub_type' in event: - if event['sub_type'] == 'poke' and 'target_id' in event: - abm.message.append(Poke(qq=str(event['target_id']), type='poke')) # noqa: F405 - + + if "sub_type" in event: + if event["sub_type"] == "poke" and "target_id" in event: + abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405 + return abm - + async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage: - '''OneBot V11 消息类事件''' + """OneBot V11 消息类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname']) - if event['message_type'] == 'group': + abm.sender = MessageMember( + str(event.sender["user_id"]), event.sender["nickname"] + ) + if event["message_type"] == "group": abm.type = MessageType.GROUP_MESSAGE abm.group_id = str(event.group_id) - elif event['message_type'] == 'private': + elif event["message_type"] == "private": abm.type = MessageType.FRIEND_MESSAGE if self.unique_session and abm.type == MessageType.GROUP_MESSAGE: - abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id + abm.session_id = ( + abm.sender.user_id + "_" + str(event.group_id) + ) # 也保留群组 id else: - abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id - + abm.session_id = ( + str(event.group_id) + if abm.type == MessageType.GROUP_MESSAGE + else abm.sender.user_id + ) + abm.message_id = str(event.message_id) abm.message = [] - + message_str = "" if not isinstance(event.message, list): err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。" @@ -170,99 +195,103 @@ class AiocqhttpAdapter(Platform): except BaseException as e: logger.error(f"回复消息失败: {e}") return - + # 按消息段类型类型适配 for m in event.message: - t = m['type'] + t = m["type"] a = None - if t == 'text': - message_str += m['data']['text'].strip() - a = ComponentTypes[t](**m['data']) # noqa: F405 + if t == "text": + message_str += m["data"]["text"].strip() + a = ComponentTypes[t](**m["data"]) # noqa: F405 abm.message.append(a) - elif t == 'file': - if m['data'].get('url') and m['data'].get('url').startswith("http"): + elif t == "file": + if m["data"].get("url") and m["data"].get("url").startswith("http"): # Lagrange logger.info("guessing lagrange") - - file_name = m['data'].get('file_name', "file") + + file_name = m["data"].get("file_name", "file") path = os.path.join("data/temp", file_name) - await download_file(m['data']['url'], path) - - m['data'] = { - "file": path, - "name": file_name - } - a = ComponentTypes[t](**m['data']) # noqa: F405 + await download_file(m["data"]["url"], path) + + m["data"] = {"file": path, "name": file_name} + a = ComponentTypes[t](**m["data"]) # noqa: F405 abm.message.append(a) - + else: try: # Napcat, LLBot - ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id']) - if not ret.get('file', None): + ret = await self.bot.call_action( + action="get_file", + file_id=event.message[0]["data"]["file_id"], + ) + if not ret.get("file", None): raise ValueError(f"无法解析文件响应: {ret}") - if not os.path.exists(ret['file']): - raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot") - - m['data'] = { - "file": ret['file'], - "name": ret['file_name'] - } - a = ComponentTypes[t](**m['data']) # noqa: F405 + if not os.path.exists(ret["file"]): + raise FileNotFoundError( + f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot" + ) + + m["data"] = {"file": ret["file"], "name": ret["file_name"]} + a = ComponentTypes[t](**m["data"]) # noqa: F405 abm.message.append(a) except ActionFailed as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") except BaseException as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") - + else: - a = ComponentTypes[t](**m['data']) # noqa: F405 + a = ComponentTypes[t](**m["data"]) # noqa: F405 abm.message.append(a) abm.timestamp = int(time.time()) abm.message_str = message_str abm.raw_message = event - + return abm - + def run(self) -> Awaitable[Any]: if not self.host or not self.port: - logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199") + logger.warning( + "aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199" + ) self.host = "0.0.0.0" self.port = 6199 - - coro = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) - + + coro = self.bot.run_task( + host=self.host, + port=int(self.port), + shutdown_trigger=self.shutdown_trigger_placeholder, + ) + for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) - logging.getLogger('aiocqhttp').setLevel(logging.ERROR) - + logging.getLogger("aiocqhttp").setLevel(logging.ERROR) + return coro - + async def terminate(self): self.stop = True await asyncio.sleep(1) - + def meta(self) -> PlatformMetadata: return self.metadata - + async def shutdown_trigger_placeholder(self): while not self._event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("aiocqhttp 适配器已关闭。") async def handle_msg(self, message: AstrBotMessage): - message_event = AiocqhttpMessageEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - bot=self.bot + bot=self.bot, ) - + self.commit_event(message_event) async def get_client(self): - return self.bot \ No newline at end of file + return self.bot diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index f26831cd9..8b41abc97 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -13,140 +13,158 @@ from .downloader import GeweDownloader from astrbot.core.utils.io import download_image_by_url -class SimpleGewechatClient(): - '''针对 Gewechat 的简单实现。 - +class SimpleGewechatClient: + """针对 Gewechat 的简单实现。 + @author: Soulter @website: https://github.com/Soulter - ''' - def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue): + """ + + def __init__( + self, + base_url: str, + nickname: str, + host: str, + port: int, + event_queue: asyncio.Queue, + ): self.base_url = base_url - if self.base_url.endswith('/'): + if self.base_url.endswith("/"): self.base_url = self.base_url[:-1] - - self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口 - self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/" - + + self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口 + self.download_base_url = ":".join(self.download_base_url) + ":2532/download/" + self.base_url += "/v2/api" - + logger.info(f"Gewechat API: {self.base_url}") logger.info(f"Gewechat 下载 API: {self.download_base_url}") - + if isinstance(port, str): port = int(port) - + self.token = None self.headers = {} self.nickname = nickname self.appid = sp.get(f"gewechat-appid-{nickname}", "") - + self.server = quart.Quart(__name__) - self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST']) - self.server.add_url_rule('/astrbot-gewechat/file/', view_func=self.handle_file, methods=['GET']) - + self.server.add_url_rule( + "/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"] + ) + self.server.add_url_rule( + "/astrbot-gewechat/file/", + view_func=self.handle_file, + methods=["GET"], + ) + self.host = host - self.port = port + self.port = port self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback" self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file" - + self.event_queue = event_queue - + self.multimedia_downloader = None - + self.userrealnames = {} - + self.stop = False - + async def get_token_id(self): async with aiohttp.ClientSession() as session: async with session.post(f"{self.base_url}/tools/getTokenId") as resp: json_blob = await resp.json() - self.token = json_blob['data'] + self.token = json_blob["data"] logger.info(f"获取到 Gewechat Token: {self.token}") - self.headers = { - "X-GEWE-TOKEN": self.token - } - + self.headers = {"X-GEWE-TOKEN": self.token} + async def _convert(self, data: dict) -> AstrBotMessage: - type_name = data['TypeName'] + type_name = data["TypeName"] if type_name == "Offline": logger.critical("收到 gewechat 下线通知。") return - - if 'Data' in data and 'CreateTime' in data['Data']: + + if "Data" in data and "CreateTime" in data["Data"]: # 得到系统 UTF+8 的 ts tz_offset = datetime.timedelta(hours=8) tz = datetime.timezone(tz_offset) ts = datetime.datetime.now(tz).timestamp() - create_time = data['Data']['CreateTime'] + create_time = data["Data"]["CreateTime"] if create_time < ts - 30: logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}") return - abm = AstrBotMessage() - d = data['Data'] - - from_user_name = d['FromUserName']['string'] # 消息来源 - d['to_wxid'] = from_user_name # 用于发信息 - - abm.message_id = str(d.get('MsgId')) + d = data["Data"] + + from_user_name = d["FromUserName"]["string"] # 消息来源 + d["to_wxid"] = from_user_name # 用于发信息 + + abm.message_id = str(d.get("MsgId")) abm.session_id = from_user_name - abm.self_id = data['Wxid'] # 机器人的 wxid - - user_id = "" # 发送人 wxid - content = d['Content']['string'] # 消息内容 - + abm.self_id = data["Wxid"] # 机器人的 wxid + + user_id = "" # 发送人 wxid + content = d["Content"]["string"] # 消息内容 + at_me = False if "@chatroom" in from_user_name: abm.type = MessageType.GROUP_MESSAGE - _t = content.split(':\n') + _t = content.split(":\n") user_id = _t[0] content = _t[1] - if '\u2005' in content: + if "\u2005" in content: # at # content = content.split('\u2005')[1] - content = re.sub(r'@[^\u2005]*\u2005', '', content) + content = re.sub(r"@[^\u2005]*\u2005", "", content) abm.group_id = from_user_name # at - msg_source = d['MsgSource'] - if f'' in msg_source \ - or f'' in msg_source: + msg_source = d["MsgSource"] + if ( + f"" in msg_source + or f"" in msg_source + ): at_me = True - if '在群聊中@了你' in d.get('PushContent', ''): + if "在群聊中@了你" in d.get("PushContent", ""): at_me = True else: abm.type = MessageType.FRIEND_MESSAGE user_id = from_user_name - + abm.message = [] if at_me: abm.message.insert(0, At(qq=abm.self_id)) - + # 解析用户真实名字 user_real_name = "unknown" if abm.group_id: - if abm.group_id not in self.userrealnames or user_id not in self.userrealnames[abm.group_id]: + if ( + abm.group_id not in self.userrealnames + or user_id not in self.userrealnames[abm.group_id] + ): # 获取群成员列表,并且缓存 if abm.group_id not in self.userrealnames: self.userrealnames[abm.group_id] = {} member_list = await self.get_chatroom_member_list(abm.group_id) logger.debug(f"获取到 {abm.group_id} 的群成员列表。") - if member_list and 'memberList' in member_list: - for member in member_list['memberList']: - self.userrealnames[abm.group_id][member['wxid']] = member['nickName'] + if member_list and "memberList" in member_list: + for member in member_list["memberList"]: + self.userrealnames[abm.group_id][member["wxid"]] = member[ + "nickName" + ] if user_id in self.userrealnames[abm.group_id]: user_real_name = self.userrealnames[abm.group_id][user_id] else: user_real_name = self.userrealnames[abm.group_id][user_id] else: - user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] + user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0] abm.sender = MessageMember(user_id, user_real_name) abm.raw_message = d abm.message_str = "" # 不同消息类型 - match d['MsgType']: + match d["MsgType"]: case 1: # 文本消息 abm.message.append(Plain(content)) @@ -154,23 +172,22 @@ class SimpleGewechatClient(): case 3: # 图片消息 file_url = await self.multimedia_downloader.download_image( - self.appid, - content + self.appid, content ) logger.debug(f"下载图片: {file_url}") file_path = await download_image_by_url(file_url) abm.message.append(Image(file=file_path, url=file_path)) - + case 34: # 语音消息 # data = await self.multimedia_downloader.download_voice( - # self.appid, - # content, + # self.appid, + # content, # abm.message_id # ) # print(data) - if 'ImgBuf' in d and 'buffer' in d['ImgBuf']: - voice_data = base64.b64decode(d['ImgBuf']['buffer']) + if "ImgBuf" in d and "buffer" in d["ImgBuf"]: + voice_data = base64.b64decode(d["ImgBuf"]["buffer"]) file_path = f"data/temp/gewe_voice_{abm.message_id}.silk" with open(file_path, "wb") as f: f.write(voice_data) @@ -178,34 +195,36 @@ class SimpleGewechatClient(): case _: logger.info(f"未实现的消息类型: {d['MsgType']}") abm.raw_message = d - + logger.debug(f"abm: {abm}") return abm async def callback(self): data = await quart.request.json logger.debug(f"收到 gewechat 回调: {data}") - - if data.get('testMsg', None): + + if data.get("testMsg", None): return quart.jsonify({"r": "AstrBot ACK"}) - + abm = None try: abm = await self._convert(data) except BaseException as e: - logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。") - + logger.warning( + f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。" + ) + if abm: coro = getattr(self, "on_event_received") if coro: await coro(abm) - + return quart.jsonify({"r": "AstrBot ACK"}) - + async def handle_file(self, file_id): file_path = f"data/temp/{file_id}" return await quart.send_file(file_path) - + async def _set_callback_url(self): logger.info("设置回调,请等待...") await asyncio.sleep(3) @@ -213,43 +232,40 @@ class SimpleGewechatClient(): async with session.post( f"{self.base_url}/tools/setCallback", headers=self.headers, - json={ - "token": self.token, - "callbackUrl": self.callback_url - } + json={"token": self.token, "callbackUrl": self.callback_url}, ) as resp: json_blob = await resp.json() logger.info(f"设置回调结果: {json_blob}") - if json_blob['ret'] != 200: + if json_blob["ret"] != 200: raise Exception(f"设置回调失败: {json_blob}") - logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。") - + logger.info( + f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。" + ) + async def start_polling(self): threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start() await self.server.run_task( - host='0.0.0.0', - port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder + host="0.0.0.0", + port=self.port, + shutdown_trigger=self.shutdown_trigger_placeholder, ) - + async def shutdown_trigger_placeholder(self): while not self.event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("gewechat 适配器已关闭。") - + async def check_online(self, appid: str): # /login/checkOnline async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/login/checkOnline", headers=self.headers, - json={ - "appId": appid - } + json={"appId": appid}, ) as resp: json_blob = await resp.json() - return json_blob['data'] - + return json_blob["data"] + async def logout(self): if self.appid: online = await self.check_online(self.appid) @@ -258,53 +274,50 @@ class SimpleGewechatClient(): async with session.post( f"{self.base_url}/login/logout", headers=self.headers, - json={ - "appId": self.appid - } + json={"appId": self.appid}, ) as resp: json_blob = await resp.json() logger.info(f"登出结果: {json_blob}") - + async def login(self): if self.token is None: await self.get_token_id() - - self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token) - + + self.multimedia_downloader = GeweDownloader( + self.base_url, self.download_base_url, self.token + ) + if self.appid: online = await self.check_online(self.appid) if online: logger.info(f"APPID: {self.appid} 已在线") return - - payload = { - "appId": self.appid - } - + + payload = {"appId": self.appid} + if self.appid: logger.info(f"使用 APPID: {self.appid}, {self.nickname}") - + async with aiohttp.ClientSession() as session: async with session.post( - f"{self.base_url}/login/getLoginQrCode", + f"{self.base_url}/login/getLoginQrCode", headers=self.headers, - json=payload + json=payload, ) as resp: json_blob = await resp.json() - if json_blob['ret'] != 200: + if json_blob["ret"] != 200: raise Exception(f"获取二维码失败: {json_blob}") - qr_data = json_blob['data']['qrData'] - qr_uuid = json_blob['data']['uuid'] - appid = json_blob['data']['appId'] + qr_data = json_blob["data"]["qrData"] + qr_uuid = json_blob["data"]["uuid"] + appid = json_blob["data"]["appId"] logger.info(f"APPID: {appid}") - logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}") - + logger.warning( + f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}" + ) + # 执行登录 retry_cnt = 64 - payload.update({ - "uuid": qr_uuid, - "appId": appid - }) + payload.update({"uuid": qr_uuid, "appId": appid}) while retry_cnt > 0: retry_cnt -= 1 @@ -313,10 +326,12 @@ class SimpleGewechatClient(): with open("data/temp/gewe_code", "r") as f: code = f.read().strip() if not code: - logger.warning("未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456") + logger.warning( + "未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456" + ) await asyncio.sleep(5) continue - payload['captchCode'] = code + payload["captchCode"] = code logger.info(f"使用验证码: {code}") try: os.remove("data/temp/gewe_code") @@ -327,21 +342,23 @@ class SimpleGewechatClient(): async with session.post( f"{self.base_url}/login/checkLogin", headers=self.headers, - json=payload + json=payload, ) as resp: json_blob = await resp.json() logger.info(f"检查登录状态: {json_blob}") - ret = json_blob['ret'] - msg = '' - if json_blob['data'] and 'msg' in json_blob['data']: - msg = json_blob['data']['msg'] - if ret == 500 and '安全验证码' in msg: - logger.warning("此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456") + ret = json_blob["ret"] + msg = "" + if json_blob["data"] and "msg" in json_blob["data"]: + msg = json_blob["data"]["msg"] + if ret == 500 and "安全验证码" in msg: + logger.warning( + "此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456" + ) else: - status = json_blob['data']['status'] - nickname = json_blob['data'].get('nickName', '') - if status == 1: + status = json_blob["data"]["status"] + nickname = json_blob["data"].get("nickName", "") + if status == 1: logger.info(f"等待确认...{nickname}") elif status == 2: logger.info(f"绿泡泡平台登录成功: {nickname}") @@ -351,29 +368,26 @@ class SimpleGewechatClient(): else: logger.warning(f"未知状态: {status}") await asyncio.sleep(5) - + if appid: sp.put(f"gewechat-appid-{self.nickname}", appid) self.appid = appid logger.info(f"已保存 APPID: {appid}") - - '''API''' - + + """API""" + async def get_chatroom_member_list(self, chatroom_wxid: str): - payload = { - "appId": self.appid, - "chatroomId": chatroom_wxid - } - + payload = {"appId": self.appid, "chatroomId": chatroom_wxid} + async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/group/getChatroomMemberList", headers=self.headers, - json=payload + json=payload, ) as resp: json_blob = await resp.json() - return json_blob['data'] - + return json_blob["data"] + async def post_text(self, to_wxid, content: str, ats: str = ""): payload = { "appId": self.appid, @@ -381,65 +395,57 @@ class SimpleGewechatClient(): "content": content, } if ats: - payload['ats'] = ats - + payload["ats"] = ats + async with aiohttp.ClientSession() as session: async with session.post( - f"{self.base_url}/message/postText", - headers=self.headers, - json=payload + f"{self.base_url}/message/postText", headers=self.headers, json=payload ) as resp: json_blob = await resp.json() logger.debug(f"发送消息结果: {json_blob}") - + async def post_image(self, to_wxid, image_url: str): payload = { "appId": self.appid, "toWxid": to_wxid, "imgUrl": image_url, } - + async with aiohttp.ClientSession() as session: async with session.post( - f"{self.base_url}/message/postImage", - headers=self.headers, - json=payload + f"{self.base_url}/message/postImage", headers=self.headers, json=payload ) as resp: json_blob = await resp.json() logger.debug(f"发送图片结果: {json_blob}") - + async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): payload = { "appId": self.appid, "toWxid": to_wxid, "voiceUrl": voice_url, - "voiceDuration": voice_duration + "voiceDuration": voice_duration, } - + logger.debug(f"发送语音: {payload}") - + async with aiohttp.ClientSession() as session: async with session.post( - f"{self.base_url}/message/postVoice", - headers=self.headers, - json=payload + f"{self.base_url}/message/postVoice", headers=self.headers, json=payload ) as resp: json_blob = await resp.json() logger.debug(f"发送语音结果: {json_blob}") - + async def post_file(self, to_wxid, file_url: str, file_name: str): payload = { "appId": self.appid, "toWxid": to_wxid, "fileUrl": file_url, - "fileName": file_name + "fileName": file_name, } - + async with aiohttp.ClientSession() as session: async with session.post( - f"{self.base_url}/message/postFile", - headers=self.headers, - json=payload + f"{self.base_url}/message/postFile", headers=self.headers, json=payload ) as resp: json_blob = await resp.json() - logger.debug(f"发送文件结果: {json_blob}") \ No newline at end of file + logger.debug(f"发送文件结果: {json_blob}") diff --git a/astrbot/core/platform/sources/gewechat/downloader.py b/astrbot/core/platform/sources/gewechat/downloader.py index d0f2efce7..d2227e75f 100644 --- a/astrbot/core/platform/sources/gewechat/downloader.py +++ b/astrbot/core/platform/sources/gewechat/downloader.py @@ -2,50 +2,40 @@ from astrbot import logger import aiohttp import json -class GeweDownloader(): + +class GeweDownloader: def __init__(self, base_url: str, download_base_url: str, token: str): self.base_url = base_url self.download_base_url = download_base_url - self.headers = { - "Content-Type": "application/json", - "X-GEWE-TOKEN": token - } - + self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token} + async def _post_json(self, baseurl: str, route: str, payload: dict): async with aiohttp.ClientSession() as session: async with session.post( - f"{baseurl}{route}", - headers=self.headers, - json=payload + f"{baseurl}{route}", headers=self.headers, json=payload ) as resp: return await resp.read() - + async def download_voice(self, appid: str, xml: str, msg_id: str): - payload = { - "appId": appid, - "xml": xml, - "msgId": msg_id - } + payload = {"appId": appid, "xml": xml, "msgId": msg_id} return await self._post_json(self.base_url, "/message/downloadVoice", payload) - + async def download_image(self, appid: str, xml: str) -> str: - '''返回一个可下载的 URL''' - choices = [2, 3] # 2:常规图片 3:缩略图 - + """返回一个可下载的 URL""" + choices = [2, 3] # 2:常规图片 3:缩略图 + for choice in choices: try: - payload = { - "appId": appid, - "xml": xml, - "type": choice - } - data = await self._post_json(self.base_url, "/message/downloadImage", payload) + payload = {"appId": appid, "xml": xml, "type": choice} + data = await self._post_json( + self.base_url, "/message/downloadImage", payload + ) json_blob = json.loads(data) - if 'fileUrl' in json_blob['data']: - return self.download_base_url + json_blob['data']['fileUrl'] + if "fileUrl" in json_blob["data"]: + return self.download_base_url + json_blob["data"]["fileUrl"] except BaseException as e: logger.error(f"gewe download image: {e}") continue - - raise Exception("无法下载图片") \ No newline at end of file + + raise Exception("无法下载图片") diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index fa0f73457..4a90065c2 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -10,8 +10,9 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record, At, File from .client import SimpleGewechatClient + def get_wav_duration(file_path): - with wave.open(file_path, 'rb') as wav_file: + with wave.open(file_path, "rb") as wav_file: file_size = os.path.getsize(file_path) n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4] if n_frames == 2147483647: @@ -22,30 +23,30 @@ def get_wav_duration(file_path): duration = n_frames / float(framerate) return duration + class GewechatPlatformEvent(AstrMessageEvent): def __init__( - self, - message_str: str, - message_obj: AstrBotMessage, - platform_meta: PlatformMetadata, - session_id: str, - client: SimpleGewechatClient - ): + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client: SimpleGewechatClient, + ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client - + @staticmethod async def send_with_client(message: MessageChain, user_name: str): pass - - + async def send(self, message: MessageChain): - to_wxid = self.message_obj.raw_message.get('to_wxid', None) - + to_wxid = self.message_obj.raw_message.get("to_wxid", None) + if not to_wxid: logger.error("无法获取到 to_wxid。") return - + # 检查@ ats = [] ats_names = [] @@ -54,7 +55,7 @@ class GewechatPlatformEvent(AstrMessageEvent): ats.append(comp.qq) ats_names.append(comp.name) has_at = False - + for comp in message.chain: if isinstance(comp, Plain): text = comp.text @@ -70,7 +71,7 @@ class GewechatPlatformEvent(AstrMessageEvent): payload["ats"] = ats has_at = True await self.client.post_text(**payload) - + elif isinstance(comp, Image): img_url = comp.file img_path = "" @@ -80,9 +81,9 @@ class GewechatPlatformEvent(AstrMessageEvent): img_path = await download_image_by_url(comp.file) else: img_path = img_url - + # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 - temp_directory = os.path.abspath('data/temp') + temp_directory = os.path.abspath("data/temp") img_path = os.path.abspath(img_path) if os.path.commonpath([temp_directory, img_path]) != temp_directory: with open(img_path, "rb") as f: @@ -96,27 +97,29 @@ class GewechatPlatformEvent(AstrMessageEvent): # 默认已经存在 data/temp 中 record_url = comp.file record_path = "" - + if record_url.startswith("file:///"): record_path = record_url[8:] elif record_url.startswith("http"): await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") else: record_path = record_url - + silk_path = f"data/temp/{uuid.uuid4()}.silk" try: duration = await wav_to_tencent_silk(record_path, silk_path) except Exception as e: logger.error(traceback.format_exc()) - await self.send(MessageChain().message(f"语音文件转换失败。{str(e)}")) + await self.send( + MessageChain().message(f"语音文件转换失败。{str(e)}") + ) logger.info("Silk 语音文件格式转换至: " + record_path) if duration == 0: duration = get_wav_duration(record_path) file_id = os.path.basename(silk_path) record_url = f"{self.client.file_server_url}/{file_id}" logger.debug(f"gewe callback record url: {record_url}") - await self.client.post_voice(to_wxid, record_url, duration*1000) + await self.client.post_voice(to_wxid, record_url, duration * 1000) elif isinstance(comp, File): file_path = comp.file file_name = comp.name @@ -126,14 +129,14 @@ class GewechatPlatformEvent(AstrMessageEvent): await download_file(file_path, f"data/temp/{file_name}") else: file_path = file_path - + file_id = os.path.basename(file_path) file_url = f"{self.client.file_server_url}/{file_id}" logger.debug(f"gewe callback file url: {file_url}") - await self.client.post_file(to_wxid, file_url, file_id) + await self.client.post_file(to_wxid, file_url, file_id) elif isinstance(comp, At): pass else: - logger.debug(f"gewechat 忽略: {comp.type}") - - await super().send(message) \ No newline at end of file + logger.debug(f"gewechat 忽略: {comp.type}") + + await super().send(message) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index ffbfb6c94..fa7be6bb2 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -15,44 +15,47 @@ if sys.version_info >= (3, 12): from typing import override else: from typing_extensions import override - + @register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器") class GewechatPlatformAdapter(Platform): - - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) self.config = platform_config self.settingss = platform_settings - self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' + self.test_mode = os.environ.get("TEST_MODE", "off") == "on" self.client = None - + self.client = SimpleGewechatClient( - self.config['base_url'], - self.config['nickname'], - self.config['host'], - self.config['port'], + self.config["base_url"], + self.config["nickname"], + self.config["host"], + self.config["port"], self._event_queue, ) - + async def on_event_received(abm: AstrBotMessage): await self.handle_msg(abm) - + self.client.on_event_received = on_event_received - + @override - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): to_wxid = session.session_id if not to_wxid: logger.error("无法获取到 to_wxid。") return - + for comp in message_chain.chain: if isinstance(comp, Plain): await self.client.post_text(to_wxid, comp.text) await super().send_by_session(session, message_chain) - + @override def meta(self) -> PlatformMetadata: return PlatformMetadata( @@ -63,32 +66,32 @@ class GewechatPlatformAdapter(Platform): async def terminate(self): self.client.stop = True await asyncio.sleep(1) - + async def logout(self): await self.client.logout() - + @override def run(self): return self._run() - + async def _run(self): await self.client.login() await self.client.start_polling() - + async def handle_msg(self, message: AstrBotMessage): if message.type == MessageType.GROUP_MESSAGE: - if self.settingss['unique_session']: + if self.settingss["unique_session"]: message.session_id = message.sender.user_id + "_" + message.group_id - + message_event = GewechatPlatformEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - client=self.client + client=self.client, ) - + self.commit_event(message_event) - + def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index d0ca5a5b8..282b22149 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -3,7 +3,13 @@ import asyncio import json import re -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) from astrbot.api.event import MessageChain from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion @@ -13,80 +19,87 @@ from astrbot import logger import lark_oapi as lark from lark_oapi.api.im.v1 import * + @register_platform_adapter("lark", "飞书机器人官方 API 适配器") class LarkPlatformAdapter(Platform): - - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) - - self.config = platform_config - - self.unique_session = platform_settings['unique_session'] - self.appid = platform_config['app_id'] - self.appsecret = platform_config['app_secret'] - self.domain = platform_config.get('domain', lark.FEISHU_DOMAIN) - self.bot_name = platform_config.get('lark_bot_name', "astrbot") - + self.config = platform_config + + self.unique_session = platform_settings["unique_session"] + + self.appid = platform_config["app_id"] + self.appsecret = platform_config["app_secret"] + self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN) + self.bot_name = platform_config.get("lark_bot_name", "astrbot") + if not self.bot_name: logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") - + async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): await self.convert_msg(event) - + def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): asyncio.create_task(on_msg_event_recv(event)) - - self.event_handler = lark.EventDispatcherHandler.builder("", "") \ - .register_p2_im_message_receive_v1(do_v2_msg_event) \ + + self.event_handler = ( + lark.EventDispatcherHandler.builder("", "") + .register_p2_im_message_receive_v1(do_v2_msg_event) .build() - + ) + self.client = lark.ws.Client( app_id=self.appid, app_secret=self.appsecret, log_level=lark.LogLevel.ERROR, domain=self.domain, - event_handler=self.event_handler + event_handler=self.event_handler, ) - + self.lark_api = ( - lark.Client.builder() - .app_id(self.appid) - .app_secret(self.appsecret) - .build() + lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build() ) - - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") - + def meta(self) -> PlatformMetadata: return PlatformMetadata( "lark", "飞书机器人官方 API 适配器", ) - + async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): message = event.event.message abm = AstrBotMessage() abm.timestamp = int(message.create_time) / 1000 abm.message = [] - abm.type = MessageType.GROUP_MESSAGE if message.chat_type == 'group' else MessageType.FRIEND_MESSAGE - if message.chat_type == 'group': + abm.type = ( + MessageType.GROUP_MESSAGE + if message.chat_type == "group" + else MessageType.FRIEND_MESSAGE + ) + if message.chat_type == "group": abm.group_id = message.chat_id abm.self_id = self.bot_name abm.message_str = "" - + at_list = {} if message.mentions: for m in message.mentions: at_list[m.key] = At(qq=m.id.open_id, name=m.name) if m.name == self.bot_name: abm.self_id = m.id.open_id - + content_json_b = json.loads(message.content) - - if message.message_type == 'text': - message_str_raw = content_json_b['text'] # 带有 @ 的消息 + + if message.message_type == "text": + message_str_raw = content_json_b["text"] # 带有 @ 的消息 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 # at_users = re.findall(at_pattern, message_str_raw) # 拆分文本,去掉AT符号部分 @@ -99,41 +112,43 @@ class LarkPlatformAdapter(Platform): abm.message.append(at_list[s]) else: abm.message.append(Plain(parts[i].strip())) - elif message.message_type == 'post': + elif message.message_type == "post": _ls = [] - - content_ls = content_json_b.get('content', []) + + content_ls = content_json_b.get("content", []) for comp in content_ls: if isinstance(comp, list): _ls.extend(comp) elif isinstance(comp, dict): _ls.append(comp) content_json_b = _ls - elif message.message_type == 'image': + elif message.message_type == "image": content_json_b = [ {"tag": "img", "image_key": content_json_b["image_key"], "style": []} ] - - if message.message_type in ('post', 'image'): + + if message.message_type in ("post", "image"): for comp in content_json_b: - if comp['tag'] == 'at': - abm.message.append(at_list[comp['user_id']]) - elif comp['tag'] == 'text' and comp['text'].strip(): - abm.message.append(Plain(comp['text'].strip())) - elif comp['tag'] == 'img': - image_key = comp['image_key'] - request = GetMessageResourceRequest.builder() \ - .message_id(message.message_id) \ - .file_key(image_key) \ - .type("image") \ + if comp["tag"] == "at": + abm.message.append(at_list[comp["user_id"]]) + elif comp["tag"] == "text" and comp["text"].strip(): + abm.message.append(Plain(comp["text"].strip())) + elif comp["tag"] == "img": + image_key = comp["image_key"] + request = ( + GetMessageResourceRequest.builder() + .message_id(message.message_id) + .file_key(image_key) + .type("image") .build() + ) response = await self.lark_api.im.v1.message_resource.aget(request) if not response.success(): logger.error(f"无法下载飞书图片: {image_key}") image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() abm.message.append(Image.fromBase64(image_base64)) - + for comp in abm.message: if isinstance(comp, Plain): abm.message_str += comp.text @@ -141,7 +156,7 @@ class LarkPlatformAdapter(Platform): abm.raw_message = message abm.sender = MessageMember( user_id=event.event.sender.sender_id.open_id, - nickname=event.event.sender.sender_id.open_id[:8] + nickname=event.event.sender.sender_id.open_id[:8], ) # 独立会话 if not self.unique_session: @@ -151,24 +166,24 @@ class LarkPlatformAdapter(Platform): abm.session_id = abm.sender.user_id else: abm.session_id = abm.sender.user_id - + logger.debug(abm) await self.handle_msg(abm) - + async def handle_msg(self, abm: AstrBotMessage): event = LarkMessageEvent( message_str=abm.message_str, message_obj=abm, platform_meta=self.meta(), session_id=abm.session_id, - bot=self.lark_api + bot=self.lark_api, ) - + self._event_queue.put_nowait(event) - + async def run(self): # self.client.start() await self.client._connect() - + async def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index f3b1529ec..e170b76a0 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -3,36 +3,32 @@ import uuid import lark_oapi as lark from typing import List from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Plain, Image as AstrBotImage, Record, At, Node, Music, Video -from astrbot.core.utils.io import file_to_base64, download_image_by_url +from astrbot.api.message_components import Plain, Image as AstrBotImage, At +from astrbot.core.utils.io import download_image_by_url from lark_oapi.api.im.v1 import * from astrbot import logger + class LarkMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id, bot: lark.Client): + def __init__( + self, message_str, message_obj, platform_meta, session_id, bot: lark.Client + ): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot - + @staticmethod async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List: ret = [] _stage = [] for comp in message.chain: if isinstance(comp, Plain): - _stage.append({ - "tag": "md", - "text": comp.text - }) + _stage.append({"tag": "md", "text": comp.text}) elif isinstance(comp, At): - _stage.append({ - "tag": "at", - "user_id": comp.qq, - "style": [] - }) + _stage.append({"tag": "at", "user_id": comp.qq, "style": []}) elif isinstance(comp, AstrBotImage): file_path = "" if comp.file and comp.file.startswith("file:///"): - file_path = comp.file.replace('file:///', '') + file_path = comp.file.replace("file:///", "") elif comp.file and comp.file.startswith("http"): image_file_path = await download_image_by_url(comp.file) file_path = image_file_path @@ -40,31 +36,30 @@ class LarkMessageEvent(AstrMessageEvent): pass else: file_path = comp.file - - request = CreateImageRequest.builder() \ - .request_body( \ - CreateImageRequestBody.builder() \ - .image_type("message") \ - .image(open(file_path, 'rb')) \ - .build() \ - ) \ + + request = ( + CreateImageRequest.builder() + .request_body( + CreateImageRequestBody.builder() + .image_type("message") + .image(open(file_path, "rb")) .build() + ) + .build() + ) response = await lark_client.im.v1.image.acreate(request) if not response.success(): logger.error(f"无法上传飞书图片({response.code}): {response.msg}") image_key = response.data.image_key print(image_key) ret.append(_stage) - ret.append([{ - "tag": "img", - "image_key": image_key - }]) + ret.append([{"tag": "img", "image_key": image_key}]) _stage.clear() else: logger.warning(f"飞书 暂时不支持消息段: {comp.type}") if _stage: - ret.append(_stage) + ret.append(_stage) return ret async def send(self, message: MessageChain): @@ -76,21 +71,23 @@ class LarkMessageEvent(AstrMessageEvent): } } - request = ReplyMessageRequest.builder() \ - .message_id(self.message_obj.message_id) \ - .request_body( \ - ReplyMessageRequestBody.builder() \ - .content(json.dumps(wrapped)) \ - .msg_type("post") \ - .uuid(str(uuid.uuid4())) \ - .reply_in_thread(False) \ - .build() \ - ) \ + request = ( + ReplyMessageRequest.builder() + .message_id(self.message_obj.message_id) + .request_body( + ReplyMessageRequestBody.builder() + .content(json.dumps(wrapped)) + .msg_type("post") + .uuid(str(uuid.uuid4())) + .reply_in_thread(False) + .build() + ) .build() + ) response = await self.bot.im.v1.message.areply(request) - + if not response.success(): logger.error(f"回复飞书消息失败({response.code}): {response.msg}") - - await super().send(message) \ No newline at end of file + + await super().send(message) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2d8cd71f6..ec7eaa8fc 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -12,11 +12,18 @@ from astrbot.api import logger class QQOfficialMessageEvent(AstrMessageEvent): - def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + bot: Client, + ): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None - + async def send(self, message: MessageChain): if not self.send_buffer: self.send_buffer = message @@ -24,61 +31,87 @@ class QQOfficialMessageEvent(AstrMessageEvent): self.send_buffer.chain.extend(message.chain) async def _post_send(self): - '''QQ 官方 API 仅支持回复一次''' + """QQ 官方 API 仅支持回复一次""" source = self.message_obj.raw_message - assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) - - plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) - + assert isinstance( + source, + ( + botpy.message.Message, + botpy.message.GroupMessage, + botpy.message.DirectMessage, + botpy.message.C2CMessage, + ), + ) + + ( + plain_text, + image_base64, + image_path, + ) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) + if not plain_text and not image_base64 and not image_path: return - + payload = { - 'content': plain_text, - 'msg_id': self.message_obj.message_id, + "content": plain_text, + "msg_id": self.message_obj.message_id, } - + match type(source): case botpy.message.GroupMessage: if image_base64: - media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid) - payload['media'] = media - payload['msg_type'] = 7 - await self.bot.api.post_group_message(group_openid=source.group_openid, **payload) + media = await self.upload_group_and_c2c_image( + image_base64, 1, group_openid=source.group_openid + ) + payload["media"] = media + payload["msg_type"] = 7 + await self.bot.api.post_group_message( + group_openid=source.group_openid, **payload + ) case botpy.message.C2CMessage: if image_base64: - media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid) - payload['media'] = media - payload['msg_type'] = 7 - await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload) + media = await self.upload_group_and_c2c_image( + image_base64, 1, openid=source.author.user_openid + ) + payload["media"] = media + payload["msg_type"] = 7 + await self.bot.api.post_c2c_message( + openid=source.author.user_openid, **payload + ) case botpy.message.Message: if image_path: - payload['file_image'] = image_path + payload["file_image"] = image_path await self.bot.api.post_message(channel_id=source.channel_id, **payload) case botpy.message.DirectMessage: if image_path: - payload['file_image'] = image_path + payload["file_image"] = image_path await self.bot.api.post_dms(guild_id=source.guild_id, **payload) await super().send(self.send_buffer) - + self.send_buffer = None - - async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media: + + async def upload_group_and_c2c_image( + self, image_base64: str, file_type: int, **kwargs + ) -> botpy.types.message.Media: payload = { - 'file_data': image_base64, - 'file_type': file_type, - "srv_send_msg": False + "file_data": image_base64, + "file_type": file_type, + "srv_send_msg": False, } - if 'openid' in kwargs: - payload['openid'] = kwargs['openid'] - route = Route("POST", "/v2/users/{openid}/files", openid=kwargs['openid']) + if "openid" in kwargs: + payload["openid"] = kwargs["openid"] + route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) return await self.bot.api._http.request(route, json=payload) - elif 'group_openid' in kwargs: - payload['group_openid'] = kwargs['group_openid'] - route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs['group_openid']) + elif "group_openid" in kwargs: + payload["group_openid"] = kwargs["group_openid"] + route = Route( + "POST", + "/v2/groups/{group_openid}/files", + group_openid=kwargs["group_openid"], + ) return await self.bot.api._http.request(route, json=payload) - + @staticmethod async def _parse_to_qqofficial(message: MessageChain): plain_text = "" @@ -93,10 +126,12 @@ class QQOfficialMessageEvent(AstrMessageEvent): image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path).replace("base64://", "") + image_base64 = file_to_base64(image_file_path).replace( + "base64://", "" + ) else: image_base64 = file_to_base64(i.file).replace("base64://", "") image_file_path = i.file else: logger.debug(f"qq_official 忽略 {i.type}") - return plain_text, image_base64, image_file_path \ No newline at end of file + return plain_text, image_base64, image_file_path diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index d0a9a4dcb..f932b4ec8 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -10,7 +10,13 @@ import botpy.types.message import os from botpy import Client -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) from astrbot.api.event import MessageChain from typing import Union, List from astrbot.api.message_components import Image, Plain, At @@ -23,67 +29,84 @@ from astrbot.core.message.components import BaseMessageComponent for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) + # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: 'QQOfficialPlatformAdapter'): + def set_platform(self, platform: "QQOfficialPlatformAdapter"): self.platform = platform - + # 收到群消息 async def on_group_at_message_create(self, message: botpy.message.GroupMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) - abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.GROUP_MESSAGE + ) + abm.session_id = ( + abm.sender.user_id if self.platform.unique_session else message.group_openid + ) self._commit(abm) # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) - abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.GROUP_MESSAGE + ) + abm.session_id = ( + abm.sender.user_id if self.platform.unique_session else message.channel_id + ) self._commit(abm) - + # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.FRIEND_MESSAGE + ) abm.session_id = abm.sender.user_id self._commit(abm) - + # 收到 C2C 消息 async def on_c2c_message_create(self, message: botpy.message.C2CMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.FRIEND_MESSAGE + ) abm.session_id = abm.sender.user_id self._commit(abm) - + def _commit(self, abm: AstrBotMessage): - self.platform.commit_event(QQOfficialMessageEvent( - abm.message_str, - abm, - self.platform.meta(), - abm.session_id, - self.platform.client - )) + self.platform.commit_event( + QQOfficialMessageEvent( + abm.message_str, + abm, + self.platform.meta(), + abm.session_id, + self.platform.client, + ) + ) + + @register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器") class QQOfficialPlatformAdapter(Platform): - - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) - + self.config = platform_config - - self.appid = platform_config['appid'] - self.secret = platform_config['secret'] - self.unique_session = platform_settings['unique_session'] - qq_group = platform_config['enable_group_c2c'] - guild_dm = platform_config['enable_guild_direct_message'] + + self.appid = platform_config["appid"] + self.secret = platform_config["secret"] + self.unique_session = platform_settings["unique_session"] + qq_group = platform_config["enable_group_c2c"] + guild_dm = platform_config["enable_guild_direct_message"] if qq_group: self.intents = botpy.Intents( public_messages=True, public_guild_messages=True, - direct_message=guild_dm + direct_message=guild_dm, ) else: self.intents = botpy.Intents( - public_guild_messages=True, - direct_message=guild_dm + public_guild_messages=True, direct_message=guild_dm ) self.client = botClient( intents=self.intents, @@ -92,12 +115,14 @@ class QQOfficialPlatformAdapter(Platform): ) self.client.set_platform(self) - - self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' - - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + + self.test_mode = os.environ.get("TEST_MODE", "off") == "on" + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") - + def meta(self) -> PlatformMetadata: return PlatformMetadata( "qq_official", @@ -105,8 +130,10 @@ class QQOfficialPlatformAdapter(Platform): ) @staticmethod - def _parse_from_qqofficial(message: Union[botpy.message.Message, botpy.message.GroupMessage], - message_type: MessageType): + def _parse_from_qqofficial( + message: Union[botpy.message.Message, botpy.message.GroupMessage], + message_type: MessageType, + ): abm = AstrBotMessage() abm.type = message_type abm.timestamp = int(time.time()) @@ -114,20 +141,15 @@ class QQOfficialPlatformAdapter(Platform): abm.message_id = message.id abm.tag = "qq_official" msg: List[BaseMessageComponent] = [] - - if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage): + if isinstance(message, botpy.message.GroupMessage) or isinstance( + message, botpy.message.C2CMessage + ): if isinstance(message, botpy.message.GroupMessage): - abm.sender = MessageMember( - message.author.member_openid, - "" - ) + abm.sender = MessageMember(message.author.member_openid, "") abm.group_id = message.group_openid else: - abm.sender = MessageMember( - message.author.user_openid, - "" - ) + abm.sender = MessageMember(message.author.user_openid, "") abm.message_str = message.content.strip() abm.self_id = "unknown_selfid" msg.append(At(qq="qq_official")) @@ -137,37 +159,39 @@ class QQOfficialPlatformAdapter(Platform): if i.content_type.startswith("image"): url = i.url if not url.startswith("http"): - url = "https://"+url + url = "https://" + url img = Image.fromURL(url) msg.append(img) abm.message = msg - elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage): + elif isinstance(message, botpy.message.Message) or isinstance( + message, botpy.message.DirectMessage + ): try: abm.self_id = str(message.mentions[0].id) except BaseException as _: abm.self_id = "" plain_content = message.content.replace( - "<@!"+str(abm.self_id)+">", "").strip() + "<@!" + str(abm.self_id) + ">", "" + ).strip() if message.attachments: for i in message.attachments: if i.content_type.startswith("image"): url = i.url if not url.startswith("http"): - url = "https://"+url + url = "https://" + url img = Image.fromURL(url) msg.append(img) abm.message = msg abm.message_str = plain_content abm.sender = MessageMember( - str(message.author.id), - str(message.author.username) + str(message.author.id), str(message.author.username) ) msg.append(At(qq="qq_official")) msg.append(Plain(plain_content)) - + if isinstance(message, botpy.message.Message): abm.group_id = message.channel_id else: @@ -176,10 +200,7 @@ class QQOfficialPlatformAdapter(Platform): return abm def run(self): - return self.client.start( - appid=self.appid, - secret=self.secret - ) - + return self.client.start(appid=self.appid, secret=self.secret) + def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 1f131cb68..44686e1dc 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -17,72 +17,85 @@ from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # remove logger handler for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) - + + # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: 'QQOfficialWebhookPlatformAdapter'): + def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"): self.platform = platform - + # 收到群消息 async def on_group_at_message_create(self, message: botpy.message.GroupMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) - abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.GROUP_MESSAGE + ) + abm.session_id = ( + abm.sender.user_id if self.platform.unique_session else message.group_openid + ) self._commit(abm) # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) - abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.GROUP_MESSAGE + ) + abm.session_id = ( + abm.sender.user_id if self.platform.unique_session else message.channel_id + ) self._commit(abm) - + # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.FRIEND_MESSAGE + ) abm.session_id = abm.sender.user_id self._commit(abm) - + # 收到 C2C 消息 async def on_c2c_message_create(self, message: botpy.message.C2CMessage): - abm = QQOfficialPlatformAdapter._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm = QQOfficialPlatformAdapter._parse_from_qqofficial( + message, MessageType.FRIEND_MESSAGE + ) abm.session_id = abm.sender.user_id self._commit(abm) - + def _commit(self, abm: AstrBotMessage): - self.platform.commit_event(QQOfficialWebhookMessageEvent( - abm.message_str, - abm, - self.platform.meta(), - abm.session_id, - self - )) - + self.platform.commit_event( + QQOfficialWebhookMessageEvent( + abm.message_str, abm, self.platform.meta(), abm.session_id, self + ) + ) + + @register_platform_adapter("qq_official_webhook", "QQ 机器人官方 API 适配器(Webhook)") class QQOfficialWebhookPlatformAdapter(Platform): - - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) - + self.config = platform_config - - self.appid = platform_config['appid'] - self.secret = platform_config['secret'] - self.unique_session = platform_settings['unique_session'] + + self.appid = platform_config["appid"] + self.secret = platform_config["secret"] + self.unique_session = platform_settings["unique_session"] intents = botpy.Intents( - public_messages=True, - public_guild_messages=True, - direct_message=True + public_messages=True, public_guild_messages=True, direct_message=True ) self.client = botClient( - intents=intents, # 已经无用 + intents=intents, # 已经无用 bot_log=False, timeout=20, ) self.client.set_platform(self) - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") - + def meta(self) -> PlatformMetadata: return PlatformMetadata( "qq_official_webhook", @@ -91,13 +104,10 @@ class QQOfficialWebhookPlatformAdapter(Platform): async def run(self): self.webhook_helper = QQOfficialWebhook( - self.config, - self._event_queue, - self.client + self.config, self._event_queue, self.client ) await self.webhook_helper.initialize() await self.webhook_helper.start_polling() - - + async def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 2056ab56e..4c0bf8329 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -1,18 +1,15 @@ -import botpy -import botpy.message -import botpy.types -import botpy.types.message -from astrbot.core.utils.io import file_to_base64, download_image_by_url -from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, Reply from botpy import Client -from botpy.http import Route -from astrbot.api import logger from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent class QQOfficialWebhookMessageEvent(QQOfficialMessageEvent): - def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + bot: Client, + ): super().__init__(message_str, message_obj, platform_meta, session_id, bot) - \ No newline at end of file diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 81ce8469e..f863c87be 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -1,47 +1,46 @@ -import aiohttp import quart -import json import logging import asyncio -import typing from botpy import BotAPI, BotHttp, Client, Token, BotWebSocket, ConnectionSession from astrbot.api import logger -import traceback from cryptography.hazmat.primitives.asymmetric import ed25519 # remove logger handler for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) -class QQOfficialWebhook(): + +class QQOfficialWebhook: def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client): - self.appid = config['appid'] - self.secret = config['secret'] + self.appid = config["appid"] + self.secret = config["secret"] self.port = config.get("port", 6196) - + if isinstance(self.port, str): self.port = int(self.port) - + self.http: BotHttp = BotHttp(timeout=300) self.api: BotAPI = BotAPI(http=self.http) self.token = Token(self.appid, self.secret) - + self.server = quart.Quart(__name__) - self.server.add_url_rule('/astrbot-qo-webhook/callback', view_func=self.callback, methods=['POST']) + self.server.add_url_rule( + "/astrbot-qo-webhook/callback", view_func=self.callback, methods=["POST"] + ) self.client = botpy_client self.event_queue = event_queue - + async def initialize(self): - logger.info(f"正在登录到 QQ 官方机器人...") + logger.info("正在登录到 QQ 官方机器人...") self.user = await self.http.login(self.token) logger.info(f"已登录 QQ 官方机器人账号: {self.user}") # 直接注入到 botpy 的 Client,移花接木! self.client.api = self.api self.client.http = self.http - + async def bot_connect(): pass - + self._connection = ConnectionSession( max_async=1, connect=bot_connect, @@ -54,21 +53,22 @@ class QQOfficialWebhook(): seed = bot_secret while len(seed) < target_size: seed *= 2 - return seed[:target_size].encode('utf-8') + return seed[:target_size].encode("utf-8") - async def webhook_validation(self, validation_payload: dict): seed = await self.repeat_seed(self.secret) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) - msg = validation_payload.get("event_ts", "") + validation_payload.get("plain_token", "") + msg = validation_payload.get("event_ts", "") + validation_payload.get( + "plain_token", "" + ) # sign signature = private_key.sign(msg.encode()).hex() response = { "plain_token": validation_payload.get("plain_token"), - "signature": signature + "signature": signature, } return response - + async def callback(self): msg: dict = await quart.request.json logger.debug(f"收到 qq_official_webhook 回调: {msg}") @@ -76,7 +76,7 @@ class QQOfficialWebhook(): event = msg.get("t") opcode = msg.get("op") data = msg.get("d") - + if opcode == 13: # validation signed = await self.webhook_validation(data) @@ -91,18 +91,17 @@ class QQOfficialWebhook(): logger.error("_parser unknown event %s.", event) else: func(msg) - + return {"opcode": 12} - + async def start_polling(self): await self.server.run_task( - host='0.0.0.0', - port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder + host="0.0.0.0", + port=self.port, + shutdown_trigger=self.shutdown_trigger_placeholder, ) - + async def shutdown_trigger_placeholder(self): while not self.event_queue.closed: await asyncio.sleep(1) logger.info("qq_official_webhook 适配器已关闭。") - \ No newline at end of file diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 446d9ec38..6c7cdf1fc 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -2,9 +2,22 @@ import sys import uuid import asyncio -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, PlatformMetadata, MessageType +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + PlatformMetadata, + MessageType, +) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image, Record, File as AstrBotFile, Video, At +from astrbot.api.message_components import ( + Plain, + Image, + Record, + File as AstrBotFile, + Video, + At, +) from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.api.platform import register_platform_adapter @@ -20,32 +33,45 @@ if sys.version_info >= (3, 12): else: from typing_extensions import override + @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): - - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) self.config = platform_config self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] - - base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot") + + base_url = self.config.get( + "telegram_api_base_url", "https://api.telegram.org/bot" + ) if not base_url: base_url = "https://api.telegram.org/bot" - self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build() + self.application = ( + ApplicationBuilder() + .token(self.config["telegram_token"]) + .base_url(base_url) + .build() + ) message_handler = TelegramMessageHandler( filters=filters.ALL, # receive all messages - callback=self.convert_message + callback=self.convert_message, ) self.application.add_handler(message_handler) self.client = self.application.bot - + @override - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): from_username = session.session_id - await TelegramPlatformEvent.send_with_client(self.client, message_chain, from_username) + await TelegramPlatformEvent.send_with_client( + self.client, message_chain, from_username + ) await super().send_by_session(session, message_chain) - + @override def meta(self) -> PlatformMetadata: return PlatformMetadata( @@ -62,9 +88,13 @@ class TelegramPlatformAdapter(Platform): await queue async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): - await context.bot.send_message(chat_id=update.effective_chat.id, text=self.config["start_message"]) + await context.bot.send_message( + chat_id=update.effective_chat.id, text=self.config["start_message"] + ) - async def convert_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> AstrBotMessage: + async def convert_message( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> AstrBotMessage: message = AstrBotMessage() # 获得是群聊还是私聊 if update.effective_chat.type == ChatType.PRIVATE: @@ -74,57 +104,68 @@ class TelegramPlatformAdapter(Platform): message.group_id = update.effective_chat.id message.message_id = str(update.message.message_id) message.session_id = str(update.effective_chat.id) - message.sender = MessageMember(str(update.effective_user.id), update.effective_user.username) + message.sender = MessageMember( + str(update.effective_user.id), update.effective_user.username + ) message.self_id = str(context.bot.id) message.raw_message = update message.message_str = "" message.message = [] - + logger.debug(f"Telegram message: {update.message}") - + if update.message.text: plain_text = update.message.text if update.message.entities: for entity in update.message.entities: if entity.type == "mention": - name = plain_text[entity.offset:entity.offset+entity.length] + name = plain_text[entity.offset : entity.offset + entity.length] message.message.append(At(qq=message.self_id, name=name)) - plain_text = plain_text[:entity.offset] + plain_text[entity.offset+entity.length:] - + plain_text = ( + plain_text[: entity.offset] + + plain_text[entity.offset + entity.length :] + ) + message.message.append(Plain(plain_text)) message.message_str = plain_text - - + elif update.message.voice: file = await update.message.voice.get_file() - message.message = [Record(file=file.file_path, url=file.file_path),] - + message.message = [ + Record(file=file.file_path, url=file.file_path), + ] + elif update.message.photo: - photo = update.message.photo[-1] # get the largest photo + photo = update.message.photo[-1] # get the largest photo file = await photo.get_file() message.message.append(Image(file=file.file_path, url=file.file_path)) - + elif update.message.document: file = await update.message.document.get_file() - message.message = [AstrBotFile(file=file.file_path, name=update.message.document.file_name),] - + message.message = [ + AstrBotFile( + file=file.file_path, name=update.message.document.file_name + ), + ] + elif update.message.video: file = await update.message.video.get_file() - message.message = [Video(file=file.file_path, path=file.file_path),] - - + message.message = [ + Video(file=file.file_path, path=file.file_path), + ] + await self.handle_msg(message) - + async def handle_msg(self, message: AstrBotMessage): message_event = TelegramPlatformEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - client=self.client + client=self.client, ) self.commit_event(message_event) - + async def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 625e94c48..de0f9a58e 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -3,15 +3,23 @@ from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType from astrbot.api.message_components import Plain, Image, Reply, At, File, Record from telegram.ext import ExtBot + class TelegramPlatformEvent(AstrMessageEvent): - def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: ExtBot): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client: ExtBot, + ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client - + @staticmethod async def send_with_client(client: ExtBot, message: MessageChain, user_name: str): image_path = None - + has_reply = False reply_message_id = None at_user_id = None @@ -21,7 +29,7 @@ class TelegramPlatformEvent(AstrMessageEvent): reply_message_id = i.id if isinstance(i, At): at_user_id = i.name - + at_flag = False for i in message.chain: payload = { @@ -29,7 +37,7 @@ class TelegramPlatformEvent(AstrMessageEvent): } if has_reply: payload["reply_to_message_id"] = reply_message_id - + if isinstance(i, Plain): if at_user_id and not at_flag: i.text = f"@{at_user_id} " + i.text @@ -43,6 +51,7 @@ class TelegramPlatformEvent(AstrMessageEvent): if image_path.startswith("base64://"): import base64 + base64_data = image_path[9:] image_bytes = base64.b64decode(base64_data) await client.send_photo(photo=image_bytes, **payload) @@ -52,11 +61,10 @@ class TelegramPlatformEvent(AstrMessageEvent): await client.send_document(document=i.file, filename=i.name, **payload) elif isinstance(i, Record): await client.send_voice(voice=i.file, **payload) - async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: await self.send_with_client(self.client, message, self.message_obj.group_id) else: await self.send_with_client(self.client, message, self.get_sender_id()) - await super().send(message) \ No newline at end of file + await super().send(message) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 63da1d6f4..f6a1cd8f7 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -3,7 +3,13 @@ import asyncio import uuid import os from typing import Awaitable, Any -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) from astrbot.api.event import MessageChain from astrbot.api.message_components import Plain, Image, Record # noqa: F403 from astrbot.api import logger @@ -12,32 +18,38 @@ from .webchat_event import WebChatMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter + class QueueListener: def __init__(self, queue: asyncio.Queue, callback: callable) -> None: self.queue = queue self.callback = callback - + async def run(self): while True: data = await self.queue.get() await self.callback(data) + @register_platform_adapter("webchat", "webchat") class WebChatAdapter(Platform): - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) - + self.config = platform_config self.settings = platform_settings - self.unique_session = platform_settings['unique_session'] + self.unique_session = platform_settings["unique_session"] self.imgs_dir = "data/webchat/imgs" self.metadata = PlatformMetadata( "webchat", "webchat", ) - - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): # abm.session_id = f"webchat!{username}!{cid}" plain = "" cid = session.session_id.split("!")[-1] @@ -45,68 +57,72 @@ class WebChatAdapter(Platform): if isinstance(comp, Plain): plain += comp.text web_chat_back_queue.put_nowait((plain, cid)) - + await super().send_by_session(session, message_chain) - + async def convert_message(self, data: tuple) -> AstrBotMessage: username, cid, payload = data - - + abm = AstrBotMessage() abm.self_id = "webchat" abm.tag = "webchat" - abm.sender = MessageMember(username, username) + abm.sender = MessageMember(username, username) abm.type = MessageType.FRIEND_MESSAGE - + abm.session_id = f"webchat!{username}!{cid}" - + abm.message_id = str(uuid.uuid4()) abm.message = [] - - if payload['message']: - abm.message.append(Plain(payload['message'])) - if payload['image_url']: - if isinstance(payload['image_url'], list): - for img in payload['image_url']: - abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img))) + + if payload["message"]: + abm.message.append(Plain(payload["message"])) + if payload["image_url"]: + if isinstance(payload["image_url"], list): + for img in payload["image_url"]: + abm.message.append( + Image.fromFileSystem(os.path.join(self.imgs_dir, img)) + ) else: - abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url']))) - if payload['audio_url']: - if isinstance(payload['audio_url'], list): - for audio in payload['audio_url']: + abm.message.append( + Image.fromFileSystem( + os.path.join(self.imgs_dir, payload["image_url"]) + ) + ) + if payload["audio_url"]: + if isinstance(payload["audio_url"], list): + for audio in payload["audio_url"]: path = os.path.join(self.imgs_dir, audio) abm.message.append(Record(file=path, path=path)) else: - path = os.path.join(self.imgs_dir, payload['audio_url']) + path = os.path.join(self.imgs_dir, payload["audio_url"]) abm.message.append(Record(file=path, path=path)) - + logger.debug(f"WebChatAdapter: {abm.message}") - - message_str = payload['message'] + + message_str = payload["message"] abm.timestamp = int(time.time()) abm.message_str = message_str abm.raw_message = data return abm - + def run(self) -> Awaitable[Any]: async def callback(data: tuple): abm = await self.convert_message(data) await self.handle_msg(abm) - + bot = QueueListener(web_chat_queue, callback) return bot.run() - + def meta(self) -> PlatformMetadata: return self.metadata async def handle_msg(self, message: AstrBotMessage): - message_event = WebChatMessageEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), - session_id=message.session_id + session_id=message.session_id, ) - - self.commit_event(message_event) \ No newline at end of file + + self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index afcda5e1c..01bc4b352 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -7,19 +7,20 @@ from astrbot.api.message_components import Plain, Image from astrbot.core.utils.io import download_image_by_url from astrbot.core import web_chat_back_queue + class WebChatMessageEvent(AstrMessageEvent): def __init__(self, message_str, message_obj, platform_meta, session_id): super().__init__(message_str, message_obj, platform_meta, session_id) self.imgs_dir = "data/webchat/imgs" - os.makedirs(self.imgs_dir, exist_ok=True) + os.makedirs(self.imgs_dir, exist_ok=True) async def send(self, message: MessageChain): if not message: web_chat_back_queue.put_nowait(None) return - + cid = self.session_id.split("!")[-1] - + for comp in message.chain: if isinstance(comp, Plain): web_chat_back_queue.put_nowait((comp.text, cid)) @@ -47,4 +48,4 @@ class WebChatMessageEvent(AstrMessageEvent): else: logger.debug(f"webchat 忽略: {comp.type}") web_chat_back_queue.put_nowait(None) - await super().send(message) \ No newline at end of file + await super().send(message) diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 6bfea54e7..2e3f5fb19 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -3,7 +3,13 @@ import uuid import asyncio import quart -from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, PlatformMetadata, MessageType +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + PlatformMetadata, + MessageType, +) from astrbot.api.event import MessageChain from astrbot.api.message_components import Plain, Image, Record from astrbot.core.platform.astr_message_event import MessageSesion @@ -22,120 +28,118 @@ if sys.version_info >= (3, 12): from typing import override else: from typing_extensions import override - -class WecomServer(): - def __init__( - self, - event_queue: asyncio.Queue, - config: dict - ): + + +class WecomServer: + def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) self.port = int(config.get("port")) - self.server.add_url_rule('/callback/command', view_func=self.verify, methods=['GET']) - self.server.add_url_rule('/callback/command', view_func=self.callback_command, methods=['POST']) + self.server.add_url_rule( + "/callback/command", view_func=self.verify, methods=["GET"] + ) + self.server.add_url_rule( + "/callback/command", view_func=self.callback_command, methods=["POST"] + ) self.event_queue = event_queue - + self.crypto = WeChatCrypto( - config['token'].strip(), - config['encoding_aes_key'].strip(), - config['corpid'].strip() + config["token"].strip(), + config["encoding_aes_key"].strip(), + config["corpid"].strip(), ) self.callback = None - + async def verify(self): logger.info(f"验证请求有效性: {quart.request.args}") args = quart.request.args try: echo_str = self.crypto.check_signature( - args.get('msg_signature'), - args.get('timestamp'), - args.get('nonce'), - args.get('echostr') + args.get("msg_signature"), + args.get("timestamp"), + args.get("nonce"), + args.get("echostr"), ) logger.info("验证请求有效性成功。") return echo_str except InvalidSignatureException: logger.error("验证请求有效性失败,签名异常,请检查配置。") - raise - + raise + async def callback_command(self): data = await quart.request.get_data() - msg_signature = quart.request.args.get('msg_signature') - timestamp = quart.request.args.get('timestamp') - nonce = quart.request.args.get('nonce') + msg_signature = quart.request.args.get("msg_signature") + timestamp = quart.request.args.get("timestamp") + nonce = quart.request.args.get("nonce") try: - xml = self.crypto.decrypt_message( - data, - msg_signature, - timestamp, - nonce - ) + xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce) except InvalidSignatureException: logger.error("解密失败,签名异常,请检查配置。") raise else: msg = parse_message(xml) logger.info(f"解析成功: {msg}") - + if self.callback: await self.callback(msg) - + return "success" - + async def start_polling(self): logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。") await self.server.run_task( - host='0.0.0.0', - port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder + host="0.0.0.0", + port=self.port, + shutdown_trigger=self.shutdown_trigger_placeholder, ) - + async def shutdown_trigger_placeholder(self): while not self.event_queue.closed: await asyncio.sleep(1) logger.info("企业微信 适配器已关闭。") - + @register_platform_adapter("wecom", "wecom 适配器") class WecomPlatformAdapter(Platform): - def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: super().__init__(event_queue) self.config = platform_config self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] - self.api_base_url = platform_config.get("api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/") - + self.api_base_url = platform_config.get( + "api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/" + ) + if not self.api_base_url: self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/" - + if self.api_base_url.endswith("/"): self.api_base_url = self.api_base_url[:-1] if not self.api_base_url.endswith("/cgi-bin"): self.api_base_url += "/cgi-bin" - + if not self.api_base_url.endswith("/"): self.api_base_url += "/" - - self.server = WecomServer( - self._event_queue, - self.config - ) - + + self.server = WecomServer(self._event_queue, self.config) + self.client = WeChatClient( - self.config['corpid'].strip(), - self.config['secret'].strip(), + self.config["corpid"].strip(), + self.config["secret"].strip(), ) self.client.API_BASE_URL = self.api_base_url - + async def callback(msg): await self.convert_message(msg) - + self.server.callback = callback - @override - async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): await super().send_by_session(session, message_chain) @override @@ -148,10 +152,10 @@ class WecomPlatformAdapter(Platform): @override async def run(self): await self.server.start_polling() - + async def convert_message(self, msg): abm = AstrBotMessage() - if msg.type == 'text': + if msg.type == "text": assert isinstance(msg, TextMessage) abm.message_str = msg.content abm.self_id = str(msg.agent) @@ -165,7 +169,7 @@ class WecomPlatformAdapter(Platform): abm.timestamp = msg.time abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == 'image': + elif msg.type == "image": assert isinstance(msg, ImageMessage) abm.message_str = "[图片]" abm.self_id = str(msg.agent) @@ -179,18 +183,16 @@ class WecomPlatformAdapter(Platform): abm.timestamp = msg.time abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == 'voice': + elif msg.type == "voice": assert isinstance(msg, VoiceMessage) - + resp: Response = await asyncio.get_event_loop().run_in_executor( - None, - self.client.media.download, - msg.media_id + None, self.client.media.download, msg.media_id ) path = f"data/temp/wecom_{msg.media_id}.amr" - with open(path, 'wb') as f: + with open(path, "wb") as f: f.write(resp.content) - + try: from pydub import AudioSegment @@ -201,7 +203,7 @@ class WecomPlatformAdapter(Platform): logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。") path_wav = path return - + abm.message_str = "" abm.self_id = str(msg.agent) abm.message = [Record(file=path_wav, url=path_wav)] @@ -214,9 +216,7 @@ class WecomPlatformAdapter(Platform): abm.timestamp = msg.time abm.session_id = abm.sender.user_id abm.raw_message = msg - - - + logger.info(f"abm: {abm}") await self.handle_msg(abm) @@ -226,9 +226,9 @@ class WecomPlatformAdapter(Platform): message_obj=message, platform_meta=self.meta(), session_id=message.session_id, - client=self.client + client=self.client, ) self.commit_event(message_event) - + def get_client(self): - return self.client \ No newline at end of file + return self.client diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index 246a74e57..f30d1ac32 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -2,9 +2,4 @@ from .provider import Provider, Personality, STTProvider from .entites import ProviderMetaData -__all__ = [ - "Provider", - "Personality", - "ProviderMetaData", - "STTProvider" -] \ No newline at end of file +__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"] diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 87b9d006b..a98bbfa7c 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -10,55 +10,58 @@ class ProviderType(enum.Enum): CHAT_COMPLETION = "chat_completion" SPEECH_TO_TEXT = "speech_to_text" TEXT_TO_SPEECH = "text_to_speech" - -@dataclass -class ProviderMetaData(): - type: str - '''提供商适配器名称,如 openai, ollama''' - desc: str = "" - '''提供商适配器描述.''' - provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: Type = None - - default_config_tmpl: dict = None - '''平台的默认配置模板''' - provider_display_name: str = None - '''显示在 WebUI 配置页中的提供商名称,如空则是 type''' + @dataclass -class ProviderRequest(): +class ProviderMetaData: + type: str + """提供商适配器名称,如 openai, ollama""" + desc: str = "" + """提供商适配器描述.""" + provider_type: ProviderType = ProviderType.CHAT_COMPLETION + cls_type: Type = None + + default_config_tmpl: dict = None + """平台的默认配置模板""" + provider_display_name: str = None + """显示在 WebUI 配置页中的提供商名称,如空则是 type""" + + +@dataclass +class ProviderRequest: prompt: str - '''提示词''' + """提示词""" session_id: str = "" - '''会话 ID''' + """会话 ID""" image_urls: List[str] = None - '''图片 URL 列表''' + """图片 URL 列表""" func_tool: FuncCall = None - '''工具''' + """工具""" contexts: List = None - '''上下文。格式与 openai 的上下文格式一致: + """上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages - ''' + """ system_prompt: str = "" - '''系统提示词''' + """系统提示词""" conversation: Conversation = None - + def __repr__(self): return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})" - + def __str__(self): return self.__repr__() - + + @dataclass class LLMResponse: role: str - '''角色, assistant, tool, err''' + """角色, assistant, tool, err""" completion_text: str = "" - '''LLM 返回的文本''' + """LLM 返回的文本""" tools_call_args: List[Dict[str, any]] = field(default_factory=list) - '''工具调用参数''' + """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) - '''工具调用名称''' - + """工具调用名称""" + raw_completion: ChatCompletion = None - _new_record: Dict[str, any] = None \ No newline at end of file + _new_record: Dict[str, any] = None diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 8ddc21b75..1b2bd6410 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -14,10 +14,11 @@ class FuncTool: parameters: Dict description: str handler: Awaitable - handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools active: bool = True - '''是否激活''' + """是否激活""" + SUPPORTED_TYPES = [ "string", @@ -101,7 +102,7 @@ class FuncCall: } ) return _l - + def get_func_desc_anthropic_style(self) -> list: """ 获得 Anthropic API 风格的**已经激活**的工具描述 @@ -119,12 +120,12 @@ class FuncCall: "type": "object", "properties": f.parameters.get("properties", {}), # Keep the required field from the original parameters if it exists - "required": f.parameters.get("required", []) - } + "required": f.parameters.get("required", []), + }, } tools.append(tool) return tools - + def get_func_desc_google_genai_style(self) -> Dict: declarations = {} tools = [] @@ -132,10 +133,7 @@ class FuncCall: if not f.active: continue - func_declaration = { - "name": f.name, - "description": f.description - } + func_declaration = {"name": f.name, "description": f.description} # 检查并添加非空的properties参数 params = f.parameters if isinstance(f.parameters, dict) else {} @@ -147,7 +145,6 @@ class FuncCall: if tools: declarations["function_declarations"] = tools return declarations - async def func_call(self, question: str, session_id: str, provider) -> tuple: _l = [] @@ -220,9 +217,8 @@ class FuncCall: tool_call_result.append(str(ret)) return tool_call_result, True - def __str__(self): return str(self.func_list) - + def __repr__(self): - return str(self.func_list) \ No newline at end of file + return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 546198094..3da2ea00c 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -7,25 +7,28 @@ from astrbot.core.db import BaseDatabase from .register import provider_cls_map, llm_tools from astrbot.core import logger, sp -class ProviderManager(): + +class ProviderManager: def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): - self.providers_config: List = config['provider'] - self.provider_settings: dict = config['provider_settings'] - self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) - self.provider_tts_settings: dict = config.get('provider_tts_settings', {}) - self.persona_configs: list = config.get('persona', []) + self.providers_config: List = config["provider"] + self.provider_settings: dict = config["provider_settings"] + self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) + self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) + self.persona_configs: list = config.get("persona", []) self.astrbot_config = config - + self.selected_provider_id = sp.get("curr_provider") self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") self.selected_tts_provider_id = self.provider_settings.get("provider_id") self.provider_enabled = self.provider_settings.get("enable", False) self.stt_enabled = self.provider_stt_settings.get("enable", False) self.tts_enabled = self.provider_tts_settings.get("enable", False) - + # 人格情景管理 # 目前没有拆成独立的模块 - self.default_persona_name = self.provider_settings.get('default_personality', 'default') + self.default_persona_name = self.provider_settings.get( + "default_personality", "default" + ) self.personas: List[Personality] = [] self.selected_default_persona = None for persona in self.persona_configs: @@ -35,76 +38,81 @@ class ProviderManager(): mid_processed = "" if begin_dialogs: if len(begin_dialogs) % 2 != 0: - logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。") + logger.error( + f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。" + ) begin_dialogs = [] user_turn = True for dialog in begin_dialogs: - bd_processed.append({ - "role": "user" if user_turn else "assistant", - "content": dialog, - "_no_save": None # 不持久化到 db - }) + bd_processed.append( + { + "role": "user" if user_turn else "assistant", + "content": dialog, + "_no_save": None, # 不持久化到 db + } + ) user_turn = not user_turn if mood_imitation_dialogs: if len(mood_imitation_dialogs) % 2 != 0: - logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。") + logger.error( + f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。" + ) mood_imitation_dialogs = [] user_turn = True for dialog in mood_imitation_dialogs: role = "A" if user_turn else "B" mid_processed += f"{role}: {dialog}\n" if not user_turn: - mid_processed += '\n' + mid_processed += "\n" user_turn = not user_turn - + try: persona = Personality( - **persona, + **persona, _begin_dialogs_processed=bd_processed, - _mood_imitation_dialogs_processed=mid_processed + _mood_imitation_dialogs_processed=mid_processed, ) - if persona['name'] == self.default_persona_name: + if persona["name"] == self.default_persona_name: self.selected_default_persona = persona self.personas.append(persona) except Exception as e: logger.error(f"解析 Persona 配置失败:{e}") - + if not self.selected_default_persona and len(self.personas) > 0: # 默认选择第一个 self.selected_default_persona = self.personas[0] - + if not self.selected_default_persona: self.selected_default_persona = Personality( prompt="You are a helpful and friendly assistant.", name="default", _begin_dialogs_processed=[], - _mood_imitation_dialogs_processed="" + _mood_imitation_dialogs_processed="", ) self.personas.append(self.selected_default_persona) self.provider_insts: List[Provider] = [] - '''加载的 Provider 的实例''' + """加载的 Provider 的实例""" self.stt_provider_insts: List[STTProvider] = [] - '''加载的 Speech To Text Provider 的实例''' + """加载的 Speech To Text Provider 的实例""" self.tts_provider_insts: List[TTSProvider] = [] - '''加载的 Text To Speech Provider 的实例''' + """加载的 Text To Speech Provider 的实例""" self.inst_map = {} - '''Provider 实例映射. key: provider_id, value: Provider 实例''' + """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools self.curr_provider_inst: Provider = None - '''当前使用的 Provider 实例''' + """当前使用的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None - '''当前使用的 Speech To Text Provider 实例''' + """当前使用的 Speech To Text Provider 实例""" self.curr_tts_provider_inst: TTSProvider = None - '''当前使用的 Text To Speech Provider 实例''' + """当前使用的 Text To Speech Provider 实例""" self.db_helper = db_helper - + # kdb(experimental) self.curr_kdb_name = "" kdb_cfg = config.get("knowledge_db", {}) if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] - async def initialize(self): for provider_config in self.providers_config: @@ -112,131 +120,184 @@ class ProviderManager(): if not self.curr_provider_inst: logger.warning("未启用任何用于 文本生成 的提供商适配器。") - + if self.stt_enabled and not self.curr_stt_provider_inst: - logger.warning("未启用任何用于 语音转文本 的提供商适配器。") - + logger.warning("未启用任何用于 语音转文本 的提供商适配器。") + if self.tts_enabled and not self.curr_tts_provider_inst: - logger.warning("未启用任何用于 文本转语音 的提供商适配器。") - + logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + async def load_provider(self, provider_config: dict): - if not provider_config['enable']: + if not provider_config["enable"]: return - - logger.info(f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ...") + + logger.info( + f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ..." + ) logger.debug(f"Provider Config: {provider_config}") - + # 动态导入 try: - match provider_config['type']: + match provider_config["type"]: case "openai_chat_completion": - from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial + from .sources.openai_source import ( + ProviderOpenAIOfficial as ProviderOpenAIOfficial, + ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "anthropic_chat_completion": - from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic + from .sources.anthropic_source import ( + ProviderAnthropic as ProviderAnthropic, + ) case "llm_tuner": logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader + from .sources.llmtuner_source import ( + LLMTunerModelLoader as LLMTunerModelLoader, + ) case "dify": from .sources.dify_source import ProviderDify as ProviderDify case "dashscope": - from .sources.dashscope_source import ProviderDashscope as ProviderDashscope + from .sources.dashscope_source import ( + ProviderDashscope as ProviderDashscope, + ) case "googlegenai_chat_completion": - from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI + from .sources.gemini_source import ( + ProviderGoogleGenAI as ProviderGoogleGenAI, + ) case "sensevoice_stt_selfhost": - from .sources.sensevoice_selfhosted_source import ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost + from .sources.sensevoice_selfhosted_source import ( + ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost, + ) case "openai_whisper_api": - from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI + from .sources.whisper_api_source import ( + ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI, + ) case "openai_whisper_selfhost": - from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost + from .sources.whisper_selfhosted_source import ( + ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost, + ) case "openai_tts_api": - from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI + from .sources.openai_tts_api_source import ( + ProviderOpenAITTSAPI as ProviderOpenAITTSAPI, + ) case "edge_tts": - from .sources.edge_tts_source import ProviderEdgeTTS as ProviderEdgeTTS + from .sources.edge_tts_source import ( + ProviderEdgeTTS as ProviderEdgeTTS, + ) case "gsvi_tts_api": - from .sources.gsvi_tts_source import ProviderGSVITTS as ProviderGSVITTS + from .sources.gsvi_tts_source import ( + ProviderGSVITTS as ProviderGSVITTS, + ) case "fishaudio_tts_api": - from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI + from .sources.fishaudio_tts_api_source import ( + ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI, + ) except (ImportError, ModuleNotFoundError) as e: - logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") + logger.critical( + f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" + ) return except Exception as e: - logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因") + logger.critical( + f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因" + ) return - if provider_config['type'] not in provider_cls_map: - logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") + if provider_config["type"] not in provider_cls_map: + logger.error( + f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。" + ) return - provider_metadata = provider_cls_map[provider_config['type']] + provider_metadata = provider_cls_map[provider_config["type"]] try: # 按任务实例化提供商 - + if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: # STT 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - + inst = provider_metadata.cls_type( + provider_config, self.provider_settings + ) + if getattr(inst, "initialize", None): await inst.initialize() - + self.stt_provider_insts.append(inst) - if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled: + if ( + self.selected_stt_provider_id == provider_config["id"] + and self.stt_enabled + ): self.curr_stt_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。" + ) if not self.curr_stt_provider_inst and self.stt_enabled: self.curr_stt_provider_inst = inst - + elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: # TTS 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - + inst = provider_metadata.cls_type( + provider_config, self.provider_settings + ) + if getattr(inst, "initialize", None): await inst.initialize() - + self.tts_provider_insts.append(inst) - if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled: + if ( + self.selected_tts_provider_id == provider_config["id"] + and self.tts_enabled + ): self.curr_tts_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。" + ) if not self.curr_tts_provider_inst and self.tts_enabled: self.curr_tts_provider_inst = inst - + elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: # 文本生成任务 inst = provider_metadata.cls_type( - provider_config, - self.provider_settings, + provider_config, + self.provider_settings, self.db_helper, - self.provider_settings.get('persistant_history', True), - self.selected_default_persona + self.provider_settings.get("persistant_history", True), + self.selected_default_persona, ) - + if getattr(inst, "initialize", None): await inst.initialize() - + self.provider_insts.append(inst) - if self.selected_provider_id == provider_config['id'] and self.provider_enabled: + if ( + self.selected_provider_id == provider_config["id"] + and self.provider_enabled + ): self.curr_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。" + ) if not self.curr_provider_inst and self.provider_enabled: self.curr_provider_inst = inst - - self.inst_map[provider_config['id']] = inst + + self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error(traceback.format_exc()) - logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") + logger.error( + f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}" + ) async def reload(self, provider_config: dict): - await self.terminate_provider(provider_config['id']) - if provider_config['enable']: + await self.terminate_provider(provider_config["id"]) + if provider_config["enable"]: await self.load_provider(provider_config) - + # 和配置文件保持同步 - config_ids = [provider['id'] for provider in self.providers_config] + config_ids = [provider["id"] for provider in self.providers_config] for key in list(self.inst_map.keys()): if key not in config_ids: await self.terminate_provider(key) - + if len(self.provider_insts) == 0: self.curr_provider_inst = None if len(self.stt_provider_insts) == 0: @@ -246,12 +307,13 @@ class ProviderManager(): def get_insts(self): return self.provider_insts - + async def terminate_provider(self, provider_id: str): if provider_id in self.inst_map: - - logger.info(f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...") - + logger.info( + f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." + ) + if self.inst_map[provider_id] in self.provider_insts: self.provider_insts.remove(self.inst_map[provider_id]) if self.inst_map[provider_id] in self.stt_provider_insts: @@ -265,14 +327,16 @@ class ProviderManager(): self.curr_stt_provider_inst = None if self.inst_map[provider_id] == self.curr_tts_provider_inst: self.curr_tts_provider_inst = None - - if getattr(self.inst_map[provider_id], 'terminate', None): + + if getattr(self.inst_map[provider_id], "terminate", None): await self.inst_map[provider_id].terminate() - - logger.info(f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})") + + logger.info( + f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" + ) del self.inst_map[provider_id] - + async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): - await provider_inst.terminate() \ No newline at end of file + await provider_inst.terminate() diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7fd93b41c..57dc57f90 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,9 +1,6 @@ import abc -import json -from collections import defaultdict from typing import List from astrbot.core.db import BaseDatabase -from astrbot.core import logger from typing import TypedDict from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.entites import LLMResponse @@ -15,19 +12,19 @@ class Personality(TypedDict): name: str = "" begin_dialogs: List[str] = [] mood_imitation_dialogs: List[str] = [] - + # cache _begin_dialogs_processed: List[dict] = [] _mood_imitation_dialogs_processed: str = "" - - + + @dataclass -class ProviderMeta(): +class ProviderMeta: id: str model: str type: str - - + + class AbstractProvider(abc.ABC): def __init__(self, provider_config: dict) -> None: super().__init__() @@ -35,66 +32,68 @@ class AbstractProvider(abc.ABC): self.provider_config = provider_config def set_model(self, model_name: str): - '''设置当前使用的模型名称''' + """设置当前使用的模型名称""" self.model_name = model_name - + def get_model(self) -> str: - '''获得当前使用的模型名称''' + """获得当前使用的模型名称""" return self.model_name - + def meta(self) -> ProviderMeta: - '''获取 Provider 的元数据''' + """获取 Provider 的元数据""" return ProviderMeta( - id=self.provider_config['id'], + id=self.provider_config["id"], model=self.get_model(), - type=self.provider_config['type'] + type=self.provider_config["type"], ) class Provider(AbstractProvider): def __init__( - self, + self, provider_config: dict, - provider_settings: dict, + provider_settings: dict, persistant_history: bool = True, db_helper: BaseDatabase = None, - default_persona: Personality = None + default_persona: Personality = None, ) -> None: super().__init__(provider_config) - + self.provider_settings = provider_settings - + self.curr_personality: Personality = default_persona - '''维护了当前的使用的 persona,即人格。可能为 None''' + """维护了当前的使用的 persona,即人格。可能为 None""" @abc.abstractmethod def get_current_key(self) -> str: raise NotImplementedError() - + def get_keys(self) -> List[str]: - '''获得提供商 Key''' + """获得提供商 Key""" return self.provider_config.get("key", []) - + @abc.abstractmethod def set_key(self, key: str): raise NotImplementedError() - + @abc.abstractmethod def get_models(self) -> List[str]: - '''获得支持的模型列表''' + """获得支持的模型列表""" raise NotImplementedError() - + @abc.abstractmethod - async def text_chat(self, - prompt: str, - session_id: str=None, - image_urls: List[str]=None, - func_tool: FuncCall=None, - contexts: List=None, - system_prompt: str=None, - **kwargs) -> LLMResponse: - '''获得 LLM 的文本对话结果。会使用当前的模型进行对话。 - + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + **kwargs, + ) -> LLMResponse: + """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 + Args: prompt: 提示词 session_id: 会话 ID(此属性已经被废弃) @@ -102,17 +101,17 @@ class Provider(AbstractProvider): tools: Function-calling 工具 contexts: 上下文 kwargs: 其他参数 - + Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 - ''' + """ raise NotImplementedError() async def pop_record(self, context: List): - ''' + """ 弹出 context 第一条非系统提示词对话记录 - ''' + """ poped = 0 indexs_to_pop = [] for idx, record in enumerate(context): @@ -123,20 +122,20 @@ class Provider(AbstractProvider): poped += 1 if poped == 2: break - + for idx in reversed(indexs_to_pop): context.pop(idx) - + class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings - + @abc.abstractmethod async def get_text(self, audio_url: str) -> str: - '''获取音频的文本''' + """获取音频的文本""" raise NotImplementedError() @@ -145,8 +144,8 @@ class TTSProvider(AbstractProvider): super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings - + @abc.abstractmethod async def get_audio(self, text: str) -> str: - '''获取文本的音频,返回音频文件路径''' - raise NotImplementedError() \ No newline at end of file + """获取文本的音频,返回音频文件路径""" + raise NotImplementedError() diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index c7f85f493..41a7a29d5 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,35 +1,39 @@ -from typing import List, Dict, Type +from typing import List, Dict from .entites import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall provider_registry: List[ProviderMetaData] = [] -'''维护了通过装饰器注册的 Provider''' +"""维护了通过装饰器注册的 Provider""" provider_cls_map: Dict[str, ProviderMetaData] = {} -'''维护了 Provider 类型名称和 ProviderMetadata 的映射''' +"""维护了 Provider 类型名称和 ProviderMetadata 的映射""" llm_tools = FuncCall() + def register_provider_adapter( - provider_type_name: str, - desc: str, + provider_type_name: str, + desc: str, provider_type: ProviderType = ProviderType.CHAT_COMPLETION, default_config_tmpl: dict = None, - provider_display_name: str = None + provider_display_name: str = None, ): - '''用于注册平台适配器的带参装饰器''' + """用于注册平台适配器的带参装饰器""" + def decorator(cls): if provider_type_name in provider_cls_map: - raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。") - + raise ValueError( + f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。" + ) + # 添加必备选项 if default_config_tmpl: - if 'type' not in default_config_tmpl: - default_config_tmpl['type'] = provider_type_name - if 'enable' not in default_config_tmpl: - default_config_tmpl['enable'] = False - if 'id' not in default_config_tmpl: - default_config_tmpl['id'] = provider_type_name + if "type" not in default_config_tmpl: + default_config_tmpl["type"] = provider_type_name + if "enable" not in default_config_tmpl: + default_config_tmpl["enable"] = False + if "id" not in default_config_tmpl: + default_config_tmpl["id"] = provider_type_name pm = ProviderMetaData( type=provider_type_name, @@ -37,11 +41,11 @@ def register_provider_adapter( provider_type=provider_type, cls_type=cls, default_config_tmpl=default_config_tmpl, - provider_display_name=provider_display_name + provider_display_name=provider_display_name, ) provider_registry.append(pm) provider_cls_map[provider_type_name] = pm logger.debug(f"服务提供商 Provider {provider_type_name} 已注册") return cls - + return decorator diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 35200bf23..13f7482d4 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -13,18 +13,28 @@ from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse from .openai_source import ProviderOpenAIOfficial -@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器") + +@register_provider_adapter( + "anthropic_chat_completion", "Anthropic Claude API 提供商适配器" +) class ProviderAnthropic(ProviderOpenAIOfficial): def __init__( self, provider_config: dict, provider_settings: dict, db_helper: BaseDatabase, - persistant_history = True, - default_persona: Personality = None + persistant_history=True, + default_persona: Personality = None, ) -> None: # Skip OpenAI's __init__ and call Provider's __init__ directly - Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona) + Provider.__init__( + self, + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, + ) self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) @@ -35,23 +45,18 @@ class ProviderAnthropic(ProviderOpenAIOfficial): self.timeout = int(self.timeout) self.client = AsyncAnthropic( - api_key=self.chosen_api_key, - timeout=self.timeout, - base_url=self.base_url + api_key=self.chosen_api_key, timeout=self.timeout, base_url=self.base_url ) - - self.set_model(provider_config['model_config']['model']) + + self.set_model(provider_config["model_config"]["model"]) async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: tool_list = tools.get_func_desc_anthropic_style() if tool_list: - payloads['tools'] = tool_list + payloads["tools"] = tool_list - completion = await self.client.messages.create( - **payloads, - stream=False - ) + completion = await self.client.messages.create(**payloads, stream=False) assert isinstance(completion, Message) logger.debug(f"completion: {completion}") @@ -61,14 +66,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial): # TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容 # 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求 content = completion.content[-1] - + llm_response = LLMResponse("assistant") - + if content.type == "text": # text completion completion_text = str(content.text).strip() llm_response.completion_text = completion_text - + # Anthropic每次只返回一个函数调用 if completion.stop_reason == "tool_use": # tools call (function calling) @@ -83,11 +88,11 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if not llm_response.completion_text and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") raise Exception(f"API 返回的 completion 无法解析:{completion}。") - + llm_response.raw_completion = completion - + return llm_response - + async def text_chat( self, prompt: str, @@ -96,28 +101,24 @@ class ProviderAnthropic(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=[], system_prompt=None, - **kwargs + **kwargs, ) -> LLMResponse: - if not prompt: prompt = "" - + new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] for part in context_query: - if '_no_save' in part: - del part['_no_save'] + if "_no_save" in part: + del part["_no_save"] model_config = self.provider_config.get("model_config", {}) - payloads = { - "messages": context_query, - **model_config - } + payloads = {"messages": context_query, **model_config} # Anthropic has a different way of handling system prompts if system_prompt: - payloads['system'] = system_prompt + payloads["system"] = system_prompt llm_response = None try: @@ -127,12 +128,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if "maximum context length" in str(e): retry_cnt = 20 while retry_cnt > 0: - logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}") + logger.warning( + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) try: await self.pop_record(context_query) response = await self.client.messages.create( - messages=context_query, - **model_config + messages=context_query, **model_config ) llm_response = LLMResponse("assistant") llm_response.completion_text = response.content[0].text @@ -147,17 +149,17 @@ class ProviderAnthropic(ProviderOpenAIOfficial): else: logger.error(f"发生了错误。Provider 配置如下: {model_config}") raise e - + return llm_response async def assemble_context(self, text: str, image_urls: List[str] = None): - '''组装上下文,支持文本和图片''' + """组装上下文,支持文本和图片""" if not image_urls: return {"role": "user", "content": text} - + content = [] content.append({"type": "text", "text": text}) - + for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) @@ -167,23 +169,27 @@ class ProviderAnthropic(ProviderOpenAIOfficial): image_data = await self.encode_image_bs64(image_path) else: image_data = await self.encode_image_bs64(image_url) - + if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - + # Get mime type for the image mime_type, _ = guess_type(image_url) if not mime_type: mime_type = "image/jpeg" # Default to JPEG if can't determine - - content.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": image_data.split("base64,")[1] if "base64," in image_data else image_data + + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": image_data.split("base64,")[1] + if "base64," in image_data + else image_data, + }, } - }) - - return {"role": "user", "content": content} \ No newline at end of file + ) + + return {"role": "user", "content": content} diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 9fd26c90b..9647b41c0 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -125,4 +125,4 @@ class ProviderDashscope(ProviderOpenAIOfficial): raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。") async def terminate(self): - pass \ No newline at end of file + pass diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 9e8e344f7..37f575f21 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -8,6 +8,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.io import download_image_by_url from astrbot.core import logger, sp + @register_provider_adapter("dify", "Dify APP 适配器。") class ProviderDify(Provider): def __init__( @@ -16,10 +17,14 @@ class ProviderDify(Provider): provider_settings: dict, db_helper: BaseDatabase, persistant_history=False, - default_persona: Personality=None + default_persona: Personality = None, ) -> None: super().__init__( - provider_config, provider_settings, persistant_history, db_helper, default_persona + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, ) self.api_key = provider_config.get("dify_api_key", "") if not self.api_key: @@ -30,8 +35,12 @@ class ProviderDify(Provider): if not self.api_type: raise Exception("Dify API 类型不能为空。") self.model_name = "dify" - self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output") - self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query") + self.workflow_output_key = provider_config.get( + "dify_workflow_output_key", "astrbot_wf_output" + ) + self.dify_query_input_key = provider_config.get( + "dify_query_input_key", "astrbot_text_query" + ) self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" @@ -39,8 +48,7 @@ class ProviderDify(Provider): if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.conversation_ids = {} - '''记录当前 session id 的对话 ID''' - + """记录当前 session id 的对话 ID""" async def text_chat( self, @@ -54,31 +62,37 @@ class ProviderDify(Provider): ) -> LLMResponse: result = "" conversation_id = self.conversation_ids.get(session_id, "") - + files_payload = [] for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) - file_response = await self.api_client.file_upload(image_path, user=session_id) - if 'id' not in file_response: - logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。") + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" + ) continue - files_payload.append({ - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response['id'], - }) + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) else: # TODO: 处理更多情况 logger.warning(f"未知的图片链接:{image_url},图片将忽略。") - + # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 session_vars = sp.get("session_variables", {}) session_var = session_vars.get(session_id, {}) payload_vars.update(session_var) - + try: match self.api_type: case "chat" | "agent": @@ -90,22 +104,28 @@ class ProviderDify(Provider): user=session_id, conversation_id=conversation_id, files=files_payload, - timeout=self.timeout + timeout=self.timeout, ): logger.debug(f"dify resp chunk: {chunk}") - if chunk['event'] == "message" or \ - chunk['event'] == "agent_message": - result += chunk['answer'] + if ( + chunk["event"] == "message" + or chunk["event"] == "agent_message" + ): + result += chunk["answer"] if not conversation_id: - self.conversation_ids[session_id] = chunk['conversation_id'] - conversation_id = chunk['conversation_id'] - elif chunk['event'] == 'message_end': + self.conversation_ids[session_id] = chunk[ + "conversation_id" + ] + conversation_id = chunk["conversation_id"] + elif chunk["event"] == "message_end": logger.debug("Dify message end") break - elif chunk['event'] == 'error': + elif chunk["event"] == "error": logger.error(f"Dify 出现错误:{chunk}") - raise Exception(f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}") - + raise Exception( + f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}" + ) + case "workflow": async for chunk in self.api_client.workflow_run( inputs={ @@ -115,30 +135,47 @@ class ProviderDify(Provider): }, user=session_id, files=files_payload, - timeout=self.timeout + timeout=self.timeout, ): - match chunk['event']: + match chunk["event"]: case "workflow_started": - logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。") + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。" + ) case "node_finished": - logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。") + logger.debug( + f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。" + ) case "workflow_finished": - logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。") - if chunk['data']['error']: - logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}") - raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}") - if self.workflow_output_key not in chunk['data']['outputs']: - raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}") - result = chunk['data']['outputs'][self.workflow_output_key] + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。" + ) + if chunk["data"]["error"]: + logger.error( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + raise Exception( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + if ( + self.workflow_output_key + not in chunk["data"]["outputs"] + ): + raise Exception( + f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" + ) + result = chunk["data"]["outputs"][ + self.workflow_output_key + ] case _: raise Exception(f"未知的 Dify API 类型:{self.api_type}") except Exception as e: logger.error(f"Dify 请求失败:{str(e)}") return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}") - + if not result: logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - + return LLMResponse(role="assistant", completion_text=result) async def forget(self, session_id): @@ -158,4 +195,4 @@ class ProviderDify(Provider): raise Exception("暂不支持获得 Dify 的历史消息记录。") async def terminate(self): - await self.api_client.close() \ No newline at end of file + await self.api_client.close() diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index c64fb07c4..db28244c4 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -15,29 +15,32 @@ pip install edge_tts Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot """ -@register_provider_adapter("edge_tts", "Microsoft Edge TTS", provider_type=ProviderType.TEXT_TO_SPEECH) + +@register_provider_adapter( + "edge_tts", "Microsoft Edge TTS", provider_type=ProviderType.TEXT_TO_SPEECH +) class ProviderEdgeTTS(TTSProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - + # 设置默认语音,如果没有指定则使用中文小萱 self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") self.rate = provider_config.get("rate", None) self.volume = provider_config.get("volume", None) self.pitch = provider_config.get("pitch", None) self.timeout = provider_config.get("timeout", 30) - + self.set_model("edge_tts") - + async def get_audio(self, text: str) -> str: os.makedirs("data/temp", exist_ok=True) - mp3_path = f'data/temp/edge_tts_temp_{uuid.uuid4()}.mp3' - wav_path = f'data/temp/edge_tts_{uuid.uuid4()}.wav' - + mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3" + wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav" + # 构建Edge TTS参数 kwargs = {"text": text, "voice": self.voice} if self.rate: @@ -46,22 +49,29 @@ class ProviderEdgeTTS(TTSProvider): kwargs["volume"] = self.volume if self.pitch: kwargs["pitch"] = self.pitch - + try: communicate = edge_tts.Communicate(**kwargs) await communicate.save(mp3_path) - + # 使用ffmpeg将MP3转换为标准WAV格式 - _ = subprocess.run([ - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", mp3_path, # 输入文件 - "-acodec", "pcm_s16le", # 16位PCM编码 - "-ar", "24000", # 采样率24kHz (适合微信语音) - "-ac", "1", # 单声道 - wav_path # 输出文件 - ], capture_output=True, check=True) - + _ = subprocess.run( + [ + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + wav_path, # 输出文件 + ], + capture_output=True, + check=True, + ) os.remove(mp3_path) if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: @@ -69,22 +79,21 @@ class ProviderEdgeTTS(TTSProvider): else: logger.error("生成的WAV文件不存在或为空") raise RuntimeError("生成的WAV文件不存在或为空") - + except subprocess.CalledProcessError as e: logger.error(f"FFmpeg转换失败: {e.stderr.decode() if e.stderr else str(e)}") try: if os.path.exists(mp3_path): os.remove(mp3_path) - except: + except Exception: pass raise RuntimeError(f"FFmpeg转换失败: {str(e)}") - + except Exception as e: logger.error(f"音频生成失败: {str(e)}") try: if os.path.exists(mp3_path): os.remove(mp3_path) - except: + except Exception: pass raise RuntimeError(f"音频生成失败: {str(e)}") - \ No newline at end of file diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c6357bf7a..233cc8932 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -10,8 +10,9 @@ from typing import List from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse -class SimpleGoogleGenAIClient(): - def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None: + +class SimpleGoogleGenAIClient: + def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None: self.api_key = api_key if api_base.endswith("/"): self.api_base = api_base[:-1] @@ -19,36 +20,38 @@ class SimpleGoogleGenAIClient(): self.api_base = api_base self.client = aiohttp.ClientSession(trust_env=True) self.timeout = timeout - + async def models_list(self) -> List[str]: request_url = f"{self.api_base}/v1beta/models?key={self.api_key}" async with self.client.get(request_url, timeout=self.timeout) as resp: response = await resp.json() - + models = [] for model in response["models"]: - if 'generateContent' in model["supportedGenerationMethods"]: + if "generateContent" in model["supportedGenerationMethods"]: models.append(model["name"].replace("models/", "")) return models async def generate_content( - self, - contents: List[dict], - model: str="gemini-1.5-flash", - system_instruction: str="", - tools: dict=None + self, + contents: List[dict], + model: str = "gemini-1.5-flash", + system_instruction: str = "", + tools: dict = None, ): payload = {} if system_instruction: - payload["system_instruction"] = { - "parts": {"text": system_instruction} - } + payload["system_instruction"] = {"parts": {"text": system_instruction}} if tools: payload["tools"] = [tools] payload["contents"] = contents logger.debug(f"payload: {payload}") - request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" - async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp: + request_url = ( + f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}" + ) + async with self.client.post( + request_url, json=payload, timeout=self.timeout + ) as resp: if "application/json" in resp.headers.get("Content-Type"): try: response = await resp.json() @@ -63,17 +66,25 @@ class SimpleGoogleGenAIClient(): raise Exception("Gemini 返回了非 json 数据: ") -@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器") +@register_provider_adapter( + "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" +) class ProviderGoogleGenAI(Provider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history = True, - default_persona: Personality=None + db_helper: BaseDatabase, + persistant_history=True, + default_persona: Personality = None, ) -> None: - super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona) + super().__init__( + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, + ) self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None @@ -83,37 +94,36 @@ class ProviderGoogleGenAI(Provider): self.client = SimpleGoogleGenAIClient( api_key=self.chosen_api_key, api_base=provider_config.get("api_base", None), - timeout=self.timeout + timeout=self.timeout, ) - self.set_model(provider_config['model_config']['model']) + self.set_model(provider_config["model_config"]["model"]) async def get_models(self): return await self.client.models_list() - + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: tool = None if tools: tool = tools.get_func_desc_google_genai_style() if not tool: tool = None - + system_instruction = "" for message in payloads["messages"]: if message["role"] == "system": system_instruction = message["content"] break - + google_genai_conversation = [] for message in payloads["messages"]: if message["role"] == "user": if isinstance(message["content"], str): - if not message['content']: - message['content'] = "" - - google_genai_conversation.append({ - "role": "user", - "parts": [{"text": message["content"]}] - }) + if not message["content"]: + message["content"] = "" + + google_genai_conversation.append( + {"role": "user", "parts": [{"text": message["content"]}]} + ) elif isinstance(message["content"], list): # images parts = [] @@ -123,46 +133,48 @@ class ProviderGoogleGenAI(Provider): part["text"] = "" parts.append({"text": part["text"]}) elif part["type"] == "image_url": - parts.append({"inline_data": { - "mime_type": "image/jpeg", - "data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64 - }}) - google_genai_conversation.append({ - "role": "user", - "parts": parts - }) - + parts.append( + { + "inline_data": { + "mime_type": "image/jpeg", + "data": part["image_url"]["url"].replace( + "data:image/jpeg;base64,", "" + ), # base64 + } + } + ) + google_genai_conversation.append({"role": "user", "parts": parts}) + elif message["role"] == "assistant": if not message["content"]: message["content"] = "" - google_genai_conversation.append({ - "role": "model", - "parts": [{"text": message["content"]}] - }) + google_genai_conversation.append( + {"role": "model", "parts": [{"text": message["content"]}]} + ) logger.debug(f"google_genai_conversation: {google_genai_conversation}") - + result = await self.client.generate_content( contents=google_genai_conversation, model=self.get_model(), system_instruction=system_instruction, - tools=tool + tools=tool, ) logger.debug(f"result: {result}") - + if "candidates" not in result: raise Exception("Gemini 返回异常结果: " + str(result)) - - candidates = result["candidates"][0]['content']['parts'] + + candidates = result["candidates"][0]["content"]["parts"] llm_response = LLMResponse("assistant") for candidate in candidates: - if 'text' in candidate: - llm_response.completion_text += candidate['text'] - elif 'functionCall' in candidate: + if "text" in candidate: + llm_response.completion_text += candidate["text"] + elif "functionCall" in candidate: llm_response.role = "tool" - llm_response.tools_call_args.append(candidate['functionCall']['args']) - llm_response.tools_call_name.append(candidate['functionCall']['name']) - + llm_response.tools_call_args.append(candidate["functionCall"]["args"]) + llm_response.tools_call_name.append(candidate["functionCall"]["name"]) + llm_response.completion_text = llm_response.completion_text.strip() return llm_response @@ -170,45 +182,44 @@ class ProviderGoogleGenAI(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str]=None, - func_tool: FuncCall=None, + image_urls: List[str] = None, + func_tool: FuncCall = None, contexts=[], system_prompt=None, - **kwargs - ) -> LLMResponse: + **kwargs, + ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) context_query = [] context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) - - for part in context_query: - if '_no_save' in part: - del part['_no_save'] - - model_config = self.provider_config.get("model_config", {}) - model_config['model'] = self.get_model() - payloads = { - "messages": context_query, - **model_config - } + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": context_query, **model_config} llm_response = None - + retry = 10 keys = self.api_keys.copy() chosen_key = random.choice(keys) - + for i in range(retry): try: - self.client.api_key = chosen_key + self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break except Exception as e: if "maximum context length" in str(e): retry_cnt = 20 while retry_cnt > 0: - logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}") + logger.warning( + f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) try: await self.pop_record(context_query) llm_response = await self._query(payloads, func_tool) @@ -219,24 +230,34 @@ class ProviderGoogleGenAI(Provider): else: raise e if retry_cnt == 0: - llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话") + llm_response = LLMResponse( + "err", "err: 请尝试 /reset 重置会话" + ) elif "Function calling is not enabled" in str(e): - logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。") - if 'tools' in payloads: - del payloads['tools'] + logger.info( + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + ) + if "tools" in payloads: + del payloads["tools"] llm_response = await self._query(payloads, None) break elif "429" in str(e) or "API key not valid" in str(e): keys.remove(chosen_key) if len(keys) > 0: chosen_key = random.choice(keys) - logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...") + logger.info( + f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..." + ) continue else: - logger.error(f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...") + logger.error( + f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..." + ) raise Exception("API 资源已耗尽,且没有可用的 Key 重试...") else: - logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}") + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + ) raise e return llm_response @@ -246,16 +267,16 @@ class ProviderGoogleGenAI(Provider): def get_keys(self) -> List[str]: return self.api_keys - + def set_key(self, key): self.client.api_key = key - + async def assemble_context(self, text: str, image_urls: List[str] = None): - ''' + """ 组装上下文。 - ''' + """ if image_urls: - user_content = {"role": "user","content": [{"type": "text", "text": text}]} + user_content = {"role": "user", "content": [{"type": "text", "text": text}]} for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) @@ -268,22 +289,24 @@ class ProviderGoogleGenAI(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) + user_content["content"].append( + {"type": "image_url", "image_url": {"url": image_data}} + ) return user_content else: - return {"role": "user","content": text} + return {"role": "user", "content": text} async def encode_image_bs64(self, image_url: str) -> str: - ''' + """ 将图片转换为 base64 - ''' + """ if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode('utf-8') + image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return '' - + return "" + async def terminate(self): await self.client.client.close() - logger.info("Google GenAI 适配器已终止。") \ No newline at end of file + logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index 0c8b7ab33..b57932edf 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,5 +1,4 @@ import uuid -import os import aiohttp import urllib.parse from ..provider import TTSProvider @@ -7,11 +6,13 @@ from ..entites import ProviderType from ..register import register_provider_adapter -@register_provider_adapter("gsvi_tts_api", "GSVI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH) +@register_provider_adapter( + "gsvi_tts_api", "GSVI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH +) class ProviderGSVITTS(TTSProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) @@ -22,9 +23,9 @@ class ProviderGSVITTS(TTSProvider): self.emotion = provider_config.get("emotion") async def get_audio(self, text: str) -> str: - path = f'data/temp/gsvi_tts_{uuid.uuid4()}.wav' + path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav" params = {"text": text} - + if self.character: params["character"] = self.character if self.emotion: @@ -34,16 +35,18 @@ class ProviderGSVITTS(TTSProvider): for key, value in params.items(): encoded_value = urllib.parse.quote(str(value)) query_parts.append(f"{key}={encoded_value}") - + url = f"{self.api_base}/tts?{'&'.join(query_parts)}" - + async with aiohttp.ClientSession() as session: async with session.get(url) as response: if response.status == 200: - with open(path, 'wb') as f: + with open(path, "wb") as f: f.write(await response.read()) else: error_text = await response.text() - raise Exception(f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}") - - return path \ No newline at end of file + raise Exception( + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}" + ) + + return path diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index 3f8f4b9f6..bfd9e03a5 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -1,8 +1,7 @@ -import json import os from llmtuner.chat import ChatModel from typing import List -from .. import Provider, Personality +from .. import Provider from ..entites import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase @@ -22,7 +21,11 @@ class LLMTunerModelLoader(Provider): default_persona=None, ) -> None: super().__init__( - provider_config, provider_settings, persistant_history, db_helper, default_persona + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, ) if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists( provider_config["adapter_model_path"] @@ -70,10 +73,10 @@ class LLMTunerModelLoader(Provider): for idx, context in enumerate(query_context): if context["role"] == "system": system_idxs.append(idx) - - if '_no_save' in context: - del context['_no_save'] - + + if "_no_save" in context: + del context["_no_save"] + for idx in reversed(system_idxs): system_prompt += " " + query_context.pop(idx)["content"] @@ -84,12 +87,12 @@ class LLMTunerModelLoader(Provider): if func_tool: tool_list = func_tool.get_func_desc_openai_style() if tool_list: - conf['tools'] = tool_list + conf["tools"] = tool_list responses = await self.model.achat(**conf) llm_response = LLMResponse("assistant", responses[-1].response_text) - + return llm_response async def get_current_key(self): @@ -99,4 +102,4 @@ class LLMTunerModelLoader(Provider): pass async def get_models(self): - return [self.get_model()] \ No newline at end of file + return [self.get_model()] diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index d43f00e99..4b66c50ef 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -16,17 +16,26 @@ from typing import List from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse -@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器") + +@register_provider_adapter( + "openai_chat_completion", "OpenAI API Chat Completion 提供商适配器" +) class ProviderOpenAIOfficial(Provider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history = True, - default_persona: Personality = None + db_helper: BaseDatabase, + persistant_history=True, + default_persona: Personality = None, ) -> None: - super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona) + super().__init__( + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, + ) self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None @@ -40,18 +49,20 @@ class ProviderOpenAIOfficial(Provider): api_key=self.chosen_api_key, api_version=provider_config.get("api_version", None), base_url=provider_config.get("api_base", None), - timeout=self.timeout + timeout=self.timeout, ) else: # 使用 openai api self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=provider_config.get("api_base", None), - timeout=self.timeout + timeout=self.timeout, ) - - self.default_params = inspect.signature(self.client.chat.completions.create).parameters.keys() - + + self.default_params = inspect.signature( + self.client.chat.completions.create + ).parameters.keys() + model_config = provider_config.get("model_config", {}) model = model_config.get("model", "unknown") self.set_model(model) @@ -66,12 +77,12 @@ class ProviderOpenAIOfficial(Provider): return models_str except NotFoundError as e: raise Exception(f"获取模型列表失败:{e}") - + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: tool_list = tools.get_func_desc_openai_style() if tool_list: - payloads['tools'] = tool_list + payloads["tools"] = tool_list # 不在默认参数中的参数放在 extra_body 中 extra_body = {} @@ -84,27 +95,27 @@ class ProviderOpenAIOfficial(Provider): del payloads[key] completion = await self.client.chat.completions.create( - **payloads, - stream=False, - extra_body=extra_body + **payloads, stream=False, extra_body=extra_body ) if not isinstance(completion, ChatCompletion): - raise Exception(f"API 返回的 completion 类型错误:{type(completion)}: {completion}。") + raise Exception( + f"API 返回的 completion 类型错误:{type(completion)}: {completion}。" + ) logger.debug(f"completion: {completion}") if len(completion.choices) == 0: raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - + llm_response = LLMResponse("assistant") - + if choice.message.content: # text completion completion_text = str(choice.message.content).strip() llm_response.completion_text = completion_text - + if choice.message.tool_calls: # tools call (function calling) args_ls = [] @@ -118,45 +129,43 @@ class ProviderOpenAIOfficial(Provider): llm_response.role = "tool" llm_response.tools_call_args = args_ls llm_response.tools_call_name = func_name_ls - - if choice.finish_reason == 'content_filter': - raise Exception("API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。") + + if choice.finish_reason == "content_filter": + raise Exception( + "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。" + ) if not llm_response.completion_text and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") raise Exception(f"API 返回的 completion 无法解析:{completion}。") - + llm_response.raw_completion = completion - + return llm_response async def text_chat( self, prompt: str, - session_id: str=None, - image_urls: List[str]=[], - func_tool: FuncCall=None, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, contexts=[], system_prompt=None, - **kwargs - ) -> LLMResponse: + **kwargs, + ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) for part in context_query: - if '_no_save' in part: - del part['_no_save'] - - - model_config = self.provider_config.get("model_config", {}) - model_config['model'] = self.get_model() + if "_no_save" in part: + del part["_no_save"] - payloads = { - "messages": context_query, - **model_config - } + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": context_query, **model_config} llm_response = None try: llm_response = await self._query(payloads, func_tool) @@ -164,7 +173,7 @@ class ProviderOpenAIOfficial(Provider): logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") # 尝试删除所有 image new_contexts = await self._remove_image_from_context(context_query) - payloads['messages'] = new_contexts + payloads["messages"] = new_contexts context_query = new_contexts llm_response = await self._query(payloads, func_tool) except Exception as e: @@ -172,7 +181,9 @@ class ProviderOpenAIOfficial(Provider): # 重试 10 次 retry_cnt = 20 while retry_cnt > 0: - logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}") + logger.warning( + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) try: await self.pop_record(context_query) llm_response = await self._query(payloads, func_tool) @@ -183,83 +194,93 @@ class ProviderOpenAIOfficial(Provider): else: raise e if retry_cnt == 0: - llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。") - elif "The model is not a VLM" in str(e): # siliconcloud + llm_response = LLMResponse( + "err", "err: 请尝试 /reset 清除会话记录。" + ) + elif "The model is not a VLM" in str(e): # siliconcloud # 尝试删除所有 image new_contexts = await self._remove_image_from_context(context_query) - payloads['messages'] = new_contexts + payloads["messages"] = new_contexts llm_response = await self._query(payloads, func_tool) # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 - elif 'does not support Function Calling' in str(e) \ - or 'does not support tools' in str(e) \ - or 'Function call is not supported' in str(e) \ - or 'Function calling is not enabled' in str(e) \ - or 'Tool calling is not supported' in str(e) \ - or 'No endpoints found that support tool use' in str(e) \ - or 'model does not support function calling' in str(e) \ - or ('tool' in str(e) and 'support' in str(e).lower()) \ - or ('function' in str(e) and 'support' in str(e).lower()): - logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。") - if 'tools' in payloads: - del payloads['tools'] - llm_response = await self._query(payloads, None) + elif ( + "does not support Function Calling" in str(e) + or "does not support tools" in str(e) + or "Function call is not supported" in str(e) + or "Function calling is not enabled" in str(e) + or "Tool calling is not supported" in str(e) + or "No endpoints found that support tool use" in str(e) + or "model does not support function calling" in str(e) + or ("tool" in str(e) and "support" in str(e).lower()) + or ("function" in str(e) and "support" in str(e).lower()) + ): + logger.info( + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + ) + if "tools" in payloads: + del payloads["tools"] + llm_response = await self._query(payloads, None) else: logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") - - if 'tool' in str(e).lower() and 'support' in str(e).lower(): - logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") - - if 'Connection error.' in str(e): + + if "tool" in str(e).lower() and "support" in str(e).lower(): + logger.error( + "疑似该模型不支持函数调用工具调用。请输入 /tool off_all" + ) + + if "Connection error." in str(e): proxy = os.environ.get("http_proxy", None) if proxy: - logger.error(f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}") - + logger.error( + f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" + ) + raise e - + return llm_response - + async def _remove_image_from_context(self, contexts: List): - ''' + """ 从上下文中删除所有带有 image 的记录 - ''' + """ new_contexts = [] - + flag = False for context in contexts: if flag: - flag = False # 删除 image 后,下一条(LLM 响应)也要删除 + flag = False # 删除 image 后,下一条(LLM 响应)也要删除 continue - if isinstance(context['content'], list): + if isinstance(context["content"], list): flag = True # continue new_content = [] - for item in context['content']: - if isinstance(item, dict) and 'image_url' in item: + for item in context["content"]: + if isinstance(item, dict) and "image_url" in item: continue new_content.append(item) if not new_content: # 用户只发了图片 new_content = [{"type": "text", "text": "[图片]"}] - context['content'] = new_content + context["content"] = new_content new_contexts.append(context) return new_contexts - + def get_current_key(self) -> str: return self.client.api_key def get_keys(self) -> List[str]: return self.api_keys - + def set_key(self, key): self.client.api_key = key - + async def assemble_context(self, text: str, image_urls: List[str] = None): - ''' + """ 组装上下文。 - ''' + """ if image_urls: - user_content = {"role": "user","content": [{"type": "text", "text": text}]} + user_content = {"role": "user", "content": [{"type": "text", "text": text}]} for image_url in image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) @@ -272,18 +293,20 @@ class ProviderOpenAIOfficial(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) + user_content["content"].append( + {"type": "image_url", "image_url": {"url": image_data}} + ) return user_content else: - return {"role": "user","content": text} + return {"role": "user", "content": text} async def encode_image_bs64(self, image_url: str) -> str: - ''' + """ 将图片转换为 base64 - ''' + """ if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode('utf-8') + image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - return '' \ No newline at end of file + return "" diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 9d59dec70..b59f2c283 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,16 +1,17 @@ import uuid -import os from openai import AsyncOpenAI, NOT_GIVEN from ..provider import TTSProvider from ..entites import ProviderType from ..register import register_provider_adapter -@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH) +@register_provider_adapter( + "openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH +) class ProviderOpenAITTSAPI(TTSProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) @@ -22,19 +23,15 @@ class ProviderOpenAITTSAPI(TTSProvider): base_url=provider_config.get("api_base", None), timeout=provider_config.get("timeout", NOT_GIVEN), ) - + self.set_model(provider_config.get("model", None)) - async def get_audio(self, text: str) -> str: - path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav' + path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav" async with self.client.audio.speech.with_streaming_response.create( - model=self.model_name, - voice=self.voice, - response_format='wav', - input=text + model=self.model_name, voice=self.voice, response_format="wav", input=text ) as response: - with open(path, 'wb') as f: + with open(path, "wb") as f: async for chunk in response.iter_bytes(chunk_size=1024): f.write(chunk) - return path \ No newline at end of file + return path diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index e08c1bd0a..4842b0e04 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -1,8 +1,9 @@ -''' +""" Author: diudiu62 Date: 2025-02-24 18:04:18 LastEditTime: 2025-02-25 14:06:30 -''' +""" + import asyncio from datetime import datetime import os @@ -16,41 +17,45 @@ from ..register import register_provider_adapter from astrbot.core import logger from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav -@register_provider_adapter("sensevoice_stt_selfhost", "SenseVoice 自托管语音识别 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT) + +@register_provider_adapter( + "sensevoice_stt_selfhost", + "SenseVoice 自托管语音识别 模型部署", + provider_type=ProviderType.SPEECH_TO_TEXT, +) class ProviderSenseVoiceSTTSelfHost(STTProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) self.set_model(provider_config.get("stt_model", None)) self.model = None self.is_emotion = provider_config.get("is_emotion", False) - + async def initialize(self): logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") - # 将模型加载放到线程池中执行 self.model = await asyncio.get_event_loop().run_in_executor( - None, - lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16) + None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16) ) logger.info("SenseVoice 模型加载完成。") - + async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") return os.path.join("data", "temp", f"{timestamp}") - + async def _convert_audio(self, path: str) -> str: from pyffmpeg import FFmpeg - filename = await self.get_timestamped_path() + '.mp3' + + filename = await self.get_timestamped_path() + ".mp3" ff = FFmpeg() output_path = ff.convert(path, os.path.join('data","temp', filename)) return output_path - + async def _is_silk_file(self, file_path): silk_header = b"SILK" with open(file_path, "rb") as f: @@ -63,13 +68,15 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): async def get_text(self, audio_url: str) -> str: try: - is_tencent = audio_url.startswith("http") and "multimedia.nt.qq.com.cn" in audio_url - + is_tencent = ( + audio_url.startswith("http") and "multimedia.nt.qq.com.cn" in audio_url + ) + if is_tencent: path = await self.get_timestamped_path() await download_file(audio_url, path) audio_url = path - + if not os.path.isfile(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") @@ -77,7 +84,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - output_path = await self.get_timestamped_path()+'.wav' + output_path = await self.get_timestamped_path() + ".wav" await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path @@ -85,7 +92,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): loop = asyncio.get_event_loop() res = await loop.run_in_executor( None, # 使用默认的线程池 - lambda: self.model(audio_url, language="auto", use_itn=True) + lambda: self.model(audio_url, language="auto", use_itn=True), ) # res = self.model(audio_url, language="auto", use_itn=True) @@ -93,7 +100,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): text = rich_transcription_postprocess(res[0]) if self.is_emotion: # 提取第二个匹配的值 - matches = re.findall(r'<\|([^|]+)\|>', res[0]) + matches = re.findall(r"<\|([^|]+)\|>", res[0]) if len(matches) >= 2: emotion = matches[1] text = f"(当前的情绪:{emotion}) {text}" @@ -102,4 +109,4 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): return text except Exception as e: logger.error(f"处理音频文件时出错: {e}") - raise \ No newline at end of file + raise diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 3190c042e..ce474f4ef 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -8,11 +8,16 @@ from ..register import register_provider_adapter from astrbot.core import logger from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav -@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT) + +@register_provider_adapter( + "openai_whisper_api", + "OpenAI Whisper API", + provider_type=ProviderType.SPEECH_TO_TEXT, +) class ProviderOpenAIWhisperAPI(STTProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) @@ -23,16 +28,17 @@ class ProviderOpenAIWhisperAPI(STTProvider): base_url=provider_config.get("api_base", None), timeout=provider_config.get("timeout", NOT_GIVEN), ) - + self.set_model(provider_config.get("model", None)) - + async def _convert_audio(self, path: str) -> str: from pyffmpeg import FFmpeg - filename = str(uuid.uuid4()) + '.mp3' + + filename = str(uuid.uuid4()) + ".mp3" ff = FFmpeg() - output_path = ff.convert(path, os.path.join('data/temp', filename)) + output_path = ff.convert(path, os.path.join("data/temp", filename)) return output_path - + async def _is_silk_file(self, file_path): silk_header = b"SILK" with open(file_path, "rb") as f: @@ -44,31 +50,31 @@ class ProviderOpenAIWhisperAPI(STTProvider): return False async def get_text(self, audio_url: str) -> str: - '''only supports mp3, mp4, mpeg, m4a, wav, webm''' + """only supports mp3, mp4, mpeg, m4a, wav, webm""" is_tencent = False - + if audio_url.startswith("http"): if "multimedia.nt.qq.com.cn" in audio_url: is_tencent = True - + name = str(uuid.uuid4()) path = os.path.join("data/temp", name) await download_file(audio_url, path) audio_url = path - + if not os.path.exists(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") - + if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav') + output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav") await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path - + result = await self.client.audio.transcriptions.create( model=self.model_name, file=open(audio_url, "rb"), ) - return result.text \ No newline at end of file + return result.text diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 6b95a57b8..1bbc2a1dc 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -9,30 +9,38 @@ from ..register import register_provider_adapter from astrbot.core import logger from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav -@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT) + +@register_provider_adapter( + "openai_whisper_selfhost", + "OpenAI Whisper 模型部署", + provider_type=ProviderType.SPEECH_TO_TEXT, +) class ProviderOpenAIWhisperSelfHost(STTProvider): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) self.set_model(provider_config.get("model", None)) self.model = None - + async def initialize(self): loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") - self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name) + self.model = await loop.run_in_executor( + None, whisper.load_model, self.model_name + ) logger.info("Whisper 模型加载完成。") - + async def _convert_audio(self, path: str) -> str: from pyffmpeg import FFmpeg - filename = str(uuid.uuid4()) + '.mp3' + + filename = str(uuid.uuid4()) + ".mp3" ff = FFmpeg() - output_path = ff.convert(path, os.path.join('data/temp', filename)) + output_path = ff.convert(path, os.path.join("data/temp", filename)) return output_path - + async def _is_silk_file(self, file_path): silk_header = b"SILK" with open(file_path, "rb") as f: @@ -45,28 +53,28 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): async def get_text(self, audio_url: str) -> str: loop = asyncio.get_event_loop() - + is_tencent = False - + if audio_url.startswith("http"): if "multimedia.nt.qq.com.cn" in audio_url: is_tencent = True - + name = str(uuid.uuid4()) path = os.path.join("data/temp", name) await download_file(audio_url, path) audio_url = path - + if not os.path.exists(audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") - + if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav') + output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav") await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path - + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) - return result['text'] \ No newline at end of file + return result["text"] diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 1e1f67802..3e819d633 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -1,4 +1,3 @@ -import traceback from astrbot.core.db import BaseDatabase from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall @@ -7,37 +6,44 @@ from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse from .openai_source import ProviderOpenAIOfficial + @register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器") class ProviderZhipu(ProviderOpenAIOfficial): def __init__( - self, - provider_config: dict, + self, + provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history = True, - default_persona = None + db_helper: BaseDatabase, + persistant_history=True, + default_persona=None, ) -> None: - super().__init__(provider_config, provider_settings, db_helper, persistant_history, default_persona) - + super().__init__( + provider_config, + provider_settings, + db_helper, + persistant_history, + default_persona, + ) + async def text_chat( self, prompt: str, session_id: str = None, - image_urls: List[str]=None, - func_tool: FuncCall=None, + image_urls: List[str] = None, + func_tool: FuncCall = None, contexts=[], system_prompt=None, - **kwargs - ) -> LLMResponse: + **kwargs, + ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) context_query = [] - + context_query = [*contexts, new_record] - + model_cfgs: dict = self.provider_config.get("model_config", {}) model = self.get_model() # glm-4v-flash 只支持一张图片 - if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1: + if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1: logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片") logger.debug(context_query) new_context_query_ = [] @@ -45,18 +51,15 @@ class ProviderZhipu(ProviderOpenAIOfficial): if isinstance(context_query[i].get("content", ""), list): continue new_context_query_.append(context_query[i]) - new_context_query_.append(context_query[i+1]) - new_context_query_.append(context_query[-1]) # 保留最后一条记录 + new_context_query_.append(context_query[i + 1]) + new_context_query_.append(context_query[-1]) # 保留最后一条记录 context_query = new_context_query_ logger.debug(context_query) - + if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) - - payloads = { - "messages": context_query, - **model_cfgs - } + + payloads = {"messages": context_query, **model_cfgs} try: llm_response = await self._query(payloads, func_tool) return llm_response @@ -64,7 +67,9 @@ class ProviderZhipu(ProviderOpenAIOfficial): if "maximum context length" in str(e): retry_cnt = 10 while retry_cnt > 0: - logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + logger.warning( + f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。" + ) try: self.pop_record(session_id) llm_response = await self._query(payloads, func_tool) @@ -75,4 +80,4 @@ class ProviderZhipu(ProviderOpenAIOfficial): else: raise e else: - raise e \ No newline at end of file + raise e diff --git a/astrbot/core/rag/embedding/openai_source.py b/astrbot/core/rag/embedding/openai_source.py index 648de0fda..dc09d84dc 100644 --- a/astrbot/core/rag/embedding/openai_source.py +++ b/astrbot/core/rag/embedding/openai_source.py @@ -1,25 +1,20 @@ from typing import List from openai import AsyncOpenAI -class SimpleOpenAIEmbedding(): + +class SimpleOpenAIEmbedding: def __init__( - self, + self, model, api_key, api_base=None, ) -> None: - self.client = AsyncOpenAI( - api_key=api_key, - base_url=api_base - ) + self.client = AsyncOpenAI(api_key=api_key, base_url=api_base) self.model = model - + async def get_embedding(self, text) -> List[float]: - ''' + """ 获取文本的嵌入 - ''' - embedding = await self.client.embeddings.create( - input=text, - model=self.model - ) + """ + embedding = await self.client.embeddings.create(input=text, model=self.model) return embedding.data[0].embedding diff --git a/astrbot/core/rag/knowledge_db_mgr.py b/astrbot/core/rag/knowledge_db_mgr.py index 2ee8199b7..2aed0e448 100644 --- a/astrbot/core/rag/knowledge_db_mgr.py +++ b/astrbot/core/rag/knowledge_db_mgr.py @@ -4,7 +4,8 @@ from astrbot.core import logger from .store import Store from astrbot.core.config import AstrBotConfig -class KnowledgeDBManager(): + +class KnowledgeDBManager: def __init__(self, astrbot_config: AstrBotConfig) -> None: self.db_path = "data/knowledge_db/" self.config = astrbot_config.get("knowledge_db", {}) @@ -20,23 +21,27 @@ class KnowledgeDBManager(): except ImportError as ie: logger.error(f"{ie} 可能未安装 chromadb 库。") continue - self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"]) + self.store_insts[name] = ChromaVectorStore( + name, cfg["embedding_config"] + ) else: logger.error(f"不支持的策略:{cfg['strategy']}") - async def list_knowledge_db(self) -> List[str]: - return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))] - - + return [ + f + for f in os.listdir(self.db_path) + if os.path.isfile(os.path.join(self.db_path, f)) + ] + async def create_knowledge_db(self, name: str, config: Dict): - ''' + """ config 格式: ``` { "strategy": "embedding", # 目前只支持 embedding "chunk_method": { - "strategy": "fixed", + "strategy": "fixed", "chunk_size": 100, "overlap_size": 10 }, @@ -48,40 +53,37 @@ class KnowledgeDBManager(): } } ``` - ''' + """ if name in self.config: raise ValueError(f"知识库已存在:{name}") - + self.config[name] = config self.astrbot_config["knowledge_db"] = self.config self.astrbot_config.save_config() - - + async def insert_record(self, name: str, text: str): if name not in self.store_insts: raise ValueError(f"未找到知识库:{name}") - + ret = [] - match self.config[name]["chunk_method"]['strategy']: + match self.config[name]["chunk_method"]["strategy"]: case "fixed": chunk_size = self.config[name]["chunk_method"]["chunk_size"] chunk_overlap = self.config[name]["chunk_method"]["overlap_size"] ret = self._fixed_chunk(text, chunk_size, chunk_overlap) case _: pass - + for chunk in ret: await self.store_insts[name].save(chunk) - async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]: if name not in self.store_insts: raise ValueError(f"未找到知识库:{name}") - + inst = self.store_insts[name] return await inst.query(query, top_n) - - + def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]: chunks = [] start = 0 @@ -89,4 +91,4 @@ class KnowledgeDBManager(): end = start + chunk_size chunks.append(text[start:end]) start += chunk_size - chunk_overlap - return chunks \ No newline at end of file + return chunks diff --git a/astrbot/core/rag/store/__init__.py b/astrbot/core/rag/store/__init__.py index cd4a3060a..0e74c5a07 100644 --- a/astrbot/core/rag/store/__init__.py +++ b/astrbot/core/rag/store/__init__.py @@ -1,8 +1,9 @@ from typing import List -class Store(): + +class Store: async def save(self, text: str): pass - + async def query(self, query: str, top_n: int = 3) -> List[str]: pass diff --git a/astrbot/core/rag/store/chroma_db.py b/astrbot/core/rag/store/chroma_db.py index 58ee9d9fb..30befb978 100644 --- a/astrbot/core/rag/store/chroma_db.py +++ b/astrbot/core/rag/store/chroma_db.py @@ -5,35 +5,38 @@ from astrbot.api import logger from ..embedding.openai_source import SimpleOpenAIEmbedding from . import Store + class ChromaVectorStore(Store): def __init__(self, name: str, embedding_cfg: Dict) -> None: - self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db') + self.chroma_client = chromadb.PersistentClient( + path="data/long_term_memory_chroma.db" + ) self.collection = self.chroma_client.get_or_create_collection(name=name) self.embedding = None if embedding_cfg["strategy"] == "openai": self.embedding = SimpleOpenAIEmbedding( model=embedding_cfg["model"], api_key=embedding_cfg["api_key"], - api_base=embedding_cfg.get("base_url", None) + api_base=embedding_cfg.get("base_url", None), ) - + async def save(self, text: str, metadata: Dict = None): logger.debug(f"Saving text: {text}") embedding = await self.embedding.get_embedding(text) - + self.collection.upsert( - documents=text, + documents=text, metadatas=metadata, ids=str(uuid.uuid4()), - embeddings=embedding + embeddings=embedding, ) - - async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]: + + async def query( + self, query: str, top_n=3, metadata_filter: Dict = None + ) -> List[str]: embedding = await self.embedding.get_embedding(query) - + results = self.collection.query( - query_embeddings=embedding, - n_results=top_n, - where=metadata_filter + query_embeddings=embedding, n_results=top_n, where=metadata_filter ) - return results['documents'][0] + return results["documents"][0] diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 28cb9a909..b1bd5de81 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -5,27 +5,26 @@ from astrbot.core.provider import Provider from astrbot.core.utils.command_parser import CommandParserMixin from astrbot.core import html_renderer + class Star(CommandParserMixin): - '''所有插件(Star)的父类,所有插件都应该继承于这个类''' + """所有插件(Star)的父类,所有插件都应该继承于这个类""" + def __init__(self, context: Context): self.context = context - - async def text_to_image(self, text: str, return_url = True) -> str: - '''将文本转换为图片''' + + async def text_to_image(self, text: str, return_url=True) -> str: + """将文本转换为图片""" return await html_renderer.render_t2i(text, return_url=return_url) - - async def html_render(self, tmpl: str, data: dict, return_url = True) -> str: - '''渲染 HTML''' - return await html_renderer.render_custom_template(tmpl, data, return_url=return_url) - + + async def html_render(self, tmpl: str, data: dict, return_url=True) -> str: + """渲染 HTML""" + return await html_renderer.render_custom_template( + tmpl, data, return_url=return_url + ) + async def terminate(self): - '''当插件被禁用、重载插件时会调用这个方法''' + """当插件被禁用、重载插件时会调用这个方法""" pass - -__all__ = [ - 'Star', - 'StarMetadata', - 'PluginManager', - 'Context', - 'Provider' -] \ No newline at end of file + + +__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"] diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index 83bb9f4cb..dc07fe6f5 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -1,17 +1,18 @@ -''' +""" 此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta -''' +""" from typing import Union import os import json + def load_config(namespace: str) -> Union[dict, bool]: - ''' + """ 从配置文件中加载配置。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。 - ''' + """ path = f"data/config/{namespace}.json" if not os.path.exists(path): return False @@ -21,9 +22,10 @@ def load_config(namespace: str) -> Union[dict, bool]: for k in data: ret[k] = data[k]["value"] return ret - + + def put_config(namespace: str, name: str, key: str, value, description: str): - ''' + """ 将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 @@ -32,7 +34,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): description: str, 配置项的描述。 注意:只有当 namespace 为插件名(info 函数中的 name)时,该配置才会显示到可视化面板上。 注意:value一定要是该配置项对应类型的值,否则类型判断会乱。 - ''' + """ if namespace == "": raise ValueError("namespace 不能为空。") if namespace.startswith("internal_"): @@ -47,7 +49,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): f.write("{}") with open(path, "r", encoding="utf-8-sig") as f: d = json.load(f) - assert(isinstance(d, dict)) + assert isinstance(d, dict) if key not in d: d[key] = { "config_type": "item", @@ -55,32 +57,29 @@ def put_config(namespace: str, name: str, key: str, value, description: str): "description": description, "path": key, "value": value, - "val_type": type(value).__name__ + "val_type": type(value).__name__, } with open(path, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=2, ensure_ascii=False) f.flush() - + + def update_config(namespace: str, key: str, value): - ''' + """ 更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 value: str, int, float, bool, list, 配置项的值。 - ''' + """ path = f"data/config/{namespace}.json" if not os.path.exists(path): raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。") with open(path, "r", encoding="utf-8-sig") as f: d = json.load(f) - assert(isinstance(d, dict)) + assert isinstance(d, dict) if key not in d: raise KeyError(f"配置项 {key} 不存在。") d[key]["value"] = value with open(path, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=2, ensure_ascii=False) - f.flush() - - - - \ No newline at end of file + f.flush() diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 2ebb20f8e..6d2403321 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -18,38 +18,43 @@ from .filter.regex import RegexFilter from typing import Awaitable from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterType, ADAPTER_NAME_2_TYPE +from astrbot.core.star.filter.platform_adapter_type import ( + PlatformAdapterType, + ADAPTER_NAME_2_TYPE, +) class Context: - ''' + """ 暴露给插件的接口上下文。 - ''' + """ + _event_queue: Queue = None - '''事件队列。消息平台通过事件队列传递消息事件。''' - + """事件队列。消息平台通过事件队列传递消息事件。""" + _config: AstrBotConfig = None - '''AstrBot 配置信息''' - + """AstrBot 配置信息""" + _db: BaseDatabase = None - '''AstrBot 数据库''' - + """AstrBot 数据库""" + provider_manager: ProviderManager = None - + platform_manager: PlatformManager = None - + # back compatibility _register_tasks: List[Awaitable] = [] _star_manager = None - def __init__(self, - event_queue: Queue, - config: AstrBotConfig, - db: BaseDatabase, - provider_manager: ProviderManager = None, + def __init__( + self, + event_queue: Queue, + config: AstrBotConfig, + db: BaseDatabase, + provider_manager: ProviderManager = None, platform_manager: PlatformManager = None, conversation_manager: ConversationManager = None, - knowledge_db_manager: KnowledgeDBManager = None + knowledge_db_manager: KnowledgeDBManager = None, ): self._event_queue = event_queue self._config = config @@ -60,122 +65,123 @@ class Context: self.conversation_manager = conversation_manager def get_registered_star(self, star_name: str) -> StarMetadata: - '''根据插件名获取插件的 Metadata''' + """根据插件名获取插件的 Metadata""" for star in star_registry: if star.name == star_name: return star def get_all_stars(self) -> List[StarMetadata]: - '''获取当前载入的所有插件 Metadata 的列表''' + """获取当前载入的所有插件 Metadata 的列表""" return star_registry - + def get_llm_tool_manager(self) -> FuncCall: - '''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools''' + """获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools""" return self.provider_manager.llm_tools - + def activate_llm_tool(self, name: str) -> bool: - '''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。 - + """激活一个已经注册的函数调用工具。注册的工具默认是激活状态。 + Returns: 如果没找到,会返回 False - ''' + """ func_tool = self.provider_manager.llm_tools.get_func(name) if func_tool is not None: - if func_tool.handler_module_path in star_map: if not star_map[func_tool.handler_module_path].activated: - raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。") - + raise ValueError( + f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" + ) + func_tool.active = True - + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) if name in inactivated_llm_tools: inactivated_llm_tools.remove(name) sp.put("inactivated_llm_tools", inactivated_llm_tools) - + return True return False - + def deactivate_llm_tool(self, name: str) -> bool: - '''停用一个已经注册的函数调用工具。 - + """停用一个已经注册的函数调用工具。 + Returns: - 如果没找到,会返回 False''' + 如果没找到,会返回 False""" func_tool = self.provider_manager.llm_tools.get_func(name) if func_tool is not None: func_tool.active = False - + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) if name not in inactivated_llm_tools: inactivated_llm_tools.append(name) sp.put("inactivated_llm_tools", inactivated_llm_tools) - + return True return False - + def register_provider(self, provider: Provider): - ''' + """ 注册一个 LLM Provider(Chat_Completion 类型)。 - ''' + """ self.provider_manager.provider_insts.append(provider) - + def get_provider_by_id(self, provider_id: str) -> Provider: - '''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。''' + """通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" for provider in self.provider_manager.provider_insts: if provider.meta().id == provider_id: return provider return None - + def get_all_providers(self) -> List[Provider]: - '''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。''' + """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.provider_insts - + def get_all_tts_providers(self) -> List[TTSProvider]: - '''获取所有用于 TTS 任务的 Provider。''' + """获取所有用于 TTS 任务的 Provider。""" return self.provider_manager.tts_provider_insts - + def get_all_stt_providers(self) -> List[STTProvider]: - '''获取所有用于 STT 任务的 Provider。''' + """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - + def get_using_provider(self) -> Provider: - ''' + """ 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。 - + 通过 /provider 指令切换。 - ''' + """ return self.provider_manager.curr_provider_inst - + def get_using_tts_provider(self) -> TTSProvider: - ''' + """ 获取当前使用的用于 TTS 任务的 Provider。 - ''' + """ return self.provider_manager.curr_tts_provider_inst - + def get_using_stt_provider(self) -> STTProvider: - ''' + """ 获取当前使用的用于 STT 任务的 Provider。 - ''' + """ return self.provider_manager.curr_stt_provider_inst - + def get_config(self) -> AstrBotConfig: - '''获取 AstrBot 的配置。''' + """获取 AstrBot 的配置。""" return self._config - + def get_db(self) -> BaseDatabase: - '''获取 AstrBot 数据库。''' + """获取 AstrBot 数据库。""" return self._db - + def get_event_queue(self) -> Queue: - ''' + """ 获取事件队列。 - ''' + """ return self._event_queue - + def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform: - ''' + """ 获取指定类型的平台适配器。 - ''' + """ for platform in self.platform_manager.platform_insts: if isinstance(platform_type, str): if platform.meta().name == platform_type: @@ -183,21 +189,23 @@ class Context: else: if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]: return platform - - async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: - ''' + + async def send_message( + self, session: Union[str, MessageSesion], message_chain: MessageChain + ) -> bool: + """ 根据 session(unified_msg_origin) 主动发送消息。 - + @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 @param message_chain: 消息链。 - + @return: 是否找到匹配的平台。 - + 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 - + NOTE: qq_official(QQ 官方 API 平台) 不支持此方法 - ''' - + """ + if isinstance(session, str): try: session = MessageSesion.from_str(session) @@ -209,22 +217,24 @@ class Context: await platform.send_by_session(session, message_chain) return True return False - - ''' + + """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 - ''' - - def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: - ''' + """ + + def register_llm_tool( + self, name: str, func_args: list, desc: str, func_obj: Awaitable + ) -> None: + """ 为函数调用(function-calling / tools-use)添加工具。 - + @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param desc: 函数描述 @param func_obj: 异步处理函数。 - + 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 - ''' + """ md = StarHandlerMetadata( event_type=EventType.OnLLMRequestEvent, handler_full_name=func_obj.__module__ + "_" + func_obj.__name__, @@ -232,28 +242,39 @@ class Context: handler_module_path=func_obj.__module__, handler=func_obj, event_filters=[], - desc=desc + desc=desc, ) star_handlers_registry.append(md) - self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj) - + self.provider_manager.llm_tools.add_func( + name, func_args, desc, func_obj, func_obj + ) + def unregister_llm_tool(self, name: str) -> None: - '''删除一个函数调用工具。如果再要启用,需要重新注册。''' + """删除一个函数调用工具。如果再要启用,需要重新注册。""" self.provider_manager.llm_tools.remove_func(name) - - def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False): - ''' + + def register_commands( + self, + star_name: str, + command_name: str, + desc: str, + priority: int, + awaitable: Awaitable, + use_regex=False, + ignore_prefix=False, + ): + """ 注册一个命令。 - + [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 - + @param star_name: 插件(Star)名称。 @param command_name: 命令名称。 @param desc: 命令描述。 @param priority: 优先级。1-10。 @param awaitable: 异步处理函数。 - - ''' + + """ md = StarHandlerMetadata( event_type=EventType.AdapterMessageEvent, handler_full_name=awaitable.__module__ + "_" + awaitable.__name__, @@ -261,21 +282,18 @@ class Context: handler_module_path=awaitable.__module__, handler=awaitable, event_filters=[], - desc=desc + desc=desc, ) if use_regex: - md.event_filters.append(RegexFilter( - regex=command_name - )) + md.event_filters.append(RegexFilter(regex=command_name)) else: - md.event_filters.append(CommandFilter( - command_name=command_name, - handler_md=md - )) + md.event_filters.append( + CommandFilter(command_name=command_name, handler_md=md) + ) star_handlers_registry.append(md) - + def register_task(self, task: Awaitable, desc: str): - ''' + """ 注册一个异步任务。 - ''' - self._register_tasks.append(task) \ No newline at end of file + """ + self._register_tasks.append(task) diff --git a/astrbot/core/star/filter/__init__.py b/astrbot/core/star/filter/__init__.py index bada25d6f..c2f78e275 100644 --- a/astrbot/core/star/filter/__init__.py +++ b/astrbot/core/star/filter/__init__.py @@ -3,8 +3,12 @@ from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig + class HandlerFilter(abc.ABC): @abc.abstractmethod def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: - '''是否应当被过滤''' + """是否应当被过滤""" raise NotImplementedError + + +__all__ = ["HandlerFilter", "MessageType", "AstrMessageEvent", "AstrBotConfig"] diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index ea20be1a5..93746980c 100644 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -1,4 +1,3 @@ - import re import inspect from typing import List, Any, Type, Dict @@ -8,17 +7,25 @@ from astrbot.core.config import AstrBotConfig from .custom_filter import CustomFilter from ..star_handler import StarHandlerMetadata + # 标准指令受到 wake_prefix 的制约。 class CommandFilter(HandlerFilter): - '''标准指令过滤器''' - def __init__(self, command_name: str, alias: set = None, handler_md: StarHandlerMetadata = None, parent_command_names: List[str] = [""]): + """标准指令过滤器""" + + def __init__( + self, + command_name: str, + alias: set = None, + handler_md: StarHandlerMetadata = None, + parent_command_names: List[str] = [""], + ): self.command_name = command_name self.alias = alias if alias else set() self.parent_command_names = parent_command_names if handler_md: self.init_handler_md(handler_md) self.custom_filter_list: List[CustomFilter] = [] - + def print_types(self): result = "" for k, v in self.handler_params.items(): @@ -28,11 +35,11 @@ class CommandFilter(HandlerFilter): result += f"{k}({type(v).__name__})={v}," result = result.rstrip(",") return result - + def init_handler_md(self, handle_md: StarHandlerMetadata): self.handler_md = handle_md signature = inspect.signature(self.handler_md.handler) - self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 + self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 idx = 0 for k, v in signature.parameters.items(): if idx < 2: @@ -43,7 +50,7 @@ class CommandFilter(HandlerFilter): self.handler_params[k] = v.annotation else: self.handler_params[k] = v.default - + def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md @@ -55,16 +62,22 @@ class CommandFilter(HandlerFilter): if not custom_filter.filter(event, cfg): return False return True - - def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]: - '''将参数列表 params 根据 param_type 转换为参数字典。 - ''' + + def validate_and_convert_params( + self, params: List[Any], param_type: Dict[str, Type] + ) -> Dict[str, Any]: + """将参数列表 params 根据 param_type 转换为参数字典。""" result = {} for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()): if i >= len(params): - if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty: + if ( + isinstance(param_type_or_default_val, Type) + or param_type_or_default_val is inspect.Parameter.empty + ): # 是类型 - raise ValueError(f"必要参数缺失。该指令完整参数: {self.print_types()}") + raise ValueError( + f"必要参数缺失。该指令完整参数: {self.print_types()}" + ) else: # 是默认值 result[param_name] = param_type_or_default_val @@ -86,7 +99,9 @@ class CommandFilter(HandlerFilter): else: result[param_name] = param_type_or_default_val(params[i]) except ValueError: - raise ValueError(f"参数 {param_name} 类型错误。完整参数: {self.print_types()}") + raise ValueError( + f"参数 {param_name} 类型错误。完整参数: {self.print_types()}" + ) return result def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: @@ -95,7 +110,7 @@ class CommandFilter(HandlerFilter): if not self.custom_filter_ok(event, cfg): return False - + # 检查是否以指令开头 message_str = re.sub(r"\s+", " ", event.get_message_str().strip()) candidates = [self.command_name] + list(self.alias) @@ -107,7 +122,7 @@ class CommandFilter(HandlerFilter): else: _full = candidate if message_str.startswith(f"{_full} ") or message_str == _full: - message_str = message_str[len(_full):].strip() + message_str = message_str[len(_full) :].strip() ok = True break if not ok: @@ -122,7 +137,7 @@ class CommandFilter(HandlerFilter): params = self.validate_and_convert_params(ls, self.handler_params) except ValueError as e: raise e - + event.set_extra("parsed_params", params) - - return True \ No newline at end of file + + return True diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 0c596feaf..55106ad9b 100644 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -1,49 +1,57 @@ from __future__ import annotations -import re -from typing import List, Union, Tuple +from typing import List, Union from . import HandlerFilter from .command import CommandFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig from .custom_filter import CustomFilter -from ..star_handler import StarHandlerMetadata + # 指令组受到 wake_prefix 的制约。 class CommandGroupFilter(HandlerFilter): - def __init__(self, group_name: str, alias: set = None, parent_group: CommandGroupFilter = None): + def __init__( + self, + group_name: str, + alias: set = None, + parent_group: CommandGroupFilter = None, + ): self.group_name = group_name self.alias = alias if alias else set() self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] self.custom_filter_list: List[CustomFilter] = [] self.parent_group = parent_group - - def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]): + + def add_sub_command_filter( + self, sub_command_filter: Union[CommandFilter, CommandGroupFilter] + ): self.sub_command_filters.append(sub_command_filter) def add_custom_filter(self, custom_filter: CustomFilter): self.custom_filter_list.append(custom_filter) def get_complete_command_names(self) -> List[str]: - '''遍历父节点获取完整的指令名。 - - 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。''' - parent_cmd_names = self.parent_group.get_complete_command_names() if self.parent_group else [] - + """遍历父节点获取完整的指令名。 + + 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。""" + parent_cmd_names = ( + self.parent_group.get_complete_command_names() if self.parent_group else [] + ) + if not parent_cmd_names: # 根节点 return [self.group_name] + list(self.alias) - + result = [] candidates = [self.group_name] + list(self.alias) for parent_cmd_name in parent_cmd_names: for candidate in candidates: result.append(parent_cmd_name + " " + candidate) return result - # 以树的形式打印出来 - def print_cmd_tree(self, + def print_cmd_tree( + self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "", event: AstrMessageEvent = None, @@ -62,7 +70,7 @@ class CommandGroupFilter(HandlerFilter): result += f" ({cmd_th})" else: result += " (无参数指令)" - + if sub_filter.handler_md and sub_filter.handler_md.desc: result += f": {sub_filter.handler_md.desc}" @@ -74,7 +82,12 @@ class CommandGroupFilter(HandlerFilter): if custom_filter_pass: result += f"{prefix}├── {sub_filter.group_name}" result += "\n" - result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ", event=event, cfg=cfg) + result += sub_filter.print_cmd_tree( + sub_filter.sub_command_filters, + prefix + "│ ", + event=event, + cfg=cfg, + ) return result @@ -94,9 +107,16 @@ class CommandGroupFilter(HandlerFilter): complete_command_names = self.get_complete_command_names() if event.message_str.strip() in complete_command_names: - tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) - raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) - + tree = ( + self.group_name + + "\n" + + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) + ) + raise ValueError( + f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n" + + tree + ) + # complete_command_names = [name + " " for name in complete_command_names] # return event.message_str.startswith(tuple(complete_command_names)) - return False \ No newline at end of file + return False diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index 5be1b8dbe..9a76b74f2 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -4,6 +4,7 @@ from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig + class CustomFilterMeta(ABCMeta): def __and__(cls, other): if not issubclass(other, CustomFilter): @@ -15,13 +16,14 @@ class CustomFilterMeta(ABCMeta): raise TypeError("Operands must be subclasses of CustomFilter.") return CustomFilterOr(cls(), other()) + class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): - def __init__(self, raise_error: bool = True, **kwargs): + def __init__(self, raise_error: bool = True, **kwargs): self.raise_error = raise_error @abstractmethod def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: - ''' 一个用于重写的自定义Filter ''' + """一个用于重写的自定义Filter""" raise NotImplementedError def __or__(self, other): @@ -30,22 +32,28 @@ class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): def __and__(self, other): return CustomFilterAnd(self, other) + class CustomFilterOr(CustomFilter): def __init__(self, filter1: CustomFilter, filter2: CustomFilter): super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): - raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + raise ValueError( + "CustomFilter lass can only operate with other CustomFilter." + ) self.filter1 = filter1 self.filter2 = filter2 def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: return self.filter1.filter(event, cfg) or self.filter2.filter(event, cfg) + class CustomFilterAnd(CustomFilter): def __init__(self, filter1: CustomFilter, filter2: CustomFilter): super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): - raise ValueError("CustomFilter lass can only operate with other CustomFilter.") + raise ValueError( + "CustomFilter lass can only operate with other CustomFilter." + ) self.filter1 = filter1 self.filter2 = filter2 diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index 5e16e2e75..ce36ec9ed 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -4,25 +4,28 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig from astrbot.core.platform.message_type import MessageType + class EventMessageType(enum.Flag): GROUP_MESSAGE = enum.auto() PRIVATE_MESSAGE = enum.auto() OTHER_MESSAGE = enum.auto() ALL = GROUP_MESSAGE | PRIVATE_MESSAGE | OTHER_MESSAGE - + + MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE = { MessageType.GROUP_MESSAGE: EventMessageType.GROUP_MESSAGE, MessageType.FRIEND_MESSAGE: EventMessageType.PRIVATE_MESSAGE, - MessageType.OTHER_MESSAGE: EventMessageType.OTHER_MESSAGE + MessageType.OTHER_MESSAGE: EventMessageType.OTHER_MESSAGE, } + class EventMessageTypeFilter(HandlerFilter): def __init__(self, event_message_type: EventMessageType): self.event_message_type = event_message_type - + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: message_type = event.get_message_type() if message_type in MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE: event_message_type = MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE[message_type] return bool(event_message_type & self.event_message_type) - return False \ No newline at end of file + return False diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 5e961aa5b..307b492a4 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -3,20 +3,21 @@ from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig + class PermissionType(enum.Flag): - '''权限类型。当选择 MEMBER,ADMIN 也可以通过。 - ''' + """权限类型。当选择 MEMBER,ADMIN 也可以通过。""" + ADMIN = enum.auto() MEMBER = enum.auto() + class PermissionTypeFilter(HandlerFilter): def __init__(self, permission_type: PermissionType, raise_error: bool = True): self.permission_type = permission_type self.raise_error = raise_error - + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: - '''过滤器 - ''' + """过滤器""" if self.permission_type == PermissionType.ADMIN: if not event.is_admin(): # event.stop_event() diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 17bffcf34..0926cc337 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -4,6 +4,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig from typing import Union + class PlatformAdapterType(enum.Flag): AIOCQHTTP = enum.auto() QQOFFICIAL = enum.auto() @@ -13,7 +14,8 @@ class PlatformAdapterType(enum.Flag): WECOM = enum.auto() LARK = enum.auto() ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK - + + ADAPTER_NAME_2_TYPE = { "aiocqhttp": PlatformAdapterType.AIOCQHTTP, "qq_official": PlatformAdapterType.QQOFFICIAL, @@ -21,15 +23,16 @@ ADAPTER_NAME_2_TYPE = { "gewechat": PlatformAdapterType.GEWECHAT, "telegram": PlatformAdapterType.TELEGRAM, "wecom": PlatformAdapterType.WECOM, - "lark": PlatformAdapterType.LARK + "lark": PlatformAdapterType.LARK, } + class PlatformAdapterTypeFilter(HandlerFilter): def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]): self.type_or_str = platform_adapter_type_or_str - + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: adapter_name = event.get_platform_name() if adapter_name in ADAPTER_NAME_2_TYPE: return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str - return False \ No newline at end of file + return False diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index 816b14109..af9cb3a5a 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -1,15 +1,16 @@ - import re from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig + # 正则表达式过滤器不会受到 wake_prefix 的制约。 class RegexFilter(HandlerFilter): - '''正则表达式过滤器''' + """正则表达式过滤器""" + def __init__(self, regex: str): self.regex_str = regex self.regex = re.compile(regex) - + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: - return bool(self.regex.match(event.get_message_str().strip())) \ No newline at end of file + return bool(self.regex.match(event.get_message_str().strip())) diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index 705d026be..fa6a730ba 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -12,22 +12,22 @@ from .star_handler import ( register_on_llm_response, register_llm_tool, register_on_decorating_result, - register_after_message_sent + register_after_message_sent, ) __all__ = [ - 'register_star', - 'register_command', - 'register_command_group', - 'register_event_message_type', - 'register_platform_adapter_type', - 'register_regex', - 'register_permission_type', - 'register_custom_filter', - 'register_on_astrbot_loaded', - 'register_on_llm_request', - 'register_on_llm_response', - 'register_llm_tool', - 'register_on_decorating_result', - 'register_after_message_sent' -] \ No newline at end of file + "register_star", + "register_command", + "register_command_group", + "register_event_message_type", + "register_platform_adapter_type", + "register_regex", + "register_permission_type", + "register_custom_filter", + "register_on_astrbot_loaded", + "register_on_llm_request", + "register_on_llm_response", + "register_llm_tool", + "register_on_decorating_result", + "register_after_message_sent", +] diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index c3d098512..01ff9adaa 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -1,24 +1,26 @@ from ..star import star_registry, StarMetadata, star_map + def register_star(name: str, author: str, desc: str, version: str, repo: str = None): - '''注册一个插件(Star)。 - + """注册一个插件(Star)。 + Args: name: 插件名称。 author: 作者。 desc: 插件的简述。 version: 版本号。 repo: 仓库地址。如果没有填写仓库地址,将无法更新这个插件。 - + 如果需要为插件填写帮助信息,请使用如下格式: - + ```python class MyPlugin(star.Star): \'\'\'这是帮助信息\'\'\' ... - + 帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。` - ''' + """ + def decorator(cls): star_metadata = StarMetadata( name=name, @@ -32,5 +34,5 @@ def register_star(name: str, author: str, desc: str, version: str, repo: str = N star_registry.append(star_metadata) star_map[cls.__module__] = star_metadata return cls - + return decorator diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 999fd562a..4e2f9d176 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -5,7 +5,10 @@ from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventTyp from ..filter.command import CommandFilter from ..filter.command_group import CommandGroupFilter from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType -from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType +from ..filter.platform_adapter_type import ( + PlatformAdapterTypeFilter, + PlatformAdapterType, +) from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter @@ -14,17 +17,16 @@ from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools from astrbot.core import logger + def get_handler_full_name(awaitable: Awaitable) -> str: - '''获取 Handler 的全名''' + """获取 Handler 的全名""" return f"{awaitable.__module__}_{awaitable.__name__}" + def get_handler_or_create( - handler: Awaitable, - event_type: EventType, - dont_add = False, - **kwargs + handler: Awaitable, event_type: EventType, dont_add=False, **kwargs ) -> StarHandlerMetadata: - '''获取 Handler 或者创建一个新的 Handler''' + """获取 Handler 或者创建一个新的 Handler""" handler_full_name = get_handler_full_name(handler) md = star_handlers_registry.get_handler_by_full_name(handler_full_name) if md: @@ -36,40 +38,48 @@ def get_handler_or_create( handler_name=handler.__name__, handler_module_path=handler.__module__, handler=handler, - event_filters=[] + event_filters=[], ) - + # 插件handler的附加额外信息 if handler.__doc__: md.desc = handler.__doc__.strip() - if 'desc' in kwargs: - md.desc = kwargs['desc'] - del kwargs['desc'] + if "desc" in kwargs: + md.desc = kwargs["desc"] + del kwargs["desc"] md.extras_configs = kwargs - + if not dont_add: star_handlers_registry.append(md) return md -def register_command(command_name: str = None, sub_command: str = None, alias: set = None, **kwargs): - '''注册一个 Command. - ''' + +def register_command( + command_name: str = None, sub_command: str = None, alias: set = None, **kwargs +): + """注册一个 Command.""" new_command = None add_to_event_filters = False if isinstance(command_name, RegisteringCommandable): # 子指令 parent_command_names = command_name.parent_group.get_complete_command_names() - new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names) + new_command = CommandFilter( + sub_command, alias, None, parent_command_names=parent_command_names + ) command_name.parent_group.add_sub_command_filter(new_command) else: # 裸指令 new_command = CommandFilter(command_name, alias, None) add_to_event_filters = True - + def decorator(awaitable): if not add_to_event_filters: - kwargs['sub_command'] = True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管) - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + kwargs["sub_command"] = ( + True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管) + ) + handler_md = get_handler_or_create( + awaitable, EventType.AdapterMessageEvent, **kwargs + ) new_command.init_handler_md(handler_md) handler_md.event_filters.append(new_command) return awaitable @@ -78,13 +88,13 @@ def register_command(command_name: str = None, sub_command: str = None, alias: s def register_custom_filter(custom_type_filter, *args, **kwargs): - '''注册一个自定义的 CustomFilter + """注册一个自定义的 CustomFilter Args: custom_type_filter: 在裸指令时为CustomFilter对象 在指令组时为父指令的RegisteringCommandable对象,即self或者command_group的返回 raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True - ''' + """ add_to_event_filters = False raise_error = True @@ -107,56 +117,75 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): def decorator(awaitable): # 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。 - if not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) or \ - (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)): + if ( + not add_to_event_filters + and isinstance(awaitable, RegisteringCommandable) + or (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)) + ): # 指令组 与 根指令组,添加到本层的grouphandle中一起判断 awaitable.parent_group.add_custom_filter(custom_filter) else: - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + handler_md = get_handler_or_create( + awaitable, EventType.AdapterMessageEvent, **kwargs + ) - if not add_to_event_filters and not isinstance(awaitable, RegisteringCommandable): + if not add_to_event_filters and not isinstance( + awaitable, RegisteringCommandable + ): # 底层子指令 handle_full_name = get_handler_full_name(awaitable) - for sub_handle in parent_register_commandable.parent_group.sub_command_filters: + for ( + sub_handle + ) in parent_register_commandable.parent_group.sub_command_filters: # 所有符合fullname一致的子指令handle添加自定义过滤器。 # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? sub_handle_md = sub_handle.get_handler_md() - if sub_handle_md and sub_handle_md.handler_full_name == handle_full_name: + if ( + sub_handle_md + and sub_handle_md.handler_full_name == handle_full_name + ): sub_handle.add_custom_filter(custom_filter) else: # 裸指令 - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + handler_md = get_handler_or_create( + awaitable, EventType.AdapterMessageEvent, **kwargs + ) handler_md.event_filters.append(custom_filter) return awaitable + return decorator + def register_command_group( command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs ): - '''注册一个 CommandGroup - ''' + """注册一个 CommandGroup""" new_group = None if isinstance(command_group_name, RegisteringCommandable): # 子指令组 - new_group = CommandGroupFilter(sub_command, alias, parent_group=command_group_name.parent_group) + new_group = CommandGroupFilter( + sub_command, alias, parent_group=command_group_name.parent_group + ) command_group_name.parent_group.add_sub_command_filter(new_group) else: # 根指令组 new_group = CommandGroupFilter(command_group_name, alias) - + def decorator(obj): # 根指令组 handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs) handler_md.event_filters.append(new_group) - + return RegisteringCommandable(new_group) return decorator -class RegisteringCommandable(): - '''用于指令组级联注册''' + +class RegisteringCommandable: + """用于指令组级联注册""" + group: CommandGroupFilter = register_command_group command: CommandFilter = register_command custom_filter = register_custom_filter @@ -164,164 +193,199 @@ class RegisteringCommandable(): def __init__(self, parent_group: CommandGroupFilter): self.parent_group = parent_group + def register_event_message_type(event_message_type: EventMessageType, **kwargs): - '''注册一个 EventMessageType''' + """注册一个 EventMessageType""" + def decorator(awaitable): - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + handler_md = get_handler_or_create( + awaitable, EventType.AdapterMessageEvent, **kwargs + ) handler_md.event_filters.append(EventMessageTypeFilter(event_message_type)) return awaitable return decorator -def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs): - '''注册一个 PlatformAdapterType''' + +def register_platform_adapter_type( + platform_adapter_type: PlatformAdapterType, **kwargs +): + """注册一个 PlatformAdapterType""" + def decorator(awaitable): handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) - handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type)) + handler_md.event_filters.append( + PlatformAdapterTypeFilter(platform_adapter_type) + ) return awaitable return decorator + def register_regex(regex: str, **kwargs): - '''注册一个 Regex''' + """注册一个 Regex""" + def decorator(awaitable): - handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs) + handler_md = get_handler_or_create( + awaitable, EventType.AdapterMessageEvent, **kwargs + ) handler_md.event_filters.append(RegexFilter(regex)) return awaitable return decorator + def register_permission_type(permission_type: PermissionType, raise_error: bool = True): - '''注册一个 PermissionType - + """注册一个 PermissionType + Args: permission_type: PermissionType raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True - ''' + """ + def decorator(awaitable): handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) - handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error)) + handler_md.event_filters.append( + PermissionTypeFilter(permission_type, raise_error) + ) return awaitable return decorator + def register_on_astrbot_loaded(**kwargs): - '''当 AstrBot 加载完成时 - ''' + """当 AstrBot 加载完成时""" + def decorator(awaitable): _ = get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs) return awaitable - + return decorator + def register_on_llm_request(**kwargs): - '''当有 LLM 请求时的事件 - + """当有 LLM 请求时的事件 + Examples: ```py from astrbot.api.provider import ProviderRequest - + @on_llm_request() async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None: request.system_prompt += "你是一个猫娘..." ``` - + 请务必接收两个参数:event, request - ''' + """ + def decorator(awaitable): _ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs) return awaitable - + return decorator + def register_on_llm_response(**kwargs): - '''当有 LLM 请求后的事件 - + """当有 LLM 请求后的事件 + Examples: ```py from astrbot.api.provider import LLMResponse - + @on_llm_response() async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None: ... ``` - + 请务必接收两个参数:event, request - ''' + """ + def decorator(awaitable): _ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs) return awaitable - + return decorator def register_llm_tool(name: str = None): - '''为函数调用(function-calling / tools-use)添加工具。 - + """为函数调用(function-calling / tools-use)添加工具。 + 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) - + ``` @llm_tool(name="get_weather") # 如果 name 不填,将使用函数名 async def get_weather(event: AstrMessageEvent, location: str): \'\'\'获取天气信息。 - + Args: location(string): 地点 \'\'\' # 处理逻辑 ``` - + 可接受的参数类型有:string, number, object, array, boolean。 - + 返回值: - 返回 str:结果会被加入下一次 LLM 请求的 prompt 中,用于让 LLM 总结工具返回的结果 - 返回 None:结果不会被加入下一次 LLM 请求的 prompt 中。 - + 可以使用 yield 发送消息、终止事件。 - + 发送消息:请参考文档。 - + 终止事件: ``` event.stop_event() yield ``` - ''' - + """ + name_ = name - + def decorator(awaitable: Awaitable): llm_tool_name = name_ if name_ else awaitable.__name__ docstring = docstring_parser.parse(awaitable.__doc__) args = [] for arg in docstring.params: if arg.type_name not in SUPPORTED_TYPES: - raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}") - args.append({ - "type": arg.type_name, - "name": arg.arg_name, - "description": arg.description - }) + raise ValueError( + f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}" + ) + args.append( + { + "type": arg.type_name, + "name": arg.arg_name, + "description": arg.description, + } + ) md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler) - + logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") return awaitable - + return decorator + def register_on_decorating_result(**kwargs): - '''在发送消息前的事件''' + """在发送消息前的事件""" + def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs) + _ = get_handler_or_create( + awaitable, EventType.OnDecoratingResultEvent, **kwargs + ) return awaitable - + return decorator + def register_after_message_sent(**kwargs): - '''在消息发送后的事件''' + """在消息发送后的事件""" + def decorator(awaitable): - _ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs) + _ = get_handler_or_create( + awaitable, EventType.OnAfterMessageSentEvent, **kwargs + ) return awaitable - - return decorator \ No newline at end of file + + return decorator diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index eb940b405..e2bec4594 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -7,13 +7,15 @@ from astrbot.core.config import AstrBotConfig star_registry: List[StarMetadata] = [] star_map: Dict[str, StarMetadata] = {} -'''key 是模块路径,__module__''' +"""key 是模块路径,__module__""" + @dataclass class StarMetadata: - ''' + """ 插件的元数据。 - ''' + """ + name: str author: str # 插件作者 desc: str # 插件简介 @@ -21,27 +23,27 @@ class StarMetadata: repo: str = None # 插件仓库地址 star_cls_type: type = None - '''插件的类对象的类型''' + """插件的类对象的类型""" module_path: str = None - '''插件的模块路径''' - + """插件的模块路径""" + star_cls: object = None - '''插件的类对象''' + """插件的类对象""" module: ModuleType = None - '''插件的模块对象''' + """插件的模块对象""" root_dir_name: str = None - '''插件的目录名称''' + """插件的目录名称""" reserved: bool = False - '''是否是 AstrBot 的保留插件''' - + """是否是 AstrBot 的保留插件""" + activated: bool = True - '''是否被激活''' - + """是否被激活""" + config: AstrBotConfig = None - '''插件配置''' - + """插件配置""" + star_handler_full_names: List[str] = field(default_factory=list) - '''注册的 Handler 的全名列表''' + """注册的 Handler 的全名列表""" def __str__(self) -> str: - return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" \ No newline at end of file + return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index ee927eac9..7be0e053c 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -6,52 +6,68 @@ from typing import Awaitable, List, Dict, TypeVar, Generic from .filter import HandlerFilter from .star import star_map -T = TypeVar('T', bound='StarHandlerMetadata') +T = TypeVar("T", bound="StarHandlerMetadata") + + class StarHandlerRegistry(Generic[T]): - '''用于存储所有的 Star Handler''' - + """用于存储所有的 Star Handler""" + star_handlers_map: Dict[str, StarHandlerMetadata] = {} - '''用于快速查找。key 是 handler_full_name''' + """用于快速查找。key 是 handler_full_name""" _handlers = [] - + def append(self, handler: StarHandlerMetadata): - '''添加一个 Handler''' - if 'priority' not in handler.extras_configs: - handler.extras_configs['priority'] = 0 - - heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler)) + """添加一个 Handler""" + if "priority" not in handler.extras_configs: + handler.extras_configs["priority"] = 0 + + heapq.heappush(self._handlers, (-handler.extras_configs["priority"], handler)) self.star_handlers_map[handler.handler_full_name] = handler - + def _print_handlers(self): - '''打印所有的 Handler''' + """打印所有的 Handler""" for _, handler in self._handlers: print(handler.handler_full_name) - - def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]: - '''通过事件类型获取 Handler''' + + def get_handlers_by_event_type( + self, event_type: EventType, only_activated=True + ) -> List[StarHandlerMetadata]: + """通过事件类型获取 Handler""" handlers = [ - handler - for _, handler in self._handlers - if handler.event_type == event_type and - (not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated)) + handler + for _, handler in self._handlers + if handler.event_type == event_type + and ( + not only_activated + or ( + star_map[handler.handler_module_path] + and star_map[handler.handler_module_path].activated + ) + ) ] return handlers - + def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: - '''通过 Handler 的全名获取 Handler''' + """通过 Handler 的全名获取 Handler""" return self.star_handlers_map.get(full_name, None) - - def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]: - '''通过模块名获取 Handler''' - return [handler for _, handler in self._handlers if handler.handler_module_path == module_name] - + + def get_handlers_by_module_name( + self, module_name: str + ) -> List[StarHandlerMetadata]: + """通过模块名获取 Handler""" + return [ + handler + for _, handler in self._handlers + if handler.handler_module_path == module_name + ] + def clear(self): - '''清空所有的 Handler''' + """清空所有的 Handler""" self.star_handlers_map.clear() self._handlers.clear() - + def remove(self, handler: StarHandlerMetadata): - '''删除一个 Handler''' + """删除一个 Handler""" # self._handlers.remove(handler) for i, h in enumerate(self._handlers): if h[1] == handler: @@ -61,59 +77,65 @@ class StarHandlerRegistry(Generic[T]): del self.star_handlers_map[handler.handler_full_name] except KeyError: pass - + def __iter__(self): - '''使 StarHandlerRegistry 支持迭代''' + """使 StarHandlerRegistry 支持迭代""" return (handler for _, handler in self._handlers) - + def __len__(self): - '''返回 Handler 的数量''' + """返回 Handler 的数量""" return len(self._handlers) - + + star_handlers_registry = StarHandlerRegistry() + class EventType(enum.Enum): - '''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等 - + """表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等 + 用于对 Handler 的职能分组。 - ''' - OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成 - - AdapterMessageEvent = enum.auto() # 收到适配器发来的消息 - OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件) - OnLLMResponseEvent = enum.auto() # LLM 响应后 - OnDecoratingResultEvent = enum.auto() # 发送消息前 - OnCallingFuncToolEvent = enum.auto() # 调用函数工具 - OnAfterMessageSentEvent = enum.auto() # 发送消息后 + """ + + OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成 + + AdapterMessageEvent = enum.auto() # 收到适配器发来的消息 + OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件) + OnLLMResponseEvent = enum.auto() # LLM 响应后 + OnDecoratingResultEvent = enum.auto() # 发送消息前 + OnCallingFuncToolEvent = enum.auto() # 调用函数工具 + OnAfterMessageSentEvent = enum.auto() # 发送消息后 + @dataclass -class StarHandlerMetadata(): - '''描述一个 Star 所注册的某一个 Handler。''' - +class StarHandlerMetadata: + """描述一个 Star 所注册的某一个 Handler。""" + event_type: EventType - '''Handler 的事件类型''' - + """Handler 的事件类型""" + handler_full_name: str '''格式为 f"{handler.__module__}_{handler.__name__}"''' - + handler_name: str - '''Handler 的名字,也就是方法名''' - + """Handler 的名字,也就是方法名""" + handler_module_path: str - '''Handler 所在的模块路径。''' - + """Handler 所在的模块路径。""" + handler: Awaitable - '''Handler 的函数对象,应当是一个异步函数''' - + """Handler 的函数对象,应当是一个异步函数""" + event_filters: List[HandlerFilter] - '''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件''' - + """一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件""" + desc: str = "" - '''Handler 的描述信息''' + """Handler 的描述信息""" extras_configs: dict = field(default_factory=dict) - '''插件注册的一些其他的信息, 如 priority 等''' + """插件注册的一些其他的信息, 如 priority 等""" def __lt__(self, other: StarHandlerMetadata): - '''定义小于运算符以支持优先队列''' - return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0) \ No newline at end of file + """定义小于运算符以支持优先队列""" + return self.extras_configs.get("priority", 0) < other.extras_configs.get( + "priority", 0 + ) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 6802f8619..f7c8593a2 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -21,34 +21,43 @@ from astrbot.core.provider.register import llm_tools from .filter.permission import PermissionTypeFilter, PermissionType + class PluginManager: - def __init__( - self, - context: Context, - config: AstrBotConfig - ): - self.updator = PluginUpdator(config['plugin_repo_mirror']) - + def __init__(self, context: Context, config: AstrBotConfig): + self.updator = PluginUpdator(config["plugin_repo_mirror"]) + self.context = context self.context._star_manager = self - + self.config = config - self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) - '''存储插件的路径。即 data/plugins''' - self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config")) - '''存储插件配置的路径。data/config''' - self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages")) - '''保留插件的路径。在 packages 目录下''' + self.plugin_store_path = os.path.abspath( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins" + ) + ) + """存储插件的路径。即 data/plugins""" + self.plugin_config_path = os.path.abspath( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../../data/config" + ) + ) + """存储插件配置的路径。data/config""" + self.reserved_plugin_path = os.path.abspath( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../../packages" + ) + ) + """保留插件的路径。在 packages 目录下""" self.conf_schema_fname = "_conf_schema.json" - '''插件配置 Schema 文件名''' - + """插件配置 Schema 文件名""" + self.failed_plugin_info = "" def _get_classes(self, arg: ModuleType): - '''获取指定模块(可以理解为一个 python 文件)下所有的类''' + """获取指定模块(可以理解为一个 python 文件)下所有的类""" classes = [] clsmembers = inspect.getmembers(arg, inspect.isclass) - for (name, _) in clsmembers: + for name, _ in clsmembers: if name.lower().endswith("plugin") or name.lower() == "main": classes.append(name) break @@ -62,20 +71,24 @@ class PluginManager: for d in dirs: if os.path.isdir(os.path.join(path, d)): if os.path.exists(os.path.join(path, d, "main.py")): - module_str = 'main' + module_str = "main" elif os.path.exists(os.path.join(path, d, d + ".py")): module_str = d else: print(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。") continue - if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(os.path.join(path, d, d + ".py")): - modules.append({ - "pname": d, - "module": module_str, - "module_path": os.path.join(path, d, module_str) - }) + if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists( + os.path.join(path, d, d + ".py") + ): + modules.append( + { + "pname": d, + "module": module_str, + "module_path": os.path.join(path, d, module_str), + } + ) return modules - + def _get_plugin_modules(self) -> List[dict]: plugins = [] if os.path.exists(self.plugin_store_path): @@ -83,14 +96,14 @@ class PluginManager: if os.path.exists(self.reserved_plugin_path): _p = self._get_modules(self.reserved_plugin_path) for p in _p: - p['reserved'] = True + p["reserved"] = True plugins.extend(_p) return plugins - + def _check_plugin_dept_update(self, target_plugin: str = None): - '''检查插件的依赖 + """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 - ''' + """ plugin_dir = self.plugin_store_path if not os.path.exists(plugin_dir): return False @@ -110,46 +123,55 @@ class PluginManager: except Exception as e: logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") - def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata: - '''v3.4.0 以前的方式载入插件元数据 - + def _load_plugin_metadata(self, plugin_path: str, plugin_obj=None) -> StarMetadata: + """v3.4.0 以前的方式载入插件元数据 + 先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 - ''' + """ metadata = None - + if not os.path.exists(plugin_path): raise Exception("插件不存在。") - + if os.path.exists(os.path.join(plugin_path, "metadata.yaml")): - with open(os.path.join(plugin_path, "metadata.yaml"), "r", encoding='utf-8') as f: + with open( + os.path.join(plugin_path, "metadata.yaml"), "r", encoding="utf-8" + ) as f: metadata = yaml.safe_load(f) elif plugin_obj: # 使用 info() 函数 metadata = plugin_obj.info() - + if isinstance(metadata, dict): - if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: - raise Exception("插件元数据信息不完整。name, desc, version, author 是必须的字段。") + if ( + "name" not in metadata + or "desc" not in metadata + or "version" not in metadata + or "author" not in metadata + ): + raise Exception( + "插件元数据信息不完整。name, desc, version, author 是必须的字段。" + ) metadata = StarMetadata( - name=metadata['name'], - author=metadata['author'], - desc=metadata['desc'], - version=metadata['version'], - repo=metadata['repo'] if 'repo' in metadata else None + name=metadata["name"], + author=metadata["author"], + desc=metadata["desc"], + version=metadata["version"], + repo=metadata["repo"] if "repo" in metadata else None, ) - + return metadata - + async def reload(self, specified_plugin_name=None): - '''扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件''' - + """扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件""" + specified_module_path = None if specified_plugin_name: for smd in star_registry: if smd.name == specified_plugin_name: specified_module_path = smd.module_path break - + # 终止插件 if not specified_module_path: # 重载所有插件 @@ -158,7 +180,9 @@ class PluginManager: await self._terminate_plugin(smd) except Exception as e: logger.warning(traceback.format_exc()) - logger.warning(f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。") + logger.warning( + f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" + ) star_handlers_registry.clear() star_map.clear() @@ -174,42 +198,45 @@ class PluginManager: await self._terminate_plugin(smd) except Exception as e: logger.warning(traceback.format_exc()) - logger.warning(f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。") - + logger.warning( + f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" + ) + await self._unbind_plugin(smd.name, specified_module_path) try: del sys.modules[specified_module_path] except KeyError: logger.warning(f"模块 {specified_module_path} 未载入") - plugin_modules = self._get_plugin_modules() if plugin_modules is None: return False, "未找到任何插件模块" - + fail_rec = "" - + inactivated_plugins: list = sp.get("inactivated_plugins", []) inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - + alter_cmd = sp.get("alter_cmd", {}) - + # 导入插件模块,并尝试实例化插件类 for plugin_module in plugin_modules: try: - module_str = plugin_module['module'] + module_str = plugin_module["module"] # module_path = plugin_module['module_path'] - root_dir_name = plugin_module['pname'] # 插件的目录名 - reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。 - + root_dir_name = plugin_module["pname"] # 插件的目录名 + reserved = plugin_module.get( + "reserved", False + ) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。 + path = "data.plugins." if not reserved else "packages." path += root_dir_name + "." + module_str - + if specified_module_path and path != specified_module_path: continue - + logger.info(f"正在载入插件 {root_dir_name} ...") - + # 尝试导入模块 try: module = __import__(path, fromlist=[module_str]) @@ -221,27 +248,36 @@ class PluginManager: logger.error(traceback.format_exc()) logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}") continue - + # 检查 _conf_schema.json plugin_config = None - plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \ - if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name) - plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname) + plugin_dir_path = ( + os.path.join(self.plugin_store_path, root_dir_name) + if not reserved + else os.path.join(self.reserved_plugin_path, root_dir_name) + ) + plugin_schema_path = os.path.join( + plugin_dir_path, self.conf_schema_fname + ) if os.path.exists(plugin_schema_path): # 加载插件配置 - with open(plugin_schema_path, 'r', encoding='utf-8') as f: + with open(plugin_schema_path, "r", encoding="utf-8") as f: plugin_config = AstrBotConfig( - config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"), - schema=json.loads(f.read()) + config_path=os.path.join( + self.plugin_config_path, f"{root_dir_name}_config.json" + ), + schema=json.loads(f.read()), ) if path in star_map: # 通过装饰器的方式注册插件 metadata = star_map[path] - + try: # yaml 文件的元数据优先 - metadata_yaml = self._load_plugin_metadata(plugin_path=plugin_dir_path) + metadata_yaml = self._load_plugin_metadata( + plugin_path=plugin_dir_path + ) if metadata_yaml: metadata.name = metadata_yaml.name metadata.author = metadata_yaml.author @@ -250,47 +286,69 @@ class PluginManager: metadata.repo = metadata_yaml.repo except Exception: pass - + if plugin_config: metadata.config = plugin_config try: - metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config) + metadata.star_cls = metadata.star_cls_type( + context=self.context, config=plugin_config + ) except TypeError as _: - metadata.star_cls = metadata.star_cls_type(context=self.context) + metadata.star_cls = metadata.star_cls_type( + context=self.context + ) else: metadata.star_cls = metadata.star_cls_type(context=self.context) - + metadata.module = module metadata.root_dir_name = root_dir_name metadata.reserved = reserved - + # 绑定 handler - related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path) + related_handlers = ( + star_handlers_registry.get_handlers_by_module_name( + metadata.module_path + ) + ) for handler in related_handlers: - handler.handler = functools.partial(handler.handler, metadata.star_cls) + handler.handler = functools.partial( + handler.handler, metadata.star_cls + ) # 绑定 llm_tool handler for func_tool in llm_tools.func_list: if func_tool.handler.__module__ == metadata.module_path: func_tool.handler_module_path = metadata.module_path - func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls) + func_tool.handler = functools.partial( + func_tool.handler, metadata.star_cls + ) if func_tool.name in inactivated_llm_tools: func_tool.active = False - + else: # v3.4.0 以前的方式注册插件 - logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。") + logger.debug( + f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。" + ) classes = self._get_classes(module) - + if plugin_config: try: - obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类 + obj = getattr(module, classes[0])( + context=self.context, config=plugin_config + ) # 实例化插件类 except TypeError as _: - obj = getattr(module, classes[0])(context=self.context) # 实例化插件类 + obj = getattr(module, classes[0])( + context=self.context + ) # 实例化插件类 else: - obj = getattr(module, classes[0])(context=self.context) # 实例化插件类 + obj = getattr(module, classes[0])( + context=self.context + ) # 实例化插件类 metadata = None - metadata = self._load_plugin_metadata(plugin_path=plugin_dir_path, plugin_obj=obj) + metadata = self._load_plugin_metadata( + plugin_path=plugin_dir_path, plugin_obj=obj + ) metadata.star_cls = obj metadata.config = plugin_config metadata.module = module @@ -300,18 +358,25 @@ class PluginManager: metadata.module_path = path star_map[path] = metadata star_registry.append(metadata) - + # 禁用/启用插件 if metadata.module_path in inactivated_plugins: metadata.activated = False - + full_names = [] - for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path): + for handler in star_handlers_registry.get_handlers_by_module_name( + metadata.module_path + ): full_names.append(handler.handler_full_name) - + # 检查并且植入自定义的权限过滤器(alter_cmd) - if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]: - cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member") + if ( + metadata.name in alter_cmd + and handler.handler_name in alter_cmd[metadata.name] + ): + cmd_type = alter_cmd[metadata.name][handler.handler_name].get( + "permission", "member" + ) found_permission_filter = False for filter_ in handler.event_filters: if isinstance(filter_, PermissionTypeFilter): @@ -322,20 +387,28 @@ class PluginManager: found_permission_filter = True break if not found_permission_filter: - handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER)) + handler.event_filters.append( + PermissionTypeFilter( + PermissionType.ADMIN + if cmd_type == "admin" + else PermissionType.MEMBER + ) + ) + + logger.debug( + f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。" + ) - logger.debug(f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。") - metadata.star_handler_full_names = full_names # 执行 initialize() 方法 if hasattr(metadata.star_cls, "initialize"): await metadata.star_cls.initialize() - + except BaseException as e: logger.error(f"----- 插件 {root_dir_name} 载入失败 -----") errors = traceback.format_exc() - for line in errors.split('\n'): + for line in errors.split("\n"): logger.error(f"| {line}") logger.error("----------------------------------") fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}。\n" @@ -349,13 +422,13 @@ class PluginManager: else: self.failed_plugin_info = fail_rec return False, fail_rec - + async def install_plugin(self, repo_url: str, proxy=""): plugin_path = await self.updator.install(repo_url, proxy) # reload the plugin await self.reload() return plugin_path - + async def uninstall_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) if not plugin: @@ -364,88 +437,100 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法卸载。") root_dir_name = plugin.root_dir_name ppath = self.plugin_store_path - + # 终止插件 try: await self._terminate_plugin(plugin) except Exception as e: logger.warning(traceback.format_exc()) - logger.warning(f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。") - + logger.warning( + f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。" + ) + # 从 star_registry 和 star_map 中删除 await self._unbind_plugin(plugin_name, plugin.module_path) - + if not remove_dir(os.path.join(ppath, root_dir_name)): - raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") - + raise Exception( + "移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。" + ) + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): del star_map[plugin_module_path] for i, p in enumerate(star_registry): if p.name == plugin_name: del star_registry[i] break - for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path): + for handler in star_handlers_registry.get_handlers_by_module_name( + plugin_module_path + ): logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}") star_handlers_registry.remove(handler) - keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)] + keys_to_delete = [ + k + for k, v in star_handlers_registry.star_handlers_map.items() + if k.startswith(plugin_module_path) + ] for k in keys_to_delete: v = star_handlers_registry.star_handlers_map[k] logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)") del star_handlers_registry.star_handlers_map[k] - async def update_plugin(self, plugin_name: str, proxy = ""): - '''升级一个插件''' + async def update_plugin(self, plugin_name: str, proxy=""): + """升级一个插件""" plugin = self.context.get_registered_star(plugin_name) if not plugin: raise Exception("插件不存在。") if plugin.reserved: raise Exception("该插件是 AstrBot 保留插件,无法更新。") - + await self.updator.update(plugin, proxy=proxy) await self.reload() - + async def turn_off_plugin(self, plugin_name: str): - ''' + """ 禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 并且同时将插件启用的 llm_tool 禁用。 - ''' + """ plugin = self.context.get_registered_star(plugin_name) if not plugin: raise Exception("插件不存在。") - + # 调用插件的终止方法 await self._terminate_plugin(plugin) - + # 加入到 shared_preferences 中 inactivated_plugins: list = sp.get("inactivated_plugins", []) if plugin.module_path not in inactivated_plugins: inactivated_plugins.append(plugin.module_path) - - inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容 - + + inactivated_llm_tools: list = list( + set(sp.get("inactivated_llm_tools", [])) + ) # 后向兼容 + # 禁用插件启用的 llm_tool for func_tool in llm_tools.func_list: if func_tool.handler_module_path == plugin.module_path: func_tool.active = False if func_tool.name not in inactivated_llm_tools: inactivated_llm_tools.append(func_tool.name) - + sp.put("inactivated_plugins", inactivated_plugins) sp.put("inactivated_llm_tools", inactivated_llm_tools) - + plugin.activated = False - + async def _terminate_plugin(self, star_metadata: StarMetadata): - '''终止插件,调用插件的 terminate() 和 __del__() 方法''' + """终止插件,调用插件的 terminate() 和 __del__() 方法""" logging.info(f"正在终止插件 {star_metadata.name} ...") - + if hasattr(star_metadata.star_cls, "__del__"): asyncio.get_event_loop().run_in_executor(star_metadata.star_cls.__del__) else: await star_metadata.star_cls.terminate() - + async def turn_on_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) inactivated_plugins: list = sp.get("inactivated_plugins", []) @@ -453,16 +538,16 @@ class PluginManager: if plugin.module_path in inactivated_plugins: inactivated_plugins.remove(plugin.module_path) sp.put("inactivated_plugins", inactivated_plugins) - + # 启用插件启用的 llm_tool for func_tool in llm_tools.func_list: if func_tool.handler_module_path == plugin.module_path: inactivated_llm_tools.remove(func_tool.name) func_tool.active = True sp.put("inactivated_llm_tools", inactivated_llm_tools) - + plugin.activated = True - + async def install_plugin_from_file(self, zip_file_path: str): dir_name = os.path.basename(zip_file_path).replace(".zip", "") dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower() diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 640fb4abb..5f98af9c1 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -7,51 +7,58 @@ from astrbot.core.utils.io import remove_dir, on_error from ..star.star import StarMetadata from astrbot.core import logger + class PluginUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) - + self.plugin_store_path = os.path.abspath( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins" + ) + ) + def get_plugin_store_path(self) -> str: return self.plugin_store_path - + async def install(self, repo_url: str, proxy="") -> str: repo_name = self.format_repo_name(repo_url) plugin_path = os.path.join(self.plugin_store_path, repo_name) await self.download_from_repo_url(plugin_path, repo_url, proxy) self.unzip_file(plugin_path + ".zip", plugin_path) - + return plugin_path async def update(self, plugin: StarMetadata, proxy="") -> str: repo_url = plugin.repo - + if not repo_url: raise Exception(f"插件 {plugin.name} 没有指定仓库地址。") - + if proxy: proxy = proxy.removesuffix("/") repo_url = f"{proxy}/{repo_url}" - + plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) - + logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}") await self.download_from_repo_url(plugin_path, repo_url) - + try: remove_dir(plugin_path) except BaseException as e: - logger.error(f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。") - + logger.error( + f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。" + ) + self.unzip_file(plugin_path + ".zip", plugin_path) - + return plugin_path def unzip_file(self, zip_path: str, target_dir: str): os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"解压文件: {zip_path}") - with zipfile.ZipFile(zip_path, 'r') as z: + with zipfile.ZipFile(zip_path, "r") as z: update_dir = z.namelist()[0] z.extractall(target_dir) @@ -64,11 +71,14 @@ class PluginUpdator(RepoZipUpdator): if os.path.exists(os.path.join(target_dir, f)): os.remove(os.path.join(target_dir, f)) shutil.move(os.path.join(target_dir, update_dir, f), target_dir) - + try: - logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") + logger.info( + f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}" + ) shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except BaseException: - logger.warning(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") - + logger.warning( + f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}" + ) diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 62d30cc79..0d9860a60 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -7,10 +7,13 @@ from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.utils.io import download_file + class AstrBotUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) + self.MAIN_PATH = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") + ) self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" def terminate_child_processes(self): @@ -30,7 +33,7 @@ class AstrBotUpdator(RepoZipUpdator): child.kill() except psutil.NoSuchProcess: pass - + def _reboot(self, delay: int = 3): py = sys.executable time.sleep(delay) @@ -41,28 +44,28 @@ class AstrBotUpdator(RepoZipUpdator): except Exception as e: logger.error(f"重启失败({py}, {e}),请尝试手动重启。") raise e - + async def check_update(self, url: str, current_version: str) -> ReleaseInfo: return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION) - + async def get_releases(self) -> list: return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - - async def update(self, reboot = False, latest = True, version = None, proxy = ""): + + async def update(self, reboot=False, latest=True, version=None, proxy=""): update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None - + if latest: - latest_version = update_data[0]['tag_name'] + latest_version = update_data[0]["tag_name"] if self.compare_version(VERSION, latest_version) >= 0: raise Exception("当前已经是最新版本。") - file_url = update_data[0]['zipball_url'] + file_url = update_data[0]["zipball_url"] elif str(version).startswith("v"): # 更新到指定版本 logger.info(f"正在更新到指定版本: {version}") for data in update_data: - if data['tag_name'] == version: - file_url = data['zipball_url'] + if data["tag_name"] == version: + file_url = data["zipball_url"] if not file_url: raise Exception(f"未找到版本号为 {version} 的更新文件。") else: @@ -74,12 +77,12 @@ class AstrBotUpdator(RepoZipUpdator): if proxy: proxy = proxy.removesuffix("/") file_url = f"{proxy}/{file_url}" - + try: await download_file(file_url, "temp.zip") self.unzip_file("temp.zip", self.MAIN_PATH) except BaseException as e: raise e - + if reboot: self._reboot() diff --git a/astrbot/core/utils/command_parser.py b/astrbot/core/utils/command_parser.py index f454a00f9..7829140f5 100644 --- a/astrbot/core/utils/command_parser.py +++ b/astrbot/core/utils/command_parser.py @@ -1,22 +1,23 @@ import re -class CommandTokens(): + +class CommandTokens: def __init__(self) -> None: self.tokens = [] self.len = 0 - + def get(self, idx: int): if idx >= self.len: return None return self.tokens[idx].strip() -class CommandParserMixin(): + +class CommandParserMixin: def parse_commands(self, message: str): cmd_tokens = CommandTokens() cmd_tokens.tokens = re.split(r"\s+", message) cmd_tokens.len = len(cmd_tokens.tokens) return cmd_tokens - + def regex_match(self, message: str, command: str) -> bool: return re.search(command, message, re.MULTILINE) is not None - \ No newline at end of file diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index b5326e4b6..80be3fff7 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -34,20 +34,20 @@ class DifyAPIClient: if resp.status != 200: text = await resp.text() raise Exception(f"chat_messages 请求失败:{resp.status}. {text}") - + buffer = "" while True: # 保持原有的8192字节限制,防止数据过大导致高水位报错 chunk = await resp.content.read(8192) if not chunk: break - - buffer += chunk.decode('utf-8') - blocks = buffer.split('\n\n') - + + buffer += chunk.decode("utf-8") + blocks = buffer.split("\n\n") + # 处理完整的数据块 for block in blocks[:-1]: - if block.strip() and block.startswith('data:'): + if block.strip() and block.startswith("data:"): try: json_str = block[5:] # 移除 "data:" 前缀 json_obj = json.loads(json_str) @@ -55,7 +55,7 @@ class DifyAPIClient: except json.JSONDecodeError as e: logger.error(f"JSON解析错误: {str(e)}") logger.error(f"原始数据块: {json_str}") - + # 保留最后一个可能不完整的块 buffer = blocks[-1] if blocks else "" @@ -78,20 +78,20 @@ class DifyAPIClient: if resp.status != 200: text = await resp.text() raise Exception(f"workflow_run 请求失败:{resp.status}. {text}") - + buffer = "" while True: # 保持原有的8192字节限制,防止数据过大导致高水位报错 chunk = await resp.content.read(8192) if not chunk: break - - buffer += chunk.decode('utf-8') - blocks = buffer.split('\n\n') - + + buffer += chunk.decode("utf-8") + blocks = buffer.split("\n\n") + # 处理完整的数据块 for block in blocks[:-1]: - if block.strip() and block.startswith('data:'): + if block.strip() and block.startswith("data:"): try: json_str = block[5:] # 移除 "data:" 前缀 json_obj = json.loads(json_str) @@ -99,7 +99,7 @@ class DifyAPIClient: except json.JSONDecodeError as e: logger.error(f"JSON解析错误: {str(e)}") logger.error(f"原始数据块: {json_str}") - + # 保留最后一个可能不完整的块 buffer = blocks[-1] if blocks else "" @@ -113,51 +113,33 @@ class DifyAPIClient: "user": user, "file": open(file_path, "rb"), } - async with self.session.post( - url, data=payload, headers=self.headers - ) as resp: - return await resp.json() # {"id": "xxx", ...} - + async with self.session.post(url, data=payload, headers=self.headers) as resp: + return await resp.json() # {"id": "xxx", ...} + async def close(self): await self.session.close() - - async def get_chat_convs( - self, - user: str, - limit: int = 20 - ): + + async def get_chat_convs(self, user: str, limit: int = 20): # conversations. GET url = f"{self.api_base}/conversations" payload = { "user": user, "limit": limit, } - async with self.session.get( - url, params=payload, headers=self.headers - ) as resp: + async with self.session.get(url, params=payload, headers=self.headers) as resp: return await resp.json() - - async def delete_chat_conv( - self, - user: str, - conversation_id: str - ): + + async def delete_chat_conv(self, user: str, conversation_id: str): # conversation. DELETE url = f"{self.api_base}/conversations/{conversation_id}" payload = { "user": user, } - async with self.session.delete( - url, json=payload, headers=self.headers - ) as resp: + async with self.session.delete(url, json=payload, headers=self.headers) as resp: return await resp.json() - + async def rename( - self, - conversation_id: str, - name: str, - user: str, - auto_generate: bool = False + self, conversation_id: str, name: str, user: str, auto_generate: bool = False ): # /conversations/:conversation_id/name url = f"{self.api_base}/conversations/{conversation_id}/name" @@ -166,8 +148,5 @@ class DifyAPIClient: "name": name, "auto_generate": auto_generate, } - async with self.session.post( - url, json=payload, headers=self.headers - ) as resp: + async with self.session.post(url, json=payload, headers=self.headers) as resp: return await resp.json() - \ No newline at end of file diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 99b7176dd..318a61835 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -12,26 +12,31 @@ from typing import Union from PIL import Image + def on_error(func, path, exc_info): - ''' + """ a callback of the rmtree function. - ''' + """ print(f"remove {path} failed.") import stat + if not os.access(path, os.W_OK): os.chmod(path, stat.S_IWUSR) func(path) else: raise - + + def remove_dir(file_path) -> bool: - if not os.path.exists(file_path): return True + if not os.path.exists(file_path): + return True try: shutil.rmtree(file_path, onerror=on_error) return True except BaseException: return False - + + def port_checker(port: int, host: str = "localhost"): sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) @@ -42,7 +47,7 @@ def port_checker(port: int, host: str = "localhost"): except Exception: sk.close() return False - + def save_temp_img(img: Union[Image.Image, str]) -> str: os.makedirs("data/temp", exist_ok=True) @@ -52,7 +57,7 @@ def save_temp_img(img: Union[Image.Image, str]) -> str: path = os.path.join("data/temp", f) if os.path.isfile(path): ctime = os.path.getctime(path) - if time.time() - ctime > 3600*12: + if time.time() - ctime > 3600 * 12: os.remove(path) except Exception as e: print(f"清除临时文件失败: {e}") @@ -68,10 +73,13 @@ def save_temp_img(img: Union[Image.Image, str]) -> str: f.write(img) return p -async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str: - ''' + +async def download_image_by_url( + url: str, post: bool = False, post_data: dict = None, path=None +) -> str: + """ 下载图片, 返回 path - ''' + """ try: async with aiohttp.ClientSession(trust_env=True) as session: if post: @@ -93,7 +101,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict = except aiohttp.client.ClientConnectorSSLError: # 关闭SSL验证 ssl_context = ssl.create_default_context() - ssl_context.set_ciphers('DEFAULT') + ssl_context.set_ciphers("DEFAULT") async with aiohttp.ClientSession() as session: if post: async with session.get(url, ssl=ssl_context) as resp: @@ -103,22 +111,23 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict = return save_temp_img(await resp.read()) except Exception as e: raise e - + + async def download_file(url: str, path: str, show_progress: bool = False): - ''' + """ 从指定 url 下载文件到指定路径 path - ''' + """ try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(url, timeout=1800) as resp: if resp.status != 200: raise Exception(f"下载文件失败: {resp.status}") - total_size = int(resp.headers.get('content-length', 0)) + total_size = int(resp.headers.get("content-length", 0)) downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, 'wb') as f: + with open(path, "wb") as f: while True: chunk = await resp.content.read(8192) if not chunk: @@ -128,19 +137,22 @@ async def download_file(url: str, path: str, show_progress: bool = False): if show_progress: elapsed_time = time.time() - start_time speed = downloaded_size / 1024 / elapsed_time # KB/s - print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='') + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) except aiohttp.client.ClientConnectorSSLError: # 关闭SSL验证 ssl_context = ssl.create_default_context() - ssl_context.set_ciphers('DEFAULT') + ssl_context.set_ciphers("DEFAULT") async with aiohttp.ClientSession() as session: async with session.get(url, ssl=ssl_context, timeout=120) as resp: - total_size = int(resp.headers.get('content-length', 0)) + total_size = int(resp.headers.get("content-length", 0)) downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, 'wb') as f: + with open(path, "wb") as f: while True: chunk = await resp.content.read(8192) if not chunk: @@ -150,11 +162,14 @@ async def download_file(url: str, path: str, show_progress: bool = False): if show_progress: elapsed_time = time.time() - start_time speed = downloaded_size / 1024 / elapsed_time # KB/s - print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='') + print( + f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", + end="", + ) if show_progress: print() - - + + def file_to_base64(file_path: str) -> str: with open(file_path, "rb") as f: data_bytes = f.read() @@ -165,14 +180,15 @@ def file_to_base64(file_path: str) -> str: def get_local_ip_addresses(): net_interfaces = psutil.net_if_addrs() network_ips = [] - + for interface, addrs in net_interfaces.items(): for addr in addrs: if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET network_ips.append(addr.address) - + return network_ips + async def get_dashboard_version(): if os.path.exists("data/dist"): if os.path.exists("data/dist/assets/version"): @@ -181,14 +197,21 @@ async def get_dashboard_version(): return v return None + async def download_dashboard(): - '''下载管理面板文件''' + """下载管理面板文件""" dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip" try: - await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True) + await download_file( + dashboard_release_url, "data/dashboard.zip", show_progress=True + ) except BaseException as _: - dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip" - await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True) + dashboard_release_url = ( + "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip" + ) + await download_file( + dashboard_release_url, "data/dashboard.zip", show_progress=True + ) print("解压管理面板文件中...") with zipfile.ZipFile("data/dashboard.zip", "r") as z: - z.extractall("data") \ No newline at end of file + z.extractall("data") diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 0064f62d3..9fc986cc0 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -1,38 +1,34 @@ import aiohttp import sys -import logging from astrbot.core.config import VERSION from astrbot.core import db_helper, logger -logger = logging.getLogger("astrbot") -class Metric(): +class Metric: @staticmethod async def upload(**kwargs): - ''' + """ 上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 - + Powered by TickStats. - ''' + """ base_url = "https://tickstats.soulter.top/api/metric/90a6c2a1" kwargs["v"] = VERSION kwargs["os"] = sys.platform - payload = { - "metrics_data": kwargs - } + payload = {"metrics_data": kwargs} try: - if 'adapter_name' in kwargs: - db_helper.insert_platform_metrics({kwargs['adapter_name']: 1}) - if 'llm_name' in kwargs: - db_helper.insert_llm_metrics({kwargs['llm_name']: 1}) + if "adapter_name" in kwargs: + db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1}) + if "llm_name" in kwargs: + db_helper.insert_llm_metrics({kwargs["llm_name"]: 1}) except Exception as e: logger.error(f"保存指标到数据库失败: {e}") pass - + try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.post(base_url, json=payload, timeout=3) as response: if response.status != 200: pass except Exception: - pass \ No newline at end of file + pass diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 672df6bc9..2cbdc9229 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -2,33 +2,39 @@ import logging from pip import main as pip_main logger = logging.getLogger("astrbot") -class PipInstaller(): + + +class PipInstaller: def __init__(self, pip_install_arg: str): self.pip_install_arg = pip_install_arg - - def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None): - args = ['install'] + + def install( + self, + package_name: str = None, + requirements_path: str = None, + mirror: str = None, + ): + args = ["install"] if package_name: args.append(package_name) elif requirements_path: - args.extend(['-r', requirements_path]) - + args.extend(["-r", requirements_path]) + if not mirror: - mirror = 'https://mirrors.aliyun.com/pypi/simple/' - - args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror]) - + mirror = "https://mirrors.aliyun.com/pypi/simple/" + + args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", mirror]) + if self.pip_install_arg: args.extend(self.pip_install_arg.split()) - + logger.info(f"Pip 包管理器: pip {' '.join(args)}") - + result_code = pip_main(args) - + # 清除 pip.main 导致的多余的 logging handlers for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) - + if result_code != 0: raise Exception(f"安装失败,错误码:{result_code}") - \ No newline at end of file diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index e9a9ed0d7..b469c9d71 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,6 +1,7 @@ import json import os + class SharedPreferences: def __init__(self, path="data/shared_preferences.json"): self.path = path @@ -30,4 +31,4 @@ class SharedPreferences: def clear(self): self._data.clear() - self._save_preferences() \ No newline at end of file + self._save_preferences() diff --git a/astrbot/core/utils/t2i/__init__.py b/astrbot/core/utils/t2i/__init__.py index 67e843b7a..8ce209ad3 100644 --- a/astrbot/core/utils/t2i/__init__.py +++ b/astrbot/core/utils/t2i/__init__.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod + class RenderStrategy(ABC): @abstractmethod def render(self, text: str, return_url: bool) -> str: pass - + @abstractmethod - def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool) -> str: - pass \ No newline at end of file + def render_custom_template( + self, tmpl_str: str, tmpl_data: dict, return_url: bool + ) -> str: + pass diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index cb0ad7423..23abba6e6 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -6,14 +6,22 @@ from . import RenderStrategy from PIL import ImageFont, Image, ImageDraw from astrbot.core.utils.io import save_temp_img -class LocalRenderStrategy(RenderStrategy): - async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str: +class LocalRenderStrategy(RenderStrategy): + async def render_custom_template( + self, tmpl_str: str, tmpl_data: dict, return_url: bool = True + ) -> str: raise NotImplementedError - + def get_font(self, size: int) -> ImageFont.FreeTypeFont: # common and default fonts on Windows, macOS and Linux - fonts = ["msyh.ttc", "NotoSansCJK-Regular.ttc", "msyhbd.ttc", "PingFang.ttc", "Heiti.ttc"] + fonts = [ + "msyh.ttc", + "NotoSansCJK-Regular.ttc", + "msyhbd.ttc", + "PingFang.ttc", + "Heiti.ttc", + ] for font in fonts: try: font = ImageFont.truetype(font, size) @@ -21,13 +29,13 @@ class LocalRenderStrategy(RenderStrategy): except Exception: pass - async def render(self, text: str, return_url: bool=False) -> str: + async def render(self, text: str, return_url: bool = False) -> str: font_size = 26 image_width = 800 image_height = 600 font_color = (0, 0, 0) bg_color = (255, 255, 255) - + HEADER_MARGIN = 20 HEADER_FONT_STANDARD_SIZE = 42 @@ -71,7 +79,7 @@ class LocalRenderStrategy(RenderStrategy): images: Image = {} # pre_process, get height of each line - pre_lines = text.split('\n') + pre_lines = text.split("\n") height = 0 pre_in_code = False i = -1 @@ -90,12 +98,21 @@ class LocalRenderStrategy(RenderStrategy): # 最大不得超过image_width的50% img_height = image_res.size[1] - if image_res.size[0] > image_width*0.5: + if image_res.size[0] > image_width * 0.5: image_res = image_res.resize( - (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + ( + int(image_width * 0.5), + int( + image_res.size[1] + * image_width + * 0.5 + / image_res.size[0] + ), + ) + ) img_height = image_res.size[1] - height += img_height + IMAGE_MARGIN*2 + height += img_height + IMAGE_MARGIN * 2 line = re.sub(IMAGE_REGEX, "", line) except Exception as e: @@ -135,11 +152,13 @@ class LocalRenderStrategy(RenderStrategy): continue if line.startswith("#"): header_level = line.count("#") - height += HEADER_FONT_STANDARD_SIZE + HEADER_MARGIN*2 - header_level * 4 + height += ( + HEADER_FONT_STANDARD_SIZE + HEADER_MARGIN * 2 - header_level * 4 + ) elif line.startswith("-"): - height += font_size+LIST_MARGIN*2 + height += font_size + LIST_MARGIN * 2 elif line.startswith(">"): - height += font_size+QUOTE_LEFT_LINE_MARGIN*2 + height += font_size + QUOTE_LEFT_LINE_MARGIN * 2 elif line.startswith("```"): if pre_in_code: pre_in_code = False @@ -149,18 +168,20 @@ class LocalRenderStrategy(RenderStrategy): pre_in_code = True height += CODE_BLOCK_MARGIN elif re.search(r"`(.*?)`", line): - height += font_size+INLINE_CODE_FONT_MARGIN*2+INLINE_CODE_MARGIN*2 + height += ( + font_size + INLINE_CODE_FONT_MARGIN * 2 + INLINE_CODE_MARGIN * 2 + ) else: - height += font_size + TEXT_LINE_MARGIN*2 + height += font_size + TEXT_LINE_MARGIN * 2 - text = '\n'.join(pre_lines) + text = "\n".join(pre_lines) image_height = height if image_height < 100: image_height = 100 image_width += 20 # 创建空白图像 - image = Image.new('RGB', (image_width, image_height), bg_color) + image = Image.new("RGB", (image_width, image_height), bg_color) draw = ImageDraw.Draw(image) # 设置初始位置 @@ -192,19 +213,34 @@ class LocalRenderStrategy(RenderStrategy): y += HEADER_MARGIN # 上边距 # 字间距 draw.text((x, y), line, font=font, fill=font_color) - draw.line((x, y + font_size_header + 8, image_width - 10, - y + font_size_header + 8), fill=(230, 230, 230), width=3) + draw.line( + ( + x, + y + font_size_header + 8, + image_width - 10, + y + font_size_header + 8, + ), + fill=(230, 230, 230), + width=3, + ) y += font_size_header + HEADER_MARGIN elif line.startswith(">"): # 处理引用 quote_text = line.strip(">") y += QUOTE_LEFT_LINE_MARGIN - draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), - fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH) + draw.line( + (x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), + fill=QUOTE_LEFT_LINE_COLOR, + width=QUOTE_LEFT_LINE_WIDTH, + ) font = self.get_font(QUOTE_FONT_SIZE) - draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), - quote_text, font=font, fill=QUOTE_FONT_COLOR) + draw.text( + (x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), + quote_text, + font=font, + fill=QUOTE_FONT_COLOR, + ) y += font_size + QUOTE_LEFT_LINE_HEIGHT + QUOTE_LEFT_LINE_MARGIN elif line.startswith("-"): @@ -212,24 +248,41 @@ class LocalRenderStrategy(RenderStrategy): list_text = line.strip("-").strip() font = self.get_font(LIST_FONT_SIZE) y += LIST_MARGIN - draw.text((x, y), " · " + list_text, - font=font, fill=LIST_FONT_COLOR) + draw.text((x, y), " · " + list_text, font=font, fill=LIST_FONT_COLOR) y += font_size + LIST_MARGIN elif line.startswith("```"): if not in_code_block: - code_block_start_y = y+CODE_BLOCK_MARGIN + code_block_start_y = y + CODE_BLOCK_MARGIN in_code_block = True else: # print(code_block_codes) in_code_block = False codes = "\n".join(code_block_codes) code_block_codes = [] - draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + - CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2) + draw.rounded_rectangle( + ( + x, + code_block_start_y, + image_width - 10, + y + + CODE_BLOCK_CODES_MARGIN_VERTICAL + + CODE_BLOCK_TEXT_MARGIN, + ), + radius=5, + fill=CODE_BLOCK_BG_COLOR, + width=2, + ) font = self.get_font(CODE_BLOCK_FONT_SIZE) - draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + - CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color) + draw.text( + ( + x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, + code_block_start_y + CODE_BLOCK_CODES_MARGIN_VERTICAL, + ), + codes, + font=font, + fill=font_color, + ) y += CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_MARGIN # y += font_size+10 elif re.search(r"`(.*?)`", line): @@ -244,16 +297,21 @@ class LocalRenderStrategy(RenderStrategy): if part in parts_inline: font = self.get_font(INLINE_CODE_FONT_SIZE) code_text = part.strip("`") - code_width = font.getsize( - code_text)[0] + INLINE_CODE_FONT_MARGIN*2 + code_width = ( + font.getsize(code_text)[0] + INLINE_CODE_FONT_MARGIN * 2 + ) x += INLINE_CODE_MARGIN - code_box = (x, y, x + code_width, - y + INLINE_CODE_BG_HEIGHT) + code_box = (x, y, x + code_width, y + INLINE_CODE_BG_HEIGHT) draw.rounded_rectangle( - code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景 - draw.text((x+INLINE_CODE_FONT_MARGIN, y), - code_text, font=font, fill=font_color) - x += code_width+INLINE_CODE_MARGIN-INLINE_CODE_FONT_MARGIN + code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2 + ) # 使用灰色填充矩形框作为引用背景 + draw.text( + (x + INLINE_CODE_FONT_MARGIN, y), + code_text, + font=font, + fill=font_color, + ) + x += code_width + INLINE_CODE_MARGIN - INLINE_CODE_FONT_MARGIN else: font = self.get_font(font_size) draw.text((x, y), part, font=font, fill=font_color) @@ -269,15 +327,24 @@ class LocalRenderStrategy(RenderStrategy): font = self.get_font(font_size) draw.text((x, y), line, font=font, fill=font_color) - y += font_size + TEXT_LINE_MARGIN*2 + y += font_size + TEXT_LINE_MARGIN * 2 # 图片特殊处理 if index in images: image_res = images[index] # 最大不得超过image_width的50% - if image_res.size[0] > image_width*0.5: + if image_res.size[0] > image_width * 0.5: image_res = image_res.resize( - (int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0]))) + ( + int(image_width * 0.5), + int( + image_res.size[1] + * image_width + * 0.5 + / image_res.size[0] + ), + ) + ) image.paste(image_res, (IMAGE_MARGIN, y)) - y += image_res.size[1] + IMAGE_MARGIN*2 + y += image_res.size[1] + IMAGE_MARGIN * 2 return save_temp_img(image) diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 63004bec7..b9b9abffe 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_image_by_url ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" + class NetworkRenderStrategy(RenderStrategy): def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None: super().__init__() @@ -14,7 +15,7 @@ class NetworkRenderStrategy(RenderStrategy): base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT self.BASE_RENDER_URL = base_url self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template") - + if self.BASE_RENDER_URL.endswith("/"): self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1] if not self.BASE_RENDER_URL.endswith("text2img"): @@ -24,15 +25,16 @@ class NetworkRenderStrategy(RenderStrategy): if not base_url: base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT self.BASE_RENDER_URL = base_url - + if self.BASE_RENDER_URL.endswith("/"): self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1] if not self.BASE_RENDER_URL.endswith("text2img"): self.BASE_RENDER_URL += "/text2img" - - async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool=True) -> str: - '''使用自定义文转图模板''' + async def render_custom_template( + self, tmpl_str: str, tmpl_data: dict, return_url: bool = True + ) -> str: + """使用自定义文转图模板""" post_data = { "tmpl": tmpl_str, "json": return_url, @@ -41,22 +43,29 @@ class NetworkRenderStrategy(RenderStrategy): "full_page": True, "type": "jpeg", "quality": 40, - } + }, } if return_url: async with aiohttp.ClientSession(trust_env=True) as session: - async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp: + async with session.post( + f"{self.BASE_RENDER_URL}/generate", json=post_data + ) as resp: ret = await resp.json() return f"{self.BASE_RENDER_URL}/{ret['data']['id']}" - return await download_image_by_url(f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data) + return await download_image_by_url( + f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data + ) - - async def render(self, text: str, return_url: bool=False) -> str: - ''' + async def render(self, text: str, return_url: bool = False) -> str: + """ 返回图像的文件路径 - ''' - with open(os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding='utf-8') as f: + """ + with open( + os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding="utf-8" + ) as f: tmpl_str = f.read() - assert(tmpl_str) + assert tmpl_str text = text.replace("`", "\\`") - return await self.render_custom_template(tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url) \ No newline at end of file + return await self.render_custom_template( + tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url + ) diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 5b3e27628..73e8eee66 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -2,40 +2,45 @@ from .network_strategy import NetworkRenderStrategy from .local_strategy import LocalRenderStrategy from astrbot.core.log import LogManager -logger = LogManager.GetLogger(log_name='astrbot') +logger = LogManager.GetLogger(log_name="astrbot") + class HtmlRenderer: def __init__(self, endpoint_url: str = None): self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - + def set_network_endpoint(self, endpoint_url: str): - '''设置 t2i 的网络端点。 - ''' + """设置 t2i 的网络端点。""" logger.info("文本转图像服务接口: " + endpoint_url) self.network_strategy.set_endpoint(endpoint_url) - async def render_custom_template(self, tmpl_str: str, tmpl_data: dict, return_url: bool = False): - '''使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 + async def render_custom_template( + self, tmpl_str: str, tmpl_data: dict, return_url: bool = False + ): + """使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。 @param tmpl_str: HTML Jinja2 模板。 @param tmpl_data: jinja2 模板数据。 @return: 图片 URL 或者文件路径,取决于 return_url 参数。 @example: 参见 https://astrbot.app 插件开发部分。 - ''' + """ local = locals() - local.pop('self') + local.pop("self") return await self.network_strategy.render_custom_template(**local) - async def render_t2i(self, text: str, use_network: bool = True, return_url: bool = False): - '''使用默认文转图模板。 - ''' + async def render_t2i( + self, text: str, use_network: bool = True, return_url: bool = False + ): + """使用默认文转图模板。""" if use_network: try: return await self.network_strategy.render(text, return_url=return_url) except BaseException as e: - logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.") + logger.error( + f"Failed to render image via AstrBot API: {e}. Falling back to local rendering." + ) return await self.local_strategy.render(text) else: return await self.local_strategy.render(text) diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 6ecd2698f..f7b2eb5a4 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -1,48 +1,52 @@ import wave from io import BytesIO + async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: import pysilk - + with open(silk_path, "rb") as f: input_data = f.read() - if input_data.startswith(b'\x02'): + if input_data.startswith(b"\x02"): input_data = input_data[1:] input_io = BytesIO(input_data) output_io = BytesIO() pysilk.decode(input_io, output_io, 24000) output_io.seek(0) - with wave.open(output_path, 'wb') as wav: + with wave.open(output_path, "wb") as wav: wav.setnchannels(1) wav.setsampwidth(2) wav.setframerate(24000) wav.writeframes(output_io.read()) - + return output_path + async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: - '''返回 duration''' + """返回 duration""" try: import pilk except (ImportError, ModuleNotFoundError) as _: - raise Exception("pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库") + raise Exception( + "pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库" + ) # with wave.open(wav_path, 'rb') as wav: # wav_data = wav.readframes(wav.getnframes()) # wav_data = BytesIO(wav_data) # output_io = BytesIO() # pysilk.encode(wav_data, output_io, 24000, 24000) # output_io.seek(0) - + # # 在首字节添加 \x02,去除结尾的\xff\xff # silk_data = output_io.read() # silk_data_with_prefix = b'\x02' + silk_data[:-2] - + # # return BytesIO(silk_data_with_prefix) # with open(output_path, "wb") as f: # f.write(silk_data_with_prefix) - + # return 0 - with wave.open(wav_path, 'rb') as wav: + with wave.open(wav_path, "rb") as wav: rate = wav.getframerate() duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True) - return duration \ No newline at end of file + return duration diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 19841cc8d..29533ea88 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -5,34 +5,38 @@ import shutil from astrbot.core.utils.io import on_error, download_file from astrbot.core import logger -class ReleaseInfo(): + +class ReleaseInfo: version: str published_at: str body: str - - def __init__(self, version: str = '', published_at: str = '', body: str = '') -> None: + + def __init__( + self, version: str = "", published_at: str = "", body: str = "" + ) -> None: self.version = version self.published_at = published_at self.body = body - + def __str__(self) -> str: return f"新版本: {self.version}, 发布于: {self.published_at}, 详细内容: {self.body}" -class RepoZipUpdator(): + +class RepoZipUpdator: def __init__(self, repo_mirror: str = "") -> None: self.repo_mirror = repo_mirror self.rm_on_error = on_error - + async def fetch_release_info(self, url: str, latest: bool = True) -> list: - ''' + """ 请求版本信息。 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 - ''' + """ try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(url) as response: result = await response.json() - if not result: + if not result: return [] # if latest: # ret = self.github_api_release_parser([result[0]]) @@ -40,48 +44,52 @@ class RepoZipUpdator(): # ret = self.github_api_release_parser(result) ret = [] for release in result: - ret.append({ - "version": release['name'], - "published_at": release['published_at'], - "body": release['body'], - "tag_name": release['tag_name'], - "zipball_url": release['zipball_url'] - }) + ret.append( + { + "version": release["name"], + "published_at": release["published_at"], + "body": release["body"], + "tag_name": release["tag_name"], + "zipball_url": release["zipball_url"], + } + ) except BaseException: raise Exception("解析版本信息失败") return ret - + def github_api_release_parser(self, releases: list) -> list: - ''' + """ 解析 GitHub API 返回的 releases 信息。 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 - ''' + """ ret = [] for release in releases: - ret.append({ - "version": release['name'], - "published_at": release['published_at'], - "body": release['body'], - "tag_name": release['tag_name'], - "zipball_url": release['zipball_url'] - }) + ret.append( + { + "version": release["name"], + "published_at": release["published_at"], + "body": release["body"], + "tag_name": release["tag_name"], + "zipball_url": release["zipball_url"], + } + ) return ret def unzip(self): raise NotImplementedError() - + async def update(self): raise NotImplementedError() - + def compare_version(self, v1: str, v2: str) -> int: - ''' + """ 比较两个版本号的大小。 返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。 - ''' - v1 = v1.replace('v', '') - v2 = v2.replace('v', '') - v1 = v1.split('.') - v2 = v2.split('.') + """ + v1 = v1.replace("v", "") + v2 = v2.replace("v", "") + v1 = v1.split(".") + v2 = v2.split(".") for i in range(3): if int(v1[i]) > int(v2[i]): @@ -89,19 +97,19 @@ class RepoZipUpdator(): elif int(v1[i]) < int(v2[i]): return -1 return 0 - + async def check_update(self, url: str, current_version: str) -> ReleaseInfo: update_data = await self.fetch_release_info(url) - tag_name = update_data[0]['tag_name'] - + tag_name = update_data[0]["tag_name"] + if self.compare_version(current_version, tag_name) >= 0: return None return ReleaseInfo( version=tag_name, - published_at=update_data[0]['published_at'], - body=update_data[0]['body'] + published_at=update_data[0]["published_at"], + body=update_data[0]["body"], ) - + async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): repo_namespace = repo_url.split("/")[-2:] author = repo_namespace[0] @@ -111,27 +119,28 @@ class RepoZipUpdator(): release_url = f"https://api.github.com/repos/{author}/{repo}/releases" releases = await self.fetch_release_info(url=release_url) if not releases: - # download from the default branch directly. + # download from the default branch directly. logger.info(f"正在从默认分支下载 {author}/{repo} ") - release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + release_url = ( + f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + ) else: - release_url = releases[0]['zipball_url'] - + release_url = releases[0]["zipball_url"] + if proxy: release_url = f"{proxy}/{release_url}" logger.info(f"使用代理下载: {release_url}") await download_file(release_url, target_path + ".zip") - - + def unzip_file(self, zip_path: str, target_dir: str): - ''' + """ 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir - ''' + """ os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"解压文件: {zip_path}") - with zipfile.ZipFile(zip_path, 'r') as z: + with zipfile.ZipFile(zip_path, "r") as z: update_dir = z.namelist()[0] z.extractall(target_dir) @@ -145,13 +154,17 @@ class RepoZipUpdator(): if os.path.exists(os.path.join(target_dir, f)): os.remove(os.path.join(target_dir, f)) shutil.move(os.path.join(target_dir, update_dir, f), target_dir) - + try: - logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") + logger.info( + f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}" + ) shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except BaseException: - logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") + logger.warn( + f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}" + ) def format_repo_name(self, repo_url: str) -> str: if repo_url.endswith("/"): @@ -159,10 +172,10 @@ class RepoZipUpdator(): repo_namespace = repo_url.split("/")[-2:] repo = repo_namespace[1] - + repo = self.format_name(repo) return repo - + def format_name(self, name: str) -> str: - return name.replace("-", "_").lower() \ No newline at end of file + return name.replace("-", "_").lower() diff --git a/astrbot/dashboard/__init__.py b/astrbot/dashboard/__init__.py index 3c420fa41..cf829d4d6 100644 --- a/astrbot/dashboard/__init__.py +++ b/astrbot/dashboard/__init__.py @@ -1 +1,3 @@ -from .dashboard_lifecycle import AstrBotDashBoardLifecycle \ No newline at end of file +from .dashboard_lifecycle import AstrBotDashBoardLifecycle + +__all__ = ["AstrBotDashBoardLifecycle"] diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py index b363ae3a7..930d5089e 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -1,20 +1,21 @@ import asyncio -import traceback from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .server import AstrBotDashboard from astrbot.core.db import BaseDatabase from astrbot.core import LogBroker + + class AstrBotDashBoardLifecycle: def __init__(self, db: BaseDatabase, log_broker: LogBroker): self.db = db self.logger = logger self.log_broker = log_broker self.dashboard_server = None - + async def start(self): core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) - + core_task = [] try: await core_lifecycle.initialize() @@ -23,12 +24,12 @@ class AstrBotDashBoardLifecycle: logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!") - + self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) task = asyncio.gather(core_task, self.dashboard_server.run()) - + try: await task except asyncio.CancelledError: logger.info("🌈 正在关闭 AstrBot...") - await core_lifecycle.stop() \ No newline at end of file + await core_lifecycle.stop() diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index b1dee8bed..f4107bdc5 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -18,4 +18,3 @@ __all__ = [ "StaticFileRoute", "ChatRoute", ] - diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 1cdb4eae9..2e525a3ca 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -4,52 +4,56 @@ from .route import Route, Response, RouteContext from quart import request from astrbot.core import WEBUI_SK + class AuthRoute(Route): def __init__(self, context: RouteContext) -> None: super().__init__(context) self.routes = { - '/auth/login': ('POST', self.login), - '/auth/account/edit': ('POST', self.edit_account), + "/auth/login": ("POST", self.login), + "/auth/account/edit": ("POST", self.edit_account), } self.register_routes() - + async def login(self): - username = self.config['dashboard']['username'] - password = self.config['dashboard']['password'] + username = self.config["dashboard"]["username"] + password = self.config["dashboard"]["password"] post_data = await request.json if post_data["username"] == username and post_data["password"] == password: - return Response().ok({ - "token": self.generate_jwt(username), - "username": username - }).__dict__ + return ( + Response() + .ok({"token": self.generate_jwt(username), "username": username}) + .__dict__ + ) else: return Response().error("用户名或密码错误").__dict__ - + async def edit_account(self): - password = self.config['dashboard']['password'] + password = self.config["dashboard"]["password"] post_data = await request.json - + if post_data["password"] != password: return Response().error("原密码错误").__dict__ - - new_pwd = post_data.get('new_password', None) - new_username = post_data.get('new_username', None) + + new_pwd = post_data.get("new_password", None) + new_username = post_data.get("new_username", None) if not new_pwd and not new_username: - return Response().error("新用户名和新密码不能同时为空,你改了个寂寞").__dict__ + return ( + Response().error("新用户名和新密码不能同时为空,你改了个寂寞").__dict__ + ) if new_pwd: - self.config['dashboard']['password'] = new_pwd + self.config["dashboard"]["password"] = new_pwd if new_username: - self.config['dashboard']['username'] = new_username + self.config["dashboard"]["username"] = new_username self.config.save_config() return Response().ok(None, "修改成功").__dict__ - + def generate_jwt(self, username): payload = { "username": username, - "exp": datetime.datetime.utcnow() + datetime.timedelta(days=30) + "exp": datetime.datetime.utcnow() + datetime.timedelta(days=30), } token = jwt.encode(payload, WEBUI_SK, algorithm="HS256") - return token \ No newline at end of file + return token diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 594620004..db1461f59 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -11,115 +11,131 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle class ChatRoute(Route): - def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: + def __init__( + self, + context: RouteContext, + db: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: super().__init__(context) self.routes = { - '/chat/send': ('POST', self.chat), - '/chat/listen': ('GET', self.listener), - '/chat/new_conversation': ('GET', self.new_conversation), - '/chat/conversations': ('GET', self.get_conversations), - '/chat/get_conversation': ('GET', self.get_conversation), - '/chat/delete_conversation': ('GET', self.delete_conversation), - '/chat/get_file': ('GET', self.get_file), - '/chat/post_image': ('POST', self.post_image), - '/chat/post_file': ('POST', self.post_file), - '/chat/status': ('GET', self.status), + "/chat/send": ("POST", self.chat), + "/chat/listen": ("GET", self.listener), + "/chat/new_conversation": ("GET", self.new_conversation), + "/chat/conversations": ("GET", self.get_conversations), + "/chat/get_conversation": ("GET", self.get_conversation), + "/chat/delete_conversation": ("GET", self.delete_conversation), + "/chat/get_file": ("GET", self.get_file), + "/chat/post_image": ("POST", self.post_image), + "/chat/post_file": ("POST", self.post_file), + "/chat/status": ("GET", self.status), } self.db = db self.core_lifecycle = core_lifecycle self.register_routes() self.imgs_dir = "data/webchat/imgs" - - self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp'] - + + self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] + self.curr_user_cid = {} self.curr_chat_sse = {} - + async def status(self): - has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None - has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None - return Response().ok(data={ - 'llm_enabled': has_llm_enabled, - 'stt_enabled': has_stt_enabled - }).__dict__ - + has_llm_enabled = ( + self.core_lifecycle.provider_manager.curr_provider_inst is not None + ) + has_stt_enabled = ( + self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None + ) + return ( + Response() + .ok(data={"llm_enabled": has_llm_enabled, "stt_enabled": has_stt_enabled}) + .__dict__ + ) + async def get_file(self): - filename = request.args.get('filename') + filename = request.args.get("filename") if not filename: return Response().error("Missing key: filename").__dict__ - + try: with open(os.path.join(self.imgs_dir, filename), "rb") as f: if filename.endswith(".wav"): return QuartResponse(f.read(), mimetype="audio/wav") - elif filename.split('.')[-1] in self.supported_imgs: + elif filename.split(".")[-1] in self.supported_imgs: return QuartResponse(f.read(), mimetype="image/jpeg") else: return QuartResponse(f.read()) - + except FileNotFoundError: return Response().error("File not found").__dict__ - + async def post_image(self): post_data = await request.files - if 'file' not in post_data: + if "file" not in post_data: return Response().error("Missing key: file").__dict__ - - file = post_data['file'] + + file = post_data["file"] filename = str(uuid.uuid4()) + ".jpg" path = os.path.join(self.imgs_dir, filename) await file.save(path) - - return Response().ok(data={ - 'filename': filename - }).__dict__ - + + return Response().ok(data={"filename": filename}).__dict__ + async def post_file(self): post_data = await request.files - if 'file' not in post_data: + if "file" not in post_data: return Response().error("Missing key: file").__dict__ - - file = post_data['file'] + + file = post_data["file"] filename = f"{str(uuid.uuid4())}" print(file) # 通过文件格式判断文件类型 - if file.content_type.startswith('audio'): + if file.content_type.startswith("audio"): filename += ".wav" - + path = os.path.join(self.imgs_dir, filename) await file.save(path) - - return Response().ok(data={ - 'filename': filename - }).__dict__ + + return Response().ok(data={"filename": filename}).__dict__ async def chat(self): - username = g.get('username', 'guest') - + username = g.get("username", "guest") + post_data = await request.json - if 'message' not in post_data and 'image_url' not in post_data: + if "message" not in post_data and "image_url" not in post_data: return Response().error("Missing key: message or image_url").__dict__ - - if 'conversation_id' not in post_data: + + if "conversation_id" not in post_data: return Response().error("Missing key: conversation_id").__dict__ - - message = post_data['message'] - conversation_id = post_data['conversation_id'] - image_url = post_data.get('image_url') - audio_url = post_data.get('audio_url') + + message = post_data["message"] + conversation_id = post_data["conversation_id"] + image_url = post_data.get("image_url") + audio_url = post_data.get("audio_url") if not message and not image_url and not audio_url: - return Response().error("Message and image_url and audio_url are empty").__dict__ + return ( + Response() + .error("Message and image_url and audio_url are empty") + .__dict__ + ) if not conversation_id: return Response().error("conversation_id is empty").__dict__ - + self.curr_user_cid[username] = conversation_id - - await web_chat_queue.put((username, conversation_id, { - 'message': message, - 'image_url': image_url, # list - 'audio_url': audio_url - })) - + + await web_chat_queue.put( + ( + username, + conversation_id, + { + "message": message, + "image_url": image_url, # list + "audio_url": audio_url, + }, + ) + ) + # 持久化 conversation = self.db.get_conversation_by_user_id(username, conversation_id) try: @@ -127,59 +143,59 @@ class ChatRoute(Route): except BaseException as e: print(e) history = [] - new_his = { - 'type': 'user', - 'message': message - } + new_his = {"type": "user", "message": message} if image_url: - new_his['image_url'] = image_url + new_his["image_url"] = image_url if audio_url: - new_his['audio_url'] = audio_url + new_his["audio_url"] = audio_url history.append(new_his) - self.db.update_conversation(username, conversation_id, history=json.dumps(history)) - + self.db.update_conversation( + username, conversation_id, history=json.dumps(history) + ) + return Response().ok().__dict__ - + async def listener(self): - '''一直保持长连接''' - - username = g.get('username', 'guest') - + """一直保持长连接""" + + username = g.get("username", "guest") + if username in self.curr_chat_sse: return "[ERROR]\n" - + self.curr_chat_sse[username] = None - + async def stream(): try: - yield '[HB]\n' + yield "[HB]\n" while True: try: - result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=10) # 设置超时时间为5秒 + result = await asyncio.wait_for( + web_chat_back_queue.get(), timeout=10 + ) # 设置超时时间为5秒 except asyncio.TimeoutError: - yield '[HB]\n' # 心跳包 + yield "[HB]\n" # 心跳包 continue - + if not result: continue result_text, cid = result if cid != self.curr_user_cid.get(username): # 丢弃 continue - yield result_text + '\n' - + yield result_text + "\n" + conversation = self.db.get_conversation_by_user_id(username, cid) try: history = json.loads(conversation.history) except BaseException as e: print(e) history = [] - history.append({ - 'type': 'bot', - 'message': result_text - }) - self.db.update_conversation(username, cid, history=json.dumps(history)) - + history.append({"type": "bot", "message": result_text}) + self.db.update_conversation( + username, cid, history=json.dumps(history) + ) + await asyncio.sleep(0.5) except BaseException as _: logger.debug(f"用户 {username} 断开聊天长连接。") @@ -189,45 +205,43 @@ class ChatRoute(Route): response = await make_response( stream(), { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - 'Transfer-Encoding': 'chunked', - 'Connection': 'keep-alive' - } + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, ) response.timeout = None return response - + async def delete_conversation(self): - username = g.get('username', 'guest') - conversation_id = request.args.get('conversation_id') + username = g.get("username", "guest") + conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - + self.db.delete_conversation(username, conversation_id) return Response().ok().__dict__ - + async def new_conversation(self): - username = g.get('username', 'guest') + username = g.get("username", "guest") conversation_id = str(uuid.uuid4()) self.db.new_conversation(username, conversation_id) - return Response().ok(data={ - 'conversation_id': conversation_id - }).__dict__ - + return Response().ok(data={"conversation_id": conversation_id}).__dict__ + async def get_conversations(self): - username = g.get('username', 'guest') + username = g.get("username", "guest") conversations = self.db.get_conversations(username) return Response().ok(data=conversations).__dict__ - + async def get_conversation(self): - username = g.get('username', 'guest') - conversation_id = request.args.get('conversation_id') + username = g.get("username", "guest") + conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - + conversation = self.db.get_conversation_by_user_id(username, conversation_id) - + self.curr_user_cid[username] = conversation_id - - return Response().ok(data=conversation).__dict__ \ No newline at end of file + + return Response().ok(data=conversation).__dict__ diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 0e2cdc1c4..c1216ac2e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -10,17 +10,25 @@ from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry from astrbot.core import logger + def try_cast(value: str, type_: str): if type_ == "int" and value.isdigit(): return int(value) - elif type_ == "float" and isinstance(value, str) \ - and value.replace(".", "", 1).isdigit(): + elif ( + type_ == "float" + and isinstance(value, str) + and value.replace(".", "", 1).isdigit() + ): return float(value) elif type_ == "float" and isinstance(value, int): return float(value) -def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.List[str], typing.Dict]: + +def validate_config( + data, schema: dict, is_core: bool +) -> typing.Tuple[typing.List[str], typing.Dict]: errors = [] + def validate(data, metadata=schema, path=""): for key, meta in metadata.items(): if key not in data: @@ -40,21 +48,33 @@ def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.Li if meta["type"] == "int" and not isinstance(value, int): casted = try_cast(value, "int") if casted is None: - errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}" + ) data[key] = casted elif meta["type"] == "float" and not isinstance(value, float): casted = try_cast(value, "float") if casted is None: - errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}" + ) data[key] = casted elif meta["type"] == "bool" and not isinstance(value, bool): - errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}" + ) elif meta["type"] in ["string", "text"] and not isinstance(value, str): - errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}" + ) elif meta["type"] == "list" and not isinstance(value, list): - errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" + ) elif meta["type"] == "object" and not isinstance(value, dict): - errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}") + errors.append( + f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}" + ) validate(value, meta["items"], path=f"{path}{key}.") if is_core: @@ -66,15 +86,18 @@ def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.Li validate(data, group_meta, path=f"{key}.") else: validate(data, schema) - + return errors, data + def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): - '''验证并保存配置''' + """验证并保存配置""" errors = None try: if is_core: - errors, post_config = validate_config(post_config, CONFIG_METADATA_2, is_core) + errors, post_config = validate_config( + post_config, CONFIG_METADATA_2, is_core + ) else: errors, post_config = validate_config(post_config, config.schema, is_core) except BaseException as e: @@ -83,23 +106,24 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) if errors: raise ValueError(f"格式校验未通过: {errors}") config.save_config(post_config) - + + class ConfigRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None: + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.routes = { - '/config/get': ('GET', self.get_configs), - '/config/astrbot/update': ('POST', self.post_astrbot_configs), - '/config/plugin/update': ('POST', self.post_plugin_configs), - - '/config/platform/new': ('POST', self.post_new_platform), - '/config/platform/update': ('POST', self.post_update_platform), - '/config/platform/delete': ('POST', self.post_delete_platform), - - '/config/provider/new': ('POST', self.post_new_provider), - '/config/provider/update': ('POST', self.post_update_provider), - '/config/provider/delete': ('POST', self.post_delete_provider) + "/config/get": ("GET", self.get_configs), + "/config/astrbot/update": ("POST", self.post_astrbot_configs), + "/config/plugin/update": ("POST", self.post_plugin_configs), + "/config/platform/new": ("POST", self.post_new_platform), + "/config/platform/update": ("POST", self.post_update_platform), + "/config/platform/delete": ("POST", self.post_delete_platform), + "/config/provider/new": ("POST", self.post_new_provider), + "/config/provider/update": ("POST", self.post_update_provider), + "/config/provider/delete": ("POST", self.post_delete_provider), } self.register_routes() @@ -119,83 +143,91 @@ class ConfigRoute(Route): except Exception as e: logger.error(e) return Response().error(str(e)).__dict__ - + async def post_plugin_configs(self): post_configs = await request.json plugin_name = request.args.get("plugin_name", "unknown") try: await self._save_plugin_configs(post_configs, plugin_name) - return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__ + return ( + Response() + .ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。") + .__dict__ + ) except Exception as e: return Response().error(str(e)).__dict__ - + async def post_new_platform(self): new_platform_config = await request.json - self.config['platform'].append(new_platform_config) + self.config["platform"].append(new_platform_config) try: save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.load_platform(new_platform_config) + await self.core_lifecycle.platform_manager.load_platform( + new_platform_config + ) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "新增平台配置成功~").__dict__ - + async def post_new_provider(self): new_provider_config = await request.json - self.config['provider'].append(new_provider_config) + self.config["provider"].append(new_provider_config) try: save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.provider_manager.load_provider(new_provider_config) + await self.core_lifecycle.provider_manager.load_provider( + new_provider_config + ) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "新增服务提供商配置成功~").__dict__ - + async def post_update_platform(self): update_platform_config = await request.json platform_id = update_platform_config.get("id", None) new_config = update_platform_config.get("config", None) if not platform_id or not new_config: return Response().error("参数错误").__dict__ - - for i, platform in enumerate(self.config['platform']): - if platform['id'] == platform_id: - self.config['platform'][i] = new_config + + for i, platform in enumerate(self.config["platform"]): + if platform["id"] == platform_id: + self.config["platform"][i] = new_config break else: return Response().error("未找到对应平台").__dict__ - + try: await self._save_astrbot_configs(self.config) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "更新平台配置成功~").__dict__ - + async def post_update_provider(self): update_provider_config = await request.json provider_id = update_provider_config.get("id", None) new_config = update_provider_config.get("config", None) if not provider_id or not new_config: return Response().error("参数错误").__dict__ - - for i, provider in enumerate(self.config['provider']): - if provider['id'] == provider_id: - self.config['provider'][i] = new_config + + for i, provider in enumerate(self.config["provider"]): + if provider["id"] == provider_id: + self.config["provider"][i] = new_config break else: return Response().error("未找到对应服务提供商").__dict__ - + try: save_config(self.config, self.config, is_core=True) await self.core_lifecycle.provider_manager.reload(new_config) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "更新成功,已经实时生效~").__dict__ - + async def post_delete_platform(self): platform_id = await request.json platform_id = platform_id.get("id") - for i, platform in enumerate(self.config['platform']): - if platform['id'] == platform_id: - del self.config['platform'][i] + for i, platform in enumerate(self.config["platform"]): + if platform["id"] == platform_id: + del self.config["platform"][i] break else: return Response().error("未找到对应平台").__dict__ @@ -204,13 +236,13 @@ class ConfigRoute(Route): except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "删除平台配置成功~").__dict__ - + async def post_delete_provider(self): provider_id = await request.json provider_id = provider_id.get("id") - for i, provider in enumerate(self.config['provider']): - if provider['id'] == provider_id: - del self.config['provider'][i] + for i, provider in enumerate(self.config["provider"]): + if provider["id"] == provider_id: + del self.config["provider"][i] break else: return Response().error("未找到对应服务提供商").__dict__ @@ -223,66 +255,66 @@ class ConfigRoute(Route): async def _get_astrbot_config(self): config = self.config - + # 平台适配器的默认配置模板注入 - platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template'] + platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][ + "platform" + ]["config_template"] for platform in platform_registry: if platform.default_config_tmpl: platform_default_tmpl[platform.name] = platform.default_config_tmpl - + # 服务提供商的默认配置模板注入 - provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template'] + provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][ + "provider" + ]["config_template"] for provider in provider_registry: if provider.default_config_tmpl: provider_default_tmpl[provider.type] = provider.default_config_tmpl - - return { - "metadata": CONFIG_METADATA_2, - "config": config - } + + return {"metadata": CONFIG_METADATA_2, "config": config} async def _get_plugin_config(self, plugin_name: str): - ret = { - "metadata": None, - "config": None - } - + ret = {"metadata": None, "config": None} + for plugin_md in star_registry: if plugin_md.name == plugin_name: if not plugin_md.config: break - ret['config'] = plugin_md.config # 这是自定义的 Dict 类(AstrBotConfig) - ret['metadata'] = { + ret["config"] = ( + plugin_md.config + ) # 这是自定义的 Dict 类(AstrBotConfig) + ret["metadata"] = { plugin_name: { "description": f"{plugin_name} 配置", "type": "object", - "items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema + "items": plugin_md.config.schema, # 初始化时通过 __setattr__ 存入了 schema } } break - + return ret - + async def _save_astrbot_configs(self, post_configs: dict): try: save_config(post_configs, self.config, is_core=True) self.core_lifecycle.restart() except Exception as e: raise e - + async def _save_plugin_configs(self, post_configs: dict, plugin_name: str): md = None for plugin_md in star_registry: if plugin_md.name == plugin_name: md = plugin_md - + if not md: raise ValueError(f"插件 {plugin_name} 不存在") if not md.config: raise ValueError(f"插件 {plugin_name} 没有注册配置") - + try: save_config(post_configs, md.config) self.core_lifecycle.restart() except Exception as e: - raise e \ No newline at end of file + raise e diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index 949023d83..e4178a832 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -2,12 +2,15 @@ import asyncio from quart import websocket from astrbot.core import logger, LogBroker from .route import Route, RouteContext - + + class LogRoute(Route): def __init__(self, context: RouteContext, log_broker: LogBroker) -> None: super().__init__(context) self.log_broker = log_broker - self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True) + self.app.add_url_rule( + "/api/live-log", view_func=self.log, methods=["GET"], websocket=True + ) async def log(self): queue = None @@ -22,4 +25,4 @@ class LogRoute(Route): logger.error(f"WebSocket 连接错误: {e}") finally: if queue: - self.log_broker.unregister(queue) \ No newline at end of file + self.log_broker.unregister(queue) diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 545ed3a95..fee6dc6b4 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -12,33 +12,39 @@ from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.filter.regex import RegexFilter from astrbot.core.star.star_handler import EventType + class PluginRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None: + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + plugin_manager: PluginManager, + ) -> None: super().__init__(context) self.routes = { - '/plugin/get': ('GET', self.get_plugins), - '/plugin/install': ('POST', self.install_plugin), - '/plugin/install-upload': ('POST', self.install_plugin_upload), - '/plugin/update': ('POST', self.update_plugin), - '/plugin/uninstall': ('POST', self.uninstall_plugin), - '/plugin/market_list': ('GET', self.get_online_plugins), - '/plugin/off': ('POST', self.off_plugin), - '/plugin/on': ('POST', self.on_plugin), - '/plugin/reload': ('POST', self.reload_plugins), + "/plugin/get": ("GET", self.get_plugins), + "/plugin/install": ("POST", self.install_plugin), + "/plugin/install-upload": ("POST", self.install_plugin_upload), + "/plugin/update": ("POST", self.update_plugin), + "/plugin/uninstall": ("POST", self.uninstall_plugin), + "/plugin/market_list": ("GET", self.get_online_plugins), + "/plugin/off": ("POST", self.off_plugin), + "/plugin/on": ("POST", self.on_plugin), + "/plugin/reload": ("POST", self.reload_plugins), } self.core_lifecycle = core_lifecycle self.plugin_manager = plugin_manager self.register_routes() - + self.translated_event_type = { EventType.AdapterMessageEvent: "平台消息下发时", EventType.OnLLMRequestEvent: "LLM 请求时", EventType.OnLLMResponseEvent: "LLM 响应后", EventType.OnDecoratingResultEvent: "回复消息前", EventType.OnCallingFuncToolEvent: "函数工具", - EventType.OnAfterMessageSentEvent: "发送消息后" + EventType.OnAfterMessageSentEvent: "发送消息后", } - + async def reload_plugins(self): data = await request.json plugin_name = data.get("name", None) @@ -50,17 +56,15 @@ class PluginRoute(Route): except Exception as e: logger.error(f"/api/plugin/reload: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - + async def get_online_plugins(self): custom = request.args.get("custom_registry") - + if custom: urls = [custom] else: - urls = [ - "https://api.soulter.top/astrbot/plugins" - ] - + urls = ["https://api.soulter.top/astrbot/plugins"] + for url in urls: try: async with aiohttp.ClientSession(trust_env=True) as session: @@ -72,58 +76,86 @@ class PluginRoute(Route): logger.error(f"请求 {url} 失败,状态码:{response.status}") except Exception as e: logger.error(f"请求 {url} 失败,错误:{e}") - + return Response().error("获取插件列表失败").__dict__ - + async def get_plugins(self): _plugin_resp = [] for plugin in self.plugin_manager.context.get_all_stars(): _t = { "name": plugin.name, - "repo": '' if plugin.repo is None else plugin.repo, + "repo": "" if plugin.repo is None else plugin.repo, "author": plugin.author, "desc": plugin.desc, "version": plugin.version, "reserved": plugin.reserved, "activated": plugin.activated, "online_vesion": "", - "handlers": await self.get_plugin_handlers_info(plugin.star_handler_full_names), + "handlers": await self.get_plugin_handlers_info( + plugin.star_handler_full_names + ), } _plugin_resp.append(_t) - return Response().ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info).__dict__ - + return ( + Response() + .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) + .__dict__ + ) + async def get_plugin_handlers_info(self, handler_full_names: list[str]): - '''解析插件行为''' + """解析插件行为""" handlers = [] - + for handler_full_name in handler_full_names: info = {} - handler = star_handlers_registry.star_handlers_map.get(handler_full_name, None) + handler = star_handlers_registry.star_handlers_map.get( + handler_full_name, None + ) if handler is None: continue info["event_type"] = handler.event_type.name - info["event_type_h"] = self.translated_event_type.get(handler.event_type, handler.event_type.name) + info["event_type_h"] = self.translated_event_type.get( + handler.event_type, handler.event_type.name + ) info["handler_full_name"] = handler.handler_full_name info["desc"] = handler.desc info["handler_name"] = handler.handler_name - + if handler.event_type == EventType.AdapterMessageEvent: # 处理平台适配器消息事件 has_admin = False - for filter in handler.event_filters: # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高 + for filter in ( + handler.event_filters + ): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高 if isinstance(filter, CommandFilter): info["type"] = "指令" - info["cmd"] = f"{filter.parent_command_names[0]} {filter.command_name}" + info["cmd"] = ( + f"{filter.parent_command_names[0]} {filter.command_name}" + ) info["cmd"] = info["cmd"].strip() - if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0: - info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" + if ( + self.core_lifecycle.astrbot_config["wake_prefix"] + and len(self.core_lifecycle.astrbot_config["wake_prefix"]) + > 0 + ): + info["cmd"] = ( + f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" + ) elif isinstance(filter, CommandGroupFilter): info["type"] = "指令组" info["cmd"] = filter.get_complete_command_names()[0] info["cmd"] = info["cmd"].strip() - info["sub_command"] = filter.print_cmd_tree(filter.sub_command_filters) - if self.core_lifecycle.astrbot_config['wake_prefix'] and len(self.core_lifecycle.astrbot_config['wake_prefix']) > 0: - info["cmd"] = f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" + info["sub_command"] = filter.print_cmd_tree( + filter.sub_command_filters + ) + if ( + self.core_lifecycle.astrbot_config["wake_prefix"] + and len(self.core_lifecycle.astrbot_config["wake_prefix"]) + > 0 + ): + info["cmd"] = ( + f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" + ) elif isinstance(filter, RegexFilter): info["type"] = "正则匹配" info["cmd"] = filter.regex_str @@ -137,22 +169,22 @@ class PluginRoute(Route): else: info["cmd"] = "自动触发" info["type"] = "无" - + if not info["desc"]: info["desc"] = "无描述" - + handlers.append(info) - + return handlers - + async def install_plugin(self): post_data = await request.json repo_url = post_data["url"] - + proxy: str = post_data.get("proxy", None) if proxy: proxy = proxy.removesuffix("/") - + try: logger.info(f"正在安装插件 {repo_url}") await self.plugin_manager.install_plugin(repo_url, proxy) @@ -162,11 +194,11 @@ class PluginRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - + async def install_plugin_upload(self): try: file = await request.files - file = file['file'] + file = file["file"] logger.info(f"正在安装用户上传的插件 {file.filename}") file_path = f"data/temp/{file.filename}" await file.save(file_path) @@ -177,7 +209,7 @@ class PluginRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - + async def uninstall_plugin(self): post_data = await request.json plugin_name = post_data["name"] @@ -189,7 +221,7 @@ class PluginRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - + async def update_plugin(self): post_data = await request.json plugin_name = post_data["name"] @@ -203,7 +235,7 @@ class PluginRoute(Route): except Exception as e: logger.error(f"/api/plugin/update: {traceback.format_exc()}") return Response().error(str(e)).__dict__ - + async def off_plugin(self): post_data = await request.json plugin_name = post_data["name"] @@ -224,4 +256,4 @@ class PluginRoute(Route): return Response().ok(None, "启用成功。").__dict__ except Exception as e: logger.error(f"/api/plugin/on: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ \ No newline at end of file + return Response().error(str(e)).__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index cbec57cdf..bd94d9adf 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -2,22 +2,25 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass from quart import Quart + @dataclass class RouteContext: config: AstrBotConfig app: Quart -class Route(): + +class Route: def __init__(self, context: RouteContext): self.app = context.app self.config = context.config - + def register_routes(self): for route, (method, func) in self.routes.items(): self.app.add_url_rule(f"/api{route}", view_func=func, methods=[method]) + @dataclass -class Response(): +class Response: status: str = None message: str = None data: dict = None @@ -27,8 +30,8 @@ class Response(): self.message = message return self - def ok(self, data: dict={}, message: str=None): + def ok(self, data: dict = {}, message: str = None): self.status = "ok" self.data = data self.message = message - return self \ No newline at end of file + return self diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 1a7e61bc2..3a21aa2b9 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,7 +1,6 @@ import traceback import psutil import time -import aiohttp from .route import Route, Response, RouteContext from astrbot.core import logger from quart import request @@ -9,71 +8,86 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.config import VERSION + class StatRoute(Route): - def __init__(self, context: RouteContext, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: + def __init__( + self, + context: RouteContext, + db_helper: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: super().__init__(context) self.routes = { - '/stat/get': ('GET', self.get_stat), - '/stat/version': ('GET', self.get_version), - '/stat/start-time': ('GET', self.get_start_time), - '/stat/restart-core': ('POST', self.restart_core) + "/stat/get": ("GET", self.get_stat), + "/stat/version": ("GET", self.get_version), + "/stat/start-time": ("GET", self.get_start_time), + "/stat/restart-core": ("POST", self.restart_core), } self.db_helper = db_helper self.register_routes() self.core_lifecycle = core_lifecycle - + async def restart_core(self): self.core_lifecycle.restart() return Response().ok().__dict__ - + def format_sec(self, sec: int): m, s = divmod(sec, 60) h, m = divmod(m, 60) return f"{h}小时{m}分{s}秒" - + async def get_version(self): - return Response().ok({ - "version": VERSION - }).__dict__ - + return Response().ok({"version": VERSION}).__dict__ + async def get_start_time(self): - return Response().ok({ - "start_time": self.core_lifecycle.start_time - }).__dict__ - + return Response().ok({"start_time": self.core_lifecycle.start_time}).__dict__ + async def get_stat(self): - offset_sec = request.args.get('offset_sec', 86400) + offset_sec = request.args.get("offset_sec", 86400) offset_sec = int(offset_sec) try: stat = self.db_helper.get_base_stats(offset_sec) now = int(time.time()) start_time = now - offset_sec message_time_based_stats = [] - + idx = 0 for bucket_end in range(start_time, now, 1800): cnt = 0 - while idx < len(stat.platform) and stat.platform[idx].timestamp < bucket_end: + while ( + idx < len(stat.platform) + and stat.platform[idx].timestamp < bucket_end + ): cnt += stat.platform[idx].count idx += 1 message_time_based_stats.append([bucket_end, cnt]) - + stat_dict = stat.__dict__ - - stat_dict.update({ - "platform": self.db_helper.get_grouped_base_stats(offset_sec).platform, - "message_count": self.db_helper.get_total_message_count() or 0, - "platform_count": len(self.core_lifecycle.platform_manager.get_insts()), - "plugin_count": len(self.core_lifecycle.star_context.get_all_stars()), - "message_time_series": message_time_based_stats, - "running": self.format_sec(int(time.time()) - self.core_lifecycle.start_time), - "memory": { - "process": psutil.Process().memory_info().rss >> 20, - "system": psutil.virtual_memory().total >> 20 + + stat_dict.update( + { + "platform": self.db_helper.get_grouped_base_stats( + offset_sec + ).platform, + "message_count": self.db_helper.get_total_message_count() or 0, + "platform_count": len( + self.core_lifecycle.platform_manager.get_insts() + ), + "plugin_count": len( + self.core_lifecycle.star_context.get_all_stars() + ), + "message_time_series": message_time_based_stats, + "running": self.format_sec( + int(time.time()) - self.core_lifecycle.start_time + ), + "memory": { + "process": psutil.Process().memory_info().rss >> 20, + "system": psutil.virtual_memory().total >> 20, + }, } - }) - + ) + return Response().ok(stat_dict).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(e.__str__()).__dict__ \ No newline at end of file + return Response().error(e.__str__()).__dict__ diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index 630d50d2a..ef2d10634 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,5 +1,4 @@ import traceback -import aiohttp from .route import Route, Response, RouteContext from quart import request from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -8,30 +7,37 @@ from astrbot.core import logger, pip_installer from astrbot.core.utils.io import download_dashboard, get_dashboard_version from astrbot.core.config.default import VERSION + class UpdateRoute(Route): - def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator, core_lifecycle: AstrBotCoreLifecycle) -> None: + def __init__( + self, + context: RouteContext, + astrbot_updator: AstrBotUpdator, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: super().__init__(context) self.routes = { - '/update/check': ('GET', self.check_update), - '/update/releases': ('GET', self.get_releases), - '/update/do': ('POST', self.update_project), - '/update/dashboard': ('POST', self.update_dashboard), - '/update/pip-install': ('POST', self.install_pip_package) + "/update/check": ("GET", self.check_update), + "/update/releases": ("GET", self.get_releases), + "/update/do": ("POST", self.update_project), + "/update/dashboard": ("POST", self.update_dashboard), + "/update/pip-install": ("POST", self.install_pip_package), } self.astrbot_updator = astrbot_updator self.core_lifecycle = core_lifecycle self.register_routes() - + async def check_update(self): - type_ = request.args.get('type', None) - + type_ = request.args.get("type", None) + try: dv = await get_dashboard_version() - if type_ == 'dashboard': - return Response().ok({ - "has_new_version": dv != f"v{VERSION}", - "current_version": dv - }).__dict__ + if type_ == "dashboard": + return ( + Response() + .ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv}) + .__dict__ + ) else: ret = await self.astrbot_updator.check_update(None, None) return Response( @@ -41,13 +47,13 @@ class UpdateRoute(Route): "version": f"v{VERSION}", "has_new_version": ret is not None, "dashboard_version": dv, - "dashboard_has_new_version": dv != f"v{VERSION}" - } + "dashboard_has_new_version": dv != f"v{VERSION}", + }, ).__dict__ except Exception as e: logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)") return Response().error(e.__str__()).__dict__ - + async def get_releases(self): try: ret = await self.astrbot_updator.get_releases() @@ -55,14 +61,14 @@ class UpdateRoute(Route): except Exception as e: logger.error(f"/api/update/releases: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - + async def update_project(self): data = await request.json - version = data.get('version', '') - reboot = data.get('reboot', True) + version = data.get("version", "") + reboot = data.get("reboot", True) if version == "" or version == "latest": latest = True - version = '' + version = "" else: latest = False @@ -71,31 +77,41 @@ class UpdateRoute(Route): proxy = proxy.removesuffix("/") try: - await self.astrbot_updator.update(latest=latest, version=version, proxy=proxy) - + await self.astrbot_updator.update( + latest=latest, version=version, proxy=proxy + ) + if latest: try: await download_dashboard() except Exception as e: logger.error(f"下载管理面板文件失败: {e}。") - + # pip 更新依赖 logger.info("更新依赖中...") try: pip_installer.install(requirements_path="requirements.txt") except Exception as e: logger.error(f"更新依赖失败: {e}") - + if reboot: # threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start() self.core_lifecycle.restart() - return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__ + return ( + Response() + .ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。") + .__dict__ + ) else: - return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__ + return ( + Response() + .ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。") + .__dict__ + ) except Exception as e: logger.error(f"/api/update_project: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - + async def update_dashboard(self): try: try: @@ -103,14 +119,16 @@ class UpdateRoute(Route): except Exception as e: logger.error(f"下载管理面板文件失败: {e}。") return Response().error(f"下载管理面板文件失败: {e}").__dict__ - return Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__ + return ( + Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__ + ) except Exception as e: logger.error(f"/api/update_dashboard: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ - + async def install_pip_package(self): data = await request.json - package = data.get('package', '') + package = data.get("package", "") if not package: return Response().error("缺少参数 package 或不合法。").__dict__ try: @@ -118,4 +136,4 @@ class UpdateRoute(Route): return Response().ok(None, "安装成功。").__dict__ except Exception as e: logger.error(f"/api/update_pip: {traceback.format_exc()}") - return Response().error(e.__str__()).__dict__ \ No newline at end of file + return Response().error(e.__str__()).__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index e3f28cb55..0a5839533 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -3,7 +3,6 @@ import jwt import asyncio import os import socket -import sys import psutil from astrbot.core.config.default import VERSION from quart import Quart, request, jsonify, g @@ -15,9 +14,12 @@ from astrbot.core import logger, WEBUI_SK from astrbot.core.db import BaseDatabase from astrbot.core.utils.io import get_local_ip_addresses -DATAPATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data")) +DATAPATH = os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data") +) -class AstrBotDashboard(): + +class AstrBotDashboard: def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config @@ -28,15 +30,19 @@ class AstrBotDashboard(): # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) self.context = RouteContext(self.config, self.app) - self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator, core_lifecycle) + self.ur = UpdateRoute( + self.context, core_lifecycle.astrbot_updator, core_lifecycle + ) self.sr = StatRoute(self.context, db, core_lifecycle) - self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager) + self.pr = PluginRoute( + self.context, core_lifecycle, core_lifecycle.plugin_manager + ) self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) self.ar = AuthRoute(self.context) self.chat_route = ChatRoute(self.context, db, core_lifecycle) - + async def auth_middleware(self): if not request.path.startswith("/api"): return @@ -63,13 +69,12 @@ class AstrBotDashboard(): r = jsonify(Response().error("Token 无效").__dict__) r.status_code = 401 return r - - + async def shutdown_trigger_placeholder(self): while not self.core_lifecycle.event_queue.closed: await asyncio.sleep(1) logger.info("管理面板已关闭。") - + def check_port_in_use(self, port: int) -> bool: """ 跨平台检测端口是否被占用 @@ -80,7 +85,7 @@ class AstrBotDashboard(): sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # 设置超时时间 sock.settimeout(2) - result = sock.connect_ex(('127.0.0.1', port)) + result = sock.connect_ex(("127.0.0.1", port)) sock.close() # result 为 0 表示端口被占用 return result == 0 @@ -88,11 +93,11 @@ class AstrBotDashboard(): logger.warning(f"检查端口 {port} 时发生错误: {str(e)}") # 如果出现异常,保守起见认为端口可能被占用 return True - + def get_process_using_port(self, port: int) -> str: """获取占用端口的进程详细信息""" try: - for conn in psutil.net_connections(kind='inet'): + for conn in psutil.net_connections(kind="inet"): if conn.laddr.port == port: try: process = psutil.Process(conn.pid) @@ -102,7 +107,7 @@ class AstrBotDashboard(): f"PID: {process.pid}", f"执行路径: {process.exe()}", f"工作目录: {process.cwd()}", - f"启动命令: {' '.join(process.cmdline())}" + f"启动命令: {' '.join(process.cmdline())}", ] return "\n ".join(proc_info) except (psutil.NoSuchProcess, psutil.AccessDenied) as e: @@ -110,34 +115,39 @@ class AstrBotDashboard(): return "未找到占用进程" except Exception as e: return f"获取进程信息失败: {str(e)}" - + def run(self): try: ip_addr = get_local_ip_addresses() - except Exception as e: + except Exception as _: ip_addr = [] - - port = self.core_lifecycle.astrbot_config['dashboard'].get("port", 6185) + + port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185) if isinstance(port, str): port = int(port) if self.check_port_in_use(port): process_info = self.get_process_using_port(port) - logger.error(f"错误:端口 {port} 已被占用\n" - f"占用信息: \n {process_info}\n" - f"请确保:\n" - f"1. 没有其他 AstrBot 实例正在运行\n" - f"2. 端口 {port} 没有被其他程序占用\n" - f"3. 如需使用其他端口,请修改配置文件") - + logger.error( + f"错误:端口 {port} 已被占用\n" + f"占用信息: \n {process_info}\n" + f"请确保:\n" + f"1. 没有其他 AstrBot 实例正在运行\n" + f"2. 端口 {port} 没有被其他程序占用\n" + f"3. 如需使用其他端口,请修改配置文件" + ) + raise Exception(f"端口 {port} 已被占用") - + display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n" display += f" ➜ 本地: http://localhost:{port}\n" for ip in ip_addr: display += f" ➜ 网络: http://{ip}:{port}\n" display += " ➜ 默认用户名和密码: astrbot\n ✨✨✨\n" logger.info(display) - - return self.app.run_task(host="0.0.0.0", port=port, shutdown_trigger=self.shutdown_trigger_placeholder) \ No newline at end of file + return self.app.run_task( + host="0.0.0.0", + port=port, + shutdown_trigger=self.shutdown_trigger_placeholder, + ) diff --git a/main.py b/main.py index d7de4e04e..55f785d86 100644 --- a/main.py +++ b/main.py @@ -21,22 +21,24 @@ logo_tmpl = r""" """ + def check_env(): if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): logger.error("请使用 Python3.10+ 运行本项目。") exit() - + os.makedirs("data/config", exist_ok=True) os.makedirs("data/plugins", exist_ok=True) os.makedirs("data/temp", exist_ok=True) # workaround for issue #181 - mimetypes.add_type("text/javascript", ".js") + mimetypes.add_type("text/javascript", ".js") mimetypes.add_type("text/javascript", ".mjs") mimetypes.add_type("application/json", ".json") - + + async def check_dashboard_files(): - '''下载管理面板文件''' + """下载管理面板文件""" v = await get_dashboard_version() if v is not None: @@ -44,11 +46,15 @@ async def check_dashboard_files(): if v == f"v{VERSION}": logger.info("管理面板文件已是最新。") else: - logger.warning("检测到管理面板有更新。可以使用 /dashboard_update 命令更新。") + logger.warning( + "检测到管理面板有更新。可以使用 /dashboard_update 命令更新。" + ) return - - logger.info("开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/Soulter/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。") - + + logger.info( + "开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/Soulter/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。" + ) + try: await download_dashboard() except Exception as e: @@ -57,20 +63,21 @@ async def check_dashboard_files(): logger.info("管理面板下载完成。") + if __name__ == "__main__": check_env() - + # start log broker log_broker = LogBroker() LogManager.set_queue_handler(logger, log_broker) - + # check dashboard files asyncio.run(check_dashboard_files()) - + db = db_helper - + # print logo logger.info(logo_tmpl) - + dashboard_lifecycle = AstrBotDashBoardLifecycle(db, log_broker) - asyncio.run(dashboard_lifecycle.start()) \ No newline at end of file + asyncio.run(dashboard_lifecycle.start()) diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index 6d5e0f63f..2e7b20a80 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -9,9 +9,11 @@ from astrbot.api.message_components import Plain, Image from astrbot import logger from collections import defaultdict -''' +""" 聊天记忆增强 -''' +""" + + class LongTermMemory: def __init__(self, config: dict, context: star.Context): self.config = config @@ -26,16 +28,16 @@ class LongTermMemory: self.image_caption = self.config["image_caption"] self.image_caption_prompt = self.config["image_caption_prompt"] self.image_caption_provider_id = self.config["image_caption_provider_id"] - + self.active_reply = self.config["active_reply"] self.enable_active_reply = self.active_reply.get("enable", False) self.ar_method = self.active_reply["method"] self.ar_possibility = self.active_reply["possibility_reply"] self.ar_prompt = self.active_reply.get("prompt", "") self.ar_whitelist = self.active_reply.get("whitelist", []) - + # self.put_history_to_prompt = self.config["put_history_to_prompt"] - + async def remove_session(self, event: AstrMessageEvent) -> int: cnt = 0 if event.unified_msg_origin in self.session_chats: @@ -44,13 +46,14 @@ class LongTermMemory: return cnt async def get_image_caption(self, image_url: str) -> str: - if not self.image_caption_provider_id: provider = self.context.get_using_provider() else: provider = self.context.get_provider_by_id(self.image_caption_provider_id) if not provider: - raise Exception(f"没有找到 ID 为 {self.image_caption_provider_id} 的提供商") + raise Exception( + f"没有找到 ID 为 {self.image_caption_provider_id} 的提供商" + ) response = await provider.text_chat( prompt=self.image_caption_prompt, session_id=uuid.uuid4().hex, @@ -58,17 +61,17 @@ class LongTermMemory: persist=False, ) return response.completion_text - + async def need_active_reply(self, event: AstrMessageEvent) -> bool: if not self.enable_active_reply: return False if event.get_message_type() != MessageType.GROUP_MESSAGE: return False - + if event.is_at_or_wake_command: # if the message is a command, let it pass return False - + if self.ar_whitelist and ( event.unified_msg_origin not in self.ar_whitelist and (event.get_group_id() and event.get_group_id() not in self.ar_whitelist) @@ -79,12 +82,11 @@ class LongTermMemory: case "possibility_reply": trig = random.random() < self.ar_possibility return trig - + return False - async def handle_message(self, event: AstrMessageEvent): - '''仅支持群聊''' + """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") @@ -111,21 +113,23 @@ class LongTermMemory: self.session_chats[event.unified_msg_origin].pop(0) async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): - '''当触发 LLM 请求前,调用此方法修改 req''' + """当触发 LLM 请求前,调用此方法修改 req""" if event.unified_msg_origin not in self.session_chats: return - - chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin]) - + + chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) + if self.enable_active_reply: prompt = req.prompt req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}" req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information." - req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 + req.contexts = [] # 清空上下文,当使用了主动回复,所有聊天记录都在一个prompt中。 else: - req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n" + req.system_prompt += ( + "You are now in a chatroom. The chat history is as follows: \n" + ) req.system_prompt += chats_str - + async def after_req_llm(self, event: AstrMessageEvent): if event.unified_msg_origin not in self.session_chats: return diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 1882c07ce..34712fbd8 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -21,39 +21,54 @@ from astrbot.api.message_components import Plain, Image from typing import Union -@star.register(name="astrbot", desc="AstrBot 基础指令结合 + 拓展功能", author="Soulter", version="4.0.0") + +@star.register( + name="astrbot", + desc="AstrBot 基础指令结合 + 拓展功能", + author="Soulter", + version="4.0.0", +) class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context cfg = context.get_config() - self.prompt_prefix = cfg['provider_settings']['prompt_prefix'] - self.identifier = cfg['provider_settings']['identifier'] - self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"] - + self.prompt_prefix = cfg["provider_settings"]["prompt_prefix"] + self.identifier = cfg["provider_settings"]["identifier"] + self.enable_datetime = cfg["provider_settings"]["datetime_system_prompt"] + self.ltm = None - if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']: + if ( + self.context.get_config()["provider_ltm_settings"]["group_icl_enable"] + or self.context.get_config()["provider_ltm_settings"]["active_reply"][ + "enable" + ] + ): try: - self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context) + self.ltm = LongTermMemory( + self.context.get_config()["provider_ltm_settings"], self.context + ) except BaseException as e: logger.error(f"聊天增强 err: {e}") - + async def _query_astrbot_notice(self): try: async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get("https://astrbot.app/notice.json", timeout=2) as resp: + async with session.get( + "https://astrbot.app/notice.json", timeout=2 + ) as resp: return (await resp.json())["notice"] except BaseException: return "" - + @filter.command("help") async def help(self, event: AstrMessageEvent): - '''查看帮助''' + """查看帮助""" notice = "" try: notice = await self._query_astrbot_notice() except BaseException: pass - + dashboard_version = await get_dashboard_version() msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version}) @@ -87,82 +102,108 @@ AstrBot 指令: {notice}""" event.set_result(MessageEventResult().message(msg).use_t2i(False)) - + @filter.command_group("tool") def tool(self): pass - + @tool.command("ls") async def tool_ls(self, event: AstrMessageEvent): - '''查看函数工具列表''' + """查看函数工具列表""" tm = self.context.get_llm_tool_manager() msg = "函数工具:\n" for tool in tm.func_list: active = " (启用)" if tool.active else "(停用)" msg += f"- {tool.name}: {tool.description} {active}\n" - + msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。" event.set_result(MessageEventResult().message(msg).use_t2i(False)) - + @tool.command("on") async def tool_on(self, event: AstrMessageEvent, tool_name: str): - '''启用一个函数工具''' + """启用一个函数工具""" if self.context.activate_llm_tool(tool_name): - event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 成功。")) + event.set_result( + MessageEventResult().message(f"激活工具 {tool_name} 成功。") + ) else: - event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 失败,未找到此工具。")) - + event.set_result( + MessageEventResult().message( + f"激活工具 {tool_name} 失败,未找到此工具。" + ) + ) + @tool.command("off") async def tool_off(self, event: AstrMessageEvent, tool_name: str): - '''停用一个函数工具''' + """停用一个函数工具""" if self.context.deactivate_llm_tool(tool_name): - event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。")) + event.set_result( + MessageEventResult().message(f"停用工具 {tool_name} 成功。") + ) else: - event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。")) - + event.set_result( + MessageEventResult().message( + f"停用工具 {tool_name} 失败,未找到此工具。" + ) + ) + @tool.command("off_all") async def tool_all_off(self, event: AstrMessageEvent): - '''停用所有函数工具''' + """停用所有函数工具""" tm = self.context.get_llm_tool_manager() for tool in tm.func_list: self.context.deactivate_llm_tool(tool.name) event.set_result(MessageEventResult().message("停用所有工具成功。")) @filter.command("plugin") - async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None): - '''插件管理''' + async def plugin( + self, event: AstrMessageEvent, oper1: str = None, oper2: str = None + ): + """插件管理""" if oper1 is None: plugin_list_info = "已加载的插件:\n" for plugin in self.context.get_all_stars(): - plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n" + plugin_list_info += ( + f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n" + ) if plugin_list_info.strip() == "": plugin_list_info = "没有加载任何插件。" - + plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" - event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False)) + event.set_result( + MessageEventResult().message(f"{plugin_list_info}").use_t2i(False) + ) else: if oper1 == "off": # 禁用插件 if oper2 is None: - event.set_result(MessageEventResult().message("/plugin off <插件名> 禁用插件。")) + event.set_result( + MessageEventResult().message("/plugin off <插件名> 禁用插件。") + ) return await self.context._star_manager.turn_off_plugin(oper2) event.set_result(MessageEventResult().message(f"插件 {oper2} 已禁用。")) elif oper1 == "on": # 启用插件 if oper2 is None: - event.set_result(MessageEventResult().message("/plugin on <插件名> 启用插件。")) + event.set_result( + MessageEventResult().message("/plugin on <插件名> 启用插件。") + ) return await self.context._star_manager.turn_on_plugin(oper2) event.set_result(MessageEventResult().message(f"插件 {oper2} 已启用。")) - + else: # 获取插件帮助 plugin = self.context.get_registered_star(oper1) if plugin is None: event.set_result(MessageEventResult().message("未找到此插件。")) return - help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "帮助信息: 未提供" + help_msg = ( + plugin.star_cls.__doc__ + if plugin.star_cls.__doc__ + else "帮助信息: 未提供" + ) help_msg += f"\n\n作者: {plugin.author}\n版本: {plugin.version}" command_handlers = [] command_names = [] @@ -178,47 +219,47 @@ AstrBot 指令: elif isinstance(filter_, CommandGroupFilter): command_handlers.append(handler) command_names.append(filter_.group_name) - + if len(command_handlers) > 0: help_msg += "\n\n指令列表:\n" for i in range(len(command_handlers)): help_msg += f"{command_names[i]}: {command_handlers[i].desc}\n" - + help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。" - + ret = f"插件 {oper1} 帮助信息:\n" + help_msg ret += "更多帮助信息请查看插件仓库 README。" event.set_result(MessageEventResult().message(ret).use_t2i(False)) @filter.command("t2i") async def t2i(self, event: AstrMessageEvent): - '''开关文本转图片''' + """开关文本转图片""" config = self.context.get_config() - if config['t2i']: - config['t2i'] = False + if config["t2i"]: + config["t2i"] = False config.save_config() event.set_result(MessageEventResult().message("已关闭文本转图片模式。")) return - config['t2i'] = True + config["t2i"] = True config.save_config() event.set_result(MessageEventResult().message("已开启文本转图片模式。")) @filter.command("tts") async def tts(self, event: AstrMessageEvent): - '''开关文本转语音''' + """开关文本转语音""" config = self.context.get_config() - if config['provider_tts_settings']['enable']: - config['provider_tts_settings']['enable'] = False + if config["provider_tts_settings"]["enable"]: + config["provider_tts_settings"]["enable"] = False config.save_config() event.set_result(MessageEventResult().message("已关闭文本转语音。")) return - config['provider_tts_settings']['enable'] = True + config["provider_tts_settings"]["enable"] = True config.save_config() event.set_result(MessageEventResult().message("已开启文本转语音。")) - + @filter.command("sid") async def sid(self, event: AstrMessageEvent): - '''获取会话 ID 和 管理员 ID''' + """获取会话 ID 和 管理员 ID""" sid = event.unified_msg_origin user_id = str(event.get_sender_id()) ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl 添加白名单, /dwl 删除白名单。 @@ -228,50 +269,56 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") async def op(self, event: AstrMessageEvent, admin_id: str): - '''授权管理员。op ''' - self.context.get_config()['admins_id'].append(admin_id) + """授权管理员。op """ + self.context.get_config()["admins_id"].append(admin_id) self.context.get_config().save_config() event.set_result(MessageEventResult().message("授权成功。")) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") async def deop(self, event: AstrMessageEvent, admin_id: str): - '''取消授权管理员。deop ''' + """取消授权管理员。deop """ try: - self.context.get_config()['admins_id'].remove(admin_id) + self.context.get_config()["admins_id"].remove(admin_id) self.context.get_config().save_config() event.set_result(MessageEventResult().message("取消授权成功。")) except ValueError: - event.set_result(MessageEventResult().message("此用户 ID 不在管理员名单内。")) + event.set_result( + MessageEventResult().message("此用户 ID 不在管理员名单内。") + ) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") async def wl(self, event: AstrMessageEvent, sid: str): - '''添加白名单。wl ''' - self.context.get_config()['platform_settings']['id_whitelist'].append(sid) + """添加白名单。wl """ + self.context.get_config()["platform_settings"]["id_whitelist"].append(sid) self.context.get_config().save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") async def dwl(self, event: AstrMessageEvent, sid: str): - '''删除白名单。dwl ''' + """删除白名单。dwl """ try: - self.context.get_config()['platform_settings']['id_whitelist'].remove(sid) + self.context.get_config()["platform_settings"]["id_whitelist"].remove(sid) self.context.get_config().save_config() event.set_result(MessageEventResult().message("删除白名单成功。")) except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) @filter.command("provider") - async def provider(self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None): - '''查看或者切换 LLM Provider''' - + async def provider( + self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None + ): + """查看或者切换 LLM Provider""" + if not self.context.get_using_provider(): - event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + event.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) return - - if idx is None: + + if idx is None: ret = "## 载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id @@ -279,7 +326,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo if self.context.get_using_provider().meta().id == id_: ret += " (当前使用)" ret += "\n" - + tts_providers = self.context.get_all_tts_providers() if tts_providers: ret += "\n## 载入的 TTS 提供商\n" @@ -290,7 +337,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo if tts_using and tts_using.meta().id == id_: ret += " (当前使用)" ret += "\n" - + stt_providers = self.context.get_all_stt_providers() if stt_providers: ret += "\n## 载入的 STT 提供商\n" @@ -303,12 +350,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo ret += "\n" ret += "\n使用 /provider <序号> 切换 LLM 提供商。" - + if tts_providers: ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" if stt_providers: - ret += "\n使用 /provider stt <切换> STT 提供商。" - + ret += "\n使用 /provider stt <切换> STT 提供商。" + event.set_result(MessageEventResult().message(ret)) else: if idx == "tts": @@ -324,7 +371,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo sp.put("curr_provider_tts", id_) if not self.context.provider_manager.tts_enabled: self.context.provider_manager.tts_enabled = True - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) elif idx == "stt": if idx2 is None: event.set_result(MessageEventResult().message("请输入序号。")) @@ -338,7 +387,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo sp.put("curr_provider_stt", id_) if not self.context.provider_manager.stt_enabled: self.context.provider_manager.stt_enabled = True - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + event.set_result( + MessageEventResult().message(f"成功切换到 {id_}。") + ) elif isinstance(idx, int): if idx > len(self.context.get_all_providers()) or idx < 1: event.set_result(MessageEventResult().message("无效的序号。")) @@ -355,65 +406,90 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("reset") async def reset(self, message: AstrMessageEvent): - '''重置 LLM 会话''' - is_unique_session = self.context.get_config()['platform_settings']['unique_session'] + """重置 LLM 会话""" + is_unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] if message.get_group_id() and not is_unique_session and message.role != "admin": # 群聊,没开独立会话,发送人不是管理员 - message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限重置当前对话。")) + message.set_result( + MessageEventResult().message( + f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限重置当前对话。" + ) + ) return - + if not self.context.get_using_provider(): - message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) return - + provider = self.context.get_using_provider() - if provider and provider.meta().type == 'dify': + if provider and provider.meta().type == "dify": assert isinstance(provider, ProviderDify) await provider.forget(message.unified_msg_origin) - message.set_result(MessageEventResult().message("已重置当前 Dify 会话,新聊天将更换到新的会话。")) + message.set_result( + MessageEventResult().message( + "已重置当前 Dify 会话,新聊天将更换到新的会话。" + ) + ) return - - cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) - + + cid = await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + if not cid: - message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。")) + message.set_result( + MessageEventResult().message( + "当前未处于对话状态,请 /switch 切换或者 /new 创建。" + ) + ) return - + await self.context.conversation_manager.update_conversation( message.unified_msg_origin, cid, [] ) - + ret = "清除会话 LLM 聊天历史成功。" if self.ltm: cnt = await self.ltm.remove_session(event=message) ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" - + message.set_result(MessageEventResult().message(ret)) @filter.command("model") - async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None): - '''查看或者切换模型''' + async def model_ls( + self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None + ): + """查看或者切换模型""" if not self.context.get_using_provider(): - message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) return - - + if idx_or_name is None: models = [] try: models = await self.context.get_using_provider().get_models() except BaseException as e: - message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)).use_t2i(False)) + message.set_result( + MessageEventResult() + .message("获取模型列表失败: " + str(e)) + .use_t2i(False) + ) return i = 1 ret = "下面列出了此服务提供商可用模型:" for model in models: ret += f"\n{i}. {model}" i += 1 - + curr_model = self.context.get_using_provider().get_model() or "无" ret += f"\n当前模型: [{curr_model}]" - + ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" message.set_result(MessageEventResult().message(ret).use_t2i(False)) else: @@ -422,38 +498,57 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo try: models = await self.context.get_using_provider().get_models() except BaseException as e: - message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e))) + message.set_result( + MessageEventResult().message("获取模型列表失败: " + str(e)) + ) return if idx_or_name > len(models) or idx_or_name < 1: message.set_result(MessageEventResult().message("模型序号错误。")) else: try: - new_model = models[idx_or_name-1] + new_model = models[idx_or_name - 1] self.context.get_using_provider().set_model(new_model) except BaseException as e: message.set_result( - MessageEventResult().message("切换模型未知错误: "+str(e))) + MessageEventResult().message("切换模型未知错误: " + str(e)) + ) message.set_result(MessageEventResult().message("切换模型成功。")) else: self.context.get_using_provider().set_model(idx_or_name) message.set_result( - MessageEventResult().message(f"切换模型到 {self.context.get_using_provider().get_model()}。")) - + MessageEventResult().message( + f"切换模型到 {self.context.get_using_provider().get_model()}。" + ) + ) + @filter.command("history") async def his(self, message: AstrMessageEvent, page: int = 1): - '''查看对话记录''' + """查看对话记录""" if not self.context.get_using_provider(): - message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) return - + size_per_page = 6 - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) - + session_curr_cid = ( + await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + ) + if not session_curr_cid: - message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 序号 切换或者 /new 创建。")) + message.set_result( + MessageEventResult().message( + "当前未处于对话状态,请 /switch 序号 切换或者 /new 创建。" + ) + ) return - - contexts, total_pages = await self.context.conversation_manager.get_human_readable_context( + + ( + contexts, + total_pages, + ) = await self.context.conversation_manager.get_human_readable_context( message.unified_msg_origin, session_curr_cid, page, size_per_page ) @@ -462,7 +557,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo if len(context) > 150: context = context[:150] + "..." history += f"{context}\n" - + ret = f"""当前对话历史记录: {history} 第 {page} 页 | 共 {total_pages} 页 @@ -474,17 +569,19 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo @filter.command("ls") async def convs(self, message: AstrMessageEvent, page: int = 1): - '''查看对话列表''' - + """查看对话列表""" + provider = self.context.get_using_provider() - if provider and provider.meta().type == 'dify': + if provider and provider.meta().type == "dify": """原有的Dify处理逻辑保持不变""" ret = "Dify 对话列表:\n" assert isinstance(provider, ProviderDify) data = await provider.api_client.get_chat_convs(message.unified_msg_origin) idx = 1 - for conv in data['data']: - ts_h = datetime.datetime.fromtimestamp(conv['updated_at']).strftime('%m-%d %H:%M') + for conv in data["data"]: + ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( + "%m-%d %H:%M" + ) ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" idx += 1 if idx == 1: @@ -496,7 +593,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo size_per_page = 6 """获取所有对话列表""" - conversations_all = await self.context.conversation_manager.get_conversations(message.unified_msg_origin) + conversations_all = await self.context.conversation_manager.get_conversations( + message.unified_msg_origin + ) """计算总页数""" total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page """确保页码有效""" @@ -505,110 +604,144 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo start_idx = (page - 1) * size_per_page end_idx = start_idx + size_per_page conversations_paged = conversations_all[start_idx:end_idx] - + ret = "对话列表:\n---\n" """全局序号从当前页的第一个开始""" - global_index = start_idx + 1 - + global_index = start_idx + 1 + """生成所有对话的标题字典""" _titles = {} for conv in conversations_all: persona_id = conv.persona_id if not persona_id or persona_id == "[%None]": - persona_id = self.context.provider_manager.selected_default_persona['name'] + persona_id = self.context.provider_manager.selected_default_persona[ + "name" + ] title = conv.title if conv.title else "新对话" _titles[conv.cid] = title - + """遍历分页后的对话生成列表显示""" for conv in conversations_paged: persona_id = conv.persona_id if not persona_id or persona_id == "[%None]": - persona_id = self.context.provider_manager.selected_default_persona['name'] + persona_id = self.context.provider_manager.selected_default_persona[ + "name" + ] title = _titles.get(conv.cid, "新对话") ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" global_index += 1 - + ret += "---\n" - curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) if curr_cid: """从所有对话的标题字典中获取标题""" title = _titles.get(curr_cid, "新对话") ret += f"\n当前对话: {title}({curr_cid[:4]})" else: ret += "\n当前对话: 无" - - unique_session = self.context.get_config()['platform_settings']['unique_session'] + + unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] if unique_session: ret += "\n会话隔离粒度: 个人" else: ret += "\n会话隔离粒度: 群聊" - + ret += f"\n第 {page} 页 | 共 {total_pages} 页" ret += "\n*输入 /ls 2 跳转到第 2 页" - + message.set_result(MessageEventResult().message(ret).use_t2i(False)) return - + @filter.command("new") async def new_conv(self, message: AstrMessageEvent): - '''创建新对话''' + """创建新对话""" provider = self.context.get_using_provider() - if provider and provider.meta().type == 'dify': + if provider and provider.meta().type == "dify": assert isinstance(provider, ProviderDify) await provider.forget(message.unified_msg_origin) - message.set_result(MessageEventResult().message("成功,下次聊天将是新对话。")) + message.set_result( + MessageEventResult().message("成功,下次聊天将是新对话。") + ) return - cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin) - message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。")) - + cid = await self.context.conversation_manager.new_conversation( + message.unified_msg_origin + ) + message.set_result( + MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。") + ) + @filter.command("switch") async def switch_conv(self, message: AstrMessageEvent, index: int = None): - '''通过 /ls 前面的序号切换对话''' + """通过 /ls 前面的序号切换对话""" if not isinstance(index, int): - message.set_result(MessageEventResult().message("类型错误,请输入数字对话序号。")) + message.set_result( + MessageEventResult().message("类型错误,请输入数字对话序号。") + ) return provider = self.context.get_using_provider() - if provider and provider.meta().type == 'dify': + if provider and provider.meta().type == "dify": assert isinstance(provider, ProviderDify) data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - if not data['data']: + if not data["data"]: message.set_result(MessageEventResult().message("未找到任何对话。")) return selected_conv = None if index is not None: try: - selected_conv = data['data'][index-1] + selected_conv = data["data"][index - 1] except IndexError: - message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看")) + message.set_result( + MessageEventResult().message("对话序号错误,请使用 /ls 查看") + ) return else: - selected_conv = data['data'][0] - ret = f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" - provider.conversation_ids[message.unified_msg_origin] = selected_conv['id'] + selected_conv = data["data"][0] + ret = ( + f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" + ) + provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"] message.set_result(MessageEventResult().message(ret)) return - + if index is None: - message.set_result(MessageEventResult().message("请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话")) + message.set_result( + MessageEventResult().message( + "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话" + ) + ) return - conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin) + conversations = await self.context.conversation_manager.get_conversations( + message.unified_msg_origin + ) if index > len(conversations) or index < 1: - message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看")) + message.set_result( + MessageEventResult().message("对话序号错误,请使用 /ls 查看") + ) else: - conversation = conversations[index-1] + conversation = conversations[index - 1] title = conversation.title if conversation.title else "新对话" - await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid) - message.set_result(MessageEventResult().message(f"切换到对话: {title}({conversation.cid[:4]})。")) - + await self.context.conversation_manager.switch_conversation( + message.unified_msg_origin, conversation.cid + ) + message.set_result( + MessageEventResult().message( + f"切换到对话: {title}({conversation.cid[:4]})。" + ) + ) + @filter.command("rename") async def rename_conv(self, message: AstrMessageEvent, new_name: str): - '''重命名对话''' + """重命名对话""" provider = self.context.get_using_provider() - - if provider and provider.meta().type == 'dify': + + if provider and provider.meta().type == "dify": assert isinstance(provider, ProviderDify) cid = provider.conversation_ids.get(message.unified_msg_origin, None) if not cid: @@ -617,50 +750,77 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo await provider.api_client.rename(cid, new_name, message.unified_msg_origin) message.set_result(MessageEventResult().message("重命名对话成功。")) return - - await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name) + + await self.context.conversation_manager.update_conversation_title( + message.unified_msg_origin, new_name + ) message.set_result(MessageEventResult().message("重命名对话成功。")) - + @filter.command("del") async def del_conv(self, message: AstrMessageEvent): - '''删除当前对话''' - is_unique_session = self.context.get_config()['platform_settings']['unique_session'] + """删除当前对话""" + is_unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] if message.get_group_id() and not is_unique_session and message.role != "admin": # 群聊,没开独立会话,发送人不是管理员 - message.set_result(MessageEventResult().message(f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。")) + message.set_result( + MessageEventResult().message( + f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。" + ) + ) return - + provider = self.context.get_using_provider() - if provider and provider.meta().type == 'dify': + if provider and provider.meta().type == "dify": assert isinstance(provider, ProviderDify) await provider.api_client.delete_chat_conv(message.unified_msg_origin) provider.conversation_ids.pop(message.unified_msg_origin, None) - message.set_result(MessageEventResult().message("删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。")) + message.set_result( + MessageEventResult().message( + "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" + ) + ) return - - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) - + + session_curr_cid = ( + await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + ) + if not session_curr_cid: - message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 序号 切换或 /new 创建。")) + message.set_result( + MessageEventResult().message( + "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。" + ) + ) return - - await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid) - message.set_result(MessageEventResult().message("删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。")) - + + await self.context.conversation_manager.delete_conversation( + message.unified_msg_origin, session_curr_cid + ) + message.set_result( + MessageEventResult().message( + "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" + ) + ) + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") - async def key(self, message: AstrMessageEvent, index: int=None): - + async def key(self, message: AstrMessageEvent, index: int = None): if not self.context.get_using_provider(): - message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")) + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) return - + if index is None: keys_data = self.context.get_using_provider().get_keys() curr_key = self.context.get_using_provider().get_current_key() ret = "Key:" for i, k in enumerate(keys_data): - ret += f"\n{i+1}. {k[:8]}" + ret += f"\n{i + 1}. {k[:8]}" ret += f"\n当前 Key: {curr_key[:8]}" ret += "\n当前模型: " + self.context.get_using_provider().get_model() @@ -673,60 +833,73 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo message.set_result(MessageEventResult().message("Key 序号错误。")) else: try: - new_key = keys_data[index-1] + new_key = keys_data[index - 1] self.context.get_using_provider().set_key(new_key) except BaseException as e: message.set_result( - MessageEventResult().message("切换 Key 未知错误: "+str(e))) + MessageEventResult().message("切换 Key 未知错误: " + str(e)) + ) message.set_result(MessageEventResult().message("切换 Key 成功。")) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") async def persona(self, message: AstrMessageEvent): - l = message.message_str.split(" ") - + l = message.message_str.split(" ") # noqa: E741 + curr_persona_name = "无" - cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin) + cid = await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) curr_cid_title = "无" if cid: - conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid) + conversation = await self.context.conversation_manager.get_conversation( + message.unified_msg_origin, cid + ) if not conversation.persona_id and not conversation.persona_id == "[%None]": - curr_persona_name = self.context.provider_manager.selected_default_persona['name'] + curr_persona_name = ( + self.context.provider_manager.selected_default_persona["name"] + ) else: curr_persona_name = conversation.persona_id - + curr_cid_title = conversation.title if conversation.title else "新对话" curr_cid_title += f"({cid[:4]})" - + if len(l) == 1: message.set_result( - MessageEventResult().message(f"""[Persona] + MessageEventResult() + .message(f"""[Persona] - 人格情景列表: `/persona list` - 设置人格情景: `/persona 人格` - 人格情景详细信息: `/persona view 人格` - 取消人格: `/persona unset` -默认人格情景: {self.context.provider_manager.selected_default_persona['name']} +默认人格情景: {self.context.provider_manager.selected_default_persona["name"]} 当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} 配置人格情景请前往管理面板-配置页 -""").use_t2i(False)) +""") + .use_t2i(False) + ) elif l[1] == "list": msg = "人格列表:\n" for persona in self.context.provider_manager.personas: msg += f"- {persona['name']}\n" - msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息' + msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息" message.set_result(MessageEventResult().message(msg)) elif l[1] == "view": if len(l) == 2: message.set_result(MessageEventResult().message("请输入人格情景名")) return ps = l[2].strip() - if persona := next(builtins.filter( - lambda persona: persona['name'] == ps, - self.context.provider_manager.personas - ), None): + if persona := next( + builtins.filter( + lambda persona: persona["name"] == ps, + self.context.provider_manager.personas, + ), + None, + ): msg = f"人格{ps}的详细信息:\n" msg += f"{persona['prompt']}\n" else: @@ -734,21 +907,38 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo message.set_result(MessageEventResult().message(msg)) elif l[1] == "unset": if not cid: - message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。")) + message.set_result( + MessageEventResult().message("当前没有对话,无法取消人格。") + ) return - await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]") + await self.context.conversation_manager.update_conversation_persona_id( + message.unified_msg_origin, "[%None]" + ) message.set_result(MessageEventResult().message("取消人格成功。")) else: ps = "".join(l[1:]).strip() - if persona := next(builtins.filter( - lambda persona: persona['name'] == ps, - self.context.provider_manager.personas - ), None): - await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps) - message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。")) + if persona := next( + builtins.filter( + lambda persona: persona["name"] == ps, + self.context.provider_manager.personas, + ), + None, + ): + await self.context.conversation_manager.update_conversation_persona_id( + message.unified_msg_origin, ps + ) + message.set_result( + MessageEventResult().message( + "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。" + ) + ) else: - message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。")) - + message.set_result( + MessageEventResult().message( + "不存在该人格情景。使用 /persona list 查看所有。" + ) + ) + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") async def update_dashboard(self, event: AstrMessageEvent): @@ -761,30 +951,30 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo # session_id = event.get_session_id() uid = event.unified_msg_origin session_vars = sp.get("session_variables", {}) - + session_var = session_vars.get(uid, {}) session_var[key] = value - + session_vars[uid] = session_var - + sp.put("session_variables", session_vars) - + yield event.plain_result(f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。") - + @filter.command("unset") async def unset_variable(self, event: AstrMessageEvent, key: str): uid = event.unified_msg_origin session_vars = sp.get("session_variables", {}) - + session_var = session_vars.get(uid, {}) - + if key not in session_var: yield event.plain_result("没有那个变量名。格式 /unset 变量名。") else: del session_var[key] sp.put("session_variables", session_vars) yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。") - + @filter.command("gewe_logout") async def gewe_logout(self, event: AstrMessageEvent): platforms = self.context.platform_manager.platform_insts @@ -794,38 +984,39 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo await platform.logout() yield event.plain_result("已登出 gewechat,请重启 AstrBot") return - - + @filter.command("gewe_code") async def gewe_code(self, event: AstrMessageEvent, code: str): - '''保存 gewechat 验证码''' - with open("data/temp/gewe_code", "w", encoding='utf-8') as f: + """保存 gewechat 验证码""" + with open("data/temp/gewe_code", "w", encoding="utf-8") as f: f.write(code) yield event.plain_result("验证码已保存。") - + @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) async def on_message(self, event: AstrMessageEvent): - '''群聊记忆增强''' - + """群聊记忆增强""" + has_image_or_plain = False for comp in event.message_obj.message: if isinstance(comp, Plain) or isinstance(comp, Image): has_image_or_plain = True break - + if self.ltm and has_image_or_plain: need_active = await self.ltm.need_active_reply(event) - - group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable'] + + group_icl_enable = self.context.get_config()["provider_ltm_settings"][ + "group_icl_enable" + ] if group_icl_enable: - '''记录对话''' + """记录对话""" try: await self.ltm.handle_message(event) except BaseException as e: logger.error(e) - + if need_active: - '''主动回复''' + """主动回复""" provider = self.context.get_using_provider() if not provider: logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复") @@ -833,32 +1024,39 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo try: conv = None history = [] - if provider.meta().type != 'dify': + if provider.meta().type != "dify": # Dify 自己有维护对话,不需要 bot 端维护。 - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin) - + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + event.unified_msg_origin + ) + if not session_curr_cid: - logger.error("当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。") + logger.error( + "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。" + ) return - + conv = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, - session_curr_cid + event.unified_msg_origin, session_curr_cid ) history = [] if conv: history = json.loads(conv.history) else: assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get(event.unified_msg_origin, None) + cid = provider.conversation_ids.get( + event.unified_msg_origin, None + ) if cid is None: - logger.error("[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。") + logger.error( + "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。" + ) return - + prompt = self.ltm.ar_prompt if not prompt: prompt = event.message_str - + yield event.request_llm( prompt=prompt, func_tool_manager=self.context.get_llm_tool_manager(), @@ -868,75 +1066,80 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo ) except BaseException as e: logger.error(f"主动回复失败: {e}") - - + @filter.on_llm_request() async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): - '''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt''' + """在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt""" if self.prompt_prefix: req.prompt = self.prompt_prefix + req.prompt - + if self.identifier: user_id = event.message_obj.sender.user_id user_nickname = event.message_obj.sender.nickname user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n" req.prompt = user_info + req.prompt - + if self.enable_datetime: tz_offset = datetime.timedelta(hours=8) tz = datetime.timezone(tz_offset) - current_time = datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M') + current_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M") req.system_prompt += f"\nCurrent datetime: {current_time}\n" - + if req.conversation: persona_id = req.conversation.persona_id - if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格 - persona_id = self.context.provider_manager.selected_default_persona['name'] - persona = next(builtins.filter( - lambda persona: persona['name'] == persona_id, - self.context.provider_manager.personas - ), None) + if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格 + persona_id = self.context.provider_manager.selected_default_persona[ + "name" + ] + persona = next( + builtins.filter( + lambda persona: persona["name"] == persona_id, + self.context.provider_manager.personas, + ), + None, + ) if persona: - if prompt := persona['prompt']: + if prompt := persona["prompt"]: req.system_prompt += prompt - if mood_dialogs := persona['_mood_imitation_dialogs_processed']: + if mood_dialogs := persona["_mood_imitation_dialogs_processed"]: req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" req.system_prompt += mood_dialogs if begin_dialogs := persona["_begin_dialogs_processed"]: req.contexts[:0] = begin_dialogs - + if self.ltm: try: await self.ltm.on_req_llm(event, req) except BaseException as e: logger.error(f"ltm: {e}") - @filter.after_message_sent() async def after_llm_req(self, event: AstrMessageEvent): - '''在 LLM 请求后记录对话''' + """在 LLM 请求后记录对话""" if self.ltm: try: await self.ltm.after_req_llm(event) except BaseException as e: logger.error(f"ltm: {e}") - + @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd") async def alter_cmd(self, event: AstrMessageEvent): # token = event.message_str.split(" ") token = self.parse_commands(event.message_str) if token.len < 2: - yield event.plain_result("可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd \n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令") + yield event.plain_result( + "可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd \n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令" + ) return - + cmd_name = token.get(1) cmd_type = token.get(2) - + if cmd_type not in ["admin", "member"]: yield event.plain_result("指令类型错误,可选类型有 admin, member") return - + # 查找指令 found_command = None for handler in star_handlers_registry: @@ -950,22 +1153,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo if cmd_name == filter_.group_name: found_command = handler break - + if not found_command: yield event.plain_result("未找到该指令") return - + found_plugin = star_map[found_command.handler_module_path] - + alter_cmd_cfg = sp.get("alter_cmd", {}) plugin_ = alter_cmd_cfg.get(found_plugin.name, {}) cfg = plugin_.get(found_command.handler_name, {}) cfg["permission"] = cmd_type plugin_[found_command.handler_name] = cfg alter_cmd_cfg[found_plugin.name] = plugin_ - + sp.put("alter_cmd", alter_cmd_cfg) - + # 注入权限过滤器 found_permission_filter = False for filter_ in found_command.event_filters: @@ -977,6 +1180,13 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo found_permission_filter = True break if not found_permission_filter: - found_command.event_filters.insert(0, PermissionTypeFilter(filter.PermissionType.ADMIN if cmd_type == "admin" else filter.PermissionType.MEMBER)) + found_command.event_filters.insert( + 0, + PermissionTypeFilter( + filter.PermissionType.ADMIN + if cmd_type == "admin" + else filter.PermissionType.MEMBER + ), + ) yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令") diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 42f394867..aaeba473f 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -86,19 +86,26 @@ send_text("If you need more information, please let me know :)") DEFAULT_CONFIG = { "sandbox": { "image": "soulter/astrbot-code-interpreter-sandbox", - "docker_mirror": "", # cjie.eu.org + "docker_mirror": "", # cjie.eu.org }, - "docker_host_astrbot_abs_path": "" + "docker_host_astrbot_abs_path": "", } PATH = "data/config/python_interpreter.json" -@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1") + +@star.register( + name="astrbot-python-interpreter", + desc="Python 代码执行器", + author="Soulter", + version="0.0.1", +) class Main(star.Star): - '''基于 Docker 沙箱的 Python 代码执行器''' + """基于 Docker 沙箱的 Python 代码执行器""" + def __init__(self, context: star.Context) -> None: self.context = context self.curr_dir = os.path.dirname(os.path.abspath(__file__)) - + self.shared_path = os.path.join("data", "py_interpreter_shared") if not os.path.exists(self.shared_path): # 复制 api.py 到 shared 目录 @@ -107,12 +114,12 @@ class Main(star.Star): shutil.copy(shared_api_file, self.shared_path) self.workplace_path = os.path.join("data", "py_interpreter_workplace") os.makedirs(self.workplace_path, exist_ok=True) - + self.user_file_msg_buffer = defaultdict(list) - '''存放用户上传的文件和图片''' + """存放用户上传的文件和图片""" self.user_waiting = {} - '''正在等待用户的文件或图片''' - + """正在等待用户的文件或图片""" + # 加载配置 if not os.path.exists(PATH): self.config = DEFAULT_CONFIG @@ -120,33 +127,38 @@ class Main(star.Star): else: with open(PATH, "r") as f: self.config = json.load(f) - + async def initialize(self): ok = await self.is_docker_available() if not ok: - logger.info("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。") - await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter") - + logger.info( + "Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。" + ) + await self.context._star_manager.turn_off_plugin( + "astrbot-python-interpreter" + ) + async def file_upload(self, file_path: str): - ''' + """ 上传图像文件到 S3 - ''' + """ ext = os.path.splitext(file_path)[1] S3_URL = "https://s3.neko.soulter.top/astrbot-s3" with open(file_path, "rb") as f: file = f.read() - + s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}" - async with aiohttp.ClientSession(headers = {"Accept": "application/json"}, trust_env=True) as session: + async with aiohttp.ClientSession( + headers={"Accept": "application/json"}, trust_env=True + ) as session: async with session.put(s3_file_url, data=file) as resp: if resp.status != 200: raise Exception(f"Failed to upload image: {resp.status}") return s3_file_url - - + async def is_docker_available(self) -> bool: - '''Check if docker is available''' + """Check if docker is available""" try: docker = aiodocker.Docker() await docker.version() @@ -155,42 +167,44 @@ class Main(star.Star): except BaseException as e: logger.info(f"检查 Docker 可用性: {e}") return False - + async def get_image_name(self) -> str: - '''Get the image name''' + """Get the image name""" if self.config["sandbox"]["docker_mirror"]: return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - + def _save_config(self): with open(PATH, "w") as f: json.dump(self.config, f) - + async def gen_magic_code(self) -> str: return uuid.uuid4().hex[:8] - - async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str: - '''Download image from url to workplace_path''' + + async def download_image( + self, image_url: str, workplace_path: str, filename: str + ) -> str: + """Download image from url to workplace_path""" async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(image_url) as resp: if resp.status != 200: return "" image_path = os.path.join(workplace_path, f"{filename}.jpg") - with open(image_path, 'wb') as f: + with open(image_path, "wb") as f: f.write(await resp.read()) return f"{filename}.jpg" - + async def tidy_code(self, code: str) -> str: - '''Tidy the code''' + """Tidy the code""" pattern = r"```(?:py|python)?\n(.*?)\n```" match = re.search(pattern, code, re.DOTALL) if match is None: raise ValueError("The code is not in the code block.") return match.group(1) - + @filter.event_message_type(filter.EventMessageType.ALL) async def on_message(self, event: AstrMessageEvent): - '''处理消息''' + """处理消息""" uid = event.get_sender_id() if uid not in self.user_waiting: return @@ -220,23 +234,24 @@ class Main(star.Star): yield event.plain_result(f"代码执行器: 图片已经上传: {image_path}") if uid in self.user_waiting: del self.user_waiting[uid] - + @filter.on_llm_request() async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] request.prompt += f"\nUser provided files: {files}" - @filter.command_group("pi") def pi(self): pass - + @pi.command("absdir") async def pi_absdir(self, event: AstrMessageEvent, path: str = ""): - '''设置 Docker 宿主机绝对路径''' + """设置 Docker 宿主机绝对路径""" if not path: - yield event.plain_result(f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}") + yield event.plain_result( + f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}" + ) else: self.config["docker_host_astrbot_abs_path"] = path self._save_config() @@ -244,9 +259,9 @@ class Main(star.Star): @pi.command("mirror") async def pi_mirror(self, event: AstrMessageEvent, url: str = ""): - '''Docker 镜像地址''' + """Docker 镜像地址""" if not url: - yield event.plain_result(f"""当前 Docker 镜像地址: {self.config['sandbox']['docker_mirror']}。 + yield event.plain_result(f"""当前 Docker 镜像地址: {self.config["sandbox"]["docker_mirror"]}。 使用 `pi mirror ` 来设置 Docker 镜像地址。 您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。 """) @@ -257,7 +272,7 @@ class Main(star.Star): @pi.command("repull") async def pi_repull(self, event: AstrMessageEvent): - '''重新拉取沙箱镜像''' + """重新拉取沙箱镜像""" docker = aiodocker.Docker() image_name = await self.get_image_name() try: @@ -267,28 +282,29 @@ class Main(star.Star): pass await docker.images.pull(image_name) yield event.plain_result("重新拉取沙箱镜像成功。") - + @pi.command("file") async def pi_file(self, event: AstrMessageEvent): - '''在规定秒数(60s)内上传一个文件''' + """在规定秒数(60s)内上传一个文件""" uid = event.get_sender_id() self.user_waiting[uid] = time.time() tip = "文件" yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}。") await asyncio.sleep(60) if uid in self.user_waiting: - yield event.plain_result(f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。") + yield event.plain_result( + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。" + ) self.user_waiting.pop(uid) - @llm_tool("python_interpreter") async def python_interpreter(self, event: AstrMessageEvent): - '''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. + """Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code. For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc. - ''' + """ if not await self.is_docker_available(): yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。") - + plain_text = event.message_str # 创建必要的工作目录和幻术码 @@ -297,7 +313,7 @@ class Main(star.Star): output_path = os.path.join(workplace_path, "output") os.makedirs(workplace_path, exist_ok=True) os.makedirs(output_path, exist_ok=True) - + files = [] # 文件 for file_path in self.user_file_msg_buffer[event.get_session_id()]: @@ -310,9 +326,9 @@ class Main(star.Star): file_name = os.path.basename(file_path) shutil.copy(file_path, os.path.join(workplace_path, file_name)) files.append(file_name) - + logger.debug(f"user query: {plain_text}, files: {files}") - + # 整理额外输入 extra_inputs = "" if files: @@ -320,30 +336,33 @@ class Main(star.Star): obs = "" n = 5 - + for i in range(n): if i > 0: - logger.info(f"Try {i+1}/{n}") - + logger.info(f"Try {i + 1}/{n}") + PROMPT_ = PROMPT.format( - prompt=plain_text, + prompt=plain_text, extra_input=extra_inputs, extra_prompt=obs, ) provider = self.context.get_using_provider() - llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}") - - logger.debug("code interpreter llm gened code:" + llm_response.completion_text) - + llm_response = await provider.text_chat( + prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}" + ) + + logger.debug( + "code interpreter llm gened code:" + llm_response.completion_text + ) + # 整理代码并保存 code_clean = await self.tidy_code(llm_response.completion_text) with open(os.path.join(workplace_path, "exec.py"), "w") as f: f.write(code_clean) - + # 启动容器 docker = aiodocker.Docker() - - + # 检查有没有image image_name = await self.get_image_name() try: @@ -352,47 +371,58 @@ class Main(star.Star): # 拉取镜像 logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...") await docker.images.pull(image_name) - - yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})") - - - self.docker_host_astrbot_abs_path = self.config.get("docker_host_astrbot_abs_path", "") + + yield event.plain_result( + f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})" + ) + + self.docker_host_astrbot_abs_path = self.config.get( + "docker_host_astrbot_abs_path", "" + ) if self.docker_host_astrbot_abs_path: - host_shared = os.path.join(self.docker_host_astrbot_abs_path, self.shared_path) - host_output = os.path.join(self.docker_host_astrbot_abs_path, output_path) - host_workplace = os.path.join(self.docker_host_astrbot_abs_path, workplace_path) - + host_shared = os.path.join( + self.docker_host_astrbot_abs_path, self.shared_path + ) + host_output = os.path.join( + self.docker_host_astrbot_abs_path, output_path + ) + host_workplace = os.path.join( + self.docker_host_astrbot_abs_path, workplace_path + ) + else: host_shared = os.path.abspath(self.shared_path) host_output = os.path.abspath(output_path) host_workplace = os.path.abspath(workplace_path) - - logger.debug(f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}") - - container = await docker.containers.run({ - "Image": image_name, - "Cmd": ["python", "exec.py"], - "Memory": 512 * 1024 * 1024, - "NanoCPUs": 1000000000, - "HostConfig": { - "Binds": [ - f"{host_shared}:/astrbot_sandbox/shared:ro", - f"{host_output}:/astrbot_sandbox/output:rw", - f"{host_workplace}:/astrbot_sandbox:rw", - ] - }, - "Env": [ - f"MAGIC_CODE={magic_code}" - ], - "AutoRemove": True - }) - + + logger.debug( + f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}" + ) + + container = await docker.containers.run( + { + "Image": image_name, + "Cmd": ["python", "exec.py"], + "Memory": 512 * 1024 * 1024, + "NanoCPUs": 1000000000, + "HostConfig": { + "Binds": [ + f"{host_shared}:/astrbot_sandbox/shared:ro", + f"{host_output}:/astrbot_sandbox/output:rw", + f"{host_workplace}:/astrbot_sandbox:rw", + ] + }, + "Env": [f"MAGIC_CODE={magic_code}"], + "AutoRemove": True, + } + ) + logger.debug(f"Container {container.id} created.") logs = await self.run_container(container) - + logger.debug(f"Container {container.id} finished.") logger.debug(f"Container {container.id} logs: {logs}") - + # 发送结果 pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)" ok = False @@ -415,39 +445,43 @@ class Main(star.Star): file_name = os.path.basename(file_path) chain = [File(name=file_name, file=file_s3_url)] yield event.set_result(MessageEventResult(chain=chain)) - - elif "Traceback (most recent call last)" in log \ - or "[Error]: " in log: + + elif "Traceback (most recent call last)" in log or "[Error]: " in log: traceback = "\n".join(logs[idx:]) - + if not ok: if traceback: obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code." else: - logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}") + logger.warning( + f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}" + ) break else: # 成功了 self.user_file_msg_buffer.pop(event.get_session_id()) return - - yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。") - + + yield event.plain_result( + "经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。" + ) + @pi.command("cleanfile") async def pi_cleanfile(self, event: AstrMessageEvent): - '''清理用户上传的文件''' + """清理用户上传的文件""" for file in self.user_file_msg_buffer[event.get_session_id()]: try: os.remove(file) except BaseException as e: logger.error(f"删除文件 {file} 失败: {e}") - + self.user_file_msg_buffer.pop(event.get_session_id()) yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。") - - async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]: - '''Run the container and get the output''' + async def run_container( + self, container: aiodocker.docker.DockerContainer, timeout: int = 20 + ) -> list[str]: + """Run the container and get the output""" try: await container.wait(timeout=timeout) logs = await container.log(stdout=True, stderr=True) diff --git a/packages/python_interpreter/shared/api.py b/packages/python_interpreter/shared/api.py index 9fe27ce67..287773fb0 100644 --- a/packages/python_interpreter/shared/api.py +++ b/packages/python_interpreter/shared/api.py @@ -1,18 +1,22 @@ import os + def _get_magic_code(): - '''防止注入攻击''' + """防止注入攻击""" return os.getenv("MAGIC_CODE") + def send_text(text: str): print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") - + + def send_image(image_path: str): if not os.path.exists(image_path): raise Exception(f"Image file not found: {image_path}") print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") - + + def send_file(file_path: str): if not os.path.exists(file_path): raise Exception(f"File not found: {file_path}") - print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") \ No newline at end of file + print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") diff --git a/packages/reminder/main.py b/packages/reminder/main.py index 8d5e836e4..e3f6c0a97 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -8,63 +8,72 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api import llm_tool, logger -@star.register(name="astrbot-reminder", desc="使用 LLM 待办提醒", author="Soulter", version="0.0.1") + +@star.register( + name="astrbot-reminder", desc="使用 LLM 待办提醒", author="Soulter", version="0.0.1" +) class Main(star.Star): - '''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`''' + """使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`""" + def __init__(self, context: star.Context) -> None: self.context = context - self.scheduler = AsyncIOScheduler(timezone='Asia/Shanghai') - + self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai") + # set and load config if not os.path.exists("data/astrbot-reminder.json"): - with open("data/astrbot-reminder.json", "w", encoding='utf-8') as f: + with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f: f.write("{}") - with open("data/astrbot-reminder.json", "r", encoding='utf-8') as f: + with open("data/astrbot-reminder.json", "r", encoding="utf-8") as f: self.reminder_data = json.load(f) - + self._init_scheduler() self.scheduler.start() def _init_scheduler(self): - '''Initialize the scheduler.''' + """Initialize the scheduler.""" for group in self.reminder_data: for reminder in self.reminder_data[group]: - if 'id' not in reminder: + if "id" not in reminder: id_ = str(uuid.uuid4()) - reminder['id'] = id_ + reminder["id"] = id_ else: - id_ = reminder['id'] - + id_ = reminder["id"] + if "datetime" in reminder: if self.check_is_outdated(reminder): continue self.scheduler.add_job( - self._reminder_callback, + self._reminder_callback, id=id_, - trigger='date', - args=[group, reminder], - run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"), - misfire_grace_time=60 + trigger="date", + args=[group, reminder], + run_date=datetime.datetime.strptime( + reminder["datetime"], "%Y-%m-%d %H:%M" + ), + misfire_grace_time=60, ) elif "cron" in reminder: self.scheduler.add_job( - self._reminder_callback, - trigger='cron', + self._reminder_callback, + trigger="cron", id=id_, - args=[group, reminder], + args=[group, reminder], misfire_grace_time=60, - **self._parse_cron_expr(reminder["cron"]) + **self._parse_cron_expr(reminder["cron"]), ) - + def check_is_outdated(self, reminder: dict): - '''Check if the reminder is outdated.''' + """Check if the reminder is outdated.""" if "datetime" in reminder: - return datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") < datetime.datetime.now() + return ( + datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") + < datetime.datetime.now() + ) return False - + async def _save_data(self): - '''Save the reminder data.''' - with open("data/astrbot-reminder.json", "w", encoding='utf-8') as f: + """Save the reminder data.""" + with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f: json.dump(self.reminder_data, f, ensure_ascii=False) def _parse_cron_expr(self, cron_expr: str): @@ -76,79 +85,104 @@ class Main(star.Star): "month": fields[3], "day_of_week": fields[4], } - + @llm_tool("reminder") - async def reminder_tool(self, event: AstrMessageEvent, text: str=None, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None): - '''Call this function when user is asking for setting a reminder. - + async def reminder_tool( + self, + event: AstrMessageEvent, + text: str = None, + datetime_str: str = None, + cron_expression: str = None, + human_readable_cron: str = None, + ): + """Call this function when user is asking for setting a reminder. + Args: text(string): Must Required. The content of the reminder. datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder. human_readable_cron(string): Optional. The human readable cron expression of the reminder. - ''' - if event.get_platform_name() == 'qq_official': + """ + if event.get_platform_name() == "qq_official": yield event.plain_result("reminder 暂不支持 QQ 官方机器人。") return - + if event.unified_msg_origin not in self.reminder_data: self.reminder_data[event.unified_msg_origin] = [] - + if not cron_expression and not datetime_str: - raise ValueError("The cron_expression and datetime_str cannot be both None.") + raise ValueError( + "The cron_expression and datetime_str cannot be both None." + ) reminder_time = "" - + if not text: text = "未命名待办事项" - + if cron_expression: - d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron, "id": str(uuid.uuid4()) } + d = { + "text": text, + "cron": cron_expression, + "cron_h": human_readable_cron, + "id": str(uuid.uuid4()), + } self.reminder_data[event.unified_msg_origin].append(d) self.scheduler.add_job( - self._reminder_callback, - 'cron', + self._reminder_callback, + "cron", id=d["id"], misfire_grace_time=60, - **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d] + **self._parse_cron_expr(cron_expression), + args=[event.unified_msg_origin, d], ) if human_readable_cron: reminder_time = f"{human_readable_cron}(Cron: {cron_expression})" else: - d = { "text": text, "datetime": datetime_str, "id": str(uuid.uuid4()) } + d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())} self.reminder_data[event.unified_msg_origin].append(d) - datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M") + datetime_scheduled = datetime.datetime.strptime( + datetime_str, "%Y-%m-%d %H:%M" + ) self.scheduler.add_job( - self._reminder_callback, - 'date', + self._reminder_callback, + "date", id=d["id"], - args=[event.unified_msg_origin, d], + args=[event.unified_msg_origin, d], run_date=datetime_scheduled, - misfire_grace_time=60 + misfire_grace_time=60, ) reminder_time = datetime_str await self._save_data() - yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。") - + yield event.plain_result( + "成功设置待办事项。\n内容: " + + text + + "\n时间: " + + reminder_time + + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。" + ) + @filter.command_group("reminder") def reminder(self): - '''The command group of the reminder.''' + """The command group of the reminder.""" pass - + async def get_upcoming_reminders(self, unified_msg_origin: str): - '''Get upcoming reminders.''' + """Get upcoming reminders.""" reminders = self.reminder_data.get(unified_msg_origin, []) if not reminders: return [] now = datetime.datetime.now() upcoming_reminders = [ - reminder for reminder in reminders - if "datetime" not in reminder or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now + reminder + for reminder in reminders + if "datetime" not in reminder + or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now ] return upcoming_reminders - + @reminder.command("ls") async def reminder_ls(self, event: AstrMessageEvent): - '''List upcoming reminders.''' + """List upcoming reminders.""" reminders = await self.get_upcoming_reminders(event.unified_msg_origin) if not reminders: yield event.plain_result("没有正在进行的待办事项。") @@ -162,12 +196,12 @@ class Main(star.Star): reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n" reminder_str += "\n使用 /reminder rm 删除待办事项。\n" yield event.plain_result(reminder_str) - + @reminder.command("rm") async def reminder_rm(self, event: AstrMessageEvent, index: int): - '''Remove a reminder by index.''' + """Remove a reminder by index.""" reminders = await self.get_upcoming_reminders(event.unified_msg_origin) - + if not reminders: yield event.plain_result("没有待办事项。") elif index < 1 or index > len(reminders): @@ -175,26 +209,37 @@ class Main(star.Star): else: reminder = reminders.pop(index - 1) job_id = reminder.get("id") - + # self.reminder_data[event.unified_msg_origin] = reminder users_reminders = self.reminder_data.get(event.unified_msg_origin, []) for i, r in enumerate(users_reminders): if r.get("id") == job_id: users_reminders.pop(i) - + try: self.scheduler.remove_job(job_id) except Exception as e: logger.error(f"Remove job error: {e}") - yield event.plain_result(f"成功移除对应的待办事项。删除定时任务失败: {str(e)} 可能需要重启 AstrBot 以取消该提醒任务。") + yield event.plain_result( + f"成功移除对应的待办事项。删除定时任务失败: {str(e)} 可能需要重启 AstrBot 以取消该提醒任务。" + ) await self._save_data() yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) - + async def _reminder_callback(self, unified_msg_origin: str, d: dict): - '''The callback function of the reminder.''' + """The callback function of the reminder.""" logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}") - await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", ""))) - + await self.context.send_message( + unified_msg_origin, + MessageEventResult().message( + "待办提醒: \n\n" + + d["text"] + + "\n时间: " + + d.get("datetime", "") + + d.get("cron_h", "") + ), + ) + async def terminate(self): self.scheduler.shutdown() await self._save_data() diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py index 9dad0a766..38b3ede10 100644 --- a/packages/web_searcher/engines/__init__.py +++ b/packages/web_searcher/engines/__init__.py @@ -6,29 +6,29 @@ from typing import List import urllib.parse HEADERS = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0', - 'Accept': '*/*', - 'Connection': 'keep-alive', - 'Accept-Language': 'en-GB,en;q=0.5' + "User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0", + "Accept": "*/*", + "Connection": "keep-alive", + "Accept-Language": "en-GB,en;q=0.5", } -USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0' +USER_AGENT_BING = "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0" USER_AGENTS = [ - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36', - 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0', - 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0' + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36", + "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0", + "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0", ] @dataclass -class SearchResult(): +class SearchResult: title: str url: str snippet: str @@ -36,10 +36,11 @@ class SearchResult(): def __str__(self) -> str: return f"{self.title} - {self.url}\n{self.snippet}" -class SearchEngine(): - ''' + +class SearchEngine: + """ 搜索引擎爬虫基类 - ''' + """ def __init__(self) -> None: self.TIMEOUT = 10 @@ -47,7 +48,7 @@ class SearchEngine(): self.headers = HEADERS def _set_selector(self, selector: str) -> None: - raise NotImplementedError() + raise NotImplementedError() def _get_next_page(self): raise NotImplementedError() @@ -58,37 +59,41 @@ class SearchEngine(): headers["User-Agent"] = random.choice(USER_AGENTS) if data: async with ClientSession() as session: - async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp: + async with session.post( + url, headers=headers, data=data, timeout=self.TIMEOUT + ) as resp: ret = await resp.text(encoding="utf-8") return ret else: async with ClientSession() as session: - async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp: + async with session.get( + url, headers=headers, timeout=self.TIMEOUT + ) as resp: ret = await resp.text(encoding="utf-8") return ret - - - def tidy_text(self, text: str) -> str: - ''' - 清理文本,去除空格、换行符等 - ''' - return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + def tidy_text(self, text: str) -> str: + """ + 清理文本,去除空格、换行符等 + """ + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") async def search(self, query: str, num_results: int) -> List[SearchResult]: query = urllib.parse.quote(query) - + try: resp = await self._get_next_page(query) - soup = BeautifulSoup(resp, 'html.parser') - links = soup.select(self._set_selector('links')) + soup = BeautifulSoup(resp, "html.parser") + links = soup.select(self._set_selector("links")) results = [] for link in links: - title = self.tidy_text(link.select_one(self._set_selector('title')).text) - url = link.select_one(self._set_selector('url')) - snippet = '' + title = self.tidy_text( + link.select_one(self._set_selector("title")).text + ) + url = link.select_one(self._set_selector("url")) + snippet = "" if title and url: results.append(SearchResult(title=title, url=url, snippet=snippet)) return results[:num_results] if len(results) > num_results else results except Exception as e: - raise e \ No newline at end of file + raise e diff --git a/packages/web_searcher/engines/bing.py b/packages/web_searcher/engines/bing.py index 624e3a0fb..01bec4d45 100644 --- a/packages/web_searcher/engines/bing.py +++ b/packages/web_searcher/engines/bing.py @@ -2,19 +2,20 @@ from typing import List from . import SearchEngine, SearchResult from . import USER_AGENT_BING + class Bing(SearchEngine): def __init__(self) -> None: super().__init__() self.base_urls = ["https://cn.bing.com", "https://www.bing.com"] - self.headers.update({'User-Agent': USER_AGENT_BING}) + self.headers.update({"User-Agent": USER_AGENT_BING}) def _set_selector(self, selector: str): selectors = { - 'url': 'div.b_attribution cite', - 'title': 'h2', - 'text': 'p', - 'links': 'ol#b_results > li.b_algo', - 'next': 'div#b_content nav[role="navigation"] a.sb_pagN' + "url": "div.b_attribution cite", + "title": "h2", + "text": "p", + "links": "ol#b_results > li.b_algo", + "next": 'div#b_content nav[role="navigation"] a.sb_pagN', } return selectors[selector] @@ -23,7 +24,7 @@ class Bing(SearchEngine): # await self._get_html(self.base_url) for base_url in self.base_urls: try: - url = f'{base_url}/search?q={query}' + url = f"{base_url}/search?q={query}" return await self._get_html(url, None) except Exception as _: self.base_url = base_url @@ -35,5 +36,5 @@ class Bing(SearchEngine): for result in results: if not isinstance(result.url, str): result.url = result.url.text - - return results \ No newline at end of file + + return results diff --git a/packages/web_searcher/engines/google.py b/packages/web_searcher/engines/google.py index 62ca5f3ee..ac66f7d72 100644 --- a/packages/web_searcher/engines/google.py +++ b/packages/web_searcher/engines/google.py @@ -5,17 +5,26 @@ from . import SearchEngine, SearchResult from typing import List + class Google(SearchEngine): def __init__(self) -> None: super().__init__() self.proxy = os.environ.get("https_proxy") - + async def search(self, query: str, num_results: int) -> List[SearchResult]: results = [] try: - ls = search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy) + ls = search( + query, + advanced=True, + num_results=num_results, + timeout=3, + proxy=self.proxy, + ) for i in ls: - results.append(SearchResult(title=i.title, url=i.url, snippet=i.description)) + results.append( + SearchResult(title=i.title, url=i.url, snippet=i.description) + ) except Exception as e: raise e - return results \ No newline at end of file + return results diff --git a/packages/web_searcher/engines/sogo.py b/packages/web_searcher/engines/sogo.py index 17058f250..9a505782f 100644 --- a/packages/web_searcher/engines/sogo.py +++ b/packages/web_searcher/engines/sogo.py @@ -6,27 +6,27 @@ from . import USER_AGENTS from typing import List + class Sogo(SearchEngine): def __init__(self) -> None: super().__init__() self.base_url = "https://www.sogou.com" - self.headers['User-Agent'] = random.choice(USER_AGENTS) - + self.headers["User-Agent"] = random.choice(USER_AGENTS) def _set_selector(self, selector: str): selectors = { - 'url': 'h3 > a', - 'title': 'h3', - 'text': '', - 'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)', - 'next': '' + "url": "h3 > a", + "title": "h3", + "text": "", + "links": "div.results > div.vrwrap:not(.middle-better-hintBox)", + "next": "", } return selectors[selector] async def _get_next_page(self, query) -> str: - url = f'{self.base_url}/web?query={query}' + url = f"{self.base_url}/web?query={query}" return await self._get_html(url, None) - + async def search(self, query: str, num_results: int) -> List[SearchResult]: results = await super().search(query, num_results) for result in results: @@ -35,11 +35,13 @@ class Sogo(SearchEngine): result.url = self.base_url + result.url result.url = await self._parse_url(result.url) return results - + async def _parse_url(self, url) -> str: html = await self._get_html(url) - soup = BeautifulSoup(html, 'html.parser') + soup = BeautifulSoup(html, "html.parser") script = soup.find("script") if script: - url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1) - return url \ No newline at end of file + url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group( + 1 + ) + return url diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 747813b93..fa65ce1e5 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -12,74 +12,92 @@ from bs4 import BeautifulSoup from .engines import HEADERS, USER_AGENTS -@star.register(name="astrbot-web-searcher", desc="让 LLM 具有网页检索能力", author="Soulter", version="1.14.514") +@star.register( + name="astrbot-web-searcher", + desc="让 LLM 具有网页检索能力", + author="Soulter", + version="1.14.514", +) class Main(star.Star): - '''使用 /websearch on 或者 off 开启或者关闭网页搜索功能''' + """使用 /websearch on 或者 off 开启或者关闭网页搜索功能""" + def __init__(self, context: star.Context) -> None: self.context = context - + self.bing_search = Bing() self.sogo_search = Sogo() self.google = Google() - - self.websearch_link = self.context.get_config()['provider_settings'].get('web_search_link', False) - + + self.websearch_link = self.context.get_config()["provider_settings"].get( + "web_search_link", False + ) + async def initialize(self): - websearch = self.context.get_config()['provider_settings']['web_search'] + websearch = self.context.get_config()["provider_settings"]["web_search"] if websearch: self.context.activate_llm_tool("web_search") self.context.activate_llm_tool("fetch_url") else: self.context.deactivate_llm_tool("web_search") self.context.deactivate_llm_tool("fetch_url") - + async def _tidy_text(self, text: str) -> str: - '''清理文本,去除空格、换行符等''' + """清理文本,去除空格、换行符等""" return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") - + async def _get_from_url(self, url: str) -> str: - '''获取网页内容''' + """获取网页内容""" header = HEADERS - header.update({'User-Agent': random.choice(USER_AGENTS)}) + header.update({"User-Agent": random.choice(USER_AGENTS)}) async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(url, headers=header, timeout=6) as response: html = await response.text(encoding="utf-8") doc = Document(html) ret = doc.summary(html_partial=True) - soup = BeautifulSoup(ret, 'html.parser') + soup = BeautifulSoup(ret, "html.parser") ret = await self._tidy_text(soup.get_text()) return ret @filter.command("websearch") async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str: - websearch = self.context.get_config()['provider_settings']['web_search'] + websearch = self.context.get_config()["provider_settings"]["web_search"] if oper is None: status = "开启" if websearch else "关闭" - event.set_result(MessageEventResult().message("当前网页搜索功能状态:" + status + "。使用 /websearch on 或者 off 启用或者关闭。")) + event.set_result( + MessageEventResult().message( + "当前网页搜索功能状态:" + + status + + "。使用 /websearch on 或者 off 启用或者关闭。" + ) + ) return - + if oper == "on": - self.context.get_config()['provider_settings']['web_search'] = True + self.context.get_config()["provider_settings"]["web_search"] = True self.context.get_config().save_config() self.context.activate_llm_tool("web_search") self.context.activate_llm_tool("fetch_url") event.set_result(MessageEventResult().message("已开启网页搜索功能")) elif oper == "off": - self.context.get_config()['provider_settings']['web_search'] = False + self.context.get_config()["provider_settings"]["web_search"] = False self.context.get_config().save_config() self.context.deactivate_llm_tool("web_search") self.context.deactivate_llm_tool("fetch_url") event.set_result(MessageEventResult().message("已关闭网页搜索功能")) else: - event.set_result(MessageEventResult().message("操作参数错误,应为 on 或 off")) - + event.set_result( + MessageEventResult().message("操作参数错误,应为 on 或 off") + ) + @llm_tool("web_search") - async def search_from_search_engine(self, event: AstrMessageEvent, query: str) -> str: - '''搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。 - + async def search_from_search_engine( + self, event: AstrMessageEvent, query: str + ) -> str: + """搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。 + Args: query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。 - ''' + """ logger.info("web_searcher - search_from_search_engine: " + query) results = [] RESULT_NUM = 5 @@ -110,27 +128,29 @@ class Main(star.Star): site_result = await self._get_from_url(i.url) except BaseException: site_result = "" - site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result - + site_result = ( + site_result[:700] + "..." if len(site_result) > 700 else site_result + ) + header = f"{idx}. {i.title} " - + if self.websearch_link and i.url: header += i.url - + ret += f"{header}\n{i.snippet}\n{site_result}\n\n" idx += 1 - + if self.websearch_link: ret += "针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。" - + return ret @llm_tool("fetch_url") async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str: - '''fetch the content of a website with the given web url - + """fetch the content of a website with the given web url + Args: url(string): The url of the website to fetch content from - ''' + """ resp = await self._get_from_url(url) - return resp \ No newline at end of file + return resp diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..e61f1a087 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.ruff] +lint.ignore = ["F403", "F405"] +exclude = [ + "astrbot/core/utils/t2i/local_strategy.py", + "astrbot/api/all.py" +] \ No newline at end of file diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index d4a38ad9b..9c0658465 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -16,93 +16,107 @@ def core_lifecycle_td(): core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db) return core_lifecycle_td + @pytest.fixture(scope="module") def app(core_lifecycle_td): db = SQLiteDatabase("data/data_v3.db") server = AstrBotDashboard(core_lifecycle_td, db) return server.app + @pytest.fixture(scope="module") def header(): return {} + @pytest.mark.asyncio async def test_init_core_lifecycle_td(core_lifecycle_td): await core_lifecycle_td.initialize() assert core_lifecycle_td is not None + @pytest.mark.asyncio -async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict): +async def test_auth_login( + app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict +): test_client = app.test_client() - response = await test_client.post('/api/auth/login', json={ - "username": "wrong", - "password": "password" - }) + response = await test_client.post( + "/api/auth/login", json={"username": "wrong", "password": "password"} + ) data = await response.get_json() - assert data['status'] == 'error' - - response = await test_client.post('/api/auth/login', json={ - "username": core_lifecycle_td.astrbot_config['dashboard']['username'], - "password": core_lifecycle_td.astrbot_config['dashboard']['password'] - }) + assert data["status"] == "error" + + response = await test_client.post( + "/api/auth/login", + json={ + "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], + "password": core_lifecycle_td.astrbot_config["dashboard"]["password"], + }, + ) data = await response.get_json() - assert data['status'] == 'ok' and 'token' in data['data'] - header['Authorization'] = f"Bearer {data['data']['token']}" - + assert data["status"] == "ok" and "token" in data["data"] + header["Authorization"] = f"Bearer {data['data']['token']}" + + @pytest.mark.asyncio async def test_get_stat(app: Quart, header: dict): test_client = app.test_client() - response = await test_client.get('/api/stat/get') + response = await test_client.get("/api/stat/get") assert response.status_code == 401 - response = await test_client.get('/api/stat/get', headers=header) + response = await test_client.get("/api/stat/get", headers=header) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' and 'platform' in data['data'] + assert data["status"] == "ok" and "platform" in data["data"] + @pytest.mark.asyncio async def test_plugins(app: Quart, header: dict): test_client = app.test_client() # 已经安装的插件 - response = await test_client.get('/api/plugin/get', headers=header) + response = await test_client.get("/api/plugin/get", headers=header) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' - + assert data["status"] == "ok" + # 插件市场 - response = await test_client.get('/api/plugin/market_list', headers=header) + response = await test_client.get("/api/plugin/market_list", headers=header) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' - + assert data["status"] == "ok" + # 插件安装 - response = await test_client.post('/api/plugin/install', json={ - "url": "https://github.com/Soulter/astrbot_plugin_essential" - }, headers=header) + response = await test_client.post( + "/api/plugin/install", + json={"url": "https://github.com/Soulter/astrbot_plugin_essential"}, + headers=header, + ) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' + assert data["status"] == "ok" exists = False for md in star_registry: if md.name == "astrbot_plugin_essential": exists = True break assert exists is True, "插件 astrbot_plugin_essential 未成功载入" - + # 插件更新 - response = await test_client.post('/api/plugin/update', json={ - "name": "astrbot_plugin_essential" - }, headers=header) + response = await test_client.post( + "/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header + ) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' - + assert data["status"] == "ok" + # 插件卸载 - response = await test_client.post('/api/plugin/uninstall', json={ - "name": "astrbot_plugin_essential" - }, headers=header) + response = await test_client.post( + "/api/plugin/uninstall", + json={"name": "astrbot_plugin_essential"}, + headers=header, + ) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' + assert data["status"] == "ok" exists = False for md in star_registry: if md.name == "astrbot_plugin_essential": @@ -115,34 +129,37 @@ async def test_plugins(app: Quart, header: dict): exists = True break assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" - + + @pytest.mark.asyncio async def test_check_update(app: Quart, header: dict): test_client = app.test_client() - response = await test_client.get('/api/update/check', headers=header) + response = await test_client.get("/api/update/check", headers=header) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'success' - + assert data["status"] == "success" + + @pytest.mark.asyncio -async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle): +async def test_do_update( + app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle +): global VERSION test_client = app.test_client() os.makedirs("data/astrbot_release", exist_ok=True) core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release" VERSION = "114.514.1919810" - response = await test_client.post('/api/update/do', headers=header, json={ - "version": "latest" - }) + response = await test_client.post( + "/api/update/do", headers=header, json={"version": "latest"} + ) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'error' # 已经是最新版本 - - response = await test_client.post('/api/update/do', headers=header, json={ - "version": "v3.4.0", - "reboot": False - }) + assert data["status"] == "error" # 已经是最新版本 + + response = await test_client.post( + "/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False} + ) assert response.status_code == 200 data = await response.get_json() - assert data['status'] == 'ok' - assert os.path.exists("data/astrbot_release/astrbot") \ No newline at end of file + assert data["status"] == "ok" + assert os.path.exists("data/astrbot_release/astrbot") diff --git a/tests/test_main.py b/tests/test_main.py index d2201e448..0f5e51d13 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,45 +4,52 @@ import pytest from unittest import mock from main import check_env, check_dashboard_files -class _version_info(): + +class _version_info: def __init__(self, major, minor): self.major = major self.minor = minor - + + def test_check_env(monkeypatch): version_info_correct = _version_info(3, 10) version_info_wrong = _version_info(3, 9) - monkeypatch.setattr(sys, 'version_info', version_info_correct) - with mock.patch('os.makedirs') as mock_makedirs: + monkeypatch.setattr(sys, "version_info", version_info_correct) + with mock.patch("os.makedirs") as mock_makedirs: check_env() mock_makedirs.assert_any_call("data/config", exist_ok=True) mock_makedirs.assert_any_call("data/plugins", exist_ok=True) mock_makedirs.assert_any_call("data/temp", exist_ok=True) - - monkeypatch.setattr(sys, 'version_info', version_info_wrong) + + monkeypatch.setattr(sys, "version_info", version_info_wrong) with pytest.raises(SystemExit): check_env() + @pytest.mark.asyncio async def test_check_dashboard_files(monkeypatch): - monkeypatch.setattr(os.path, 'exists', lambda x: False) + monkeypatch.setattr(os.path, "exists", lambda x: False) + async def mock_get(*args, **kwargs): class MockResponse: status = 200 + async def read(self): - return b'content' + return b"content" + return MockResponse() - - with mock.patch('aiohttp.ClientSession.get', new=mock_get): - with mock.patch('builtins.open', mock.mock_open()) as mock_file: - with mock.patch('zipfile.ZipFile.extractall') as mock_extractall: + + with mock.patch("aiohttp.ClientSession.get", new=mock_get): + with mock.patch("builtins.open", mock.mock_open()) as mock_file: + with mock.patch("zipfile.ZipFile.extractall") as mock_extractall: + async def mock_aenter(_): await check_dashboard_files() mock_file.assert_called_once_with("data/dashboard.zip", "wb") mock_extractall.assert_called_once() - + async def mock_aexit(obj, exc_type, exc, tb): return mock_extractall.__aenter__ = mock_aenter - mock_extractall.__aexit__ = mock_aexit \ No newline at end of file + mock_extractall.__aexit__ = mock_aexit diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 4eb278d81..7afedaefa 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -5,9 +5,12 @@ import asyncio from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext from astrbot.core.star import PluginManager from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.config.default import CONFIG_METADATA_2 from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType +from astrbot.core.platform.astrbot_message import ( + AstrBotMessage, + MessageMember, + MessageType, +) from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core.message.components import Plain, At from astrbot.core.platform.platform_metadata import PlatformMetadata @@ -47,9 +50,10 @@ TEST_COMMANDS = [ # ["model", "查看、切换提供商模型列表"], ["history", "历史记录:"], ["key", "当前 Key"], - ["persona", "[Persona]"] + ["persona", "[Persona]"], ] + class FakeAstrMessageEvent(AstrMessageEvent): def __init__(self, abm: AstrBotMessage = None): meta = PlatformMetadata("test_platform", "test") @@ -57,19 +61,19 @@ class FakeAstrMessageEvent(AstrMessageEvent): message_str=abm.message_str, message_obj=abm, platform_meta=meta, - session_id=abm.session_id + session_id=abm.session_id, ) - + async def send(self, message: MessageChain): await super().send(message) - + @staticmethod def create_fake_event( - message_str: str, - session_id: str = "test_sid", + message_str: str, + session_id: str = "test_sid", is_at: bool = False, is_group: bool = False, - sender_id: str = "123456" + sender_id: str = "123456", ): abm = AstrBotMessage() abm.message_str = message_str @@ -88,36 +92,46 @@ class FakeAstrMessageEvent(AstrMessageEvent): abm.type = MessageType.FRIEND_MESSAGE return FakeAstrMessageEvent(abm) + @pytest.fixture(scope="module") def event_queue(): return Queue() + @pytest.fixture(scope="module") def config(): cfg = AstrBotConfig() - cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"] - cfg['admins_id'] = ["123456"] - cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"] - cfg['provider'] = [TEST_LLM_PROVIDER] + cfg["platform_settings"]["id_whitelist"] = [ + "test_platform:FriendMessage:test_sid_wl", + "test_platform:GroupMessage:test_sid_wl", + ] + cfg["admins_id"] = ["123456"] + cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"] + cfg["provider"] = [TEST_LLM_PROVIDER] return cfg + @pytest.fixture(scope="module") def db(): return SQLiteDatabase("data/data_v3.db") + @pytest.fixture(scope="module") def platform_manager(event_queue, config): return PlatformManager(config, event_queue) - + + @pytest.fixture(scope="module") def provider_manager(config, db): return ProviderManager(config, db) + @pytest.fixture(scope="module") def star_context(event_queue, config, db, platform_manager, provider_manager): star_context = Context(event_queue, config, db, provider_manager, platform_manager) return star_context + @pytest.fixture(scope="module") def plugin_manager(star_context, config): plugin_manager = PluginManager(star_context, config) @@ -125,72 +139,105 @@ def plugin_manager(star_context, config): asyncio.run(plugin_manager.reload()) return plugin_manager + @pytest.fixture(scope="module") def pipeline_context(config, plugin_manager): return PipelineContext(config, plugin_manager) + @pytest.fixture(scope="module") def pipeline_scheduler(pipeline_context): return PipelineScheduler(pipeline_context) + @pytest.mark.asyncio async def test_platform_initialization(platform_manager: PlatformManager): await platform_manager.initialize() - + + @pytest.mark.asyncio async def test_provider_initialization(provider_manager: ProviderManager): await provider_manager.initialize() + @pytest.mark.asyncio async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler): await pipeline_scheduler.initialize() + @pytest.mark.asyncio async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog): - '''测试唤醒''' + """测试唤醒""" # 群聊无 @ 无指令 caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) - assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages) + assert any( + "执行阶段 WhitelistCheckStage" not in message for message in caplog.messages + ) # 群聊有 @ 无指令 - mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True) + mock_event = FakeAstrMessageEvent.create_fake_event( + "test", is_group=True, is_at=True + ) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages) # 群聊有指令 - mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST) + mock_event = FakeAstrMessageEvent.create_fake_event( + "/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST + ) await pipeline_scheduler.execute(mock_event) assert mock_event._has_send_oper is True + @pytest.mark.asyncio -async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog): +async def test_pipeline_wl( + pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog +): caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123") + mock_event = FakeAstrMessageEvent.create_fake_event( + "test", SESSION_ID_IN_WHITELIST, sender_id="123" + ) with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) - assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息" + assert any( + "不在会话白名单中,已终止事件传播。" not in message + for message in caplog.messages + ), "日志中未找到预期的消息" mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123") with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) - assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息" - + assert any( + "不在会话白名单中,已终止事件传播。" in message for message in caplog.messages + ), "日志中未找到预期的消息" + + @pytest.mark.asyncio async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog): # 测试默认屏蔽词 caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。 + mock_event = FakeAstrMessageEvent.create_fake_event( + "色情", session_id=SESSION_ID_IN_WHITELIST + ) # 测试需要。 with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) - assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息" + assert any("内容安全检查不通过" in message for message in caplog.messages), ( + "日志中未找到预期的消息" + ) # 测试额外屏蔽词 - mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST) + mock_event = FakeAstrMessageEvent.create_fake_event( + "TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST + ) with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) - assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息" - mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST) + assert any("内容安全检查不通过" in message for message in caplog.messages), ( + "日志中未找到预期的消息" + ) + mock_event = FakeAstrMessageEvent.create_fake_event( + "_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST + ) with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) assert any("内容安全检查不通过" not in message for message in caplog.messages) @@ -200,28 +247,39 @@ async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, ca @pytest.mark.asyncio async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog): caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST) + mock_event = FakeAstrMessageEvent.create_fake_event( + "just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST + ) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) assert any("请求 LLM" in message for message in caplog.messages) assert mock_event.get_result() is not None assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT - + + @pytest.mark.asyncio async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog): caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST) + mock_event = FakeAstrMessageEvent.create_fake_event( + "help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST + ) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) assert any("请求 LLM" in message for message in caplog.messages) - assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages) - + assert any( + "web_searcher - search_from_search_engine" in message + for message in caplog.messages + ) + + @pytest.mark.asyncio async def test_commands(pipeline_scheduler: PipelineScheduler, caplog): for command in TEST_COMMANDS: caplog.clear() - mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST) + mock_event = FakeAstrMessageEvent.create_fake_event( + command[0], session_id=SESSION_ID_IN_WHITELIST + ) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) # assert any("执行阶段 ProcessStage" in message for message in caplog.messages) - assert any(command[1] in message for message in caplog.messages) \ No newline at end of file + assert any(command[1] in message for message in caplog.messages) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 86a99c6c1..1a7831536 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -16,24 +16,28 @@ db = SQLiteDatabase("data/data_v3.db") star_context = Context(event_queue, config, db) + @pytest.fixture def plugin_manager_pm(): return PluginManager(star_context, config) + def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None + @pytest.mark.asyncio async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None - + + @pytest.mark.asyncio async def test_plugin_crud(plugin_manager_pm: PluginManager): - '''测试插件安装和重载''' + """测试插件安装和重载""" os.makedirs("data/plugins", exist_ok=True) test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_path = await plugin_manager_pm.install_plugin(test_repo) @@ -44,19 +48,19 @@ async def test_plugin_crud(plugin_manager_pm: PluginManager): break assert plugin_path is not None assert os.path.exists(plugin_path) - assert exists is True, "插件 astrbot_plugin_essential 未成功载入" + assert exists is True, "插件 astrbot_plugin_essential 未成功载入" # shutil.rmtree(plugin_path) - + # install plugin which is not exists with pytest.raises(Exception): plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha") - + # update await plugin_manager_pm.update_plugin("astrbot_plugin_essential") - + with pytest.raises(Exception): await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha") - + # uninstall await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential") assert not os.path.exists(plugin_path) @@ -72,7 +76,7 @@ async def test_plugin_crud(plugin_manager_pm: PluginManager): exists = True break assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" - + with pytest.raises(Exception): await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")