diff --git a/.gitignore b/.gitignore index 52b57f486..a863e36ec 100644 --- a/.gitignore +++ b/.gitignore @@ -17,10 +17,12 @@ addons/plugins tests/astrbot_plugin_openai chroma -node_modules/ +dashboard/node_modules/ +dashboard/dist/ .DS_Store package-lock.json package.json venv/* packages/python_interpreter/workplace -.venv/* \ No newline at end of file +.venv/* +.conda/ diff --git a/LICENSE b/LICENSE index 0ad25db4b..fb36daa15 100644 --- a/LICENSE +++ b/LICENSE @@ -629,8 +629,8 @@ to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. - - Copyright (C) + AstrBot is a llm-powered chatbot and develop framework. + Copyright (C) 2022-2099 Soulter This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published diff --git a/README.md b/README.md index bf4748809..9d84175f2 100644 --- a/README.md +++ b/README.md @@ -148,10 +148,6 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_ -## Sponsors - -[](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==) - ## Disclaimer 1. The project is protected under the `AGPL-v3` opensource license. diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d2e5dcee5..ca42825d3 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.4.30" +VERSION = "3.4.31" DB_PATH = "data/data_v3.db" # 默认配置 @@ -16,7 +16,7 @@ DEFAULT_CONFIG = { "strategy": "stall", # stall, discard }, "reply_prefix": "", - "forward_threshold": 200, + "forward_threshold": 1500, "enable_id_white_list": True, "id_whitelist": [], "id_whitelist_log": True, @@ -67,17 +67,15 @@ DEFAULT_CONFIG = { "method": "possibility_reply", "possibility_reply": 0.1, "prompt": "", - "whitelist": [] - } + "whitelist": [], + }, }, "content_safety": { "also_use_in_response": False, "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, }, - "admins_id": [ - "astrbot" - ], + "admins_id": ["astrbot"], "t2i": False, "t2i_word_threshold": 150, "http_proxy": "", @@ -85,7 +83,7 @@ DEFAULT_CONFIG = { "enable": True, "username": "astrbot", "password": "77b90590a8945a7d36c963981a307dc9", - "port": 6185 + "port": 6185, }, "platform": [], "wake_prefix": ["/"], @@ -122,9 +120,9 @@ CONFIG_METADATA_2 = { "enable": False, "appid": "", "secret": "", - "port": 6196 + "port": 6196, }, - "aiocqhtp(QQ)": { + "aiocqhttp(OneBotv11)": { "id": "default", "type": "aiocqhttp", "enable": False, @@ -140,6 +138,14 @@ CONFIG_METADATA_2 = { "host": "这里填写你的局域网IP或者公网服务器IP", "port": 11451, }, + "wecom(企业微信)": { + "corpid": "", + "secret": "", + "port": 6195, + "token": "", + "encoding_aes_key": "", + "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", + }, "lark(飞书)": { "id": "lark", "type": "lark", @@ -147,14 +153,28 @@ CONFIG_METADATA_2 = { "lark_bot_name": "", "app_id": "", "app_secret": "", - "domain": "https://open.feishu.cn" + "domain": "https://open.feishu.cn", + }, + "telegram": { + "id": "telegram", + "type": "telegram", + "enable": False, + "telegram_token": "your_bot_token", + "start_message": "Hello, I'm AstrBot!", + "telegram_api_base_url": "https://api.telegram.org/bot", }, }, "items": { + "telegram_token": { + "description": "Bot Token", + "type": "string", + "hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。", + }, "id": { "description": "ID", "type": "string", - "hint": "用于在多实例下方便管理和识别。自定义,ID 不能重复。", + "obvious_hint": True, + "hint": "ID 不能和其它的平台适配器重复,否则将发生严重冲突。", }, "type": { "description": "适配器类型", @@ -200,8 +220,8 @@ CONFIG_METADATA_2 = { "description": "飞书机器人的名字", "type": "string", "hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", - "obvious_hint": True - } + "obvious_hint": True, + }, }, }, "platform_settings": { @@ -249,7 +269,7 @@ CONFIG_METADATA_2 = { "description": "间隔时间计算方法", "type": "string", "options": ["random", "log"], - "hint": "分段回复的间隔时间计算方法。random 为随机时间,log 为根据消息长度计算,$y=log_{log\_base}(x)$,x为字数,y的单位为秒。", + "hint": "分段回复的间隔时间计算方法。random 为随机时间,log 为根据消息长度计算,$y=log_(x)$,x为字数,y的单位为秒。", }, "interval": { "description": "随机间隔时间(秒)", @@ -386,7 +406,7 @@ CONFIG_METADATA_2 = { "description": "服务提供商配置", "type": "list", "config_template": { - "openai": { + "OpenAI": { "id": "openai", "type": "openai_chat_completion", "enable": True, @@ -397,7 +417,7 @@ CONFIG_METADATA_2 = { "model": "gpt-4o-mini", }, }, - "azure_openai": { + "Azure_OpenAI": { "id": "azure", "type": "openai_chat_completion", "enable": True, @@ -409,7 +429,7 @@ CONFIG_METADATA_2 = { "model": "gpt-4o-mini", }, }, - "xAI": { + "xAI(grok)": { "id": "xai", "type": "openai_chat_completion", "enable": True, @@ -420,7 +440,19 @@ CONFIG_METADATA_2 = { "model": "grok-2-latest", }, }, - "ollama": { + "Anthropic(claude)": { + "id": "claude", + "type": "anthropic_chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.anthropic.com/v1", + "timeout": 120, + "model_config": { + "model": "claude-3-5-sonnet-latest", + "max_tokens": 4096, + }, + }, + "Ollama": { "id": "ollama_default", "type": "openai_chat_completion", "enable": True, @@ -430,7 +462,7 @@ CONFIG_METADATA_2 = { "model": "llama3.1-8b", }, }, - "gemini(OpenAI兼容)": { + "Gemini(OpenAI兼容)": { "id": "gemini_default", "type": "openai_chat_completion", "enable": True, @@ -441,7 +473,7 @@ CONFIG_METADATA_2 = { "model": "gemini-1.5-flash", }, }, - "gemini(googlegenai原生)": { + "Gemini(googlegenai原生)": { "id": "gemini_default", "type": "googlegenai_chat_completion", "enable": True, @@ -452,7 +484,7 @@ CONFIG_METADATA_2 = { "model": "gemini-1.5-flash", }, }, - "deepseek": { + "DeepSeek": { "id": "deepseek_default", "type": "openai_chat_completion", "enable": True, @@ -463,7 +495,7 @@ CONFIG_METADATA_2 = { "model": "deepseek-chat", }, }, - "zhipu": { + "Zhipu(智谱)": { "id": "zhipu_default", "type": "zhipu_chat_completion", "enable": True, @@ -474,7 +506,7 @@ CONFIG_METADATA_2 = { "model": "glm-4-flash", }, }, - "siliconflow": { + "SiliconFlow(硅基流动)": { "id": "siliconflow", "type": "openai_chat_completion", "enable": True, @@ -485,7 +517,7 @@ CONFIG_METADATA_2 = { "model": "deepseek-ai/DeepSeek-V3", }, }, - "moonshot(kimi)": { + "MoonShot(Kimi)": { "id": "moonshot", "type": "openai_chat_completion", "enable": True, @@ -496,7 +528,7 @@ CONFIG_METADATA_2 = { "model": "moonshot-v1-8k", }, }, - "llmtuner": { + "LLMTuner": { "id": "llmtuner_default", "type": "llm_tuner", "enable": True, @@ -506,7 +538,7 @@ CONFIG_METADATA_2 = { "finetuning_type": "lora", "quantization_bit": 4, }, - "dify": { + "Dify": { "id": "dify_app_default", "type": "dify", "enable": True, @@ -515,9 +547,28 @@ CONFIG_METADATA_2 = { "dify_api_base": "https://api.dify.ai/v1", "dify_workflow_output_key": "", "dify_query_input_key": "astrbot_text_query", + "variables": {}, "timeout": 60, }, - "whisper(API)": { + "Dashscope(阿里云百炼应用)": { + "id": "dashscope", + "type": "dashscope", + "enable": True, + "dashscope_app_type": "agent", + "dashscope_api_key": "", + "dashscope_app_id": "", + "variables": {}, + "timeout": 60, + }, + "FastGPT": { + "id": "fastgpt", + "type": "openai_chat_completion", + "enable": True, + "key": [], + "api_base": "https://api.fastgpt.in/api/v1", + "timeout": 60, + }, + "Whisper(API)": { "id": "whisper", "type": "openai_whisper_api", "enable": False, @@ -525,7 +576,7 @@ CONFIG_METADATA_2 = { "api_base": "", "model": "whisper-1", }, - "whisper(本地加载)": { + "Whisper(本地加载)": { "whisper_hint": "(不用修改我)", "enable": False, "id": "whisper", @@ -540,7 +591,7 @@ CONFIG_METADATA_2 = { "stt_model": "icc/SenseVoiceSmall", "is_emotion": False, }, - "openai_tts(API)": { + "OpenAI_TTS(API)": { "id": "openai_tts", "type": "openai_tts_api", "enable": False, @@ -550,7 +601,7 @@ CONFIG_METADATA_2 = { "openai-tts-voice": "alloy", "timeout": "20", }, - "fishaudio_tts(API)": { + "FishAudio_TTS(API)": { "id": "fishaudio_tts", "type": "fishaudio_tts_api", "enable": False, @@ -577,6 +628,31 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。", }, + # "variables": { + # "description": "工作流固定输入变量", + # "type": "object", + # "obvious_hint": True, + # "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", + # }, + # "fastgpt_app_type": { + # "description": "应用类型", + # "type": "string", + # "hint": "FastGPT 应用的应用类型。", + # "options": ["agent", "workflow", "plugin"], + # "obvious_hint": True, + # }, + "dashscope_app_type": { + "description": "应用类型", + "type": "string", + "hint": "阿里云百炼应用的应用类型。", + "options": [ + "agent", + "agent-arrange", + "dialog-workflow", + "task-workflow", + ], + "obvious_hint": True, + }, "timeout": { "description": "超时时间", "type": "int", @@ -603,7 +679,8 @@ CONFIG_METADATA_2 = { "id": { "description": "ID", "type": "string", - "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。", + "obvious_hint": True, + "hint": "ID 不能和其它的服务提供商重复,否则将发生严重冲突。", }, "type": { "description": "模型提供商类型", @@ -692,10 +769,10 @@ CONFIG_METADATA_2 = { }, "dify_query_input_key": { "description": "Prompt 输入变量名", - "type": "string", + "type": "string", "hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。", "obvious": True, - } + }, }, }, "provider_settings": { diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 0289c4d1a..c2a5f4838 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -27,7 +27,7 @@ class AstrBotCoreLifecycle: os.environ['https_proxy'] = self.astrbot_config['http_proxy'] os.environ['http_proxy'] = self.astrbot_config['http_proxy'] - os.environ['no_proxy'] = 'localhost,127.0.0.1' + os.environ['no_proxy'] = 'localhost' async def initialize(self): logger.info("AstrBot v"+ VERSION) @@ -63,9 +63,6 @@ class AstrBotCoreLifecycle: await self.provider_manager.initialize() '''根据配置实例化各个 Provider''' - await self.platform_manager.initialize() - '''根据配置实例化各个平台适配器''' - self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager)) await self.pipeline_scheduler.initialize() '''初始化消息事件流水线调度器''' @@ -74,19 +71,18 @@ class AstrBotCoreLifecycle: self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) self.start_time = int(time.time()) self.curr_tasks: List[asyncio.Task] = [] - + + await self.platform_manager.initialize() + '''根据配置实例化各个平台适配器''' + def _load(self): - - platform_tasks = self.load_platform() event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) - # self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks] - - tasks_ = [event_bus_task, *platform_tasks, *extra_tasks] + tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name())) diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 58cd86837..0ab5fe852 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -7,7 +7,7 @@ from typing import List CACHED_SIZE = 200 log_color_config = { - 'DEBUG': 'bold_blue', 'INFO': 'bold_cyan', + 'DEBUG': 'green', 'INFO': 'bold_cyan', 'WARNING': 'bold_yellow', 'ERROR': 'red', 'CRITICAL': 'bold_red', 'RESET': 'reset', 'asctime': 'green' @@ -58,7 +58,7 @@ class LogManager: console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG) console_formatter = colorlog.ColoredFormatter( - fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s', + fmt='%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s', datefmt='%H:%M:%S', log_colors=log_color_config ) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 16ba0c1e4..477e7eeec 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -30,11 +30,19 @@ from enum import Enum from pydantic.v1 import BaseModel class ComponentType(Enum): - Plain = "Plain" - Face = "Face" - Record = "Record" - Video = "Video" - At = "At" + Plain = "Plain" # 纯文本消息 + Face = "Face" # QQ表情 + Record = "Record" # 语音 + Video = "Video" # 视频 + At = "At" # At + Node = "Node" # 转发消息的一个节点 + Nodes = "Nodes" # 转发消息的多个节点 + Poke = "Poke" # QQ 戳一戳 + Image = "Image" # 图片 + Reply = "Reply" # 回复 + Forward = "Forward" # 转发消息 + File = "File" # 文件 + RPS = "RPS" # TODO Dice = "Dice" # TODO Shake = "Shake" # TODO @@ -43,18 +51,12 @@ class ComponentType(Enum): Contact = "Contact" # TODO Location = "Location" # TODO Music = "Music" - Image = "Image" - Reply = "Reply" RedBag = "RedBag" - Poke = "Poke" - Forward = "Forward" - Node = "Node" Xml = "Xml" Json = "Json" CardImage = "CardImage" TTS = "TTS" Unknown = "Unknown" - File = "File" class BaseMessageComponent(BaseModel): @@ -362,6 +364,18 @@ class Node(BaseMessageComponent): def toString(self): # logger.warn("Protocol: node doesn't support stringify") return "" + +class Nodes(BaseMessageComponent): + type: ComponentType = "Nodes" + nodes: T.List[Node] + + def __init__(self, nodes: T.List[Node], **_): + super().__init__(nodes=nodes, **_) + + def toDict(self): + return { + "messages": [node.toDict() for node in self.nodes] + } class Xml(BaseMessageComponent): @@ -451,6 +465,7 @@ ComponentTypes = { "poke": Poke, "forward": Forward, "node": Node, + "nodes": Nodes, "xml": Xml, "json": Json, "cardimage": CardImage, diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index ffd7689a7..7f781a1e7 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -28,4 +28,3 @@ class ContentSafetyCheckStage(Stage): event.stop_event() logger.info(f"内容安全检查不通过,原因:{info}") return - event.continue_event() diff --git a/astrbot/core/pipeline/process_stage/method/dify_request.py b/astrbot/core/pipeline/process_stage/method/dify_request.py deleted file mode 100644 index 3dca3792d..000000000 --- a/astrbot/core/pipeline/process_stage/method/dify_request.py +++ /dev/null @@ -1,90 +0,0 @@ -''' -Dify 调用 Stage -''' -import traceback -from typing import Union, AsyncGenerator -from ...context import PipelineContext -from ..stage import Stage -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType -from astrbot.core.message.components import Image -from astrbot.core import logger -from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ProviderRequest -from astrbot.core.star.star_handler import star_handlers_registry, EventType - -class DifyRequestSubStage(Stage): - - async def initialize(self, ctx: PipelineContext) -> None: - self.ctx = ctx - - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - req: ProviderRequest = None - - provider = self.ctx.plugin_manager.context.get_using_provider() - - if not provider: - return - - if provider.meta().type != "dify": - return - - if event.get_extra("provider_request"): - req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。" - else: - req = ProviderRequest(prompt="", image_urls=[]) - if self.ctx.astrbot_config['provider_settings']['wake_prefix']: - if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']): - return - req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):] - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - req.image_urls.append(image_url) - req.session_id = event.session_id - event.set_extra("provider_request", req) - - if not req.prompt: - return - - req.session_id = event.unified_msg_origin - - # 执行请求 LLM 前事件钩子。 - # 装饰 system_prompt 等功能 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) - for handler in handlers: - try: - await handler.handler(event, req) - except BaseException: - logger.error(traceback.format_exc()) - - try: - logger.debug(f"Dify 请求 Payload: {req.__dict__}") - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM - await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) - - # 执行 LLM 响应后的事件。 - handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent) - for handler in handlers: - try: - await handler.handler(event, llm_response) - except BaseException: - logger.error(traceback.format_exc()) - - if llm_response.role == 'assistant': - # text completion - event.set_result(MessageEventResult().message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT)) - return - elif llm_response.role == 'err': - event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}")) - return - elif llm_response.role == 'tool': - event.set_result(MessageEventResult().message(f"Dify 暂不支持工具调用。")) - yield - - except BaseException as e: - logger.error(traceback.format_exc()) - event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e))) - return \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fd50d4423..9b5617340 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -13,6 +13,7 @@ from astrbot.core import logger from astrbot.core.utils.metrics import Metric from astrbot.core.provider.entites import ProviderRequest, LLMResponse from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star import star_map class LLMRequestSubStage(Stage): @@ -54,7 +55,7 @@ class LLMRequestSubStage(Stage): conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin) if not conversation_id: conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin) - req.session_id = conversation_id + req.session_id = event.unified_msg_origin conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id) req.conversation = conversation req.contexts = json.loads(conversation.history) @@ -69,6 +70,7 @@ class LLMRequestSubStage(Stage): handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) for handler in handlers: try: + logger.debug(f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") await handler.handler(event, req) except BaseException: logger.error(traceback.format_exc()) @@ -82,10 +84,11 @@ class LLMRequestSubStage(Stage): req.func_tool = None # 暂时不支持递归工具调用 llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM - # 执行 LLM 响应后的事件。 + # 执行 LLM 响应后的事件钩子。 handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent) for handler in handlers: try: + logger.debug(f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") await handler.handler(event, llm_response) except BaseException: logger.error(traceback.format_exc()) @@ -154,6 +157,6 @@ class LLMRequestSubStage(Stage): contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts)) await self.conv_manager.update_conversation( event.unified_msg_origin, - req.session_id, + req.conversation.cid, history=contexts_to_save ) \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index d2a4c8382..703f9469e 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -28,10 +28,8 @@ class StarRequestSubStage(Stage): params = handlers_parsed_params.get(handler.handler_full_name, {}) try: if handler.handler_module_path not in star_map: - # 孤立无援的 star handler continue - - logger.debug(f"执行插件 handler {handler.handler_full_name}") + logger.debug(f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}") wrapper = self._call_handler(self.ctx, event, handler.handler, **params) async for ret in wrapper: yield ret diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 0026e9d6f..c22ab4d92 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -3,7 +3,6 @@ from ..stage import Stage, register_stage from ..context import PipelineContext from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage -from .method.dify_request import DifyRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entites import ProviderRequest @@ -21,9 +20,6 @@ class ProcessStage(Stage): self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) - - self.dify_request_sub_stage = DifyRequestSubStage() - await self.dify_request_sub_stage.initialize(ctx) async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: '''处理事件 @@ -49,7 +45,7 @@ class ProcessStage(Stage): if not self.ctx.astrbot_config['provider_settings'].get('enable', True): return - if not event._has_send_oper and event.is_at_or_wake_command: + if not event._has_send_oper and event.is_at_or_wake_command and not event.call_llm: # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): # 事件没有终止传播 @@ -59,10 +55,5 @@ class ProcessStage(Stage): logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") return - match provider.meta().type: - case "dify": - async for _ in self.dify_request_sub_stage.process(event): - yield - case _: - async for _ in self.llm_request_sub_stage.process(event): - yield \ No newline at end of file + async for _ in self.llm_request_sub_stage.process(event): + yield \ No newline at end of file diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 4b4ac5b3e..2d0437e34 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -73,8 +73,6 @@ class RateLimitStage(Stage): timestamps.append(now) - return event.continue_event() - def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None: """ 移除时间窗口外的时间戳。 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index a833ac76c..9980d46c3 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -10,6 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain from astrbot.core import logger from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star import star_map from astrbot.core.message.components import Plain, Reply, At @register_stage class RespondStage(Stage): @@ -90,6 +91,7 @@ class RespondStage(Stage): handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent) for handler in handlers: try: + logger.debug(f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") await handler.handler(event) except BaseException: logger.error(traceback.format_exc()) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index aa0ce02b7..c7f0bf72c 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -10,6 +10,7 @@ from astrbot.core import logger from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node from astrbot.core import html_renderer from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star import star_map @register_stage class ResultDecorateStage(Stage): @@ -47,7 +48,7 @@ class ResultDecorateStage(Stage): async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() - if result is None: + if result is None or not result.chain: return # 回复时检查内容安全 @@ -63,7 +64,10 @@ class ResultDecorateStage(Stage): handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent) for handler in handlers: try: + logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}") await handler.handler(event) + if event.get_result() is None or not event.get_result().chain: + logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。") except BaseException: logger.error(traceback.format_exc()) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index d650ba840..0500d0279 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -57,7 +57,8 @@ class AstrMessageEvent(abc.ABC): self._has_send_oper = False '''是否有过至少一次发送操作''' - + self.call_llm = False + '''是否在此消息事件中禁止默认的 LLM 请求''' # back_compability self.platform = platform_meta @@ -242,7 +243,15 @@ class AstrMessageEvent(abc.ABC): ''' if self._result is None: return False # 默认是继续传播 - return self._result.is_stopped() + return self._result.is_stopped() + + def should_call_llm(self, call_llm: bool): + ''' + 是否在此消息事件中禁止默认的 LLM 请求。 + + 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 + ''' + self.call_llm = call_llm def get_result(self) -> MessageEventResult: ''' diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 46118b554..a21b0672e 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -1,3 +1,5 @@ +import traceback +import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .platform import Platform from typing import List @@ -11,43 +13,105 @@ class PlatformManager(): self.platform_insts: List[Platform] = [] '''加载的 Platform 的实例''' + self._inst_map = {} + self.platforms_config = config['platform'] self.settings = config['platform_settings'] self.event_queue = event_queue - - try: - for platform in self.platforms_config: - if not platform['enable']: - continue - match platform['type']: - case "aiocqhttp": - from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401 - case "qq_official": - from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 - case "qq_official_webhook": - from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401 - case "gewechat": - from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401 - case "lark": - from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 - except (ImportError, ModuleNotFoundError) as e: - logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。") - except Exception as e: - logger.error(f"加载平台适配器 {platform['type']} 失败,原因:{e}。") async def initialize(self): + '''初始化所有平台适配器''' for platform in self.platforms_config: - if not platform['enable']: - continue - if platform['type'] not in platform_cls_map: - logger.error(f"未找到适用于 {platform['type']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。") - continue - cls_type = platform_cls_map[platform['type']] - logger.debug(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...") - inst = cls_type(platform, self.settings, self.event_queue) - self.platform_insts.append(inst) + await self.load_platform(platform) - self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue)) + # 网页聊天 + webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) + self.platform_insts.append(webchat_inst) + asyncio.create_task(self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat"))) + + async def load_platform(self, platform_config: dict): + '''实例化一个平台''' + if not platform_config['enable']: + return + + logger.info(f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...") + + # 动态导入 + try: + match platform_config['type']: + case "aiocqhttp": + from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter # noqa: F401 + case "qq_official": + from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 + case "qq_official_webhook": + from .sources.qqofficial_webhook.qo_webhook_adapter import QQOfficialWebhookPlatformAdapter # noqa: F401 + case "gewechat": + from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401 + case "lark": + from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401 + case "telegram": + from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401 + except (ImportError, ModuleNotFoundError) as e: + logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。") + except Exception as e: + logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") + + + if platform_config['type'] not in platform_cls_map: + logger.error(f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误") + return + cls_type = platform_cls_map[platform_config['type']] + inst = cls_type(platform_config, self.settings, self.event_queue) + self._inst_map[platform_config['id']] = inst + self.platform_insts.append(inst) + + asyncio.create_task(self._task_wrapper(asyncio.create_task(inst.run(), name=platform_config['id'] + "_platform"))) + + async def _task_wrapper(self, task: asyncio.Task): + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + + logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") + for line in traceback.format_exc().split("\n"): + logger.error(f"| {line}") + logger.error("-------") + + async def reload(self, platform_config: dict): + # 还未实现完成,不要调用此方法 + + if platform_config['id'] in self._inst_map: + # 正在运行 + if getattr(self._inst_map[platform_config['id']], 'terminate', None): + logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...") + await self._inst_map[platform_config['id']].terminate() + logger.info(f"{platform_config['id']} 平台适配器已终止。") + del self._inst_map[platform_config['id']] + self.platform_insts.remove(self._inst_map[platform_config['id']]) + else: + logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。") + + # 再启动新的实例 + await self.load_platform(platform_config) + + else: + # 先将 _inst_map 中在 platform_config 中不存在的实例删除 + config_ids = [platform['id'] for platform in self.platforms_config] + for key in list(self._inst_map.keys()): + if key not in config_ids: + if getattr(self._inst_map[key], 'terminate', None): + logger.info(f"正在尝试终止 {key} 平台适配器 ...") + await self._inst_map[key].terminate() + logger.info(f"{key} 平台适配器已终止。") + del self._inst_map[key] + self.platform_insts.remove(self._inst_map[key]) + else: + logger.warning(f"可能无法正常终止 {key} 平台适配器。") + + # 再启动新的实例 + await self.load_platform(platform_config) def get_insts(self): return self.platform_insts \ No newline at end of file diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 1dd356e89..3526d2802 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -20,6 +20,12 @@ class Platform(abc.ABC): ''' raise NotImplementedError + async def terminate(self): + ''' + 终止一个平台的运行实例。 + ''' + pass + @abc.abstractmethod def meta(self) -> PlatformMetadata: ''' diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index b91e44227..5f54f69a2 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,7 +1,7 @@ import asyncio from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video +from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp from astrbot.core.utils.io import file_to_base64, download_image_by_url @@ -45,15 +45,25 @@ class AiocqhttpMessageEvent(AstrMessageEvent): send_one_by_one = False for seg in message.chain: - if isinstance(seg, (Node, Music)): + if isinstance(seg, (Node, Nodes)): # 转发消息不能和普通消息混在一起发送 send_one_by_one = True break if send_one_by_one: for seg in message.chain: - await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg]))) - await asyncio.sleep(0.5) + if isinstance(seg, Nodes): + # 带有多个节点的合并转发消息 + payload = seg.toDict() + if self.get_group_id(): + payload['group_id'] = self.get_group_id() + await self.bot.call_action('send_group_forward_msg', **payload) + else: + payload['user_id'] = self.get_sender_id() + await self.bot.call_action('send_private_forward_msg', **payload) + else: + await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg]))) + await asyncio.sleep(0.5) else: await self.bot.send(self.message_obj.raw_message, ret) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index ce4b5700d..94b10e3d2 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -16,7 +16,7 @@ from ...register import register_platform_adapter from aiocqhttp.exceptions import ActionFailed from astrbot.core.utils.io import download_file -@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。") +@register_platform_adapter("aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。") class AiocqhttpAdapter(Platform): def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: super().__init__(event_queue) @@ -32,6 +32,8 @@ class AiocqhttpAdapter(Platform): "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", ) + self.stop = False + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): ret = await AiocqhttpMessageEvent._parse_onebot_json(message_chain) match session.message_type.value: @@ -146,8 +148,11 @@ class AiocqhttpAdapter(Platform): a = None if t == 'text': message_str += m['data']['text'].strip() + a = ComponentTypes[t](**m['data']) # noqa: F405 + abm.message.append(a) + elif t == 'file': - if m['data']['url'] and m['data']['url'].startswith("http"): + if m['data'].get('url') and m['data'].get('url').startswith("http"): # Lagrange logger.info("guessing lagrange") @@ -159,6 +164,8 @@ class AiocqhttpAdapter(Platform): "file": path, "name": file_name } + a = ComponentTypes[t](**m['data']) # noqa: F405 + abm.message.append(a) else: try: @@ -173,13 +180,17 @@ class AiocqhttpAdapter(Platform): "file": ret['file'], "name": ret['file_name'] } + a = ComponentTypes[t](**m['data']) # noqa: F405 + abm.message.append(a) except ActionFailed as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") except BaseException as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") - - a = ComponentTypes[t](**m['data']) # noqa: F405 - abm.message.append(a) + + else: + a = ComponentTypes[t](**m['data']) # noqa: F405 + abm.message.append(a) + abm.timestamp = int(time.time()) abm.message_str = message_str abm.raw_message = event @@ -220,7 +231,7 @@ class AiocqhttpAdapter(Platform): @self.bot.on_websocket_connection def on_websocket_connection(_): - logger.info("aiocqhttp 适配器已连接。") + logger.info("aiocqhttp(OneBot v11) 适配器已连接。") bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) @@ -230,11 +241,15 @@ class AiocqhttpAdapter(Platform): return bot + async def terminate(self): + self.stop = True + await asyncio.sleep(1) + def meta(self) -> PlatformMetadata: return self.metadata async def shutdown_trigger_placeholder(self): - while not self._event_queue.closed: + while not self._event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("aiocqhttp 适配器已关闭。") @@ -248,4 +263,4 @@ class AiocqhttpAdapter(Platform): bot=self.bot ) - self.commit_event(message_event) \ No newline at end of file + self.commit_event(message_event) diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index fe0bc6bfb..79f61c774 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -5,6 +5,7 @@ import quart import base64 import datetime import re +import os from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType from astrbot.api.message_components import Plain, Image, At, Record from astrbot.api import logger, sp @@ -53,6 +54,8 @@ class SimpleGewechatClient(): self.multimedia_downloader = None self.userrealnames = {} + + self.stop = False async def get_token_id(self): async with aiohttp.ClientSession() as session: @@ -230,7 +233,7 @@ class SimpleGewechatClient(): ) async def shutdown_trigger_placeholder(self): - while not self.event_queue.closed: + while not self.event_queue.closed and not self.stop: await asyncio.sleep(1) logger.info("gewechat 适配器已关闭。") @@ -304,6 +307,22 @@ class SimpleGewechatClient(): }) while retry_cnt > 0: retry_cnt -= 1 + + # 需要验证码 + if os.path.exists("data/temp/gewe_code"): + with open("data/temp/gewe_code", "r") as f: + code = f.read().strip() + if not code: + logger.warning("未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456") + await asyncio.sleep(5) + continue + payload['captchCode'] = code + logger.info(f"使用验证码: {code}") + try: + os.remove("data/temp/gewe_code") + except: + logger.warning("删除验证码文件 data/temp/gewe_code 失败。") + async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/login/checkLogin", @@ -312,17 +331,25 @@ class SimpleGewechatClient(): ) as resp: json_blob = await resp.json() logger.info(f"检查登录状态: {json_blob}") - status = json_blob['data']['status'] - nickname = json_blob['data'].get('nickName', '') - if status == 1: - logger.info(f"等待确认...{nickname}") - elif status == 2: - logger.info(f"绿泡泡平台登录成功: {nickname}") - break - elif status == 0: - logger.info("等待扫码...") + + ret = json_blob['ret'] + msg = '' + if json_blob['data'] and 'msg' in json_blob['data']: + msg = json_blob['data']['msg'] + if ret == 500 and '安全验证码' in msg: + logger.warning("此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456") else: - logger.warning(f"未知状态: {status}") + status = json_blob['data']['status'] + nickname = json_blob['data'].get('nickName', '') + if status == 1: + logger.info(f"等待确认...{nickname}") + elif status == 2: + logger.info(f"绿泡泡平台登录成功: {nickname}") + break + elif status == 0: + logger.info("等待扫码...") + else: + logger.warning(f"未知状态: {status}") await asyncio.sleep(5) if appid: diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index 1ca47391e..7e325ce8b 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -47,6 +47,10 @@ class GewechatPlatformAdapter(Platform): "基于 gewechat 的 Wechat 适配器", ) + async def terminate(self): + self.client.stop = True + await asyncio.sleep(1) + @override def run(self): self.client = SimpleGewechatClient( diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py new file mode 100644 index 000000000..bbb4d6cf5 --- /dev/null +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -0,0 +1,128 @@ +import sys +import uuid +import asyncio + +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, PlatformMetadata, MessageType +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, Image, Record, File as AstrBotFile, Video, At +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.api.platform import register_platform_adapter + +from telegram import Update +from telegram.ext import ApplicationBuilder, ContextTypes, filters +from telegram.constants import ChatType +from telegram.ext import MessageHandler as TelegramMessageHandler +from .tg_event import TelegramPlatformEvent +from astrbot.api import logger + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + +@register_platform_adapter("telegram", "telegram 适配器") +class TelegramPlatformAdapter(Platform): + + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + self.config = platform_config + self.settings = platform_settings + self.client_self_id = uuid.uuid4().hex[:8] + + @override + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + from_username = session.session_id + await TelegramPlatformEvent.send_with_client(self.client, message_chain, from_username) + await super().send_by_session(session, message_chain) + + @override + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "telegram", + "telegram 适配器", + ) + + @override + async def run(self): + base_url = self.config.get("telegram_api_base_url", "https://api.telegram.org/bot") + if not base_url: + base_url = "https://api.telegram.org/bot" + + self.application = ApplicationBuilder().token(self.config['telegram_token']).base_url(base_url).build() + message_handler = TelegramMessageHandler( + filters=filters.ALL, # receive all messages + callback=self.convert_message + ) + self.application.add_handler(message_handler) + await self.application.initialize() + await self.application.start() + queue = self.application.updater.start_polling() + self.client = self.application.bot + logger.info("Telegram Platform Adapter is running.") + + await queue + + async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + await context.bot.send_message(chat_id=update.effective_chat.id, text=self.config["start_message"]) + + async def convert_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> AstrBotMessage: + message = AstrBotMessage() + # 获得是群聊还是私聊 + if update.effective_chat.type == ChatType.PRIVATE: + message.type = MessageType.FRIEND_MESSAGE + else: + message.type = MessageType.GROUP_MESSAGE + message.group_id = update.effective_chat.id + message.message_id = str(update.message.message_id) + message.session_id = str(update.effective_chat.id) + message.sender = MessageMember(str(update.effective_user.id), update.effective_user.username) + message.self_id = str(context.bot.id) + message.raw_message = update + message.message_str = "" + message.message = [] + + logger.debug(f"Telegram message: {update.message}") + + if update.message.text: + plain_text = update.message.text + + if update.message.entities: + for entity in update.message.entities: + if entity.type == "mention": + name = plain_text[entity.offset:entity.offset+entity.length] + message.message.append(At(qq=message.self_id, name=name)) + plain_text = plain_text[:entity.offset] + plain_text[entity.offset+entity.length:] + + message.message.append(Plain(plain_text)) + message.message_str = plain_text + + + elif update.message.voice: + file = await update.message.voice.get_file() + message.message = [Record(file=file.file_path, url=file.file_path),] + + elif update.message.photo: + for photo in update.message.photo: + file = await photo.get_file() + message.message.append(Image(file=file.file_path, url=file.file_path)) + + elif update.message.document: + file = await update.message.document.get_file() + message.message = [AstrBotFile(file=file.file_path, name="file"),] + + elif update.message.video: + file = await update.message.video.get_file() + message.message = [Video(file=file.file_path, path=file.file_path),] + + + await self.handle_msg(message) + + async def handle_msg(self, message: AstrBotMessage): + message_event = TelegramPlatformEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client + ) + self.commit_event(message_event) \ No newline at end of file diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py new file mode 100644 index 000000000..04b82bdea --- /dev/null +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -0,0 +1,50 @@ +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType +from astrbot.api.message_components import Plain, Image, Reply, At +from telegram.ext import ExtBot + +class TelegramPlatformEvent(AstrMessageEvent): + def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: ExtBot): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + + @staticmethod + async def send_with_client(client: ExtBot, message: MessageChain, user_name: str): + image_path = None + + has_reply = False + reply_message_id = None + at_user_id = None + for i in message.chain: + if isinstance(i, Reply): + has_reply = True + reply_message_id = i.id + if isinstance(i, At): + at_user_id = i.name + + at_flag = False + for i in message.chain: + payload = { + "chat_id": user_name, + } + if has_reply: + payload["reply_to_message_id"] = reply_message_id + + if isinstance(i, Plain): + if at_user_id and not at_flag: + i.text = f"@{at_user_id} " + i.text + at_flag = True + await client.send_message(text=i.text, **payload) + elif isinstance(i, Image): + if i.path: + image_path = i.path + else: + image_path = i.file + await client.send_photo(photo=image_path, **payload) + + async def send(self, message: MessageChain): + if self.get_message_type() == MessageType.GROUP_MESSAGE: + await self.send_with_client(self.client, message, self.message_obj.group_id) + else: + await self.send_with_client(self.client, message, self.get_sender_id()) + await super().send(message) \ No newline at end of file diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index fe30151b2..afcda5e1c 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,5 +1,6 @@ import os import uuid +import base64 from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image @@ -31,6 +32,11 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(ph, "rb") as f2: f.write(f2.read()) + elif comp.file.startswith("base64://"): + base64_str = comp.file[9:] + image_data = base64.b64decode(base64_str) + with open(path, "wb") as f: + f.write(image_data) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) else: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py new file mode 100644 index 000000000..0127342fe --- /dev/null +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -0,0 +1,230 @@ +import sys +import uuid +import asyncio +import quart + +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, PlatformMetadata, MessageType +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, Image, Record +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.api.platform import register_platform_adapter +from astrbot.core import logger +from requests import Response + +from wechatpy.enterprise.crypto import WeChatCrypto +from wechatpy.enterprise import WeChatClient +from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage +from wechatpy.exceptions import InvalidSignatureException +from wechatpy.enterprise import parse_message +from .wecom_event import WecomPlatformEvent + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + +class WecomServer(): + def __init__( + self, + event_queue: asyncio.Queue, + config: dict + ): + self.server = quart.Quart(__name__) + self.port = int(config.get("port")) + self.server.add_url_rule('/callback/command', view_func=self.verify, methods=['GET']) + self.server.add_url_rule('/callback/command', view_func=self.callback_command, methods=['POST']) + self.event_queue = event_queue + + self.crypto = WeChatCrypto( + config['token'].strip(), + config['encoding_aes_key'].strip(), + config['corpid'].strip() + ) + + self.callback = None + + async def verify(self): + logger.info(f"验证请求有效性: {quart.request.args}") + args = quart.request.args + try: + echo_str = self.crypto.check_signature( + args.get('msg_signature'), + args.get('timestamp'), + args.get('nonce'), + args.get('echostr') + ) + logger.info("验证请求有效性成功。") + return echo_str + except InvalidSignatureException: + logger.error("验证请求有效性失败,签名异常,请检查配置。") + raise + + async def callback_command(self): + data = await quart.request.get_data() + msg_signature = quart.request.args.get('msg_signature') + timestamp = quart.request.args.get('timestamp') + nonce = quart.request.args.get('nonce') + try: + xml = self.crypto.decrypt_message( + data, + msg_signature, + timestamp, + nonce + ) + except InvalidSignatureException: + logger.error("解密失败,签名异常,请检查配置。") + raise + else: + msg = parse_message(xml) + logger.info(f"解析成功: {msg}") + + if self.callback: + await self.callback(msg) + + return "success" + + async def start_polling(self): + logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。") + await self.server.run_task( + host='0.0.0.0', + port=self.port, + shutdown_trigger=self.shutdown_trigger_placeholder + ) + + async def shutdown_trigger_placeholder(self): + while not self.event_queue.closed: + await asyncio.sleep(1) + logger.info("企业微信 适配器已关闭。") + + +@register_platform_adapter("wecom", "wecom 适配器") +class WecomPlatformAdapter(Platform): + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + self.config = platform_config + self.settingss = platform_settings + self.client_self_id = uuid.uuid4().hex[:8] + self.api_base_url = platform_config.get("api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/") + + if not self.api_base_url: + self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/" + + if self.api_base_url.endswith("/"): + self.api_base_url = self.api_base_url[:-1] + if not self.api_base_url.endswith("/cgi-bin"): + self.api_base_url += "/cgi-bin" + + if not self.api_base_url.endswith("/"): + self.api_base_url += "/" + + @override + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + await super().send_by_session(session, message_chain) + + @override + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "wecom", + "wecom 适配器", + ) + + @override + async def run(self): + self.server = WecomServer( + self._event_queue, + self.config + ) + + self.client = WeChatClient( + self.config['corpid'].strip(), + self.config['secret'].strip(), + ) + self.client.API_BASE_URL = self.api_base_url + + async def callback(msg): + await self.convert_message(msg) + + self.server.callback = callback + + await self.server.start_polling() + + async def convert_message(self, msg): + abm = AstrBotMessage() + if msg.type == 'text': + assert isinstance(msg, TextMessage) + abm.message_str = msg.content + abm.self_id = str(msg.agent) + abm.message = [Plain(msg.content)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + elif msg.type == 'image': + assert isinstance(msg, ImageMessage) + abm.message_str = "[图片]" + abm.self_id = str(msg.agent) + abm.message = [Image(file=msg.image, url=msg.image)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + elif msg.type == 'voice': + assert isinstance(msg, VoiceMessage) + + resp: Response = await asyncio.get_event_loop().run_in_executor( + None, + self.client.media.download, + msg.media_id + ) + path = f"data/temp/wecom_{msg.media_id}.amr" + with open(path, 'wb') as f: + f.write(resp.content) + + try: + from pydub import AudioSegment + + path_wav = f"data/temp/wecom_{msg.media_id}.wav" + audio = AudioSegment.from_file(path) + audio.export(path_wav, format="wav") + except Exception as e: + logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。") + path_wav = path + return + + abm.message_str = "" + abm.self_id = str(msg.agent) + abm.message = [Record(file=path_wav, url=path_wav)] + abm.type = MessageType.FRIEND_MESSAGE + abm.sender = MessageMember( + msg.source, + msg.source, + ) + abm.message_id = msg.id + abm.timestamp = msg.time + abm.session_id = abm.sender.user_id + abm.raw_message = msg + + + + logger.info(f"abm: {abm}") + await self.handle_msg(abm) + + async def handle_msg(self, message: AstrBotMessage): + message_event = WecomPlatformEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client + ) + self.commit_event(message_event) \ No newline at end of file diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py new file mode 100644 index 000000000..83e99b5c4 --- /dev/null +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -0,0 +1,103 @@ +import uuid +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.api.message_components import Plain, Image, Record +from wechatpy.enterprise import WeChatClient +from astrbot.core.utils.io import download_image_by_url, download_file + +from astrbot.api import logger + +try: + import pydub +except Exception: + logger.warning( + "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。" + ) + pass + + +class WecomPlatformEvent(AstrMessageEvent): + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + client: WeChatClient, + ): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + + @staticmethod + async def send_with_client( + client: WeChatClient, message: MessageChain, user_name: str + ): + pass + + async def send(self, message: MessageChain): + message_obj = self.message_obj + + for comp in message.chain: + if isinstance(comp, Plain): + self.client.message.send_text( + message_obj.self_id, message_obj.session_id, comp.text + ) + elif isinstance(comp, Image): + img_url = comp.file + img_path = "" + if img_url.startswith("file:///"): + img_path = img_url[8:] + elif comp.file and comp.file.startswith("http"): + img_path = await download_image_by_url(comp.file) + else: + img_path = img_url + + with open(img_path, "rb") as f: + try: + response = self.client.media.upload("image", f) + except Exception as e: + logger.error(f"企业微信上传图片失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传图片失败: {e}") + ) + return + logger.info(f"企业微信上传图片返回: {response}") + self.client.message.send_image( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) + elif isinstance(comp, Record): + record_url = comp.file + record_path = "" + + if record_url.startswith("file:///"): + record_path = record_url[8:] + elif record_url.startswith("http"): + await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") + else: + record_path = record_url + + # 转成amr + record_path_amr = f"data/temp/{uuid.uuid4()}.amr" + pydub.AudioSegment.from_wav(record_path).export( + record_path_amr, format="amr" + ) + + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"企业微信上传语音失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传语音失败: {e}") + ) + return + logger.info(f"企业微信上传语音返回: {response}") + self.client.message.send_voice( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) + + await super().send(message) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index b82ec6db7..8ddc21b75 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -102,6 +102,29 @@ class FuncCall: ) return _l + def get_func_desc_anthropic_style(self) -> list: + """ + 获得 Anthropic API 风格的**已经激活**的工具描述 + """ + tools = [] + for f in self.func_list: + if not f.active: + continue + + # Convert internal format to Anthropic style + tool = { + "name": f.name, + "description": f.description, + "input_schema": { + "type": "object", + "properties": f.parameters.get("properties", {}), + # Keep the required field from the original parameters if it exists + "required": f.parameters.get("required", []) + } + } + tools.append(tool) + return tools + def get_func_desc_google_genai_style(self) -> Dict: declarations = {} tools = [] diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7187cc100..1e64a1d9d 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,11 +1,9 @@ import traceback -import uuid from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality from .entites import ProviderType from typing import List from astrbot.core.db import BaseDatabase -from collections import defaultdict from .register import provider_cls_map, llm_tools from astrbot.core import logger, sp @@ -16,6 +14,14 @@ class ProviderManager(): self.provider_stt_settings: dict = config.get('provider_stt_settings', {}) self.provider_tts_settings: dict = config.get('provider_tts_settings', {}) self.persona_configs: list = config.get('persona', []) + self.astrbot_config = config + + self.selected_provider_id = sp.get("curr_provider") + self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") + self.selected_tts_provider_id = self.provider_settings.get("provider_id") + self.provider_enabled = self.provider_settings.get("enable", False) + self.stt_enabled = self.provider_stt_settings.get("enable", False) + self.tts_enabled = self.provider_tts_settings.get("enable", False) # 人格情景管理 # 目前没有拆成独立的模块 @@ -75,14 +81,15 @@ class ProviderManager(): _mood_imitation_dialogs_processed="" ) self.personas.append(self.selected_default_persona) - - + self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' self.stt_provider_insts: List[STTProvider] = [] '''加载的 Speech To Text Provider 的实例''' self.tts_provider_insts: List[TTSProvider] = [] '''加载的 Text To Speech Provider 的实例''' + self.inst_map = {} + '''Provider 实例映射. key: provider_id, value: Provider 实例''' self.llm_tools = llm_tools self.curr_provider_inst: Provider = None '''当前使用的 Provider 实例''' @@ -90,7 +97,6 @@ class ProviderManager(): '''当前使用的 Speech To Text Provider 实例''' self.curr_tts_provider_inst: TTSProvider = None '''当前使用的 Text To Speech Provider 实例''' - self.loaded_ids = defaultdict(bool) self.db_helper = db_helper # kdb(experimental) @@ -98,144 +104,168 @@ class ProviderManager(): kdb_cfg = config.get("knowledge_db", {}) if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] - - changed = False - for provider_cfg in self.providers_config: - if not provider_cfg['enable']: - continue - - if provider_cfg['id'] in self.loaded_ids: - new_id = f"{provider_cfg['id']}_{str(uuid.uuid4())[:8]}" - logger.info(f"Provider ID 重复:{provider_cfg['id']}。已自动更改为 {new_id}。") - provider_cfg['id'] = new_id - changed = True - self.loaded_ids[provider_cfg['id']] = True - - try: - match provider_cfg['type']: - case "openai_chat_completion": - from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial - case "zhipu_chat_completion": - from .sources.zhipu_source import ProviderZhipu as ProviderZhipu - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader - case "dify": - from .sources.dify_source import ProviderDify as ProviderDify - case "googlegenai_chat_completion": - from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI - case "openai_whisper_api": - from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI - case "openai_whisper_selfhost": - from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost - case "sensevoice_stt_selfhost": - from .sources.sensevoice_selfhosted_source import ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost - case "openai_tts_api": - from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI - case "fishaudio_tts_api": - from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI - except (ImportError, ModuleNotFoundError) as e: - logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") - continue - except Exception as e: - logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因") - continue - - if changed: - try: - config.save_config() - except Exception as e: - logger.warning(f"保存配置文件失败:{e}") + async def initialize(self): - - selected_provider_id = sp.get("curr_provider") - selected_stt_provider_id = self.provider_stt_settings.get("provider_id") - selected_tts_provider_id = self.provider_settings.get("provider_id") - provider_enabled = self.provider_settings.get("enable", False) - stt_enabled = self.provider_stt_settings.get("enable", False) - tts_enabled = self.provider_tts_settings.get("enable", False) - for provider_config in self.providers_config: - if not provider_config['enable']: - continue - if provider_config['type'] not in provider_cls_map: - logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") - continue + await self.load_provider(provider_config) - provider_metadata = provider_cls_map[provider_config['type']] - logger.debug(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...") - try: - # 按任务实例化提供商 - - if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: - # STT 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.stt_provider_insts.append(inst) - if selected_stt_provider_id == provider_config['id'] and stt_enabled: - self.curr_stt_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") - - elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: - # TTS 任务 - inst = provider_metadata.cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.tts_provider_insts.append(inst) - if selected_tts_provider_id == provider_config['id'] and tts_enabled: - self.curr_tts_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") - - elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: - # 文本生成任务 - inst = provider_metadata.cls_type( - provider_config, - self.provider_settings, - self.db_helper, - self.provider_settings.get('persistant_history', True), - self.selected_default_persona - ) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.provider_insts.append(inst) - if selected_provider_id == provider_config['id'] and provider_enabled: - self.curr_provider_inst = inst - logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") - - except Exception as e: - traceback.print_exc() - logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") - - if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled: - self.curr_provider_inst = self.provider_insts[0] - - if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled: - self.curr_stt_provider_inst = self.stt_provider_insts[0] - - if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled: - self.curr_tts_provider_inst = self.tts_provider_insts[0] - if not self.curr_provider_inst: logger.warning("未启用任何用于 文本生成 的提供商适配器。") - if stt_enabled and not self.curr_stt_provider_inst: + if self.stt_enabled and not self.curr_stt_provider_inst: logger.warning("未启用任何用于 语音转文本 的提供商适配器。") - if tts_enabled and not self.curr_tts_provider_inst: + if self.tts_enabled and not self.curr_tts_provider_inst: logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + + async def load_provider(self, provider_config: dict): + if not provider_config['enable']: + return + logger.info(f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ...") + logger.debug(f"Provider Config: {provider_config}") + + # 动态导入 + try: + match provider_config['type']: + case "openai_chat_completion": + from .sources.openai_source import ProviderOpenAIOfficial as ProviderOpenAIOfficial + case "zhipu_chat_completion": + from .sources.zhipu_source import ProviderZhipu as ProviderZhipu + case "anthropic_chat_completion": + from .sources.anthropic_source import ProviderAnthropic as ProviderAnthropic + case "llm_tuner": + logger.info("加载 LLM Tuner 工具 ...") + from .sources.llmtuner_source import LLMTunerModelLoader as LLMTunerModelLoader + case "dify": + from .sources.dify_source import ProviderDify as ProviderDify + case "dashscope": + from .sources.dashscope_source import ProviderDashscope as ProviderDashscope + case "googlegenai_chat_completion": + from .sources.gemini_source import ProviderGoogleGenAI as ProviderGoogleGenAI + case "openai_whisper_api": + from .sources.whisper_api_source import ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI + case "openai_whisper_selfhost": + from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost + case "openai_tts_api": + from .sources.openai_tts_api_source import ProviderOpenAITTSAPI as ProviderOpenAITTSAPI + case "fishaudio_tts_api": + from .sources.fishaudio_tts_api_source import ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI + except (ImportError, ModuleNotFoundError) as e: + logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。") + return + except Exception as e: + logger.critical(f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因") + return + + if provider_config['type'] not in provider_cls_map: + logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") + return + + provider_metadata = provider_cls_map[provider_config['type']] + try: + # 按任务实例化提供商 + + if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: + # STT 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.stt_provider_insts.append(inst) + if self.selected_stt_provider_id == provider_config['id'] and self.stt_enabled: + self.curr_stt_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。") + if not self.curr_stt_provider_inst and self.stt_enabled: + self.curr_stt_provider_inst = inst + + elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: + # TTS 任务 + inst = provider_metadata.cls_type(provider_config, self.provider_settings) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.tts_provider_insts.append(inst) + if self.selected_tts_provider_id == provider_config['id'] and self.tts_enabled: + self.curr_tts_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。") + if not self.curr_tts_provider_inst and self.tts_enabled: + self.curr_tts_provider_inst = inst + + elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: + # 文本生成任务 + inst = provider_metadata.cls_type( + provider_config, + self.provider_settings, + self.db_helper, + self.provider_settings.get('persistant_history', True), + self.selected_default_persona + ) + + if getattr(inst, "initialize", None): + await inst.initialize() + + self.provider_insts.append(inst) + if self.selected_provider_id == provider_config['id'] and self.provider_enabled: + self.curr_provider_inst = inst + logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。") + if not self.curr_provider_inst and self.provider_enabled: + self.curr_provider_inst = inst + + self.inst_map[provider_config['id']] = inst + except Exception as e: + traceback.print_exc() + logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}") + + async def reload(self, provider_config: dict): + await self.terminate_provider(provider_config['id']) + if provider_config['enable']: + await self.load_provider(provider_config) + + # 和配置文件保持同步 + config_ids = [provider['id'] for provider in self.providers_config] + for key in list(self.inst_map.keys()): + if key not in config_ids: + await self.terminate_provider(key) + + if len(self.provider_insts) == 0: + self.curr_provider_inst = None + if len(self.stt_provider_insts) == 0: + self.curr_stt_provider_inst = None + if len(self.tts_provider_insts) == 0: + self.curr_tts_provider_inst = None def get_insts(self): return self.provider_insts + async def terminate_provider(self, provider_id: str): + if provider_id in self.inst_map: + + logger.info(f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...") + + if self.inst_map[provider_id] in self.provider_insts: + self.provider_insts.remove(self.inst_map[provider_id]) + if self.inst_map[provider_id] in self.stt_provider_insts: + self.stt_provider_insts.remove(self.inst_map[provider_id]) + if self.inst_map[provider_id] in self.tts_provider_insts: + self.tts_provider_insts.remove(self.inst_map[provider_id]) + + if self.inst_map[provider_id] == self.curr_provider_inst: + self.curr_provider_inst = None + if self.inst_map[provider_id] == self.curr_stt_provider_inst: + self.curr_stt_provider_inst = None + if self.inst_map[provider_id] == self.curr_tts_provider_inst: + self.curr_tts_provider_inst = None + + if getattr(self.inst_map[provider_id], 'terminate', None): + await self.inst_map[provider_id].terminate() + + logger.info(f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})") + del self.inst_map[provider_id] + async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py new file mode 100644 index 000000000..35200bf23 --- /dev/null +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -0,0 +1,189 @@ +from typing import List +from mimetypes import guess_type + +from anthropic import AsyncAnthropic +from anthropic.types import Message + +from astrbot.core.utils.io import download_image_by_url +from astrbot.core.db import BaseDatabase +from astrbot.api.provider import Provider, Personality +from astrbot import logger +from astrbot.core.provider.func_tool_manager import FuncCall +from ..register import register_provider_adapter +from astrbot.core.provider.entites import LLMResponse +from .openai_source import ProviderOpenAIOfficial + +@register_provider_adapter("anthropic_chat_completion", "Anthropic Claude API 提供商适配器") +class ProviderAnthropic(ProviderOpenAIOfficial): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history = True, + default_persona: Personality = None + ) -> None: + # Skip OpenAI's __init__ and call Provider's __init__ directly + Provider.__init__(self, provider_config, provider_settings, persistant_history, db_helper, default_persona) + + self.chosen_api_key = None + self.api_keys: List = provider_config.get("key", []) + self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None + self.base_url = provider_config.get("api_base", "https://api.anthropic.com") + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + self.client = AsyncAnthropic( + api_key=self.chosen_api_key, + timeout=self.timeout, + base_url=self.base_url + ) + + self.set_model(provider_config['model_config']['model']) + + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + if tools: + tool_list = tools.get_func_desc_anthropic_style() + if tool_list: + payloads['tools'] = tool_list + + completion = await self.client.messages.create( + **payloads, + stream=False + ) + + assert isinstance(completion, Message) + logger.debug(f"completion: {completion}") + + if len(completion.content) == 0: + raise Exception("API 返回的 completion 为空。") + # TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容 + # 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求 + content = completion.content[-1] + + llm_response = LLMResponse("assistant") + + if content.type == "text": + # text completion + completion_text = str(content.text).strip() + llm_response.completion_text = completion_text + + # Anthropic每次只返回一个函数调用 + if completion.stop_reason == "tool_use": + # tools call (function calling) + args_ls = [] + func_name_ls = [] + func_name_ls.append(content.name) + args_ls.append(content.input) + llm_response.role = "tool" + llm_response.tools_call_args = args_ls + llm_response.tools_call_name = func_name_ls + + if not llm_response.completion_text and not llm_response.tools_call_args: + logger.error(f"API 返回的 completion 无法解析:{completion}。") + raise Exception(f"API 返回的 completion 无法解析:{completion}。") + + llm_response.raw_completion = completion + + return llm_response + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + **kwargs + ) -> LLMResponse: + + if not prompt: + prompt = "" + + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + + for part in context_query: + if '_no_save' in part: + del part['_no_save'] + + model_config = self.provider_config.get("model_config", {}) + + payloads = { + "messages": context_query, + **model_config + } + # Anthropic has a different way of handling system prompts + if system_prompt: + payloads['system'] = system_prompt + + llm_response = None + try: + llm_response = await self._query(payloads, func_tool) + + except Exception as e: + if "maximum context length" in str(e): + retry_cnt = 20 + while retry_cnt > 0: + logger.warning(f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}") + try: + await self.pop_record(context_query) + response = await self.client.messages.create( + messages=context_query, + **model_config + ) + llm_response = LLMResponse("assistant") + llm_response.completion_text = response.content[0].text + llm_response.raw_completion = response + return llm_response + except Exception as e: + if "maximum context length" in str(e): + retry_cnt -= 1 + else: + raise e + return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。") + else: + logger.error(f"发生了错误。Provider 配置如下: {model_config}") + raise e + + return llm_response + + async def assemble_context(self, text: str, image_urls: List[str] = None): + '''组装上下文,支持文本和图片''' + if not image_urls: + return {"role": "user", "content": text} + + content = [] + content.append({"type": "text", "text": text}) + + for image_url in image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) + else: + image_data = await self.encode_image_bs64(image_url) + + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + continue + + # Get mime type for the image + mime_type, _ = guess_type(image_url) + if not mime_type: + mime_type = "image/jpeg" # Default to JPEG if can't determine + + content.append({ + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": image_data.split("base64,")[1] if "base64," in image_data else image_data + } + }) + + return {"role": "user", "content": content} \ No newline at end of file diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py new file mode 100644 index 000000000..9fd26c90b --- /dev/null +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -0,0 +1,128 @@ +import asyncio +import functools +from typing import List +from .. import Provider, Personality +from ..entites import LLMResponse +from ..func_tool_manager import FuncCall +from astrbot.core.db import BaseDatabase +from ..register import register_provider_adapter +from .openai_source import ProviderOpenAIOfficial +from astrbot.core import logger, sp +from dashscope import Application + + +@register_provider_adapter("dashscope", "Dashscope APP 适配器。") +class ProviderDashscope(ProviderOpenAIOfficial): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history=False, + default_persona: Personality = None, + ) -> None: + Provider.__init__( + self, + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, + ) + self.api_key = provider_config.get("dashscope_api_key", "") + if not self.api_key: + raise Exception("阿里云百炼 API Key 不能为空。") + self.app_id = provider_config.get("dashscope_app_id", "") + if not self.app_id: + raise Exception("阿里云百炼 APP ID 不能为空。") + self.dashscope_app_type = provider_config.get("dashscope_app_type", "") + if not self.dashscope_app_type: + raise Exception("阿里云百炼 APP 类型不能为空。") + self.model_name = "dashscope" + self.variables: dict = provider_config.get("variables", {}) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + **kwargs, + ) -> LLMResponse: + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_vars = sp.get("session_variables", {}) + session_var = session_vars.get(session_id, {}) + payload_vars.update(session_var) + + if self.dashscope_app_type in ["agent", "dialog-workflow"]: + # 支持多轮对话的 + new_record = {"role": "user", "content": prompt} + if image_urls: + logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") + contexts_no_img = await self._remove_image_from_context(contexts) + context_query = [*contexts_no_img, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + # 调用阿里云百炼 API + partial = functools.partial( + Application.call, + app_id=self.app_id, + api_key=self.api_key, + messages=context_query, + biz_params=payload_vars or None, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + else: + # 不支持多轮对话的 + # 调用阿里云百炼 API + partial = functools.partial( + Application.call, + app_id=self.app_id, + promtp=prompt, + api_key=self.api_key, + biz_params=payload_vars or None, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + + logger.debug(f"dashscope resp: {response}") + + if response.status_code != 200: + logger.error( + f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code" + ) + return LLMResponse( + role="err", + completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + ) + + output_text = response.output.get("text", "") + return LLMResponse(role="assistant", completion_text=output_text) + + async def forget(self, session_id): + return True + + async def get_current_key(self): + return self.api_key + + async def set_key(self, key): + raise Exception("阿里云百炼 适配器不支持设置 API Key。") + + async def get_models(self): + return [self.get_model()] + + async def get_human_readable_context(self, session_id, page, page_size): + raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。") + + async def terminate(self): + pass \ No newline at end of file diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 807520e0e..9e8e344f7 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -32,6 +32,7 @@ class ProviderDify(Provider): self.model_name = "dify" self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output") self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query") + self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" self.timeout = provider_config.get("timeout", 120) @@ -72,15 +73,18 @@ class ProviderDify(Provider): logger.warning(f"未知的图片链接:{image_url},图片将忽略。") # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 session_vars = sp.get("session_variables", {}) session_var = session_vars.get(session_id, {}) + payload_vars.update(session_var) try: match self.api_type: case "chat" | "agent": async for chunk in self.api_client.chat_messages( inputs={ - **session_var + **payload_vars, }, query=prompt, user=session_id, @@ -95,13 +99,19 @@ class ProviderDify(Provider): if not conversation_id: self.conversation_ids[session_id] = chunk['conversation_id'] conversation_id = chunk['conversation_id'] + elif chunk['event'] == 'message_end': + logger.debug("Dify message end") + break + elif chunk['event'] == 'error': + logger.error(f"Dify 出现错误:{chunk}") + raise Exception(f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}") case "workflow": async for chunk in self.api_client.workflow_run( inputs={ self.dify_query_input_key: prompt, "astrbot_session_id": session_id, - **session_var + **payload_vars, }, user=session_id, files=files_payload, @@ -126,6 +136,9 @@ class ProviderDify(Provider): logger.error(f"Dify 请求失败:{str(e)}") return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}") + if not result: + logger.warning("Dify 请求结果为空,请查看 Debug 日志。") + return LLMResponse(role="assistant", completion_text=result) async def forget(self, session_id): diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index aa440ab8c..c6357bf7a 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -225,6 +225,7 @@ class ProviderGoogleGenAI(Provider): if 'tools' in payloads: del payloads['tools'] llm_response = await self._query(payloads, None) + break elif "429" in str(e) or "API key not valid" in str(e): keys.remove(chosen_key) if len(keys) > 0: @@ -232,7 +233,7 @@ class ProviderGoogleGenAI(Provider): logger.info(f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}...") continue else: - logger.error(f"A检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...") + logger.error(f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}...") raise Exception("API 资源已耗尽,且没有可用的 Key 重试...") else: logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}") @@ -281,4 +282,8 @@ class ProviderGoogleGenAI(Provider): with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode('utf-8') return "data:image/jpeg;base64," + image_bs64 - return '' \ No newline at end of file + return '' + + async def terminate(self): + await self.client.client.close() + logger.info("Google GenAI 适配器已终止。") \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index a3114a6c1..d43f00e99 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,6 +1,7 @@ import base64 import json import os +import inspect from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion @@ -48,8 +49,12 @@ class ProviderOpenAIOfficial(Provider): base_url=provider_config.get("api_base", None), timeout=self.timeout ) - - self.set_model(provider_config['model_config']['model']) + + self.default_params = inspect.signature(self.client.chat.completions.create).parameters.keys() + + model_config = provider_config.get("model_config", {}) + model = model_config.get("model", "unknown") + self.set_model(model) async def get_models(self): try: @@ -67,13 +72,26 @@ class ProviderOpenAIOfficial(Provider): tool_list = tools.get_func_desc_openai_style() if tool_list: payloads['tools'] = tool_list - + + # 不在默认参数中的参数放在 extra_body 中 + extra_body = {} + to_del = [] + for key in payloads.keys(): + if key not in self.default_params: + extra_body[key] = payloads[key] + to_del.append(key) + for key in to_del: + del payloads[key] + completion = await self.client.chat.completions.create( **payloads, - stream=False + stream=False, + extra_body=extra_body ) - assert isinstance(completion, ChatCompletion) + if not isinstance(completion, ChatCompletion): + raise Exception(f"API 返回的 completion 类型错误:{type(completion)}: {completion}。") + logger.debug(f"completion: {completion}") if len(completion.choices) == 0: diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 139fb5a39..17bffcf34 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -9,13 +9,19 @@ class PlatformAdapterType(enum.Flag): QQOFFICIAL = enum.auto() VCHAT = enum.auto() GEWECHAT = enum.auto() - ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT + TELEGRAM = enum.auto() + WECOM = enum.auto() + LARK = enum.auto() + ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT | TELEGRAM | WECOM | LARK ADAPTER_NAME_2_TYPE = { "aiocqhttp": PlatformAdapterType.AIOCQHTTP, "qq_official": PlatformAdapterType.QQOFFICIAL, "vchat": PlatformAdapterType.VCHAT, - "gewechat": PlatformAdapterType.GEWECHAT + "gewechat": PlatformAdapterType.GEWECHAT, + "telegram": PlatformAdapterType.TELEGRAM, + "wecom": PlatformAdapterType.WECOM, + "lark": PlatformAdapterType.LARK } class PlatformAdapterTypeFilter(HandlerFilter): diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index ad378ce85..2cf5aaeeb 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -59,7 +59,6 @@ def register_command(command_name: str = None, sub_command: str = None, alias: s if isinstance(command_name, RegisteringCommandable): # 子指令 parent_command_names = command_name.parent_group.get_complete_command_names() - logger.debug(f"parent_command_names: {parent_command_names}") new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names) command_name.parent_group.add_sub_command_filter(new_command) else: diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 66322f4ec..62d30cc79 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -32,9 +32,6 @@ class AstrBotUpdator(RepoZipUpdator): pass def _reboot(self, delay: int = 3): - if os.environ.get('TEST_MODE', 'off') == 'on': - logger.info("测试模式下不会重启。") - return py = sys.executable time.sleep(delay) self.terminate_child_processes() @@ -47,8 +44,11 @@ class AstrBotUpdator(RepoZipUpdator): async def check_update(self, url: str, current_version: str) -> ReleaseInfo: return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION) + + async def get_releases(self) -> list: + return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - async def update(self, reboot = False, latest = True, version = None): + async def update(self, reboot = False, latest = True, version = None, proxy = ""): update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None @@ -70,6 +70,10 @@ class AstrBotUpdator(RepoZipUpdator): raise Exception("commit hash 长度不正确,应为 40") logger.info(f"正在尝试更新到指定 commit: {version}") file_url = "https://github.com/Soulter/AstrBot/archive/" + version + ".zip" + + if proxy: + proxy = proxy.removesuffix("/") + file_url = f"{proxy}/{file_url}" try: await download_file(file_url, "temp.zip") diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 67003a1bf..19841cc8d 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -34,10 +34,19 @@ class RepoZipUpdator(): result = await response.json() if not result: return [] - if latest: - ret = self.github_api_release_parser([result[0]]) - else: - ret = self.github_api_release_parser(result) + # if latest: + # ret = self.github_api_release_parser([result[0]]) + # else: + # ret = self.github_api_release_parser(result) + ret = [] + for release in result: + ret.append({ + "version": release['name'], + "published_at": release['published_at'], + "body": release['body'], + "tag_name": release['tag_name'], + "zipball_url": release['zipball_url'] + }) except BaseException: raise Exception("解析版本信息失败") return ret @@ -49,17 +58,10 @@ class RepoZipUpdator(): ''' ret = [] for release in releases: - version = release['name'] - commit_hash = '' - # 规范是: v3.0.7.xxxxxx,其中xxxxxx为 commit hash - _t = version.split(".") - if len(_t) == 4: - commit_hash = _t[3] ret.append({ "version": release['name'], "published_at": release['published_at'], "body": release['body'], - "commit_hash": commit_hash, "tag_name": release['tag_name'], "zipball_url": release['zipball_url'] }) @@ -114,15 +116,6 @@ class RepoZipUpdator(): release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" else: release_url = releases[0]['zipball_url'] - - # 镜像站点 - # match self.repo_mirror: - # case 'https://github-mirror.us.kg/': - # release_url = self.repo_mirror + release_url - # case "https://ghp.ci/": - # release_url = self.repo_mirror + release_url - # case _: - # pass if proxy: release_url = f"{proxy}/{release_url}" diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 329ff6f36..594620004 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -181,8 +181,8 @@ class ChatRoute(Route): self.db.update_conversation(username, cid, history=json.dumps(history)) await asyncio.sleep(0.5) - except BaseException as e: - logger.debug(f"用户 {username} 断开聊天长连接: {str(e)}。") + except BaseException as _: + logger.debug(f"用户 {username} 断开聊天长连接。") self.curr_chat_sse.pop(username) return diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 61810705c..0e2cdc1c4 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,4 +1,5 @@ import typing +import traceback from .route import Route, Response, RouteContext from quart import request from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP @@ -61,7 +62,7 @@ def validate_config(data, schema: dict, is_core: bool) -> typing.Tuple[typing.Li group_meta = group.get("metadata") if not group_meta: continue - logger.info(f"验证配置: 组 {key} ...") + # logger.info(f"验证配置: 组 {key} ...") validate(data, group_meta, path=f"{key}.") else: validate(data, schema) @@ -77,6 +78,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) else: errors, post_config = validate_config(post_config, config.schema, is_core) except BaseException as e: + logger.error(traceback.format_exc()) logger.warning(f"验证配置时出现异常: {e}") if errors: raise ValueError(f"格式校验未通过: {errors}") @@ -90,6 +92,14 @@ class ConfigRoute(Route): '/config/get': ('GET', self.get_configs), '/config/astrbot/update': ('POST', self.post_astrbot_configs), '/config/plugin/update': ('POST', self.post_plugin_configs), + + '/config/platform/new': ('POST', self.post_new_platform), + '/config/platform/update': ('POST', self.post_update_platform), + '/config/platform/delete': ('POST', self.post_delete_platform), + + '/config/provider/new': ('POST', self.post_new_provider), + '/config/provider/update': ('POST', self.post_update_provider), + '/config/provider/delete': ('POST', self.post_delete_provider) } self.register_routes() @@ -118,7 +128,99 @@ class ConfigRoute(Route): return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__ except Exception as e: return Response().error(str(e)).__dict__ - + + async def post_new_platform(self): + new_platform_config = await request.json + self.config['platform'].append(new_platform_config) + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.platform_manager.load_platform(new_platform_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增平台配置成功~").__dict__ + + async def post_new_provider(self): + new_provider_config = await request.json + self.config['provider'].append(new_provider_config) + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.load_provider(new_provider_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "新增服务提供商配置成功~").__dict__ + + async def post_update_platform(self): + update_platform_config = await request.json + platform_id = update_platform_config.get("id", None) + new_config = update_platform_config.get("config", None) + if not platform_id or not new_config: + return Response().error("参数错误").__dict__ + + for i, platform in enumerate(self.config['platform']): + if platform['id'] == platform_id: + self.config['platform'][i] = new_config + break + else: + return Response().error("未找到对应平台").__dict__ + + try: + await self._save_astrbot_configs(self.config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新平台配置成功~").__dict__ + + async def post_update_provider(self): + update_provider_config = await request.json + provider_id = update_provider_config.get("id", None) + new_config = update_provider_config.get("config", None) + if not provider_id or not new_config: + return Response().error("参数错误").__dict__ + + for i, provider in enumerate(self.config['provider']): + if provider['id'] == provider_id: + self.config['provider'][i] = new_config + break + else: + return Response().error("未找到对应服务提供商").__dict__ + + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.reload(new_config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "更新成功,已经实时生效~").__dict__ + + async def post_delete_platform(self): + platform_id = await request.json + platform_id = platform_id.get("id") + for i, platform in enumerate(self.config['platform']): + if platform['id'] == platform_id: + del self.config['platform'][i] + break + else: + return Response().error("未找到对应平台").__dict__ + try: + await self._save_astrbot_configs(self.config) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除平台配置成功~").__dict__ + + async def post_delete_provider(self): + provider_id = await request.json + provider_id = provider_id.get("id") + for i, provider in enumerate(self.config['provider']): + if provider['id'] == provider_id: + del self.config['provider'][i] + break + else: + return Response().error("未找到对应服务提供商").__dict__ + try: + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.provider_manager.terminate_provider(provider_id) + except Exception as e: + return Response().error(str(e)).__dict__ + return Response().ok(None, "删除成功,已经实时生效~").__dict__ + async def _get_astrbot_config(self): config = self.config diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 4e5d68ecd..1a7e61bc2 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -16,7 +16,7 @@ class StatRoute(Route): '/stat/get': ('GET', self.get_stat), '/stat/version': ('GET', self.get_version), '/stat/start-time': ('GET', self.get_start_time), - '/stat/restart-core': ('GET', self.restart_core) + '/stat/restart-core': ('POST', self.restart_core) } self.db_helper = db_helper self.register_routes() diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 841efca82..19e146fbe 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -1,15 +1,31 @@ from .route import Route, RouteContext + + class StaticFileRoute(Route): def __init__(self, context: RouteContext) -> None: super().__init__(context) - - index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat'] + + index_ = [ + "/", + "/auth/login", + "/config", + "/logs", + "/extension", + "/dashboard/default", + "/project-atri", + "/console", + "/chat", + "/settings", + "/platforms", + "/providers", + "/about", + ] for i in index_: self.app.add_url_rule(i, view_func=self.index) - + @self.app.errorhandler(404) async def page_not_found(e): - return "404 Not found。如果你初次使用打开面板发现 404,请参考文档: https://astrbot.app/deploy/dashboard-404.html" - + return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。" + async def index(self): - return await self.app.send_static_file('index.html') \ No newline at end of file + return await self.app.send_static_file("index.html") diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index d5e06e652..630d50d2a 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -13,6 +13,7 @@ class UpdateRoute(Route): super().__init__(context) self.routes = { '/update/check': ('GET', self.check_update), + '/update/releases': ('GET', self.get_releases), '/update/do': ('POST', self.update_project), '/update/dashboard': ('POST', self.update_dashboard), '/update/pip-install': ('POST', self.install_pip_package) @@ -46,6 +47,14 @@ class UpdateRoute(Route): except Exception as e: logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)") return Response().error(e.__str__()).__dict__ + + async def get_releases(self): + try: + ret = await self.astrbot_updator.get_releases() + return Response().ok(ret).__dict__ + except Exception as e: + logger.error(f"/api/update/releases: {traceback.format_exc()}") + return Response().error(e.__str__()).__dict__ async def update_project(self): data = await request.json @@ -56,8 +65,13 @@ class UpdateRoute(Route): version = '' else: latest = False + + proxy: str = data.get("proxy", None) + if proxy: + proxy = proxy.removesuffix("/") + try: - await self.astrbot_updator.update(latest=latest, version=version) + await self.astrbot_updator.update(latest=latest, version=version, proxy=proxy) if latest: try: diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 2656ba2a9..e3f28cb55 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -2,6 +2,9 @@ import logging import jwt import asyncio import os +import socket +import sys +import psutil from astrbot.core.config.default import VERSION from quart import Quart, request, jsonify, g from quart.logging import default_handler @@ -67,6 +70,47 @@ class AstrBotDashboard(): await asyncio.sleep(1) logger.info("管理面板已关闭。") + def check_port_in_use(self, port: int) -> bool: + """ + 跨平台检测端口是否被占用 + """ + try: + # 创建 IPv4 TCP Socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # 设置超时时间 + sock.settimeout(2) + result = sock.connect_ex(('127.0.0.1', port)) + sock.close() + # result 为 0 表示端口被占用 + return result == 0 + except Exception as e: + logger.warning(f"检查端口 {port} 时发生错误: {str(e)}") + # 如果出现异常,保守起见认为端口可能被占用 + return True + + def get_process_using_port(self, port: int) -> str: + """获取占用端口的进程详细信息""" + try: + for conn in psutil.net_connections(kind='inet'): + if conn.laddr.port == port: + try: + process = psutil.Process(conn.pid) + # 获取详细信息 + proc_info = [ + f"进程名: {process.name()}", + f"PID: {process.pid}", + f"执行路径: {process.exe()}", + f"工作目录: {process.cwd()}", + f"启动命令: {' '.join(process.cmdline())}" + ] + return "\n ".join(proc_info) + except (psutil.NoSuchProcess, psutil.AccessDenied) as e: + return f"无法获取进程详细信息(可能需要管理员权限): {str(e)}" + return "未找到占用进程" + except Exception as e: + return f"获取进程信息失败: {str(e)}" + def run(self): try: ip_addr = get_local_ip_addresses() @@ -76,6 +120,17 @@ class AstrBotDashboard(): port = self.core_lifecycle.astrbot_config['dashboard'].get("port", 6185) if isinstance(port, str): port = int(port) + + if self.check_port_in_use(port): + process_info = self.get_process_using_port(port) + logger.error(f"错误:端口 {port} 已被占用\n" + f"占用信息: \n {process_info}\n" + f"请确保:\n" + f"1. 没有其他 AstrBot 实例正在运行\n" + f"2. 端口 {port} 没有被其他程序占用\n" + f"3. 如需使用其他端口,请修改配置文件") + + raise Exception(f"端口 {port} 已被占用") display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n" display += f" ➜ 本地: http://localhost:{port}\n" diff --git a/changelogs/v3.4.31.md b/changelogs/v3.4.31.md new file mode 100644 index 000000000..6047421d8 --- /dev/null +++ b/changelogs/v3.4.31.md @@ -0,0 +1,18 @@ +# What's Changed + +> 提示:改动范围较大 + +1. ✨ 新增: 添加对 Anthropic Claude 的支持 by @Rt39 +2. ✨ 新增: 支持阿里云百炼应用(dashscope)智能体、工作流 #552 by @Soulter +3. ✨ 新增: 支持 AstrBot 更新使用 Github 加速地址 by @Fridemn +4. ✨ 新增: 适配多节点的转发消息,添加新的消息段 `Nodes` +5. ✨ 新增: 支持在管理面板重启(设置页) +6. ✨ 新增: 前端支持以列表展示正式版和开发版的列表 +7. ✨ 新增: 支持插件禁止默认的llm调用(event.should_call_llm())#579 +8. 🍺 重构: 支持更大范围的热重载以及管理面板将平台和提供商配置独立化 by @Soulter +9. ⚡ 优化: 启动时检查端口占用 by @Fridemn +10. ⚡ 优化: 添加控制台关闭自动滚动按钮 by @Fridemn +11. ⚡ 优化: 在聊天页面添加粘贴图片的快捷键提示 #557 +12. 🐛 修复: 修复 webchat 未处理 base64 的问题 by @Raven95676 +13. 🐛 修复: 修复 aiocqhttp_platform_adapter 文件相关判断逻辑 by @Raven95676 +14. ‼️🐛 修复: 修复 gemini 请求时出现多次不支持函数工具调用最后 429 的问题 \ No newline at end of file diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 37efe7f6a..6ace849ef 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -6,8 +6,8 @@
+ style="margin-bottom: 8px" :text="metadata[metadataKey].items[key]?.hint" + :title="'💡 ' + metadata[metadataKey].items[key]?.description" type="info" variant="tonal" color="primary">
@@ -66,8 +66,8 @@
+ style="margin-bottom: 8px" :text="metadata[metadataKey]?.hint" + :title="'💡 ' + metadata[metadataKey]?.description" type="info" variant="tonal" color="primary">
diff --git a/dashboard/src/components/shared/ConsoleDisplayer.vue b/dashboard/src/components/shared/ConsoleDisplayer.vue index c0f4447ff..a24460e37 100644 --- a/dashboard/src/components/shared/ConsoleDisplayer.vue +++ b/dashboard/src/components/shared/ConsoleDisplayer.vue @@ -4,7 +4,7 @@ import { useCommonStore } from '@/stores/common'; @@ -13,6 +13,7 @@ export default { name: 'ConsoleDisplayer', data() { return { + autoScroll: true, // 默认开启自动滚动 logColorAnsiMap: { '\u001b[1;34m': 'color: #0000FF; font-weight: bold;', // bold_blue '\u001b[1;36m': 'color: #00FFFF; font-weight: bold;', // bold_cyan @@ -54,6 +55,9 @@ export default { } }, methods: { + toggleAutoScroll() { + this.autoScroll = !this.autoScroll; + }, printLog(log) { // append 一个 span 标签到 term,block 的方式 let ele = document.getElementById('term') @@ -66,11 +70,13 @@ export default { break } } - span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace;' + span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;' span.classList.add('fade-in') span.innerText = log ele.appendChild(span) - ele.scrollTop = ele.scrollHeight + if (this.autoScroll) { + ele.scrollTop = ele.scrollHeight + } } }, } diff --git a/dashboard/src/components/shared/ListConfigItem.vue b/dashboard/src/components/shared/ListConfigItem.vue index be23f5d32..74ba445f2 100644 --- a/dashboard/src/components/shared/ListConfigItem.vue +++ b/dashboard/src/components/shared/ListConfigItem.vue @@ -1,7 +1,7 @@ + \ No newline at end of file diff --git a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue index 3a7ee4f65..c98b4d18c 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue +++ b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue @@ -1,16 +1,69 @@ + + \ No newline at end of file diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue new file mode 100644 index 000000000..17cd6ef38 --- /dev/null +++ b/dashboard/src/views/ProviderPage.vue @@ -0,0 +1,240 @@ + + + + \ No newline at end of file diff --git a/dashboard/src/views/Settings.vue b/dashboard/src/views/Settings.vue index fb8a20425..852825dad 100644 --- a/dashboard/src/views/Settings.vue +++ b/dashboard/src/views/Settings.vue @@ -5,23 +5,39 @@ 网络 - + + 系统 + + + 重启 + + + +
+ +