diff --git a/.gitignore b/.gitignore index 4b683ee8b..a3b2aad90 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,5 @@ venv/* packages/python_interpreter/workplace .venv/* .conda/ -.idea \ No newline at end of file +.idea +pytest.ini diff --git a/README.md b/README.md index 166af8a5e..b4d97fbbb 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ Docker pull Static Badge [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) -![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600) +![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=60) [![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot) [![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg)](https://gitcode.com/Soulter/AstrBot) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 4e8db43f3..70c9ffa43 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,7 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ -VERSION = "3.4.37" +VERSION = "3.4.39" DB_PATH = "data/data_v3.db" # 默认配置 @@ -85,6 +85,7 @@ DEFAULT_CONFIG = { "enable": True, "username": "astrbot", "password": "77b90590a8945a7d36c963981a307dc9", + "host": "0.0.0.0", "port": 6185, }, "platform": [], @@ -122,6 +123,7 @@ CONFIG_METADATA_2 = { "enable": False, "appid": "", "secret": "", + "callback_server_host": "0.0.0.0", "port": 6196, }, "aiocqhttp(OneBotv11)": { @@ -146,10 +148,11 @@ CONFIG_METADATA_2 = { "enable": False, "corpid": "", "secret": "", - "port": 6195, "token": "", "encoding_aes_key": "", "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", + "callback_server_host": "0.0.0.0", + "port": 6195, }, "lark(飞书)": { "id": "lark", @@ -342,7 +345,7 @@ CONFIG_METADATA_2 = { "type": "list", "items": {"type": "string"}, "obvious_hint": True, - "hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单", + "hint": "只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单", }, "id_whitelist_log": { "description": "打印白名单日志", @@ -578,7 +581,7 @@ CONFIG_METADATA_2 = { "dify_api_type": "chat", "dify_api_key": "", "dify_api_base": "https://api.dify.ai/v1", - "dify_workflow_output_key": "", + "dify_workflow_output_key": "astrbot_wf_output", "dify_query_input_key": "astrbot_text_query", "variables": {}, "timeout": 60, @@ -590,6 +593,11 @@ CONFIG_METADATA_2 = { "dashscope_app_type": "agent", "dashscope_api_key": "", "dashscope_app_id": "", + "rag_options": { + "pipeline_ids": [], + "file_ids": [], + "output_reference": False, + }, "variables": {}, "timeout": 60, }, @@ -662,6 +670,30 @@ CONFIG_METADATA_2 = { }, }, "items": { + "rag_options": { + "description": "RAG 选项", + "type": "object", + "hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)", + "items": { + "pipeline_ids": { + "description": "知识库 ID 列表", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。", + }, + "file_ids": { + "description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。", + "type": "list", + "items": {"type": "string"}, + "hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。", + }, + "output_reference": { + "description": "是否输出知识库/文档的引用", + "type": "bool", + "hint": "在每次回答尾部加上引用源。默认为 False。", + }, + }, + }, "sensevoice_hint": { "description": "部署SenseVoice", "type": "string", @@ -678,12 +710,14 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。", }, - # "variables": { - # "description": "工作流固定输入变量", - # "type": "object", - # "obvious_hint": True, - # "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", - # }, + "variables": { + "description": "工作流固定输入变量", + "type": "object", + "obvious_hint": True, + "items": {}, + "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", + "invisible": True, + }, # "fastgpt_app_type": { # "description": "应用类型", # "type": "string", @@ -694,7 +728,7 @@ CONFIG_METADATA_2 = { "dashscope_app_type": { "description": "应用类型", "type": "string", - "hint": "阿里云百炼应用的应用类型。", + "hint": "百炼应用的应用类型。", "options": [ "agent", "agent-arrange", diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 44caa2075..64c324a9e 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,9 +25,11 @@ SOFTWARE. import base64 import json import os +import uuid import typing as T from enum import Enum from pydantic.v1 import BaseModel +from astrbot.core.utils.io import download_image_by_url, file_to_base64 class ComponentType(Enum): @@ -146,6 +148,51 @@ class Record(BaseMessageComponent): return Record(file=url, **_) raise Exception("not a valid url") + async def convert_to_file_path(self) -> str: + """将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 语音的本地路径,以绝对路径表示。 + """ + if self.file and self.file.startswith("file:///"): + file_path = self.file[8:] + return file_path + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + return os.path.abspath(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(file_path) + elif os.path.exists(self.file): + file_path = self.file + return os.path.abspath(file_path) + else: + raise Exception(f"not a valid file: {self.file}") + + async def convert_to_base64(self) -> str: + """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 + + Returns: + str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + if self.file and self.file.startswith("file:///"): + bs64_data = file_to_base64(self.file[8:]) + elif self.file and self.file.startswith("http"): + file_path = await download_image_by_url(self.file) + bs64_data = file_to_base64(file_path) + elif self.file and self.file.startswith("base64://"): + bs64_data = self.file + elif os.path.exists(self.file): + bs64_data = file_to_base64(self.file) + else: + raise Exception(f"not a valid file: {self.file}") + return bs64_data + class Video(BaseMessageComponent): type: ComponentType = "Video" @@ -279,10 +326,6 @@ class Image(BaseMessageComponent): file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): - # for k in _.keys(): - # if (k == "_type" and _[k] not in ["flash", "show", None]) or \ - # (k == "c" and _[k] not in [2, 3]): - # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(file=file, **_) @staticmethod @@ -307,14 +350,75 @@ class Image(BaseMessageComponent): def fromIO(IO): return Image.fromBytes(IO.read()) + async def convert_to_file_path(self) -> str: + """将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。 + + Returns: + str: 图片的本地路径,以绝对路径表示。 + """ + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + image_file_path = url[8:] + return image_file_path + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + return os.path.abspath(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url.removeprefix("base64://") + image_bytes = base64.b64decode(bs64_data) + image_file_path = f"data/temp/{uuid.uuid4()}.jpg" + with open(image_file_path, "wb") as f: + f.write(image_bytes) + return os.path.abspath(image_file_path) + elif os.path.exists(url): + image_file_path = url + return os.path.abspath(image_file_path) + else: + raise Exception(f"not a valid file: {url}") + + async def convert_to_base64(self) -> str: + """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 + + Returns: + str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ + # convert to base64 + url = self.url if self.url else self.file + if url and url.startswith("file:///"): + bs64_data = file_to_base64(url[8:]) + elif url and url.startswith("http"): + image_file_path = await download_image_by_url(url) + bs64_data = file_to_base64(image_file_path) + elif url and url.startswith("base64://"): + bs64_data = url + elif os.path.exists(url): + bs64_data = file_to_base64(url) + else: + raise Exception(f"not a valid file: {url}") + return bs64_data + class Reply(BaseMessageComponent): type: ComponentType = "Reply" id: T.Union[str, int] - text: T.Optional[str] = "" - qq: T.Optional[int] = 0 + """所引用的消息 ID""" + chain: T.Optional[T.List["BaseMessageComponent"]] = [] + """引用的消息段列表""" + sender_id: T.Optional[int] | T.Optional[str] = 0 + """引用的消息发送者 ID""" + sender_nickname: T.Optional[str] = "" + """引用的消息发送者昵称""" time: T.Optional[int] = 0 + """引用的消息发送时间""" + message_str: T.Optional[str] = "" + """解析后的纯文本消息字符串""" + + text: T.Optional[str] = "" + """deprecated""" + qq: T.Optional[int] = 0 + """deprecated""" seq: T.Optional[int] = 0 + """deprecated""" def __init__(self, **_): super().__init__(**_) @@ -353,16 +457,22 @@ class Node(BaseMessageComponent): id: T.Optional[int] = 0 # 忽略 name: T.Optional[str] = "" # qq昵称 uin: T.Optional[int] = 0 # qq号 - content: T.Optional[T.Union[str, list]] = "" # 子消息段列表 + content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 - def __init__(self, content: T.Union[str, list], **_): + def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): if isinstance(content, list): - _content = "" - for chain in content: - _content += chain.toString() + _content = None + if all(isinstance(item, Node) for item in content): + _content = [node.toDict() for node in content] + else: + _content = "" + for chain in content: + _content += chain.toString() content = _content + elif isinstance(content, Node): + content = content.toDict() super().__init__(content=content, **_) def toString(self): diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 89aff17a8..48d0b18c9 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -77,6 +77,10 @@ class MessageChain: self.use_t2i_ = use_t2i return self + def get_plain_text(self) -> str: + """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" + return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 @@ -147,9 +151,4 @@ class MessageEventResult(MessageChain): """是否为 LLM 结果。""" return self.result_content_type == ResultContentType.LLM_RESULT - def get_plain_text(self) -> str: - """获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" - return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - - CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 22d67d322..210e62a7c 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -64,8 +64,8 @@ class LLMRequestSubStage(Stage): req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - req.image_urls.append(image_url) + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) # 获取对话上下文 conversation_id = await self.conv_manager.get_curr_conversation_id( @@ -148,11 +148,18 @@ class LLMRequestSubStage(Stage): if llm_response.role == "assistant": # text completion - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.LLM_RESULT) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT) + ) elif llm_response.role == "err": event.set_result( MessageEventResult().message( diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 95f70fd4d..9b2b20155 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -36,7 +36,7 @@ class WakingCheckStage(Stage): # 设置 sender 身份 event.message_str = event.message_str.strip() for admin_id in self.ctx.astrbot_config["admins_id"]: - if event.get_sender_id() == admin_id: + if str(event.get_sender_id()) == admin_id: event.role = "admin" break @@ -106,6 +106,7 @@ class WakingCheckStage(Stage): f"插件 {star_map[handler.handler_module_path].name}: {e}" ) ) + await event._post_send() event.stop_event() passed = False break @@ -117,6 +118,7 @@ class WakingCheckStage(Stage): f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。" ) ) + await event._post_send() event.stop_event() return diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 02d600768..35a2c4179 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -12,6 +12,7 @@ from astrbot.core.message.components import ( At, AtAll, Forward, + Reply, ) from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.platform.message_type import MessageType @@ -102,8 +103,15 @@ class AstrMessageEvent(abc.ABC): elif isinstance(i, Forward): # 转发消息 outline += "[转发消息]" + elif isinstance(i, Reply): + # 引用回复 + if i.message_str: + outline += f"[引用消息({i.sender_nickname}: {i.message_str})]" + else: + outline += "[引用消息]" else: outline += f"[{i.type}]" + outline += " " return outline def get_message_outline(self) -> str: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index ce38296e6..08990015e 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -3,8 +3,6 @@ import asyncio from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes from aiocqhttp import CQHttp -from astrbot.core.utils.io import file_to_base64, download_image_by_url - class AiocqhttpMessageEvent(AstrMessageEvent): def __init__( @@ -21,20 +19,12 @@ class AiocqhttpMessageEvent(AstrMessageEvent): d = segment.toDict() if isinstance(segment, Plain): d["type"] = "text" + d["data"]["text"] = segment.text.strip() elif isinstance(segment, (Image, Record)): # convert to base64 - if segment.file and segment.file.startswith("file:///"): - bs64_data = file_to_base64(segment.file[8:]) - image_file_path = segment.file[8:] - elif segment.file and segment.file.startswith("http"): - image_file_path = await download_image_by_url(segment.file) - bs64_data = file_to_base64(image_file_path) - elif segment.file and segment.file.startswith("base64://"): - bs64_data = segment.file - else: - bs64_data = file_to_base64(segment.file) + bs64 = await segment.convert_to_base64() d["data"] = { - "file": bs64_data, + "file": bs64, } elif isinstance(segment, At): d["data"] = { @@ -55,8 +45,13 @@ class AiocqhttpMessageEvent(AstrMessageEvent): if send_one_by_one: for seg in message.chain: - if isinstance(seg, Nodes): - # 带有多个节点的合并转发消息 + if isinstance(seg, (Node, Nodes)): + # 合并转发消息 + + if isinstance(seg, Node): + nodes = Nodes([seg]) + seg = nodes + payload = seg.toDict() if self.get_group_id(): payload["group_id"] = self.get_group_id() diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 0a2e8d430..0d11e3c0b 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -160,8 +160,14 @@ class AiocqhttpAdapter(Platform): return abm - async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage: - """OneBot V11 消息类事件""" + async def _convert_handle_message_event( + self, event: Event, get_reply=True + ) -> AstrBotMessage: + """OneBot V11 消息类事件 + + @param event: 事件对象 + @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 + """ abm = AstrBotMessage() abm.self_id = str(event.self_id) abm.sender = MessageMember( @@ -240,6 +246,36 @@ class AiocqhttpAdapter(Platform): except BaseException as e: logger.error(f"获取文件失败: {e},此消息段将被忽略。") + elif t == "reply": + if not get_reply: + a = ComponentTypes[t](**m["data"]) # noqa: F405 + abm.message.append(a) + else: + try: + reply_event_data = await self.bot.call_action( + action="get_msg", + message_id=int(m["data"]["id"]), + ) + abm_reply = await self._convert_handle_message_event( + Event.from_payload(reply_event_data), get_reply=False + ) + + reply_seg = Reply( + id=abm_reply.message_id, + chain=abm_reply.message, + sender_id=abm_reply.sender.user_id, + sender_nickname=abm_reply.sender.nickname, + time=abm_reply.timestamp, + message_str=abm_reply.message_str, + text=abm_reply.message_str, # for compatibility + qq=abm_reply.sender.user_id, # for compatibility + ) + + abm.message.append(reply_seg) + except BaseException as e: + logger.error(f"获取引用消息失败: {e}。") + a = ComponentTypes[t](**m["data"]) # noqa: F405 + abm.message.append(a) else: a = ComponentTypes[t](**m["data"]) # noqa: F405 abm.message.append(a) diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index 7a07afe9c..90e60a255 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -89,6 +89,15 @@ class SimpleGewechatClient: type_name = data["type_name"] else: raise Exception("无法识别的消息类型") + + # 以下没有业务处理,只是避免控制台打印太多的日志 + if type_name == "ModContacts": + logger.info("gewechat下发:ModContacts消息通知。") + return + if type_name == "DelContacts": + logger.info("gewechat下发:DelContacts消息通知。") + return + if type_name == "Offline": logger.critical("收到 gewechat 下线通知。") return @@ -152,6 +161,11 @@ class SimpleGewechatClient: abm.type = MessageType.FRIEND_MESSAGE user_id = from_user_name + # 检查消息是否由自己发送,若是则忽略 + if user_id == abm.self_id: + logger.info("忽略自己发送的消息") + return None + abm.message = [] if at_me: abm.message.insert(0, At(qq=abm.self_id)) @@ -183,6 +197,11 @@ class SimpleGewechatClient: abm.sender = MessageMember(user_id, user_real_name) abm.raw_message = d abm.message_str = "" + + if user_id == "weixin": + # 忽略微信团队消息 + return + # 不同消息类型 match d["MsgType"]: case 1: @@ -191,17 +210,12 @@ class SimpleGewechatClient: abm.message_str = content case 3: # 图片消息 - # 先看看 base64 数据 - if "ImgBuf" in d and "buffer" in d["ImgBuf"]: - logger.debug("发现图片消息包含 base64 数据,使用。") - abm.message.append(Image.fromBase64(d["ImgBuf"]["buffer"])) - else: - file_url = await self.multimedia_downloader.download_image( - self.appid, content - ) - logger.debug(f"下载图片: {file_url}") - file_path = await download_image_by_url(file_url) - abm.message.append(Image(file=file_path, url=file_path)) + file_url = await self.multimedia_downloader.download_image( + self.appid, content + ) + logger.debug(f"下载图片: {file_url}") + file_path = await download_image_by_url(file_url) + abm.message.append(Image(file=file_path, url=file_path)) case 34: # 语音消息 @@ -217,6 +231,31 @@ class SimpleGewechatClient: async with await anyio.open_file(file_path, "wb") as f: await f.write(voice_data) abm.message.append(Record(file=file_path, url=file_path)) + + # 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志 + case 37: # 好友申请 + logger.info("消息类型(37):好友申请") + case 42: # 名片 + logger.info("消息类型(42):名片") + case 43: # 视频 + logger.info("消息类型(43):视频") + case 47: # emoji + logger.info("消息类型(47):emoji") + case 48: # 地理位置 + logger.info("消息类型(48):地理位置") + case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请 + logger.info( + "消息类型(49):公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请" + ) + case 51: # 帐号消息同步? + logger.info("消息类型(51):帐号消息同步?") + case 10000: # 被踢出群聊/更换群主/修改群名称 + logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称") + case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办 + logger.info( + "消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办" + ) + case _: logger.info(f"未实现的消息类型: {d['MsgType']}") abm.raw_message = d diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 7668663fb..247a2a6a4 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -70,18 +70,10 @@ class GewechatPlatformEvent(AstrMessageEvent): await client.post_text(**payload) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() - # 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径 + # 检查 record_path 是否在 data/temp 目录中 temp_directory = os.path.abspath("data/temp") - img_path = os.path.abspath(img_path) if os.path.commonpath([temp_directory, img_path]) != temp_directory: with open(img_path, "rb") as f: img_path = save_temp_img(f.read()) @@ -93,14 +85,7 @@ class GewechatPlatformEvent(AstrMessageEvent): elif isinstance(comp, Record): # 默认已经存在 data/temp 中 record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url + record_path = await comp.convert_to_file_path() silk_path = f"data/temp/{uuid.uuid4()}.silk" try: diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index fd29b3602..1ee30c482 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -2,6 +2,7 @@ import base64 import asyncio import json import re +import astrbot.api.message_components as Comp from astrbot.api.platform import ( Platform, @@ -11,7 +12,6 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .lark_event import LarkMessageEvent from ...register import register_platform_adapter @@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = At(qq=m.id.open_id, name=m.name) + at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.name == self.bot_name: abm.self_id = m.id.open_id @@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform): if s in at_list: abm.message.append(at_list[s]) else: - abm.message.append(Plain(parts[i].strip())) + abm.message.append(Comp.Plain(parts[i].strip())) elif message.message_type == "post": _ls = [] @@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform): if comp["tag"] == "at": abm.message.append(at_list[comp["user_id"]]) elif comp["tag"] == "text" and comp["text"].strip(): - abm.message.append(Plain(comp["text"].strip())) + abm.message.append(Comp.Plain(comp["text"].strip())) elif comp["tag"] == "img": image_key = comp["image_key"] request = ( @@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform): logger.error(f"无法下载飞书图片: {image_key}") image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() - abm.message.append(Image.fromBase64(image_base64)) + abm.message.append(Comp.Image.fromBase64(image_base64)) for comp in abm.message: - if isinstance(comp, Plain): + if isinstance(comp, Comp.Plain): abm.message_str += comp.text abm.message_id = message.message_id abm.raw_message = message diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index ec7eaa8fc..d31006618 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -122,16 +122,16 @@ class QQOfficialMessageEvent(AstrMessageEvent): plain_text += i.text elif isinstance(i, Image) and not image_base64: if i.file and i.file.startswith("file:///"): - image_base64 = file_to_base64(i.file[8:]).replace("base64://", "") + image_base64 = file_to_base64(i.file[8:]) image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path).replace( - "base64://", "" - ) + image_base64 = file_to_base64(image_file_path) + elif i.file and i.file.startswith("base64://"): + image_base64 = i.file else: - image_base64 = file_to_base64(i.file).replace("base64://", "") - image_file_path = i.file + image_base64 = file_to_base64(i.file) + image_base64 = image_base64.removeprefix("base64://") else: logger.debug(f"qq_official 忽略 {i.type}") return plain_text, image_base64, image_file_path diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 562574204..a219e2492 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -15,6 +15,7 @@ class QQOfficialWebhook: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) + self.callback_server_host = config.get("callback_server_host", "0.0.0.0") if isinstance(self.port, str): self.port = int(self.port) @@ -95,8 +96,11 @@ class QQOfficialWebhook: return {"opcode": 12} async def start_polling(self): + logger.info( + f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。" + ) await self.server.run_task( - host="0.0.0.0", + host=self.callback_server_host, port=self.port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index dfb882328..3dde1802b 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -17,6 +17,7 @@ from astrbot.api.message_components import ( File as AstrBotFile, Video, At, + Reply, ) from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.api.platform import register_platform_adapter @@ -68,7 +69,7 @@ class TelegramPlatformAdapter(Platform): ) message_handler = TelegramMessageHandler( filters=filters.ALL, # receive all messages - callback=self.convert_message, + callback=self.message_handler, ) self.application.add_handler(message_handler) self.client = self.application.bot @@ -104,29 +105,64 @@ class TelegramPlatformAdapter(Platform): chat_id=update.effective_chat.id, text=self.config["start_message"] ) + async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + logger.debug(f"Telegram message: {update.message}") + abm = await self.convert_message(update, context) + await self.handle_msg(abm) + async def convert_message( - self, update: Update, context: ContextTypes.DEFAULT_TYPE + self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True ) -> AstrBotMessage: + """转换 Telegram 的消息对象为 AstrBotMessage 对象。 + + @param update: Telegram 的 Update 对象。 + @param context: Telegram 的 Context 对象。 + @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 + """ message = AstrBotMessage() # 获得是群聊还是私聊 - if update.effective_chat.type == ChatType.PRIVATE: + if update.message.chat.type == ChatType.PRIVATE: message.type = MessageType.FRIEND_MESSAGE else: message.type = MessageType.GROUP_MESSAGE - message.group_id = update.effective_chat.id + message.group_id = str(update.message.chat.id) + if update.message.message_thread_id: + # Topic Group + message.group_id += "#" + str(update.message.message_thread_id) + message.message_id = str(update.message.message_id) - message.session_id = str(update.effective_chat.id) + message.session_id = str(update.message.chat.id) message.sender = MessageMember( - str(update.effective_user.id), update.effective_user.username + str(update.message.from_user.id), update.message.from_user.username ) message.self_id = str(context.bot.username) message.raw_message = update message.message_str = "" message.message = [] - logger.debug(f"Telegram message: {update.message}") + if update.message.reply_to_message: + # 获取回复消息 + reply_update = Update( + update_id=1, + message=update.message.reply_to_message, + ) + reply_abm = await self.convert_message(reply_update, context, False) + + message.message.append( + Reply( + id=reply_abm.message_id, + chain=reply_abm.message, + sender_id=reply_abm.sender.user_id, + sender_nickname=reply_abm.sender.nickname, + time=reply_abm.timestamp, + message_str=reply_abm.message_str, + text=reply_abm.message_str, + qq=reply_abm.sender.user_id, + ) + ) if update.message.text: + # 处理文本消息 plain_text = update.message.text if update.message.entities: @@ -174,7 +210,7 @@ class TelegramPlatformAdapter(Platform): Video(file=file.file_path, path=file.file_path), ] - await self.handle_msg(message) + return message async def handle_msg(self, message: AstrBotMessage): message_event = TelegramPlatformEvent( diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index de0f9a58e..d19017a4f 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -2,6 +2,7 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType from astrbot.api.message_components import Plain, Image, Reply, At, File, Record from telegram.ext import ExtBot +from astrbot.core.utils.io import download_file class TelegramPlatformEvent(AstrMessageEvent): @@ -31,12 +32,18 @@ class TelegramPlatformEvent(AstrMessageEvent): at_user_id = i.name at_flag = False + message_thread_id = None + if "#" in user_name: + # it's a supergroup chat with message_thread_id + user_name, message_thread_id = user_name.split("#") for i in message.chain: payload = { "chat_id": user_name, } if has_reply: payload["reply_to_message_id"] = reply_message_id + if message_thread_id: + payload["reply_to_message_id"] = message_thread_id if isinstance(i, Plain): if at_user_id and not at_flag: @@ -44,23 +51,18 @@ class TelegramPlatformEvent(AstrMessageEvent): at_flag = True await client.send_message(text=i.text, **payload) elif isinstance(i, Image): - if i.path: - image_path = i.path - else: - image_path = i.file - - if image_path.startswith("base64://"): - import base64 - - base64_data = image_path[9:] - image_bytes = base64.b64decode(base64_data) - await client.send_photo(photo=image_bytes, **payload) - else: - await client.send_photo(photo=image_path, **payload) + image_path = await i.convert_to_file_path() + await client.send_photo(photo=image_path, **payload) elif isinstance(i, File): + if i.file.startswith("https://"): + path = "data/temp/" + i.name + await download_file(i.file, path) + i.file = path + await client.send_document(document=i.file, filename=i.name, **payload) elif isinstance(i, Record): - await client.send_voice(voice=i.file, **payload) + path = await i.convert_to_file_path() + await client.send_voice(voice=path, **payload) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 77eae03d6..cef83b030 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -34,6 +34,7 @@ class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) self.port = int(config.get("port")) + self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( "/callback/command", view_func=self.verify, methods=["GET"] ) @@ -86,9 +87,11 @@ class WecomServer: return "success" async def start_polling(self): - logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。") + logger.info( + f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。" + ) await self.server.run_task( - host="0.0.0.0", + host=self.callback_server_host, port=self.port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 83e99b5c4..c6f8d6ef6 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -43,14 +43,7 @@ class WecomPlatformEvent(AstrMessageEvent): message_obj.self_id, message_obj.session_id, comp.text ) elif isinstance(comp, Image): - img_url = comp.file - img_path = "" - if img_url.startswith("file:///"): - img_path = img_url[8:] - elif comp.file and comp.file.startswith("http"): - img_path = await download_image_by_url(comp.file) - else: - img_path = img_url + img_path = await comp.convert_to_file_path() with open(img_path, "rb") as f: try: @@ -68,16 +61,7 @@ class WecomPlatformEvent(AstrMessageEvent): response["media_id"], ) elif isinstance(comp, Record): - record_url = comp.file - record_path = "" - - if record_url.startswith("file:///"): - record_path = record_url[8:] - elif record_url.startswith("http"): - await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav") - else: - record_path = record_url - + record_path = await comp.convert_to_file_path() # 转成amr record_path_amr = f"data/temp/{uuid.uuid4()}.amr" pydub.AudioSegment.from_wav(record_path).export( diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index 3180b4955..c51b860f0 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -4,6 +4,8 @@ from typing import List, Dict, Type from .func_tool_manager import FuncCall from openai.types.chat.chat_completion import ChatCompletion from astrbot.core.db.po import Conversation +from astrbot.core.message.message_event_result import MessageChain +import astrbot.core.message.components as Comp class ProviderType(enum.Enum): @@ -56,8 +58,8 @@ class ProviderRequest: class LLMResponse: role: str """角色, assistant, tool, err""" - completion_text: str = "" - """LLM 返回的文本""" + result_chain: MessageChain = None + """返回的消息链""" tools_call_args: List[Dict[str, any]] = field(default_factory=list) """工具调用参数""" tools_call_name: List[str] = field(default_factory=list) @@ -65,3 +67,51 @@ class LLMResponse: raw_completion: ChatCompletion = None _new_record: Dict[str, any] = None + + _completion_text: str = "" + + def __init__( + self, + role: str, + completion_text: str = "", + result_chain: MessageChain = None, + tools_call_args: List[Dict[str, any]] = None, + tools_call_name: List[str] = None, + raw_completion: ChatCompletion = None, + _new_record: Dict[str, any] = None, + ): + """初始化 LLMResponse + + Args: + role (str): 角色, assistant, tool, err + completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "". + result_chain (MessageChain, optional): 返回的消息链. Defaults to None. + tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None. + tools_call_name (List[str], optional): 工具调用名称. Defaults to None. + raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. + """ + self.role = role + self.completion_text = completion_text + self.result_chain = result_chain + self.tools_call_args = tools_call_args + self.tools_call_name = tools_call_name + self.raw_completion = raw_completion + self._new_record = _new_record + + @property + def completion_text(self): + if self.result_chain: + return self.result_chain.get_plain_text() + return self._completion_text + + @completion_text.setter + def completion_text(self, value): + if self.result_chain: + self.result_chain.chain = [ + comp + for comp in self.result_chain.chain + if not isinstance(comp, Comp.Plain) + ] # 清空 Plain 组件 + self.result_chain.chain.insert(0, Comp.Plain(value)) + else: + self._completion_text = value diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 0022d2773..0f04628f7 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -2,6 +2,7 @@ import json import textwrap from typing import Dict, List, Awaitable from dataclasses import dataclass +from astrbot import logger @dataclass @@ -46,14 +47,16 @@ class FuncCall: desc: str, handler: Awaitable, ) -> None: - """ - 为函数调用(function-calling / tools-use)添加工具。 + """添加函数调用工具 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param desc: 函数描述 @param func_obj: 处理函数 """ + # check if the tool has been added before + self.remove_func(name) + params = { "type": "object", # hard-coded here "properties": {}, @@ -70,13 +73,14 @@ class FuncCall: handler=handler, ) self.func_list.append(_func) + logger.info(f"添加函数调用工具: {name}") def remove_func(self, name: str) -> None: """ 删除一个函数调用工具。 """ for i, f in enumerate(self.func_list): - if f["name"] == name: + if f.name == name: self.func_list.pop(i) break diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 9647b41c0..7158d57b9 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -1,3 +1,4 @@ +import re import asyncio import functools from typing import List @@ -40,11 +41,24 @@ class ProviderDashscope(ProviderOpenAIOfficial): raise Exception("阿里云百炼 APP 类型不能为空。") self.model_name = "dashscope" self.variables: dict = provider_config.get("variables", {}) + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) + def has_rag_options(self): + if ( + self.rag_options + and self.rag_options.get("pipeline_ids", None) + and self.rag_options.get("file_ids", None) + ): + return True + return False + async def text_chat( self, prompt: str, @@ -62,7 +76,10 @@ class ProviderDashscope(ProviderOpenAIOfficial): session_var = session_vars.get(session_id, {}) payload_vars.update(session_var) - if self.dashscope_app_type in ["agent", "dialog-workflow"]: + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and self.has_rag_options() + ): # 支持多轮对话的 new_record = {"role": "user", "content": prompt} if image_urls: @@ -86,12 +103,17 @@ class ProviderDashscope(ProviderOpenAIOfficial): else: # 不支持多轮对话的 # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + } + if self.rag_options: + payload["rag_options"] = self.rag_options partial = functools.partial( Application.call, - app_id=self.app_id, - promtp=prompt, - api_key=self.api_key, - biz_params=payload_vars or None, + **payload, ) response = await asyncio.get_event_loop().run_in_executor(None, partial) @@ -107,6 +129,14 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) output_text = response.output.get("text", "") + # RAG 引用脚标格式化 + output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) + if self.output_reference and response.output.get("doc_references", None): + ref_str = "" + for ref in response.output.get("doc_references", []): + ref_str += f"{ref['index_id']}. {ref['title']}\n" + output_text += f"\n\n回答来源:\n{ref_str}" + return LLMResponse(role="assistant", completion_text=output_text) async def forget(self, session_id): diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 37f575f21..8b5890c28 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,3 +1,5 @@ +import astrbot.core.message.components as Comp + from typing import List from .. import Provider, Personality from ..entites import LLMResponse @@ -5,8 +7,9 @@ from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_image_by_url +from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain @register_provider_adapter("dify", "Dify APP 适配器。") @@ -30,7 +33,6 @@ class ProviderDify(Provider): if not self.api_key: raise Exception("Dify API Key 不能为空。") api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_client = DifyAPIClient(self.api_key, api_base) self.api_type = provider_config.get("dify_api_type", "") if not self.api_type: raise Exception("Dify API 类型不能为空。") @@ -41,15 +43,19 @@ class ProviderDify(Provider): self.dify_query_input_key = provider_config.get( "dify_query_input_key", "astrbot_text_query" ) - self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" + if not self.workflow_output_key: + self.workflow_output_key = "astrbot_wf_output" + self.variables: dict = provider_config.get("variables", {}) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.conversation_ids = {} """记录当前 session id 的对话 ID""" + self.api_client = DifyAPIClient(self.api_key, api_base) + async def text_chat( self, prompt: str, @@ -65,26 +71,27 @@ class ProviderDify(Provider): files_payload = [] for image_url in image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - file_response = await self.api_client.file_upload( - image_path, user=session_id + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - } - ) - else: - # TODO: 处理更多情况 - logger.warning(f"未知的图片链接:{image_url},图片将忽略。") + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() @@ -96,6 +103,9 @@ class ProviderDify(Provider): try: match self.api_type: case "chat" | "agent": + if not prompt: + prompt = "请描述这张图片。" + async for chunk in self.api_client.chat_messages( inputs={ **payload_vars, @@ -148,8 +158,9 @@ class ProviderDify(Provider): ) case "workflow_finished": logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。" + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束" ) + logger.debug(f"Dify 工作流结果:{chunk}") if chunk["data"]["error"]: logger.error( f"Dify 工作流出现错误:{chunk['data']['error']}" @@ -164,9 +175,7 @@ class ProviderDify(Provider): raise Exception( f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" ) - result = chunk["data"]["outputs"][ - self.workflow_output_key - ] + result = chunk case _: raise Exception(f"未知的 Dify API 类型:{self.api_type}") except Exception as e: @@ -176,7 +185,54 @@ class ProviderDify(Provider): if not result: logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - return LLMResponse(role="assistant", completion_text=result) + chain = await self.parse_dify_result(result) + + return LLMResponse(role="assistant", result_chain=chain) + + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: + if isinstance(chunk, str): + # Chat + return MessageChain(chain=[Comp.Plain(chunk)]) + + async def parse_file(item: dict) -> Comp: + match item["type"]: + case "image": + return Comp.Image(file=item["url"], url=item["url"]) + case "audio": + # 仅支持 wav + path = f"data/temp/{item['filename']}.wav" + await download_file(item["url"], path) + return Comp.Image(file=item["url"], url=item["url"]) + case "video": + return Comp.Video(file=item["url"]) + case _: + return Comp.File(name=item["filename"], file=item["url"]) + + output = chunk["data"]["outputs"][self.workflow_output_key] + chains = [] + if isinstance(output, str): + # 纯文本输出 + chains.append(Comp.Plain(output)) + elif isinstance(output, list): + # 主要适配 Dify 的 HTTP 请求结点的多模态输出 + for item in output: + # handle Array[File] + if ( + not isinstance(item, dict) + or item.get("dify_model_identity", "") != "__dify__file__" + ): + chains.append(Comp.Plain(str(output))) + break + else: + chains.append(Comp.Plain(str(output))) + + # scan file + files = chunk["data"].get("files", []) + for item in files: + comp = await parse_file(item) + chains.append(comp) + + return MessageChain(chain=chains) async def forget(self, session_id): self.conversation_ids[session_id] = "" diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 56b42a0ca..c7887d3ea 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -57,23 +57,30 @@ class ProviderEdgeTTS(TTSProvider): # 使用ffmpeg将MP3转换为标准WAV格式 _ = await asyncio.create_subprocess_exec( - [ - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", - mp3_path, # 输入文件 - "-acodec", - "pcm_s16le", # 16位PCM编码 - "-ar", - "24000", # 采样率24kHz (适合微信语音) - "-ac", - "1", # 单声道 - wav_path, # 输出文件 - ], - capture_output=True, - check=True, + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + "-af", + "apad=pad_dur=2", # 确保输出时长准确 + "-fflags", + "+genpts", # 强制生成时间戳 + "-hide_banner", # 隐藏版本信息 + wav_path, # 输出文件 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) - + # 等待进程完成并获取输出 + stdout, stderr = await _.communicate() + logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") + logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") + logger.info(f"[EdgeTTS] 返回值(0代表成功): {_.returncode}") os.remove(mp3_path) if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: return wav_path @@ -82,13 +89,15 @@ class ProviderEdgeTTS(TTSProvider): raise RuntimeError("生成的WAV文件不存在或为空") except subprocess.CalledProcessError as e: - logger.error(f"FFmpeg转换失败: {e.stderr.decode() if e.stderr else str(e)}") + logger.error( + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}" + ) try: if os.path.exists(mp3_path): os.remove(mp3_path) except Exception: pass - raise RuntimeError(f"FFmpeg转换失败: {str(e)}") + raise RuntimeError(f"FFmpeg 转换失败: {str(e)}") except Exception as e: logger.error(f"音频生成失败: {str(e)}") diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index b59f2c283..f120a6a59 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -18,10 +18,14 @@ class ProviderOpenAITTSAPI(TTSProvider): self.chosen_api_key = provider_config.get("api_key", "") self.voice = provider_config.get("openai-tts-voice", "alloy") + timeout = provider_config.get("timeout", NOT_GIVEN) + if isinstance(timeout, str): + timeout = int(timeout) + self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=provider_config.get("api_base", None), - timeout=provider_config.get("timeout", NOT_GIVEN), + timeout=timeout, ) self.set_model(provider_config.get("model", None)) diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 4e2f9d176..0b9f7ad09 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -15,7 +15,6 @@ from ..filter.regex import RegexFilter from typing import Awaitable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools -from astrbot.core import logger def get_handler_full_name(awaitable: Awaitable) -> str: @@ -359,9 +358,9 @@ def register_llm_tool(name: str = None): } ) md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler) - - logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") + llm_tools.add_func( + llm_tool_name, args, docstring.description.strip(), md.handler + ) return awaitable return decorator diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index bdf2c2473..347bc13ef 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -187,6 +187,8 @@ class PluginManager: f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" ) + await self._unbind_plugin(smd.name, smd.module_path) + star_handlers_registry.clear() star_map.clear() star_registry.clear() @@ -483,7 +485,9 @@ class PluginManager: for handler in star_handlers_registry.get_handlers_by_module_name( plugin_module_path ): - logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}") + logger.info( + f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})" + ) star_handlers_registry.remove(handler) keys_to_delete = [ k @@ -491,9 +495,10 @@ class PluginManager: if k.startswith(plugin_module_path) ] for k in keys_to_delete: - v = star_handlers_registry.star_handlers_map[k] - logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)") - del star_handlers_registry.star_handlers_map[k] + try: + del star_handlers_registry.star_handlers_map[k] + except KeyError: + pass try: del sys.modules[plugin_module_path] @@ -509,7 +514,7 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法更新。") await self.updator.update(plugin, proxy=proxy) - await self.reload() + await self.reload(plugin_name) async def turn_off_plugin(self, plugin_name: str): """ diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 80be3fff7..badf5d62b 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -8,7 +8,7 @@ class DifyAPIClient: def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): self.api_key = api_key self.api_base = api_base - self.session = ClientSession() + self.session = ClientSession(trust_env=True) self.headers = { "Authorization": f"Bearer {self.api_key}", } diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 2e525a3ca..34ba8fd3f 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -3,6 +3,7 @@ import datetime from .route import Route, Response, RouteContext from quart import request from astrbot.core import WEBUI_SK +from astrbot import logger class AuthRoute(Route): @@ -19,9 +20,20 @@ class AuthRoute(Route): password = self.config["dashboard"]["password"] post_data = await request.json if post_data["username"] == username and post_data["password"] == password: + change_pwd_hint = False + if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9": + change_pwd_hint = True + logger.warning("为了保证安全,请尽快修改默认密码。") + return ( Response() - .ok({"token": self.generate_jwt(username), "username": username}) + .ok( + { + "token": self.generate_jwt(username), + "username": username, + "change_pwd_hint": change_pwd_hint, + } + ) .__dict__ ) else: diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 54140d92a..088c999f9 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -29,11 +29,21 @@ def validate_config( ) -> typing.Tuple[typing.List[str], typing.Dict]: errors = [] - def validate(data, metadata=schema, path=""): - for key, meta in metadata.items(): - if key not in data: + def validate(data: dict, metadata: dict = schema, path=""): + for key, value in data.items(): + if key not in metadata: + # 无 schema 的配置项,执行类型猜测 + if isinstance(value, str): + if value.isdigit(): + data[key] = int(value) + elif value.replace(".", "", 1).isdigit(): + data[key] = float(value) + elif value == "true": + data[key] = True + elif value == "false": + data[key] = False continue - value = data[key] + meta = metadata[key] # null 转换 if value is None: data[key] = DEFAULT_VALUE_MAP[meta["type"]] @@ -43,6 +53,16 @@ def validate_config( errors.append( f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" ) + elif ( + meta["type"] == "list" + and isinstance(value, list) + and value + and "items" in meta + and isinstance(value[0], dict) + ): + # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") elif meta["type"] == "object" and isinstance(value, dict): validate(value, meta["items"], path=f"{path}{key}.") diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 42bbad318..6fc0651fa 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -120,12 +120,22 @@ class AstrBotDashboard: return f"获取进程信息失败: {str(e)}" def run(self): - try: - ip_addr = get_local_ip_addresses() - except Exception as _: - ip_addr = [] - + ip_addr = [] port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185) + host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0") + + logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}") + + if host == "0.0.0.0": + logger.info( + "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)" + ) + + if host not in ["localhost", "127.0.0.1"]: + try: + ip_addr = get_local_ip_addresses() + except Exception as _: + pass if isinstance(port, str): port = int(port) @@ -142,15 +152,21 @@ class AstrBotDashboard: raise Exception(f"端口 {port} 已被占用") - display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n" + display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n" display += f" ➜ 本地: http://localhost:{port}\n" for ip in ip_addr: display += f" ➜ 网络: http://{ip}:{port}\n" display += " ➜ 默认用户名和密码: astrbot\n ✨✨✨\n" + + if not ip_addr: + display += ( + "可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n" + ) + logger.info(display) return self.app.run_task( - host="0.0.0.0", + host=host, port=port, shutdown_trigger=self.shutdown_trigger_placeholder, ) diff --git a/changelogs/v3.4.38.md b/changelogs/v3.4.38.md new file mode 100644 index 000000000..2a3504fcf --- /dev/null +++ b/changelogs/v3.4.38.md @@ -0,0 +1,57 @@ +# What's Changed + +> Special thanks for all contributors and plugin developers and users who love AstrBot. 💖 + +## ✨ 新增的功能 + +1. 支持解析回复消息,支持 LLM 对所引用消息具有感知 #783 +2. 支持 Dify 的文件、图片、视频、音频输出 #819 +3. QQ 下支持嵌套转发(napcat) @zouyonghe +4. 配置页样式重写,更紧凑的 WebUI 配置 + +## 🎈 功能性优化 + +1. 使用系统时间而不是 UTC+8 时间作为默认时间以适应海外用户需求 @roeseth +2. 在对话隔离情况下也可以将整个群聊加入白名单 #746 +3. 在调用插件异常时更完整的报错输出 +4. gewechat 下对已知且没有业务处理的事件类型不显示详细日志 @diudiu62 +5. 优化 WebUI 悬浮文档 @IGCrystal +6. 支持自定义 WebUI、Wecom Webhook Server, QQ Official Webhook Server 的 host #821 +7. Dify 下当只有图片输入时的默认 prompt 防止一些报错 #837 + +## 🐛 修复的 Bug + +1. fishaudio 默认 baseurl 不可用 +2. gewechat 下重复登录后提示设备不存在导致无法重新登陆 @beat4ocean +3. gewechat 下用户本人发消息会触发消息回复 @beat4ocean +4. 钉钉 WebUI 文档不显示 +5. 更新插件后插件热重载不完全、函数工具重复添加 +6. OpenAI TTS API TypeError 报错 #755 +7. EdgeTTS 部分情况下无法使用 @Soulter @需要哦 +8. QQ 官方机器人平台下发送 base64 图片消息段报错 @Soulter @shuiping233 +9. QQ 官方机器人平台下命令参数报错信息无法正常发送 @shuiping233 +10. WebUI 错误地显示未知更新 +11. 部分情况下文件无法上传到 Telegram 群组 #601 +12. 插件管理的插件简介太长导致 “帮助”“操作”图标不显示 #790 +13. LLOnebot 合并消息转发错误 #842 +14. model_config 中自定义的配置项(如温度)类型自动变回 string #854 + +## 🧩 新增的插件 + +1. astrbot_plugin_image_understanding_Janus-Pro - 使用deepseek-ai/Janus-Pro系列模型为本地模型提供的图片理解补充 @xiewoc +2. astrbot_plugin_moyurenpro - 摸鱼人日历,支持自定义时间时区,自定义api,支持立即发送,工作日定时发送。 @quirrel-zh @DuBwTf +3. astrbot_plugin_wechat_manager - 微信关键字好友自动审核、关键字邀请进群。@diudiu62 +4. astrbot_plugin_qwq_filter - qwq 思考过滤工具 @beat4ocean +5. astrbot_plugin_chatsummary - 一个通过拉取历史聊天记录,调用LLM大模型接口实现消息总结功能。@laopanmemz +6. astrBot_PGR_Dialogue - 检测到部分战双角色的名称(或别称)时,有概率发送一条语音文本 @KurisuRee7 +7. astrbot_plugin_bv - 解析群内https://www.bilibili.com/video/BV号/ 的链接并获取视频数据与视频文件,以合并转发方式发送 @haliludaxuanfeng +8. astrbot_plugin_gemini_exp - 让你在AstrBot调用Gemini2.0-flash-exp来生成图片或者p图。Gemini2.0-flash-exp为原生多模态模型,其既是语言模型,也是生图模型,因此能够对图像使用简单的自然语言命令进行处理。@Elen123bot +9. astrbot_plugin_sjzb - 随机生成绝地潜兵2游戏中一组4个战备配置 @tenno1174 +10. astrbot_plugin_picture_manager - 图片管理插件,允许用户通过自定义触发指令从API或直接URL获取图片。@bigshabei +11. astrbot_plugin_bilibiliParse - 解析哔哩哔哩视频,并以图片的形式发送给用户 @7Hello12 +12. astrbot_plugin_sensoji - 这是一个模拟日本浅草寺抽签功能的插件。用户可以通过发送 /抽签 命令随机抽取一个签文,获取运势提示。签文包含吉凶结果(如“大吉”、“凶”等)以及对应的运势描述。 @Shouugou +13. astrbot_plugin_videosummary - 使用 bibigpt 实现视频总结 @kterna +14. astrbot_plugin_InitiativeDialogue - 使 bot 在用户长时间未发送消息时主动与用户对话的插件 @advent259141 +15. astrbot_plugin_emoji - 基于达莉娅综合群娱插件的表情包制作插件,仅保留了@其他群员制作表情包的部分。由桑帛云API提供表情包制作。@KurisuRee7 +16. astrbot_plugin_videos_analysis - 聚合视频分享链接解析(仅测试过napcat) @miaoxutao123 +17. astrbot_plugin_daily_news - 每日 60 秒新闻推送插件 - 自动推送每日热点新闻 @anka-afk \ No newline at end of file diff --git a/changelogs/v3.4.39.md b/changelogs/v3.4.39.md new file mode 100644 index 000000000..d80b4e86d --- /dev/null +++ b/changelogs/v3.4.39.md @@ -0,0 +1,4 @@ +# What's Changed + +1. 默认账户密码登录成功后弹出修改警告 +2. 将 WebUI 默认 host 改变回 v3.4.38 之前的版本以减少兼容性问题。 \ No newline at end of file diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 6ace849ef..2796f95da 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -1,130 +1,146 @@ - - \ No newline at end of file + removeItem(index) { + this.items.splice(index, 1); + }, + }, +}; + + + \ No newline at end of file diff --git a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue index 68e6d38ee..3d08c9cd6 100644 --- a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue +++ b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue @@ -8,6 +8,7 @@ import { useCommonStore } from '@/stores/common'; const customizer = useCustomizerStore(); let dialog = ref(false); +let accountWarning = ref(false) let updateStatusDialog = ref(false); let password = ref(''); let newPassword = ref(''); @@ -177,6 +178,14 @@ checkUpdate(); const commonStore = useCommonStore(); commonStore.createWebSocket(); commonStore.getStartTime(); + + +if (localStorage.getItem('change_pwd_hint') != null && localStorage.getItem('change_pwd_hint') == 'true') { + dialog.value = true; + accountWarning.value = true; + localStorage.removeItem('change_pwd_hint'); +} +