diff --git a/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml b/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml index e5aaaaf78..7eb5ae15c 100644 --- a/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml +++ b/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml @@ -6,7 +6,7 @@ body: - type: markdown attributes: value: | - 欢迎发布插件到插件市场! + 欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。 - type: textarea attributes: @@ -22,9 +22,10 @@ body: 插件名: 插件作者: 插件简介: - 标签: (可选) - 社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接) - description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。 + 支持的消息平台:(必填,如 QQ、微信、飞书) + 标签:(可选) + 社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接) + description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。 - type: checkboxes attributes: diff --git a/.gitignore b/.gitignore index 865b0596d..a3b2aad90 100644 --- a/.gitignore +++ b/.gitignore @@ -26,5 +26,5 @@ venv/* packages/python_interpreter/workplace .venv/* .conda/ -.idea/ +.idea pytest.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c38647e7..4dece7145 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ ci: autoupdate_commit_msg: ":balloon: pre-commit autoupdate" repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.10 + rev: v0.11.0 hooks: - id: ruff - id: ruff-format diff --git a/README.md b/README.md index b4d97fbbb..9d9589384 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,13 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ -[](https://github.com/Soulter/AstrBot/releases/latest) - - - -[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) - -[](https://codecov.io/gh/Soulter/AstrBot) -[](https://gitcode.com/Soulter/AstrBot) +[](https://github.com/Soulter/AstrBot/releases/latest) + + + +[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) + +[](https://codecov.io/gh/Soulter/AstrBot) English | 日本語 | @@ -27,6 +26,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。 +[](https://gitcode.com/Soulter/AstrBot) + ## ✨ 主要功能 1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。 @@ -51,15 +52,19 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。 -#### Replit 部署 +#### 宝塔面板部署 -[](https://repl.it/github/Soulter/AstrBot) +请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。 #### CasaOS 部署 社区贡献的部署方式。 -请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。 +请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。 + +#### Replit 部署 + +[](https://repl.it/github/Soulter/AstrBot) #### 手动部署 @@ -106,6 +111,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 | Whisper | ✔ | 语音转文本 | 支持 API、本地部署 | | SenseVoice | ✔ | 语音转文本 | 本地部署 | | OpenAI TTS API | ✔ | 文本转语音 | | +| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference | | Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 | | Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS | diff --git a/astrbot/api/platform/__init__.py b/astrbot/api/platform/__init__.py index dcc02bb49..5a98c5903 100644 --- a/astrbot/api/platform/__init__.py +++ b/astrbot/api/platform/__init__.py @@ -5,6 +5,7 @@ from astrbot.core.platform import ( MessageMember, MessageType, PlatformMetadata, + Group, ) from astrbot.core.platform.register import register_platform_adapter @@ -18,4 +19,5 @@ __all__ = [ "MessageType", "PlatformMetadata", "register_platform_adapter", + "Group", ] diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 280c841d0..7db7ff5f7 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.4.37" +VERSION = "3.4.39" DB_PATH = "data/data_v3.db" # 默认配置 @@ -85,7 +85,7 @@ DEFAULT_CONFIG = { "enable": True, "username": "astrbot", "password": "77b90590a8945a7d36c963981a307dc9", - "host": "127.0.0.1", + "host": "0.0.0.0", "port": 6185, }, "platform": [], @@ -223,7 +223,7 @@ CONFIG_METADATA_2 = { "hint": "启用后,机器人可以接收到频道的私聊消息。", }, "ws_reverse_host": { - "description": "反向 Websocket 主机地址", + "description": "反向 Websocket 主机地址(AstrBot 为服务器端)", "type": "string", "hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。", }, @@ -581,7 +581,7 @@ CONFIG_METADATA_2 = { "dify_api_type": "chat", "dify_api_key": "", "dify_api_base": "https://api.dify.ai/v1", - "dify_workflow_output_key": "", + "dify_workflow_output_key": "astrbot_wf_output", "dify_query_input_key": "astrbot_text_query", "variables": {}, "timeout": 60, @@ -593,6 +593,11 @@ CONFIG_METADATA_2 = { "dashscope_app_type": "agent", "dashscope_api_key": "", "dashscope_app_id": "", + "rag_options": { + "pipeline_ids": [], + "file_ids": [], + "output_reference": False, + }, "variables": {}, "timeout": 60, }, @@ -665,6 +670,30 @@ CONFIG_METADATA_2 = { }, }, "items": { + "rag_options": { + "description": "RAG 选项", + "type": "object", + "hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)", + "items": { + "pipeline_ids": { + "description": "知识库 ID 列表", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。", + }, + "file_ids": { + "description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。", + }, + "output_reference": { + "description": "是否输出知识库/文档的引用", + "type": "bool", + "hint": "在每次回答尾部加上引用源。默认为 False。", + }, + }, + }, "sensevoice_hint": { "description": "部署SenseVoice", "type": "string", @@ -681,12 +710,14 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。", }, - # "variables": { - # "description": "工作流固定输入变量", - # "type": "object", - # "obvious_hint": True, - # "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", - # }, + "variables": { + "description": "工作流固定输入变量", + "type": "object", + "obvious_hint": True, + "items": {}, + "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", + "invisible": True, + }, # "fastgpt_app_type": { # "description": "应用类型", # "type": "string", @@ -697,7 +728,7 @@ CONFIG_METADATA_2 = { "dashscope_app_type": { "description": "应用类型", "type": "string", - "hint": "阿里云百炼应用的应用类型。", + "hint": "百炼应用的应用类型。", "options": [ "agent", "agent-arrange", diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index e5485edb4..e52d94674 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -40,7 +40,6 @@ class AstrBotCoreLifecycle: else: logger.setLevel(self.astrbot_config["log_level"]) self.event_queue = Queue() - self.event_queue.closed = False self.provider_manager = ProviderManager(self.astrbot_config, self.db) @@ -81,6 +80,8 @@ class AstrBotCoreLifecycle: await self.platform_manager.initialize() """根据配置实例化各个平台适配器""" + self.dashboard_shutdown_event = asyncio.Event() + def _load(self): event_bus_task = asyncio.create_task( self.event_bus.dispatch(), name="event_bus" @@ -129,11 +130,12 @@ class AstrBotCoreLifecycle: await asyncio.gather(*self.curr_tasks, return_exceptions=True) async def stop(self): - self.event_queue.closed = True for task in self.curr_tasks: task.cancel() await self.provider_manager.terminate() + await self.platform_manager.terminate() + self.dashboard_shutdown_event.set() for task in self.curr_tasks: try: @@ -143,8 +145,10 @@ class AstrBotCoreLifecycle: except Exception as e: logger.error(f"任务 {task.get_name()} 发生错误: {e}") - def restart(self): - self.event_queue.closed = True + async def restart(self): + await self.provider_manager.terminate() + await self.platform_manager.terminate() + self.dashboard_shutdown_event.set() threading.Thread( target=self.astrbot_updator._reboot, name="restart", daemon=True ).start() diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/core/initial_loader.py similarity index 82% rename from astrbot/dashboard/dashboard_lifecycle.py rename to astrbot/core/initial_loader.py index 9c5c9138d..f91a71da3 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/core/initial_loader.py @@ -2,17 +2,16 @@ import asyncio import traceback from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from .server import AstrBotDashboard from astrbot.core.db import BaseDatabase from astrbot.core import LogBroker +from astrbot.dashboard.server import AstrBotDashboard -class AstrBotDashBoardLifecycle: +class InitialLoader: def __init__(self, db: BaseDatabase, log_broker: LogBroker): self.db = db self.logger = logger self.log_broker = log_broker - self.dashboard_server = None async def start(self): core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) @@ -25,7 +24,9 @@ class AstrBotDashBoardLifecycle: logger.critical(traceback.format_exc()) logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!") - self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) + self.dashboard_server = AstrBotDashboard( + core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event + ) task = asyncio.gather(core_task, self.dashboard_server.run()) try: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 7f73e5d91..64c324a9e 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,9 +25,11 @@ SOFTWARE. import base64 import json import os +import uuid import typing as T from enum import Enum from pydantic.v1 import BaseModel +from astrbot.core.utils.io import download_image_by_url, file_to_base64 class ComponentType(Enum): @@ -146,6 +148,51 @@ class Record(BaseMessageComponent): return Record(file=url, **_) raise Exception("not a valid url") + async def convert_to_file_path(self) -> str: + """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 语音的本地路径,以绝对路径表示。 + """ + if self.file and self.file.startswith("file:///"): + file_path = self.file[8:] + return file_path + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + return os.path.abspath(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(file_path) + elif os.path.exists(self.file): + file_path = self.file + return os.path.abspath(file_path) + else: + raise Exception(f"not a valid file: {self.file}") + + async def convert_to_base64(self) -> str: + """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 + + Returns: + str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + if self.file and self.file.startswith("file:///"): + bs64_data = file_to_base64(self.file[8:]) + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + bs64_data = file_to_base64(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) + else: + raise Exception(f"not a valid file: {self.file}") + return bs64_data + class Video(BaseMessageComponent): type: ComponentType = "Video" @@ -279,10 +326,6 @@ class Image(BaseMessageComponent): file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): - # for k in _.keys(): - # if (k == "_type" and _[k] not in ["flash", "show", None]) or \ - # (k == "c" and _[k] not in [2, 3]): - # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(file=file, **_) @staticmethod @@ -307,6 +350,53 @@ class Image(BaseMessageComponent): def fromIO(IO): return Image.fromBytes(IO.read()) + async def convert_to_file_path(self) -> str: + """将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 图片的本地路径,以绝对路径表示。 + """ + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + image_file_path = url[8:] + return image_file_path + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + return os.path.abspath(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + image_file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(image_file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(image_file_path) + elif os.path.exists(url): + image_file_path = url + return os.path.abspath(image_file_path) + else: + raise Exception(f"not a valid file: {url}") + + async def convert_to_base64(self) -> str: + """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 + + Returns: + str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + bs64_data = file_to_base64(url[8:]) + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + bs64_data = file_to_base64(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) + else: + raise Exception(f"not a valid file: {url}") + return bs64_data + class Reply(BaseMessageComponent): type: ComponentType = "Reply" diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 89aff17a8..4cc7fb842 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -77,6 +77,10 @@ class MessageChain: self.use_t2i_ = use_t2i return self + def get_plain_text(self) -> str: + """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" + return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 @@ -147,9 +151,5 @@ class MessageEventResult(MessageChain): """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT - def get_plain_text(self) -> str: - """获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" - return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 2050f37d6..7d7c4516f 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -64,8 +64,8 @@ class LLMRequestSubStage(Stage): req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - req.image_urls.append(image_url) + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) # 获取对话上下文 conversation_id = await self.conv_manager.get_curr_conversation_id( @@ -250,8 +250,7 @@ class LLMRequestSubStage(Stage): if llm_response.role == "assistant": # 文本回复 contexts = req.contexts - new_record = {"role": "user", "content": req.prompt} - contexts.append(new_record) + contexts.append(await req.assemble_context()) contexts.append( {"role": "assistant", "content": llm_response.completion_text} ) diff --git a/astrbot/core/platform/__init__.py b/astrbot/core/platform/__init__.py index 48ea57b8a..4007b2d90 100644 --- a/astrbot/core/platform/__init__.py +++ b/astrbot/core/platform/__init__.py @@ -1,7 +1,7 @@ from .platform import Platform from .astr_message_event import AstrMessageEvent from .platform_metadata import PlatformMetadata -from .astrbot_message import AstrBotMessage, MessageMember, MessageType +from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group __all__ = [ "Platform", @@ -10,4 +10,5 @@ __all__ = [ "AstrBotMessage", "MessageMember", "MessageType", + "Group", ] diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index fceb63ce7..3e1b14ee6 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,11 +1,9 @@ import abc import asyncio from dataclasses import dataclass -from .astrbot_message import AstrBotMessage -from .platform_metadata import PlatformMetadata -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain -from astrbot.core.platform.message_type import MessageType -from typing import List, Union +from typing import List, Union, Optional + +from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( Plain, Image, @@ -16,9 +14,12 @@ from astrbot.core.message.components import ( Forward, Reply, ) -from astrbot.core.utils.metrics import Metric +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entites import ProviderRequest -from astrbot.core.db.po import Conversation +from astrbot.core.utils.metrics import Metric +from .astrbot_message import AstrBotMessage, Group +from .platform_metadata import PlatformMetadata @dataclass @@ -201,15 +202,6 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" - async def send(self, message: MessageChain): - """ - 发送消息到消息平台。 - """ - asyncio.create_task( - Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) - ) - self._has_send_oper = True - async def _pre_send(self): """调度器会在执行 send() 前调用该方法""" @@ -371,3 +363,26 @@ class AstrMessageEvent(abc.ABC): system_prompt=system_prompt, conversation=conversation, ) + + """平台适配器""" + + async def send(self, message: MessageChain): + """发送消息到消息平台。 + + Args: + message (MessageChain): 消息链,具体使用方式请参考文档。 + """ + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + ) + self._has_send_oper = True + + async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]: + """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 + + 适配情况: + + - gewechat + - aiocqhttp(OneBotv11) + """ + ... diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index ea55eaf4b..e7bd4bd9c 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -10,6 +10,41 @@ class MessageMember: user_id: str # 发送者id nickname: str = None + def __str__(self): + # 使用 f-string 来构建返回的字符串表示形式 + return ( + f"User ID: {self.user_id}," + f"Nickname: {self.nickname if self.nickname else 'N/A'}" + ) + + +@dataclass +class Group: + group_id: str + """群号""" + group_name: str = None + """群名称""" + group_avatar: str = None + """群头像""" + group_owner: str = None + """群主 id""" + group_admins: List[str] = None + """群管理员 id""" + members: List[MessageMember] = None + """所有群成员""" + + def __str__(self): + # 使用 f-string 来构建返回的字符串表示形式 + return ( + f"Group ID: {self.group_id}\n" + f"Name: {self.group_name if self.group_name else 'N/A'}\n" + f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n" + f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n" + f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n" + f"Members Len: {len(self.members) if self.members else 0}\n" + f"First Member: {self.members[0] if self.members else 'N/A'}\n" + ) + class AstrBotMessage: """ diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 9ca3b82d1..22a06b739 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -85,14 +85,18 @@ class PlatformManager: ) return cls_type = platform_cls_map[platform_config["type"]] - inst = cls_type(platform_config, self.settings, self.event_queue) - self._inst_map[platform_config["id"]] = inst + inst: Platform = cls_type(platform_config, self.settings, self.event_queue) + self._inst_map[platform_config["id"]] = { + "inst": inst, + "client_id": inst.client_self_id, + } self.platform_insts.append(inst) asyncio.create_task( self._task_wrapper( asyncio.create_task( - inst.run(), name=platform_config["id"] + "_platform" + inst.run(), + name=f"platform_{platform_config['type']}_{platform_config['id']}", ) ) ) @@ -109,38 +113,42 @@ class PlatformManager: logger.error("-------") async def reload(self, platform_config: dict): - # 还未实现完成,不要调用此方法 - - if platform_config["id"] in self._inst_map: - # 正在运行 - if getattr(self._inst_map[platform_config["id"]], "terminate", None): - logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...") - await self._inst_map[platform_config["id"]].terminate() - logger.info(f"{platform_config['id']} 平台适配器已终止。") - del self._inst_map[platform_config["id"]] - self.platform_insts.remove(self._inst_map[platform_config["id"]]) - else: - logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。") - - # 再启动新的实例 + await self.terminate_platform(platform_config["id"]) + if platform_config["enable"]: await self.load_platform(platform_config) - else: - # 先将 _inst_map 中在 platform_config 中不存在的实例删除 - config_ids = [platform["id"] for platform in self.platforms_config] - for key in list(self._inst_map.keys()): - if key not in config_ids: - if getattr(self._inst_map[key], "terminate", None): - logger.info(f"正在尝试终止 {key} 平台适配器 ...") - await self._inst_map[key].terminate() - logger.info(f"{key} 平台适配器已终止。") - del self._inst_map[key] - self.platform_insts.remove(self._inst_map[key]) - else: - logger.warning(f"可能无法正常终止 {key} 平台适配器。") + # 和配置文件保持同步 + config_ids = [provider["id"] for provider in self.platforms_config] + for key in list(self._inst_map.keys()): + if key not in config_ids: + await self.terminate_platform(key) - # 再启动新的实例 - await self.load_platform(platform_config) + async def terminate_platform(self, platform_id: str): + if platform_id in self._inst_map: + logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") + + # client_id = self._inst_map.pop(platform_id, None) + info = self._inst_map.pop(platform_id, None) + client_id = info["client_id"] + inst = info["inst"] + try: + self.platform_insts.remove( + next( + inst + for inst in self.platform_insts + if inst.client_self_id == client_id + ) + ) + except Exception: + logger.warning(f"可能未完全移除 {platform_id} 平台适配器") + + if getattr(inst, "terminate", None): + await inst.terminate() + + async def terminate(self): + for inst in self.platform_insts: + if getattr(inst, "terminate", None): + await inst.terminate() def get_insts(self): return self.platform_insts diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 8ed0be039..6ed53fe0e 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,4 +1,5 @@ import abc +import uuid from typing import Awaitable, Any from asyncio import Queue from .platform_metadata import PlatformMetadata @@ -13,6 +14,7 @@ class Platform(abc.ABC): super().__init__() # 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。 self._event_queue = event_queue + self.client_self_id = uuid.uuid4().hex @abc.abstractmethod def run(self) -> Awaitable[Any]: @@ -25,7 +27,7 @@ class Platform(abc.ABC): """ 终止一个平台的运行实例。 """ - pass + ... @abc.abstractmethod def meta(self) -> PlatformMetadata: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index dbe9a3ce0..c7aede7d1 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,9 +1,9 @@ import asyncio - +import typing from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import Group, MessageMember from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp -from astrbot.core.utils.io import file_to_base64, download_image_by_url class AiocqhttpMessageEvent(AstrMessageEvent): @@ -24,18 +24,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent): d["data"]["text"] = segment.text.strip() elif isinstance(segment, (Image, Record)): # convert to base64 - if segment.file and segment.file.startswith("file:///"): - bs64_data = file_to_base64(segment.file[8:]) - image_file_path = segment.file[8:] - elif segment.file and segment.file.startswith("http"): - image_file_path = await download_image_by_url(segment.file) - bs64_data = file_to_base64(image_file_path) - elif segment.file and segment.file.startswith("base64://"): - bs64_data = segment.file - else: - bs64_data = file_to_base64(segment.file) + bs64 = await segment.convert_to_base64() d["data"] = { - "file": bs64_data, + "file": bs64, } elif isinstance(segment, At): d["data"] = { @@ -56,8 +47,13 @@ class AiocqhttpMessageEvent(AstrMessageEvent): if send_one_by_one: for seg in message.chain: - if isinstance(seg, Nodes): - # 带有多个节点的合并转发消息 + if isinstance(seg, (Node, Nodes)): + # 合并转发消息 + + if isinstance(seg, Node): + nodes = Nodes([seg]) + seg = nodes + payload = seg.toDict() if self.get_group_id(): payload["group_id"] = self.get_group_id() @@ -79,3 +75,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await self.bot.send(self.message_obj.raw_message, ret) await super().send(message) + + async def get_group(self, group_id=None, **kwargs): + if isinstance(group_id, str) and group_id.isdigit(): + group_id = int(group_id) + elif self.get_group_id(): + group_id = int(self.get_group_id()) + else: + return None + + info: dict = await self.bot.call_action( + "get_group_info", + group_id=group_id, + ) + + members: typing.List[typing.Dict] = await self.bot.call_action( + "get_group_member_list", + group_id=group_id, + ) + + owner_id = None + admin_ids = [] + for member in members: + if member["role"] == "owner": + owner_id = member["user_id"] + if member["role"] == "admin": + admin_ids.append(member["user_id"]) + + group = Group( + group_id=str(group_id), + group_name=info.get("group_name"), + group_avatar="", + group_admins=admin_ids, + group_owner=str(owner_id), + members=[ + MessageMember( + user_id=member["user_id"], + nickname=member.get("nickname") or member.get("card"), + ) + for member in members + ], + ) + + return group diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 0d11e3c0b..e41071a56 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -43,8 +43,6 @@ class AiocqhttpAdapter(Platform): "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", ) - self.stop = False - self.bot = CQHttp( use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180 ) @@ -303,22 +301,19 @@ class AiocqhttpAdapter(Platform): for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) logging.getLogger("aiocqhttp").setLevel(logging.ERROR) - + self.shutdown_event = asyncio.Event() return coro async def terminate(self): - self.stop = True - await asyncio.sleep(1) + self.shutdown_event.set() + + async def shutdown_trigger_placeholder(self): + await self.shutdown_event.wait() + logger.info("aiocqhttp 适配器已被优雅地关闭") def meta(self) -> PlatformMetadata: return self.metadata - async def shutdown_trigger_placeholder(self): - # TODO: use asyncio.Event - while not self._event_queue.closed and not self.stop: # noqa: ASYNC110 - await asyncio.sleep(1) - logger.info("aiocqhttp 适配器已关闭。") - async def handle_msg(self, message: AstrBotMessage): message_event = AiocqhttpMessageEvent( message_str=message.message_str, diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 507b3bb50..95347172b 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,6 +2,7 @@ import asyncio import uuid import aiohttp import dingtalk_stream +import threading from astrbot.api.platform import ( Platform, @@ -196,7 +197,31 @@ class DingtalkPlatformAdapter(Platform): self._event_queue.put_nowait(event) async def run(self): - await self.client_.start() + # await self.client_.start() + # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 + def start_client(loop: asyncio.AbstractEventLoop): + try: + self._shutdown_event = threading.Event() + task = loop.create_task(self.client_.start()) + self._shutdown_event.wait() + if task.done(): + task.result() + except Exception as e: + if "Graceful shutdown" in str(e): + logger.info("钉钉适配器已被优雅地关闭") + return + logger.error(f"钉钉机器人启动失败: {e}") + + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, start_client, loop) + + async def terminate(self): + def monkey_patch_close(): + raise Exception("Graceful shutdown") + + self.client_.open_connection = monkey_patch_close + await self.client_.websocket.close(code=1000, reason="Graceful shutdown") + self._shutdown_event.set() def get_client(self): return self.client diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index f0e64e698..53ee1878e 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -1,17 +1,19 @@ -import threading import asyncio -import aiohttp -import quart import base64 import datetime -import re import os +import re +import threading + +import aiohttp import anyio -from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType -from astrbot.api.message_components import Plain, Image, At, Record +import quart + from astrbot.api import logger, sp -from .downloader import GeweDownloader +from astrbot.api.message_components import Plain, Image, At, Record +from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType from astrbot.core.utils.io import download_image_by_url +from .downloader import GeweDownloader class SimpleGewechatClient: @@ -51,11 +53,11 @@ class SimpleGewechatClient: self.server = quart.Quart(__name__) self.server.add_url_rule( - "/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"] + "/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"] ) self.server.add_url_rule( "/astrbot-gewechat/file/", - view_func=self.handle_file, + view_func=self._handle_file, methods=["GET"], ) @@ -70,9 +72,8 @@ class SimpleGewechatClient: self.userrealnames = {} - self.stop = False - async def get_token_id(self): + """获取 Gewechat Token。""" async with aiohttp.ClientSession() as session: async with session.post(f"{self.base_url}/tools/getTokenId") as resp: json_blob = await resp.json() @@ -192,6 +193,11 @@ class SimpleGewechatClient: abm.sender = MessageMember(user_id, user_real_name) abm.raw_message = d abm.message_str = "" + + if user_id == "weixin": + # 忽略微信团队消息 + return + # 不同消息类型 match d["MsgType"]: case 1: @@ -253,7 +259,7 @@ class SimpleGewechatClient: logger.debug(f"abm: {abm}") return abm - async def callback(self): + async def _callback(self): data = await quart.request.json logger.debug(f"收到 gewechat 回调: {data}") @@ -275,7 +281,7 @@ class SimpleGewechatClient: return quart.jsonify({"r": "AstrBot ACK"}) - async def handle_file(self, file_id): + async def _handle_file(self, file_id): file_path = f"data/temp/{file_id}" return await quart.send_file(file_path) @@ -298,20 +304,10 @@ class SimpleGewechatClient: async def start_polling(self): threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start() - await self.server.run_task( - host="0.0.0.0", - port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder, - ) - - async def shutdown_trigger_placeholder(self): - # TODO: use asyncio.Event - while not self.event_queue.closed and not self.stop: # noqa: ASYNC110 - await asyncio.sleep(1) - logger.info("gewechat 适配器已关闭。") + await self.server.run_task(host="0.0.0.0", port=self.port) async def check_online(self, appid: str): - # /login/checkOnline + """检查 APPID 对应的设备是否在线。""" async with aiohttp.ClientSession() as session: async with session.post( f"{self.base_url}/login/checkOnline", @@ -322,6 +318,7 @@ class SimpleGewechatClient: return json_blob["data"] async def logout(self): + """登出 gewechat。""" if self.appid: online = await self.check_online(self.appid) if online: @@ -335,6 +332,7 @@ class SimpleGewechatClient: logger.info(f"登出结果: {json_blob}") async def login(self): + """登录 gewechat。一般来说插件用不到这个方法。""" if self.token is None: await self.get_token_id() @@ -446,9 +444,18 @@ class SimpleGewechatClient: self.appid = appid logger.info(f"已保存 APPID: {appid}") - """API""" + """API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1 + """ - async def get_chatroom_member_list(self, chatroom_wxid: str): + async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict: + """获取群成员列表。 + + Args: + chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。 + + Returns: + dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。 + """ payload = {"appId": self.appid, "chatroomId": chatroom_wxid} async with aiohttp.ClientSession() as session: @@ -461,6 +468,7 @@ class SimpleGewechatClient: return json_blob["data"] async def post_text(self, to_wxid, content: str, ats: str = ""): + """发送纯文本消息""" payload = { "appId": self.appid, "toWxid": to_wxid, @@ -477,6 +485,7 @@ class SimpleGewechatClient: logger.debug(f"发送消息结果: {json_blob}") async def post_image(self, to_wxid, image_url: str): + """发送图片消息""" payload = { "appId": self.appid, "toWxid": to_wxid, @@ -491,6 +500,12 @@ class SimpleGewechatClient: logger.debug(f"发送图片结果: {json_blob}") async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): + """发送语音信息 + + Args: + voice_url (str): 语音文件的网络链接 + voice_duration (int): 语音时长,毫秒 + """ payload = { "appId": self.appid, "toWxid": to_wxid, @@ -508,6 +523,13 @@ class SimpleGewechatClient: logger.debug(f"发送语音结果: {json_blob}") async def post_file(self, to_wxid, file_url: str, file_name: str): + """发送文件 + + Args: + to_wxid (string): 微信ID + file_url (str): 文件的网络链接 + file_name (str): 文件名 + """ payload = { "appId": self.appid, "toWxid": to_wxid, @@ -521,3 +543,114 @@ class SimpleGewechatClient: ) as resp: json_blob = await resp.json() logger.debug(f"发送文件结果: {json_blob}") + + async def add_friend(self, v3: str, v4: str, content: str): + """申请添加好友""" + payload = { + "appId": self.appid, + "scene": 3, + "content": content, + "v4": v4, + "v3": v3, + "option": 2, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/contacts/addContacts", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"申请添加好友结果: {json_blob}") + return json_blob + + async def get_group(self, group_id: str): + payload = { + "appId": self.appid, + "chatroomId": group_id, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/getChatroomInfo", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def get_group_member(self, group_id: str): + payload = { + "appId": self.appid, + "chatroomId": group_id, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/getChatroomMemberList", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def accept_group_invite(self, url: str): + """同意进群""" + payload = {"appId": self.appid, "url": url} + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/agreeJoinRoom", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def add_group_member_to_friend( + self, group_id: str, to_wxid: str, content: str + ): + payload = { + "appId": self.appid, + "chatroomId": group_id, + "content": content, + "memberWxid": to_wxid, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/group/addGroupMemberAsFriend", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob + + async def get_user_or_group_info(self, *ids): + """ + 获取用户或群组信息。 + + :param ids: 可变数量的 wxid 参数 + """ + + wxids_str = list(ids) + + payload = { + "appId": self.appid, + "wxids": wxids_str, # 使用逗号分隔的字符串 + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/contacts/getDetailInfo", + headers=self.headers, + json=payload, + ) as resp: + json_blob = await resp.json() + logger.debug(f"获取群信息结果: {json_blob}") + return json_blob diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 7668663fb..3aca64bab 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -2,11 +2,11 @@ import wave import uuid import traceback import os -from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file +from astrbot.core.utils.io import save_temp_img, download_file from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember from astrbot.api.message_components import Plain, Image, Record, At, File from .client import SimpleGewechatClient @@ -70,18 +70,10 @@ class GewechatPlatformEvent(AstrMessageEvent): await client.post_text(**payload) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() - # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + # 检查 record_path 是否在 data/temp 目录中 temp_directory = os.path.abspath("data/temp") - img_path = os.path.abspath(img_path) if os.path.commonpath([temp_directory, img_path]) != temp_directory: with open(img_path, "rb") as f: img_path = save_temp_img(f.read()) @@ -93,14 +85,7 @@ class GewechatPlatformEvent(AstrMessageEvent): elif isinstance(comp, Record): # 默认已经存在 data/temp 中 record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url + record_path = await comp.convert_to_file_path() silk_path = f"data/temp/{uuid.uuid4()}.silk" try: @@ -138,3 +123,30 @@ class GewechatPlatformEvent(AstrMessageEvent): to_wxid = self.message_obj.raw_message.get("to_wxid", None) await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client) await super().send(message) + + async def get_group(self, group_id=None, **kwargs): + # 确定有效的 group_id + if group_id is None: + group_id = self.get_group_id() + + if not group_id: + return None + + res = await self.client.get_group(group_id) + data: dict = res["data"] + + if not data["chatroomId"]: + return None + + members = [ + MessageMember(user_id=member["wxid"], nickname=member["nickName"]) + for member in data.get("memberList", []) + ] + + return Group( + group_id=data["chatroomId"], + group_name=data.get("nickName"), + group_avatar=data.get("smallHeadImgUrl"), + group_owner=data.get("chatRoomOwner"), + members=members, + ) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index 3dbdbba27..9c8c4f5ed 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -64,8 +64,7 @@ class GewechatPlatformAdapter(Platform): ) async def terminate(self): - self.client.stop = True - await asyncio.sleep(1) + await self.client.server.shutdown() async def logout(self): await self.client.logout() diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index fd29b3602..cbc3a45bb 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -2,6 +2,7 @@ import base64 import asyncio import json import re +import astrbot.api.message_components as Comp from astrbot.api.platform import ( Platform, @@ -11,7 +12,6 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .lark_event import LarkMessageEvent from ...register import register_platform_adapter @@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = At(qq=m.id.open_id, name=m.name) + at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.name == self.bot_name: abm.self_id = m.id.open_id @@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform): if s in at_list: abm.message.append(at_list[s]) else: - abm.message.append(Plain(parts[i].strip())) + abm.message.append(Comp.Plain(parts[i].strip())) elif message.message_type == "post": _ls = [] @@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform): if comp["tag"] == "at": abm.message.append(at_list[comp["user_id"]]) elif comp["tag"] == "text" and comp["text"].strip(): - abm.message.append(Plain(comp["text"].strip())) + abm.message.append(Comp.Plain(comp["text"].strip())) elif comp["tag"] == "img": image_key = comp["image_key"] request = ( @@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform): logger.error(f"无法下载飞书图片: {image_key}") image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() - abm.message.append(Image.fromBase64(image_base64)) + abm.message.append(Comp.Image.fromBase64(image_base64)) for comp in abm.message: - if isinstance(comp, Plain): + if isinstance(comp, Comp.Plain): abm.message_str += comp.text abm.message_id = message.message_id abm.raw_message = message @@ -185,5 +185,9 @@ class LarkPlatformAdapter(Platform): # self.client.start() await self.client._connect() + async def terminate(self): + await self.client._disconnect() + logger.info("飞书(Lark) 适配器已被优雅地关闭") + def get_client(self) -> lark.Client: return self.client diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index ae1dc2563..57bc8683f 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -17,6 +17,7 @@ from astrbot.api.platform import ( MessageType, PlatformMetadata, ) +from astrbot import logger from astrbot.api.event import MessageChain from typing import Union, List from astrbot.api.message_components import Image, Plain, At @@ -204,3 +205,7 @@ class QQOfficialPlatformAdapter(Platform): def get_client(self) -> botClient: return self.client + + async def terminate(self): + await self.client.close() + logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 542233591..6ad59c67e 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent from ...register import register_platform_adapter from .qo_webhook_server import QQOfficialWebhook from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter +from astrbot import logger # remove logger handler for handler in logging.root.handlers[:]: @@ -111,3 +112,8 @@ class QQOfficialWebhookPlatformAdapter(Platform): def get_client(self) -> botClient: return self.client + + async def terminate(self): + await self.client.close() + await self.webhook_helper.server.shutdown() + logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index a219e2492..681999cf0 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -99,13 +99,4 @@ class QQOfficialWebhook: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。" ) - await self.server.run_task( - host=self.callback_server_host, - port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder, - ) - - async def shutdown_trigger_placeholder(self): - while not self.event_queue.closed: # noqa: ASYNC110 - await asyncio.sleep(1) - logger.info("qq_official_webhook 适配器已关闭。") + await self.server.run_task(host=self.callback_server_host, port=self.port) diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 3dde1802b..93ed6feb6 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -1,6 +1,7 @@ import sys import uuid import asyncio +import astrbot.api.message_components as Comp from astrbot.api.platform import ( Platform, @@ -10,15 +11,6 @@ from astrbot.api.platform import ( MessageType, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import ( - Plain, - Image, - Record, - File as AstrBotFile, - Video, - At, - Reply, -) from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.api.platform import register_platform_adapter @@ -120,6 +112,7 @@ class TelegramPlatformAdapter(Platform): @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 """ message = AstrBotMessage() + message.session_id = str(update.message.chat.id) # 获得是群聊还是私聊 if update.message.chat.type == ChatType.PRIVATE: message.type = MessageType.FRIEND_MESSAGE @@ -129,9 +122,9 @@ class TelegramPlatformAdapter(Platform): if update.message.message_thread_id: # Topic Group message.group_id += "#" + str(update.message.message_thread_id) + message.session_id = message.group_id message.message_id = str(update.message.message_id) - message.session_id = str(update.message.chat.id) message.sender = MessageMember( str(update.message.from_user.id), update.message.from_user.username ) @@ -149,7 +142,7 @@ class TelegramPlatformAdapter(Platform): reply_abm = await self.convert_message(reply_update, context, False) message.message.append( - Reply( + Comp.Reply( id=reply_abm.message_id, chain=reply_abm.message, sender_id=reply_abm.sender.user_id, @@ -171,14 +164,14 @@ class TelegramPlatformAdapter(Platform): name = plain_text[ entity.offset + 1 : entity.offset + entity.length ] - message.message.append(At(qq=name, name=name)) + message.message.append(Comp.At(qq=name, name=name)) plain_text = ( plain_text[: entity.offset] + plain_text[entity.offset + entity.length :] ) if plain_text: - message.message.append(Plain(plain_text)) + message.message.append(Comp.Plain(plain_text)) message.message_str = plain_text if message.message_str == "/start": @@ -188,26 +181,34 @@ class TelegramPlatformAdapter(Platform): elif update.message.voice: file = await update.message.voice.get_file() message.message = [ - Record(file=file.file_path, url=file.file_path), + Comp.Record(file=file.file_path, url=file.file_path), ] elif update.message.photo: photo = update.message.photo[-1] # get the largest photo file = await photo.get_file() - message.message.append(Image(file=file.file_path, url=file.file_path)) + message.message.append(Comp.Image(file=file.file_path, url=file.file_path)) + if update.message.caption: + message.message_str = update.message.caption + message.message.append(Comp.Plain(message.message_str)) + if update.message.caption_entities: + for entity in update.message.caption_entities: + if entity.type == "mention": + name = message.message_str[ + entity.offset + 1 : entity.offset + entity.length + ] + message.message.append(Comp.At(qq=name, name=name)) elif update.message.document: file = await update.message.document.get_file() message.message = [ - AstrBotFile( - file=file.file_path, name=update.message.document.file_name - ), + Comp.File(file=file.file_path, name=update.message.document.file_name), ] elif update.message.video: file = await update.message.video.get_file() message.message = [ - Video(file=file.file_path, path=file.file_path), + Comp.Video(file=file.file_path, path=file.file_path), ] return message @@ -224,3 +225,7 @@ class TelegramPlatformAdapter(Platform): def get_client(self) -> ExtBot: return self.client + + async def terminate(self): + await self.application.stop() + logger.info("Telegram 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index a8a04e2e1..87fea26c6 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -43,7 +43,7 @@ class TelegramPlatformEvent(AstrMessageEvent): if has_reply: payload["reply_to_message_id"] = reply_message_id if message_thread_id: - payload["reply_to_message_id"] = message_thread_id + payload["message_thread_id"] = message_thread_id if isinstance(i, Plain): if at_user_id and not at_flag: @@ -51,19 +51,8 @@ class TelegramPlatformEvent(AstrMessageEvent): at_flag = True await client.send_message(text=i.text, **payload) elif isinstance(i, Image): - if i.path: - image_path = i.path - else: - image_path = i.file - - if image_path.startswith("base64://"): - import base64 - - base64_data = image_path[9:] - image_bytes = base64.b64decode(base64_data) - await client.send_photo(photo=image_bytes, **payload) - else: - await client.send_photo(photo=image_path, **payload) + image_path = await i.convert_to_file_path() + await client.send_photo(photo=image_path, **payload) elif isinstance(i, File): if i.file.startswith("https://"): path = "data/temp/" + i.name @@ -72,7 +61,8 @@ class TelegramPlatformEvent(AstrMessageEvent): await client.send_document(document=i.file, filename=i.name, **payload) elif isinstance(i, Record): - await client.send_voice(voice=i.file, **payload) + path = await i.convert_to_file_path() + await client.send_voice(voice=path, **payload) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 12b193f53..6fa3d5c59 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -119,3 +119,7 @@ class WebChatAdapter(Platform): ) self.commit_event(message_event) + + async def terminate(self): + # Do nothing + pass diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index cef83b030..8b6598e0d 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -93,14 +93,8 @@ class WecomServer: await self.server.run_task( host=self.callback_server_host, port=self.port, - shutdown_trigger=self.shutdown_trigger_placeholder, ) - async def shutdown_trigger_placeholder(self): - while not self.event_queue.closed: # noqa: ASYNC110 - await asyncio.sleep(1) - logger.info("企业微信 适配器已关闭。") - @register_platform_adapter("wecom", "wecom 适配器") class WecomPlatformAdapter(Platform): @@ -235,3 +229,7 @@ class WecomPlatformAdapter(Platform): def get_client(self) -> WeChatClient: return self.client + + async def terminate(self): + await self.server.server.shutdown() + logger.info("企业微信 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 83e99b5c4..470b7b1f8 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -3,7 +3,6 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image, Record from wechatpy.enterprise import WeChatClient -from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.api import logger @@ -43,14 +42,7 @@ class WecomPlatformEvent(AstrMessageEvent): message_obj.self_id, message_obj.session_id, comp.text ) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() with open(img_path, "rb") as f: try: @@ -68,16 +60,7 @@ class WecomPlatformEvent(AstrMessageEvent): response["media_id"], ) elif isinstance(comp, Record): - record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url - + record_path = await comp.convert_to_file_path() # 转成amr record_path_amr = f"data/temp/{uuid.uuid4()}.amr" pydub.AudioSegment.from_wav(record_path).export( diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 17b6edd46..4236214d4 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,10 +1,14 @@ import enum +import base64 +from astrbot.core.utils.io import download_image_by_url +from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain +import astrbot.core.message.components as Comp class ProviderType(enum.Enum): @@ -47,11 +51,81 @@ class ProviderRequest: conversation: Conversation = None def __repr__(self): - return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt.strip()})" + return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()})" def __str__(self): return self.__repr__() + def _print_friendly_context(self): + """打印友好的消息上下文。将 image_url 的值替换为 """ + if not self.contexts: + return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}" + + result_parts = [] + + for ctx in self.contexts: + role = ctx.get("role", "unknown") + content = ctx.get("content", "") + + if isinstance(content, str): + result_parts.append(f"{role}: {content}") + elif isinstance(content, list): + msg_parts = [] + image_count = 0 + + for item in content: + item_type = item.get("type", "") + + if item_type == "text": + msg_parts.append(item.get("text", "")) + elif item_type == "image_url": + image_count += 1 + + if image_count > 0: + if msg_parts: + msg_parts.append(f"[+{image_count} images]") + else: + msg_parts.append(f"[{image_count} images]") + + result_parts.append(f"{role}: {''.join(msg_parts)}") + + return result_parts + + async def assemble_context(self) -> Dict: + """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" + if self.image_urls: + user_content = { + "role": "user", + "content": [{"type": "text", "text": self.prompt}], + } + for image_url in self.image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await self._encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self._encode_image_bs64(image_path) + else: + image_data = await self._encode_image_bs64(image_url) + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + continue + user_content["content"].append( + {"type": "image_url", "image_url": {"url": image_data}} + ) + return user_content + else: + return {"role": "user", "content": self.prompt} + + async def _encode_image_bs64(self, image_url: str) -> str: + """将图片转换为 base64""" + if image_url.startswith("base64://"): + return image_url.replace("base64://", "data:image/jpeg;base64,") + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 + return "" + @dataclass class LLMResponse: @@ -59,8 +133,6 @@ class LLMResponse: """角色, assistant, tool, err""" result_chain: MessageChain = None """返回的消息链""" - completion_text: str = "" - """LLM 返回的文本, 已经废弃但仍然兼容。使用 result_chain 替代""" tools_call_args: List[Dict[str, any]] = field(default_factory=list) """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) @@ -68,3 +140,51 @@ class LLMResponse: raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None + + _completion_text: str = "" + + def __init__( + self, + role: str, + completion_text: str = "", + result_chain: MessageChain = None, + tools_call_args: List[Dict[str, any]] = None, + tools_call_name: List[str] = None, + raw_completion: ChatCompletion = None, + _new_record: Dict[str, any] = None, + ): + """初始化 LLMResponse + + Args: + role (str): 角色, assistant, tool, err + completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "". + result_chain (MessageChain, optional): 返回的消息链. Defaults to None. + tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None. + tools_call_name (List[str], optional): 工具调用名称. Defaults to None. + raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. + """ + self.role = role + self.completion_text = completion_text + self.result_chain = result_chain + self.tools_call_args = tools_call_args + self.tools_call_name = tools_call_name + self.raw_completion = raw_completion + self._new_record = _new_record + + @property + def completion_text(self): + if self.result_chain: + return self.result_chain.get_plain_text() + return self._completion_text + + @completion_text.setter + def completion_text(self, value): + if self.result_chain: + self.result_chain.chain = [ + comp + for comp in self.result_chain.chain + if not isinstance(comp, Comp.Plain) + ] # 清空 Plain 组件 + self.result_chain.chain.insert(0, Comp.Plain(value)) + else: + self._completion_text = value diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 9647b41c0..7158d57b9 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -1,3 +1,4 @@ +import re import asyncio import functools from typing import List @@ -40,11 +41,24 @@ class ProviderDashscope(ProviderOpenAIOfficial): raise Exception("阿里云百炼 APP 类型不能为空。") self.model_name = "dashscope" self.variables: dict = provider_config.get("variables", {}) + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) + def has_rag_options(self): + if ( + self.rag_options + and self.rag_options.get("pipeline_ids", None) + and self.rag_options.get("file_ids", None) + ): + return True + return False + async def text_chat( self, prompt: str, @@ -62,7 +76,10 @@ class ProviderDashscope(ProviderOpenAIOfficial): session_var = session_vars.get(session_id, {}) payload_vars.update(session_var) - if self.dashscope_app_type in ["agent", "dialog-workflow"]: + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and self.has_rag_options() + ): # 支持多轮对话的 new_record = {"role": "user", "content": prompt} if image_urls: @@ -86,12 +103,17 @@ class ProviderDashscope(ProviderOpenAIOfficial): else: # 不支持多轮对话的 # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + } + if self.rag_options: + payload["rag_options"] = self.rag_options partial = functools.partial( Application.call, - app_id=self.app_id, - promtp=prompt, - api_key=self.api_key, - biz_params=payload_vars or None, + **payload, ) response = await asyncio.get_event_loop().run_in_executor(None, partial) @@ -107,6 +129,14 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) output_text = response.output.get("text", "") + # RAG 引用脚标格式化 + output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) + if self.output_reference and response.output.get("doc_references", None): + ref_str = "" + for ref in response.output.get("doc_references", []): + ref_str += f"{ref['index_id']}. {ref['title']}\n" + output_text += f"\n\n回答来源:\n{ref_str}" + return LLMResponse(role="assistant", completion_text=output_text) async def forget(self, session_id): diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 9af198aa7..8b5890c28 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -33,7 +33,6 @@ class ProviderDify(Provider): if not self.api_key: raise Exception("Dify API Key 不能为空。") api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_client = DifyAPIClient(self.api_key, api_base) self.api_type = provider_config.get("dify_api_type", "") if not self.api_type: raise Exception("Dify API 类型不能为空。") @@ -44,15 +43,19 @@ class ProviderDify(Provider): self.dify_query_input_key = provider_config.get( "dify_query_input_key", "astrbot_text_query" ) - self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" + if not self.workflow_output_key: + self.workflow_output_key = "astrbot_wf_output" + self.variables: dict = provider_config.get("variables", {}) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.conversation_ids = {} """记录当前 session id 的对话 ID""" + self.api_client = DifyAPIClient(self.api_key, api_base) + async def text_chat( self, prompt: str, @@ -68,26 +71,27 @@ class ProviderDify(Provider): files_payload = [] for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - file_response = await self.api_client.file_upload( - image_path, user=session_id + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - } - ) - else: - # TODO: 处理更多情况 - logger.warning(f"未知的图片链接:{image_url},图片将忽略。") + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() @@ -99,6 +103,9 @@ class ProviderDify(Provider): try: match self.api_type: case "chat" | "agent": + if not prompt: + prompt = "请描述这张图片。" + async for chunk in self.api_client.chat_messages( inputs={ **payload_vars, diff --git a/astrbot/dashboard/__init__.py b/astrbot/dashboard/__init__.py deleted file mode 100644 index cf829d4d6..000000000 --- a/astrbot/dashboard/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .dashboard_lifecycle import AstrBotDashBoardLifecycle - -__all__ = ["AstrBotDashBoardLifecycle"] diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 2e525a3ca..34ba8fd3f 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -3,6 +3,7 @@ import datetime from .route import Route, Response, RouteContext from quart import request from astrbot.core import WEBUI_SK +from astrbot import logger class AuthRoute(Route): @@ -19,9 +20,20 @@ class AuthRoute(Route): password = self.config["dashboard"]["password"] post_data = await request.json if post_data["username"] == username and post_data["password"] == password: + change_pwd_hint = False + if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9": + change_pwd_hint = True + logger.warning("为了保证安全,请尽快修改默认密码。") + return ( Response() - .ok({"token": self.generate_jwt(username), "username": username}) + .ok( + { + "token": self.generate_jwt(username), + "username": username, + "change_pwd_hint": change_pwd_hint, + } + ) .__dict__ ) else: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 54140d92a..14af21bbc 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -29,11 +29,21 @@ def validate_config( ) -> typing.Tuple[typing.List[str], typing.Dict]: errors = [] - def validate(data, metadata=schema, path=""): - for key, meta in metadata.items(): - if key not in data: + def validate(data: dict, metadata: dict = schema, path=""): + for key, value in data.items(): + if key not in metadata: + # 无 schema 的配置项,执行类型猜测 + if isinstance(value, str): + if value.isdigit(): + data[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + data[key] = float(value) + elif value == "true": + data[key] = True + elif value == "false": + data[key] = False continue - value = data[key] + meta = metadata[key] # null 转换 if value is None: data[key] = DEFAULT_VALUE_MAP[meta["type"]] @@ -43,6 +53,16 @@ def validate_config( errors.append( f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" ) + elif ( + meta["type"] == "list" + and isinstance(value, list) + and value + and "items" in meta + and isinstance(value[0], dict) + ): + # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") elif meta["type"] == "object" and isinstance(value, dict): validate(value, meta["items"], path=f"{path}{key}.") @@ -199,7 +219,8 @@ class ConfigRoute(Route): return Response().error("未找到对应平台").__dict__ try: - await self._save_astrbot_configs(self.config) + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.platform_manager.reload(new_config) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "更新平台配置成功~").__dict__ @@ -235,7 +256,8 @@ class ConfigRoute(Route): else: return Response().error("未找到对应平台").__dict__ try: - await self._save_astrbot_configs(self.config) + save_config(self.config, self.config, is_core=True) + await self.core_lifecycle.platform_manager.terminate_platform(platform_id) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "删除平台配置成功~").__dict__ @@ -301,7 +323,7 @@ class ConfigRoute(Route): async def _save_astrbot_configs(self, post_configs: dict): try: save_config(post_configs, self.config, is_core=True) - self.core_lifecycle.restart() + await self.core_lifecycle.restart() except Exception as e: raise e diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 3a21aa2b9..b20e9dca4 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -28,7 +28,7 @@ class StatRoute(Route): self.core_lifecycle = core_lifecycle async def restart_core(self): - self.core_lifecycle.restart() + await self.core_lifecycle.restart() return Response().ok().__dict__ def format_sec(self, sec: int): diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index ef2d10634..e9ada18f5 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -95,8 +95,7 @@ class UpdateRoute(Route): logger.error(f"更新依赖失败: {e}") if reboot: - # threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start() - self.core_lifecycle.restart() + await self.core_lifecycle.restart() return ( Response() .ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。") diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 072ded4ae..45aac3cd6 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -20,7 +20,12 @@ DATAPATH = os.path.abspath( class AstrBotDashboard: - def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None: + def __init__( + self, + core_lifecycle: AstrBotCoreLifecycle, + db: BaseDatabase, + shutdown_event: asyncio.Event, + ) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist")) @@ -46,6 +51,8 @@ class AstrBotDashboard: self.ar = AuthRoute(self.context) self.chat_route = ChatRoute(self.context, db, core_lifecycle) + self.shutdown_event = shutdown_event + async def auth_middleware(self): if not request.path.startswith("/api"): return @@ -73,11 +80,6 @@ class AstrBotDashboard: r.status_code = 401 return r - async def shutdown_trigger_placeholder(self): - while not self.core_lifecycle.event_queue.closed: # noqa: ASYNC110 - await asyncio.sleep(1) - logger.info("管理面板已关闭。") - def check_port_in_use(self, port: int) -> bool: """ 跨平台检测端口是否被占用 @@ -122,7 +124,15 @@ class AstrBotDashboard: def run(self): ip_addr = [] port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185) - host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "127.0.0.1") + host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0") + + logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}") + + if host == "0.0.0.0": + logger.info( + "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)" + ) + if host not in ["localhost", "127.0.0.1"]: try: ip_addr = get_local_ip_addresses() @@ -144,7 +154,7 @@ class AstrBotDashboard: raise Exception(f"端口 {port} 已被占用") - display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n" + display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n" display += f" ➜ 本地: http://localhost:{port}\n" for ip in ip_addr: display += f" ➜ 网络: http://{ip}:{port}\n" @@ -158,7 +168,9 @@ class AstrBotDashboard: logger.info(display) return self.app.run_task( - host=host, - port=port, - shutdown_trigger=self.shutdown_trigger_placeholder, + host=host, port=port, shutdown_trigger=self.shutdown_trigger ) + + async def shutdown_trigger(self): + await self.shutdown_event.wait() + logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/changelogs/v3.4.38.md b/changelogs/v3.4.38.md new file mode 100644 index 000000000..2a3504fcf --- /dev/null +++ b/changelogs/v3.4.38.md @@ -0,0 +1,57 @@ +# What's Changed + +> Special thanks for all contributors and plugin developers and users who love AstrBot. 💖 + +## ✨ 新增的功能 + +1. 支持解析回复消息,支持 LLM 对所引用消息具有感知 #783 +2. 支持 Dify 的文件、图片、视频、音频输出 #819 +3. QQ 下支持嵌套转发(napcat) @zouyonghe +4. 配置页样式重写,更紧凑的 WebUI 配置 + +## 🎈 功能性优化 + +1. 使用系统时间而不是 UTC+8 时间作为默认时间以适应海外用户需求 @roeseth +2. 在对话隔离情况下也可以将整个群聊加入白名单 #746 +3. 在调用插件异常时更完整的报错输出 +4. gewechat 下对已知且没有业务处理的事件类型不显示详细日志 @diudiu62 +5. 优化 WebUI 悬浮文档 @IGCrystal +6. 支持自定义 WebUI、Wecom Webhook Server, QQ Official Webhook Server 的 host #821 +7. Dify 下当只有图片输入时的默认 prompt 防止一些报错 #837 + +## 🐛 修复的 Bug + +1. fishaudio 默认 baseurl 不可用 +2. gewechat 下重复登录后提示设备不存在导致无法重新登陆 @beat4ocean +3. gewechat 下用户本人发消息会触发消息回复 @beat4ocean +4. 钉钉 WebUI 文档不显示 +5. 更新插件后插件热重载不完全、函数工具重复添加 +6. OpenAI TTS API TypeError 报错 #755 +7. EdgeTTS 部分情况下无法使用 @Soulter @需要哦 +8. QQ 官方机器人平台下发送 base64 图片消息段报错 @Soulter @shuiping233 +9. QQ 官方机器人平台下命令参数报错信息无法正常发送 @shuiping233 +10. WebUI 错误地显示未知更新 +11. 部分情况下文件无法上传到 Telegram 群组 #601 +12. 插件管理的插件简介太长导致 “帮助”“操作”图标不显示 #790 +13. LLOnebot 合并消息转发错误 #842 +14. model_config 中自定义的配置项(如温度)类型自动变回 string #854 + +## 🧩 新增的插件 + +1. astrbot_plugin_image_understanding_Janus-Pro - 使用deepseek-ai/Janus-Pro系列模型为本地模型提供的图片理解补充 @xiewoc +2. astrbot_plugin_moyurenpro - 摸鱼人日历,支持自定义时间时区,自定义api,支持立即发送,工作日定时发送。 @quirrel-zh @DuBwTf +3. astrbot_plugin_wechat_manager - 微信关键字好友自动审核、关键字邀请进群。@diudiu62 +4. astrbot_plugin_qwq_filter - qwq 思考过滤工具 @beat4ocean +5. astrbot_plugin_chatsummary - 一个通过拉取历史聊天记录,调用LLM大模型接口实现消息总结功能。@laopanmemz +6. astrBot_PGR_Dialogue - 检测到部分战双角色的名称(或别称)时,有概率发送一条语音文本 @KurisuRee7 +7. astrbot_plugin_bv - 解析群内https://www.bilibili.com/video/BV号/ 的链接并获取视频数据与视频文件,以合并转发方式发送 @haliludaxuanfeng +8. astrbot_plugin_gemini_exp - 让你在AstrBot调用Gemini2.0-flash-exp来生成图片或者p图。Gemini2.0-flash-exp为原生多模态模型,其既是语言模型,也是生图模型,因此能够对图像使用简单的自然语言命令进行处理。@Elen123bot +9. astrbot_plugin_sjzb - 随机生成绝地潜兵2游戏中一组4个战备配置 @tenno1174 +10. astrbot_plugin_picture_manager - 图片管理插件,允许用户通过自定义触发指令从API或直接URL获取图片。@bigshabei +11. astrbot_plugin_bilibiliParse - 解析哔哩哔哩视频,并以图片的形式发送给用户 @7Hello12 +12. astrbot_plugin_sensoji - 这是一个模拟日本浅草寺抽签功能的插件。用户可以通过发送 /抽签 命令随机抽取一个签文,获取运势提示。签文包含吉凶结果(如“大吉”、“凶”等)以及对应的运势描述。 @Shouugou +13. astrbot_plugin_videosummary - 使用 bibigpt 实现视频总结 @kterna +14. astrbot_plugin_InitiativeDialogue - 使 bot 在用户长时间未发送消息时主动与用户对话的插件 @advent259141 +15. astrbot_plugin_emoji - 基于达莉娅综合群娱插件的表情包制作插件,仅保留了@其他群员制作表情包的部分。由桑帛云API提供表情包制作。@KurisuRee7 +16. astrbot_plugin_videos_analysis - 聚合视频分享链接解析(仅测试过napcat) @miaoxutao123 +17. astrbot_plugin_daily_news - 每日 60 秒新闻推送插件 - 自动推送每日热点新闻 @anka-afk \ No newline at end of file diff --git a/changelogs/v3.4.39.md b/changelogs/v3.4.39.md new file mode 100644 index 000000000..d80b4e86d --- /dev/null +++ b/changelogs/v3.4.39.md @@ -0,0 +1,4 @@ +# What's Changed + +1. 默认账户密码登录成功后弹出修改警告 +2. 将 WebUI 默认 host 改变回 v3.4.38 之前的版本以减少兼容性问题。 \ No newline at end of file diff --git a/compose.yml b/compose.yml index 805d30c11..3bab93fc3 100644 --- a/compose.yml +++ b/compose.yml @@ -1,16 +1,21 @@ version: '3.8' +# 当接入 QQ NapCat 时,请使用这个 compose 文件一件部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml + services: astrbot: image: soulter/astrbot:latest container_name: astrbot + restart: always ports: # mappings description: https://github.com/Soulter/AstrBot/issues/497 - - "6185:6185" - - "6195:6195" # optional, wecom default port - - "6199:6199" # optional, aiocqhttp default port - - "6196:6196" # optional, qq official webhook default port - - "11451:11451" # optional, gewechat default port + - "6185:6185" # 必选,AstrBot WebUI 端口 + - "6195:6195" # 可选, 企业微信 Webhook 端口 + - "6199:6199" # 可选, QQ 个人号 WebSocket 端口 + - "6196:6196" # 可选, QQ 官方接口 Webhook 端口 + - "11451:11451" # 可选, 微信个人号 Webhook 端口 + environment: + - TZ=Asia/Shanghai volumes: - ./data:/AstrBot/data - - /etc/timezone:/etc/timezone:ro - - /etc/localtime:/etc/localtime:ro + # - /etc/timezone:/etc/timezone:ro + # - /etc/localtime:/etc/localtime:ro diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 39d5812dc..2796f95da 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -1,9 +1,6 @@ - - object - {{ metadata[metadataKey]?.description }} ({{ metadataKey }}) @@ -13,27 +10,24 @@ - + + style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px; margin-top: 16px"> - - + + - {{ - metadata[metadataKey].items[key]?.type }} - - {{ metadata[metadataKey].items[key]?.description + '(' + key + ')' }} + {{ metadata[metadataKey].items[key]?.description + '(' + key + ')' }} + {{ key }} @@ -45,7 +39,14 @@ - + + {{ + metadata[metadataKey].items[key]?.type || 'string' }} + + + + - + - + + + - - + + - {{ - metadata[metadataKey]?.type }} - {{ metadata[metadataKey]?.description + '(' + metadataKey + ')' }} - @@ -101,23 +99,35 @@ - + + + {{ + metadata[metadataKey]?.type }} + + + + + dense :disabled="metadata[metadataKey]?.readonly" density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" variant="outlined" dense density="compact" flat hide-details + single-line> + v-model="iterable[metadataKey]" color="primary" hide-details> @@ -125,7 +135,7 @@ - + diff --git a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue index 68e6d38ee..3d08c9cd6 100644 --- a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue +++ b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue @@ -8,6 +8,7 @@ import { useCommonStore } from '@/stores/common'; const customizer = useCustomizerStore(); let dialog = ref(false); +let accountWarning = ref(false) let updateStatusDialog = ref(false); let password = ref(''); let newPassword = ref(''); @@ -177,6 +178,14 @@ checkUpdate(); const commonStore = useCommonStore(); commonStore.createWebSocket(); commonStore.getStartTime(); + + +if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('change_pwd_hint') == 'true') { + dialog.value = true; + accountWarning.value = true; + localStorage.removeItem('change_pwd_hint'); +} + @@ -339,6 +348,11 @@ commonStore.getStartTime(); + + + 为了安全,请尽快修改默认密码。 + + diff --git a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue index 17cec39fa..9b2abed91 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue +++ b/dashboard/src/layouts/full/vertical-sidebar/VerticalSidebar.vue @@ -1,5 +1,6 @@ - + @@ -73,7 +189,6 @@ function onMouseUp() { {{ version }} - @@ -87,75 +202,47 @@ function onMouseUp() { 面板有更新 - AGPL-3.0 - - + - - - + :style="iframeStyle" + > + + 拖拽 - - - - - + + + + + + + + + + - - + - - diff --git a/dashboard/src/stores/auth.ts b/dashboard/src/stores/auth.ts index 7eece8eb5..4503a1219 100644 --- a/dashboard/src/stores/auth.ts +++ b/dashboard/src/stores/auth.ts @@ -24,6 +24,7 @@ export const useAuthStore = defineStore({ this.username = res.data.data.username localStorage.setItem('user', this.username); localStorage.setItem('token', res.data.data.token); + localStorage.setItem('change_pwd_hint', res.data.data?.change_pwd_hint); router.push(this.returnUrl || '/dashboard/default'); } catch (error) { return Promise.reject(error); diff --git a/dashboard/src/views/ConfigPage.vue b/dashboard/src/views/ConfigPage.vue index 55d6e8e33..37d740a3c 100644 --- a/dashboard/src/views/ConfigPage.vue +++ b/dashboard/src/views/ConfigPage.vue @@ -44,11 +44,8 @@ import config from '@/config'; - + {{ metadata[key]['metadata'][key2]['description'] }} ({{ key2 }}) - - object - + @click="configExistingPlatform(platform)"> 配置 @@ -99,7 +99,6 @@ {{ save_message }} -