diff --git a/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md b/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md deleted file mode 100644 index 0358a5b27..000000000 --- a/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.md +++ /dev/null @@ -1,31 +0,0 @@ ---- -name: '🥳 发布插件' -title: "[Plugin] 插件名" -about: 提交插件到插件市场 -labels: [ "plugin-publish" ] -assignees: '' - ---- - -欢迎发布插件到插件市场! - -## 插件基本信息 - -请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。 - -```json -{ - "name": "插件名", - "desc": "插件介绍", - "author": "作者名", - "repo": "插件仓库链接", - "tags": [], - "social_link": "" -} -``` - -## 检查 - -- [ ] 我的插件经过完整的测试 -- [ ] 我的插件不包含恶意代码 -- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。 diff --git a/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml b/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml new file mode 100644 index 000000000..7957178cf --- /dev/null +++ b/.github/ISSUE_TEMPLATE/PLUGIN_PUBLISH.yml @@ -0,0 +1,56 @@ +name: 🥳 发布插件 +description: 提交插件到插件市场 +title: "[Plugin] 插件名" +labels: ["plugin-publish"] +assignees: [] +body: + - type: markdown + attributes: + value: | + 欢迎发布插件到插件市场! + + - type: markdown + attributes: + value: | + ## 插件基本信息 + + 请将插件信息填写到下方的 JSON 代码块中。其中 `tags`(插件标签)和 `social_link`(社交链接)选填。 + + 不熟悉 JSON ?现在可以从 [这里](https://plugins.astrbot.app/#/submit) 获取你的 JSON 啦!获取到了记得复制粘贴过来哦! + + - type: textarea + id: plugin-info + attributes: + label: 插件信息 + description: 请在下方代码块中填写您的插件信息,确保反引号包裹了JSON + value: | + ```json + { + "name": "插件名", + "desc": "插件介绍", + "author": "作者名", + "repo": "插件仓库链接", + "tags": [], + "social_link": "" + } + ``` + validations: + required: true + + - type: markdown + attributes: + value: | + ## 检查 + + - type: checkboxes + id: checks + attributes: + label: 插件检查清单 + description: 请确认以下所有项目 + options: + - label: 我的插件经过完整的测试 + required: true + - label: 我的插件不包含恶意代码 + required: true + - label: 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。 + required: true diff --git a/README.md b/README.md index 2240acc21..ec1a8bdf3 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 > 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本! 1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。 -2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、微信、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。 +2. **多消息平台接入**。支持接入 QQ(OneBot、QQ 官方机器人平台)、QQ 频道、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK、VoceChat。支持速率限制、白名单、关键词过滤、百度内容审核。 3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。 4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。 5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。 @@ -78,6 +78,10 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。 +#### 1Panel 部署 + +请参阅官方文档 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html) 。 + #### CasaOS 部署 社区贡献的部署方式。 @@ -125,7 +129,6 @@ uvx astrbot init | -------- | ------- | | QQ(官方机器人接口) | ✔ | | QQ(OneBot) | ✔ | -| 微信个人号 | ✔ | | Telegram | ✔ | | 企业微信 | ✔ | | 微信客服 | ✔ | @@ -246,11 +249,5 @@ _✨ WebUI ✨_ ![10k-star-banner-credit-by-kevin](https://github.com/user-attachments/assets/c97fc5fb-20b9-4bc8-9998-c20b930ab097) -## Disclaimer - -1. The project is protected under the `AGPL-v3` opensource license. -2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility. -3. Please ensure compliance with local laws and regulations when using this project. - _私は、高性能ですから!_ diff --git a/README_ja.md b/README_ja.md index 325857d9a..8e648c8e5 100644 --- a/README_ja.md +++ b/README_ja.md @@ -1,5 +1,5 @@

- + ![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)

@@ -27,7 +27,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ ## ✨ 主な機能 1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。 -2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。 +2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。 3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。 4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。 5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。 @@ -35,7 +35,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ > [!TIP] > 管理パネルのオンラインデモを体験する: [https://demo.astrbot.app/](https://demo.astrbot.app/) -> +> > ユーザー名: `astrbot`, パスワード: `astrbot`。LLM が設定されていないため、チャットページで大規模モデルを使用することはできません。(デモのログインパスワードを変更しないでください 😭) ## ✨ 使用方法 @@ -136,11 +136,11 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_ ## ⭐ Star History -> [!TIP] +> [!TIP] > このプロジェクトがあなたの生活や仕事に役立った場合、またはこのプロジェクトの将来の発展に関心がある場合は、プロジェクトに Star を付けてください。これはこのオープンソースプロジェクトを維持するためのモチベーションです <3
- + [![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
@@ -152,8 +152,7 @@ _✨ 内蔵 Web Chat、オンラインでボットと対話 ✨_ ## 免責事項 1. このプロジェクトは `AGPL-v3` オープンソースライセンスの下で保護されています。 -2. WeChat(個人アカウント)のデプロイメントには [Gewechat](https://github.com/Devo919/Gewechat) サービスを利用しています。AstrBot は Gewechat との接続を保証するだけであり、アカウントのリスク管理に関しては、このプロジェクトの著者は一切の責任を負いません。 -3. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。 +2. このプロジェクトを使用する際は、現地の法律および規制を遵守してください。 - _私は、高性能ですから!_ - diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py index 25f7f33c3..8d1eee0b1 100644 --- a/astrbot/cli/__init__.py +++ b/astrbot/cli/__init__.py @@ -1 +1 @@ -__version__ = "3.5.8" +__version__ = "3.5.23" diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index fca150ba8..5b9f14d32 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -117,6 +117,9 @@ def build_plug_list(plugins_dir: Path) -> list: # 从 metadata.yaml 加载元数据 metadata = load_yaml_metadata(plugin_dir) + if "desc" not in metadata and "description" in metadata: + metadata["desc"] = metadata["description"] + # 如果成功加载元数据,添加到结果列表 if metadata and all( k in metadata for k in ["name", "desc", "version", "author", "repo"] diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 40a602d3d..41225e0ee 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -157,15 +157,6 @@ CONFIG_METADATA_2 = { "ws_reverse_port": 6199, "ws_reverse_token": "", }, - "微信个人号(Gewechat)": { - "id": "gwchat", - "type": "gewechat", - "enable": False, - "base_url": "http://localhost:2531", - "nickname": "soulter", - "host": "这里填写你的局域网IP或者公网服务器IP", - "port": 11451, - }, "微信个人号(WeChatPadPro)": { "id": "wechatpadpro", "type": "wechatpadpro", @@ -823,6 +814,19 @@ CONFIG_METADATA_2 = { "variables": {}, "timeout": 60, }, + "ModelScope": { + "id": "modelscope", + "provider": "modelscope", + "type": "openai_chat_completion", + "provider_type": "chat_completion", + "enable": True, + "key": [], + "timeout": 120, + "api_base": "https://api-inference.modelscope.cn/v1", + "model_config": { + "model": "Qwen/Qwen3-32B", + }, + }, "FastGPT": { "id": "fastgpt", "provider": "fastgpt", @@ -1021,7 +1025,7 @@ CONFIG_METADATA_2 = { "embedding_api_key": "", "embedding_api_base": "", "embedding_model": "", - "embedding_dimensions": 1536, + "embedding_dimensions": 1024, "timeout": 20, }, "Gemini Embedding": { diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 9b78eaec6..3a1c50371 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -96,8 +96,6 @@ class LogBroker: Queue: 订阅者的队列, 可用于接收日志消息 """ q = Queue(maxsize=CACHED_SIZE + 10) - for log in self.log_cache: - q.put_nowait(log) self.subscribers.append(q) return q diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 1344d4dec..ae44cf36d 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -173,7 +173,9 @@ class LLMRequestSubStage(Stage): event=event, pipeline_ctx=self.ctx, ) - logger.debug(f"handle provider[id: {provider.provider_config['id']}] request: {req}") + logger.debug( + f"handle provider[id: {provider.provider_config['id']}] request: {req}" + ) await tool_loop_agent.reset(req=req, streaming=self.streaming_response) async def requesting(): @@ -229,7 +231,7 @@ class LLMRequestSubStage(Stage): logger.error(traceback.format_exc()) event.set_result( MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n" ) ) return diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 50b436043..54ad1e63b 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map from astrbot.core.utils.path_util import path_Mapping +from astrbot.core.utils.session_lock import session_lock_manager @register_stage @@ -177,25 +178,26 @@ class RespondStage(Stage): result.chain.remove(comp) break - for rcomp in record_comps: - i = await self._calc_comp_interval(rcomp) - await asyncio.sleep(i) - try: - await event.send(MessageChain([rcomp])) - except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break - - # 分段回复 - for comp in non_record_comps: - i = await self._calc_comp_interval(comp) - await asyncio.sleep(i) - try: - await event.send(MessageChain([*decorated_comps, comp])) - decorated_comps = [] # 清空已发送的装饰组件 - except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break + # leverage lock to guarentee the order of message sending among different events + async with session_lock_manager.acquire_lock(event.unified_msg_origin): + for rcomp in record_comps: + i = await self._calc_comp_interval(rcomp) + await asyncio.sleep(i) + try: + await event.send(MessageChain([rcomp])) + except Exception as e: + logger.error(f"发送消息失败: {e} chain: {result.chain}") + break + # 分段回复 + for comp in non_record_comps: + i = await self._calc_comp_interval(comp) + await asyncio.sleep(i) + try: + await event.send(MessageChain([*decorated_comps, comp])) + decorated_comps = [] # 清空已发送的装饰组件 + except Exception as e: + logger.error(f"发送消息失败: {e} chain: {result.chain}") + break else: for rcomp in record_comps: try: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index e565c13ce..8be20be73 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -228,7 +228,7 @@ class AstrMessageEvent(abc.ABC): ): """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 - Fallback仅支持 aiocqhttp, gewechat。 + Fallback仅支持 aiocqhttp。 """ asyncio.create_task( Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) @@ -420,7 +420,6 @@ class AstrMessageEvent(abc.ABC): 适配情况: - - gewechat - aiocqhttp(OneBotv11) """ ... diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 63ec27982..23109ca53 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -58,10 +58,6 @@ class PlatformManager: from .sources.qqofficial_webhook.qo_webhook_adapter import ( QQOfficialWebhookPlatformAdapter, # noqa: F401 ) - case "gewechat": - from .sources.gewechat.gewechat_platform_adapter import ( - GewechatPlatformAdapter, # noqa: F401 - ) case "wechatpadpro": from .sources.wechatpadpro.wechatpadpro_adapter import ( WeChatPadProAdapter, # noqa: F401 diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 1991bf393..7329ab603 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -272,8 +272,14 @@ class AiocqhttpAdapter(Platform): ) # 添加必要的 post_type 字段,防止 Event.from_payload 报错 reply_event_data["post_type"] = "message" + new_event = Event.from_payload(reply_event_data) + if not new_event: + logger.error( + f"无法从回复消息数据构造 Event 对象: {reply_event_data}" + ) + continue abm_reply = await self._convert_handle_message_event( - Event.from_payload(reply_event_data), get_reply=False + new_event, get_reply=False ) reply_seg = Reply( diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py deleted file mode 100644 index 5f97a6778..000000000 --- a/astrbot/core/platform/sources/gewechat/client.py +++ /dev/null @@ -1,812 +0,0 @@ -import asyncio -import base64 -import datetime -import os -import re -import uuid -import threading - -import aiohttp -import anyio -import quart - -from astrbot.api import logger, sp -from astrbot.api.message_components import Plain, Image, At, Record, Video -from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType -from astrbot.core.utils.io import download_image_by_url -from .downloader import GeweDownloader -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - -try: - from .xml_data_parser import GeweDataParser -except (ImportError, ModuleNotFoundError) as e: - logger.warning( - f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}" - ) - - -class SimpleGewechatClient: - """针对 Gewechat 的简单实现。 - - @author: Soulter - @website: https://github.com/Soulter - """ - - def __init__( - self, - base_url: str, - nickname: str, - host: str, - port: int, - event_queue: asyncio.Queue, - ): - self.base_url = base_url - if self.base_url.endswith("/"): - self.base_url = self.base_url[:-1] - - self.download_base_url = self.base_url.split(":")[:-1] # 去掉端口 - self.download_base_url = ":".join(self.download_base_url) + ":2532/download/" - - self.base_url += "/v2/api" - - logger.info(f"Gewechat API: {self.base_url}") - logger.info(f"Gewechat 下载 API: {self.download_base_url}") - - if isinstance(port, str): - port = int(port) - - self.token = None - self.headers = {} - self.nickname = nickname - self.appid = sp.get(f"gewechat-appid-{nickname}", "") - - self.server = quart.Quart(__name__) - self.server.add_url_rule( - "/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"] - ) - self.server.add_url_rule( - "/astrbot-gewechat/file/", - view_func=self._handle_file, - methods=["GET"], - ) - - self.host = host - self.port = port - self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback" - self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file" - - self.event_queue = event_queue - - self.multimedia_downloader = None - - self.userrealnames = {} - - self.shutdown_event = asyncio.Event() - - self.staged_files = {} - """存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。""" - - self.lock = asyncio.Lock() - - async def get_token_id(self): - """获取 Gewechat Token。""" - async with aiohttp.ClientSession() as session: - async with session.post(f"{self.base_url}/tools/getTokenId") as resp: - json_blob = await resp.json() - self.token = json_blob["data"] - logger.info(f"获取到 Gewechat Token: {self.token}") - self.headers = {"X-GEWE-TOKEN": self.token} - - async def _convert(self, data: dict) -> AstrBotMessage: - if "TypeName" in data: - type_name = data["TypeName"] - elif "type_name" in data: - type_name = data["type_name"] - else: - raise Exception("无法识别的消息类型") - - # 以下没有业务处理,只是避免控制台打印太多的日志 - if type_name == "ModContacts": - logger.info("gewechat下发:ModContacts消息通知。") - return - if type_name == "DelContacts": - logger.info("gewechat下发:DelContacts消息通知。") - return - - if type_name == "Offline": - logger.critical("收到 gewechat 下线通知。") - return - - d = None - if "Data" in data: - d = data["Data"] - elif "data" in data: - d = data["data"] - - if not d: - logger.warning(f"消息不含 data 字段: {data}") - return - - if "CreateTime" in d: - # 得到系统 UTF+8 的 ts - tz_offset = datetime.timedelta(hours=8) - tz = datetime.timezone(tz_offset) - ts = datetime.datetime.now(tz).timestamp() - create_time = d["CreateTime"] - if create_time < ts - 30: - logger.warning(f"消息时间戳过旧: {create_time},当前时间戳: {ts}") - return - - abm = AstrBotMessage() - - from_user_name = d["FromUserName"]["string"] # 消息来源 - d["to_wxid"] = from_user_name # 用于发信息 - - abm.message_id = str(d.get("MsgId")) - abm.session_id = from_user_name - abm.self_id = data["Wxid"] # 机器人的 wxid - - user_id = "" # 发送人 wxid - content = d["Content"]["string"] # 消息内容 - - at_me = False - at_wxids = [] - if "@chatroom" in from_user_name: - abm.type = MessageType.GROUP_MESSAGE - _t = content.split(":\n") - user_id = _t[0] - content = _t[1] - # at - msg_source = d["MsgSource"] - if "\u2005" in content: - # at - # content = content.split('\u2005')[1] - content = re.sub(r"@[^\u2005]*\u2005", "", content) - at_wxids = re.findall( - r")", - msg_source, - ) - - abm.group_id = from_user_name - - if ( - f"" in msg_source - or f"" in msg_source - ): - at_me = True - if "在群聊中@了你" in d.get("PushContent", ""): - at_me = True - else: - abm.type = MessageType.FRIEND_MESSAGE - user_id = from_user_name - - # 检查消息是否由自己发送,若是则忽略 - # 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。 - # if user_id == abm.self_id: - # logger.info("忽略自己发送的消息") - # return None - - abm.message = [] - - # 解析用户真实名字 - user_real_name = "unknown" - if abm.group_id: - if ( - abm.group_id not in self.userrealnames - or user_id not in self.userrealnames[abm.group_id] - ): - # 获取群成员列表,并且缓存 - if abm.group_id not in self.userrealnames: - self.userrealnames[abm.group_id] = {} - member_list = await self.get_chatroom_member_list(abm.group_id) - logger.debug(f"获取到 {abm.group_id} 的群成员列表。") - if member_list and "memberList" in member_list: - for member in member_list["memberList"]: - self.userrealnames[abm.group_id][member["wxid"]] = member[ - "nickName" - ] - if user_id in self.userrealnames[abm.group_id]: - user_real_name = self.userrealnames[abm.group_id][user_id] - else: - user_real_name = self.userrealnames[abm.group_id][user_id] - else: - try: - info = (await self.get_user_or_group_info(user_id))["data"][0] - user_real_name = info["nickName"] - except Exception as e: - logger.debug(f"获取用户 {user_id} 昵称失败: {e}") - user_real_name = user_id - - if at_me: - abm.message.insert(0, At(qq=abm.self_id, name=self.nickname)) - for wxid in at_wxids: - # 群聊里 At 其他人的列表 - _username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid) - abm.message.append(At(qq=wxid, name=_username)) - - abm.sender = MessageMember(user_id, user_real_name) - abm.raw_message = d - abm.message_str = "" - - if user_id == "weixin": - # 忽略微信团队消息 - return - - # 不同消息类型 - match d["MsgType"]: - case 1: - # 文本消息 - abm.message.append(Plain(content)) - abm.message_str = content - case 3: - # 图片消息 - file_url = await self.multimedia_downloader.download_image( - self.appid, content - ) - logger.debug(f"下载图片: {file_url}") - file_path = await download_image_by_url(file_url) - abm.message.append(Image(file=file_path, url=file_path)) - - case 34: - # 语音消息 - if "ImgBuf" in d and "buffer" in d["ImgBuf"]: - voice_data = base64.b64decode(d["ImgBuf"]["buffer"]) - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - file_path = os.path.join( - temp_dir, f"gewe_voice_{abm.message_id}.silk" - ) - - async with await anyio.open_file(file_path, "wb") as f: - await f.write(voice_data) - abm.message.append(Record(file=file_path, url=file_path)) - - # 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志 - case 37: # 好友申请 - logger.info("消息类型(37):好友申请") - case 42: # 名片 - logger.info("消息类型(42):名片") - case 43: # 视频 - video = Video(file="", cover=content) - abm.message.append(video) - case 47: # emoji - data_parser = GeweDataParser(content, abm.group_id == "") - emoji = data_parser.parse_emoji() - abm.message.append(emoji) - case 48: # 地理位置 - logger.info("消息类型(48):地理位置") - case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请 - data_parser = GeweDataParser(content, abm.group_id == "") - segments = data_parser.parse_mutil_49() - if segments: - abm.message.extend(segments) - for seg in segments: - if isinstance(seg, Plain): - abm.message_str += seg.text - case 51: # 帐号消息同步? - logger.info("消息类型(51):帐号消息同步?") - case 10000: # 被踢出群聊/更换群主/修改群名称 - logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称") - case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办 - logger.info( - "消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办" - ) - - case _: - logger.info(f"未实现的消息类型: {d['MsgType']}") - abm.raw_message = d - - logger.debug(f"abm: {abm}") - return abm - - async def _callback(self): - data = await quart.request.json - logger.debug(f"收到 gewechat 回调: {data}") - - if data.get("testMsg", None): - return quart.jsonify({"r": "AstrBot ACK"}) - - abm = None - try: - abm = await self._convert(data) - except BaseException as e: - logger.warning( - f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}。" - ) - - if abm: - coro = getattr(self, "on_event_received") - if coro: - await coro(abm) - - return quart.jsonify({"r": "AstrBot ACK"}) - - async def _register_file(self, file_path: str) -> str: - """向 AstrBot 回调服务器 注册一个允许外部访问的文件。 - - Args: - file_path (str): 文件路径。 - Returns: - str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。 - """ - async with self.lock: - if not os.path.exists(file_path): - raise Exception(f"文件不存在: {file_path}") - - file_token = str(uuid.uuid4()) - self.staged_files[file_token] = file_path - return file_token - - async def _handle_file(self, file_token): - async with self.lock: - if file_token not in self.staged_files: - logger.warning(f"请求的文件 {file_token} 不存在。") - return quart.abort(404) - if not os.path.exists(self.staged_files[file_token]): - logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。") - return quart.abort(404) - file_path = self.staged_files[file_token] - self.staged_files.pop(file_token, None) - return await quart.send_file(file_path) - - async def _set_callback_url(self): - logger.info("设置回调,请等待...") - await asyncio.sleep(3) - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/tools/setCallback", - headers=self.headers, - json={"token": self.token, "callbackUrl": self.callback_url}, - ) as resp: - json_blob = await resp.json() - logger.info(f"设置回调结果: {json_blob}") - if json_blob["ret"] != 200: - raise Exception(f"设置回调失败: {json_blob}") - logger.info( - f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。如果仍没收到请到管理面板聊天页输入 /gewe_logout 重新登录。" - ) - - async def start_polling(self): - threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start() - await self.server.run_task( - host="0.0.0.0", - port=self.port, - shutdown_trigger=self.shutdown_trigger, - ) - - async def shutdown_trigger(self): - await self.shutdown_event.wait() - - async def check_online(self, appid: str): - """检查 APPID 对应的设备是否在线。""" - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/login/checkOnline", - headers=self.headers, - json={"appId": appid}, - ) as resp: - json_blob = await resp.json() - return json_blob["data"] - - async def logout(self): - """登出 gewechat。""" - if self.appid: - online = await self.check_online(self.appid) - if online: - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/login/logout", - headers=self.headers, - json={"appId": self.appid}, - ) as resp: - json_blob = await resp.json() - logger.info(f"登出结果: {json_blob}") - - async def login(self): - """登录 gewechat。一般来说插件用不到这个方法。""" - if self.token is None: - await self.get_token_id() - - self.multimedia_downloader = GeweDownloader( - self.base_url, self.download_base_url, self.token - ) - - if self.appid: - try: - online = await self.check_online(self.appid) - if online: - logger.info(f"APPID: {self.appid} 已在线") - return - except Exception as e: - logger.error(f"检查在线状态失败: {e}") - sp.put(f"gewechat-appid-{self.nickname}", "") - self.appid = None - - payload = {"appId": self.appid} - - if self.appid: - logger.info(f"使用 APPID: {self.appid}, {self.nickname}") - - try: - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/login/getLoginQrCode", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - if json_blob["ret"] != 200: - error_msg = json_blob.get("data", {}).get("msg", "") - if "设备不存在" in error_msg: - logger.error( - f"检测到无效的appid: {self.appid},将清除并重新登录。" - ) - sp.put(f"gewechat-appid-{self.nickname}", "") - self.appid = None - return await self.login() - else: - raise Exception(f"获取二维码失败: {json_blob}") - qr_data = json_blob["data"]["qrData"] - qr_uuid = json_blob["data"]["uuid"] - appid = json_blob["data"]["appId"] - logger.info(f"APPID: {appid}") - logger.warning( - f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}" - ) - except Exception as e: - raise e - - # 执行登录 - retry_cnt = 64 - payload.update({"uuid": qr_uuid, "appId": appid}) - while retry_cnt > 0: - retry_cnt -= 1 - - # 需要验证码 - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - code_file_path = os.path.join(temp_dir, "gewe_code") - if os.path.exists(code_file_path): - with open(code_file_path, "r") as f: - code = f.read().strip() - if not code: - logger.warning( - "未找到验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456" - ) - await asyncio.sleep(5) - continue - payload["captchCode"] = code - logger.info(f"使用验证码: {code}") - try: - os.remove(code_file_path) - except Exception: - logger.warning(f"删除验证码文件 {code_file_path} 失败。") - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/login/checkLogin", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.info(f"检查登录状态: {json_blob}") - - ret = json_blob["ret"] - msg = "" - if json_blob["data"] and "msg" in json_blob["data"]: - msg = json_blob["data"]["msg"] - if ret == 500 and "安全验证码" in msg: - logger.warning( - "此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456" - ) - else: - if "status" in json_blob["data"]: - status = json_blob["data"]["status"] - nickname = json_blob["data"].get("nickName", "") - if status == 1: - logger.info(f"等待确认...{nickname}") - elif status == 2: - logger.info(f"绿泡泡平台登录成功: {nickname}") - break - elif status == 0: - logger.info("等待扫码...") - else: - logger.warning(f"未知状态: {status}") - await asyncio.sleep(5) - - if appid: - sp.put(f"gewechat-appid-{self.nickname}", appid) - self.appid = appid - logger.info(f"已保存 APPID: {appid}") - - """API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1 - """ - - async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict: - """获取群成员列表。 - - Args: - chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。 - - Returns: - dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。 - """ - payload = {"appId": self.appid, "chatroomId": chatroom_wxid} - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/group/getChatroomMemberList", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - return json_blob["data"] - - async def post_text(self, to_wxid, content: str, ats: str = ""): - """发送纯文本消息""" - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "content": content, - } - if ats: - payload["ats"] = ats - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postText", headers=self.headers, json=payload - ) as resp: - json_blob = await resp.json() - logger.debug(f"发送消息结果: {json_blob}") - - async def post_image(self, to_wxid, image_url: str): - """发送图片消息""" - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "imgUrl": image_url, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postImage", headers=self.headers, json=payload - ) as resp: - json_blob = await resp.json() - logger.debug(f"发送图片结果: {json_blob}") - - async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""): - """发送emoji消息""" - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "emojiMd5": emoji_md5, - "emojiSize": emoji_size, - } - - # 优先表情包,若拿不到表情包的md5,就用当作图片发 - try: - if emoji_md5 != "" and emoji_size != "": - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postEmoji", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.info( - f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}" - ) - else: - await self.post_image(to_wxid, cdnurl) - - except Exception as e: - logger.error(e) - - async def post_video( - self, to_wxid, video_url: str, thumb_url: str, video_duration: int - ): - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "videoUrl": video_url, - "thumbUrl": thumb_url, - "videoDuration": video_duration, - } - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postVideo", headers=self.headers, json=payload - ) as resp: - json_blob = await resp.json() - logger.debug(f"发送视频结果: {json_blob}") - - async def forward_video(self, to_wxid, cnd_xml: str): - """转发视频 - - Args: - to_wxid (str): 发送给谁 - cnd_xml (str): 视频消息的cdn信息 - """ - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "xml": cnd_xml, - } - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/forwardVideo", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"转发视频结果: {json_blob}") - - async def post_voice(self, to_wxid, voice_url: str, voice_duration: int): - """发送语音信息 - - Args: - voice_url (str): 语音文件的网络链接 - voice_duration (int): 语音时长,毫秒 - """ - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "voiceUrl": voice_url, - "voiceDuration": voice_duration, - } - - logger.debug(f"发送语音: {payload}") - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postVoice", headers=self.headers, json=payload - ) as resp: - json_blob = await resp.json() - logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}") - - async def post_file(self, to_wxid, file_url: str, file_name: str): - """发送文件 - - Args: - to_wxid (string): 微信ID - file_url (str): 文件的网络链接 - file_name (str): 文件名 - """ - payload = { - "appId": self.appid, - "toWxid": to_wxid, - "fileUrl": file_url, - "fileName": file_name, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/message/postFile", headers=self.headers, json=payload - ) as resp: - json_blob = await resp.json() - logger.debug(f"发送文件结果: {json_blob}") - - async def add_friend(self, v3: str, v4: str, content: str): - """申请添加好友""" - payload = { - "appId": self.appid, - "scene": 3, - "content": content, - "v4": v4, - "v3": v3, - "option": 2, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/contacts/addContacts", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"申请添加好友结果: {json_blob}") - return json_blob - - async def get_group(self, group_id: str): - payload = { - "appId": self.appid, - "chatroomId": group_id, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/group/getChatroomInfo", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取群信息结果: {json_blob}") - return json_blob - - async def get_group_member(self, group_id: str): - payload = { - "appId": self.appid, - "chatroomId": group_id, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/group/getChatroomMemberList", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取群信息结果: {json_blob}") - return json_blob - - async def accept_group_invite(self, url: str): - """同意进群""" - payload = {"appId": self.appid, "url": url} - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/group/agreeJoinRoom", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取群信息结果: {json_blob}") - return json_blob - - async def add_group_member_to_friend( - self, group_id: str, to_wxid: str, content: str - ): - payload = { - "appId": self.appid, - "chatroomId": group_id, - "content": content, - "memberWxid": to_wxid, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/group/addGroupMemberAsFriend", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取群信息结果: {json_blob}") - return json_blob - - async def get_user_or_group_info(self, *ids): - """ - 获取用户或群组信息。 - - :param ids: 可变数量的 wxid 参数 - """ - - wxids_str = list(ids) - - payload = { - "appId": self.appid, - "wxids": wxids_str, # 使用逗号分隔的字符串 - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/contacts/getDetailInfo", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取群信息结果: {json_blob}") - return json_blob - - async def get_contacts_list(self): - """ - 获取通讯录列表 - 见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504 - """ - payload = {"appId": self.appid} - - async with aiohttp.ClientSession() as session: - async with session.post( - f"{self.base_url}/contacts/fetchContactsList", - headers=self.headers, - json=payload, - ) as resp: - json_blob = await resp.json() - logger.debug(f"获取通讯录列表结果: {json_blob}") - return json_blob diff --git a/astrbot/core/platform/sources/gewechat/downloader.py b/astrbot/core/platform/sources/gewechat/downloader.py deleted file mode 100644 index 01c89fd28..000000000 --- a/astrbot/core/platform/sources/gewechat/downloader.py +++ /dev/null @@ -1,55 +0,0 @@ -from astrbot import logger -import aiohttp -import json - - -class GeweDownloader: - def __init__(self, base_url: str, download_base_url: str, token: str): - self.base_url = base_url - self.download_base_url = download_base_url - self.headers = {"Content-Type": "application/json", "X-GEWE-TOKEN": token} - - async def _post_json(self, baseurl: str, route: str, payload: dict): - async with aiohttp.ClientSession() as session: - async with session.post( - f"{baseurl}{route}", headers=self.headers, json=payload - ) as resp: - return await resp.read() - - async def download_voice(self, appid: str, xml: str, msg_id: str): - payload = {"appId": appid, "xml": xml, "msgId": msg_id} - return await self._post_json(self.base_url, "/message/downloadVoice", payload) - - async def download_image(self, appid: str, xml: str) -> str: - """返回一个可下载的 URL""" - choices = [2, 3] # 2:常规图片 3:缩略图 - - for choice in choices: - try: - payload = {"appId": appid, "xml": xml, "type": choice} - data = await self._post_json( - self.base_url, "/message/downloadImage", payload - ) - json_blob = json.loads(data) - if "fileUrl" in json_blob["data"]: - return self.download_base_url + json_blob["data"]["fileUrl"] - - except BaseException as e: - logger.error(f"gewe download image: {e}") - continue - - raise Exception("无法下载图片") - - async def download_emoji_md5(self, app_id, emoji_md5): - """下载emoji""" - try: - payload = {"appId": app_id, "emojiMd5": emoji_md5} - - # gewe 计划中的接口,暂时没有实现。返回代码404 - data = await self._post_json( - self.base_url, "/message/downloadEmojiMd5", payload - ) - json_blob = json.loads(data) - return json_blob - except BaseException as e: - logger.error(f"gewe download emoji: {e}") diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py deleted file mode 100644 index f549d9ece..000000000 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ /dev/null @@ -1,264 +0,0 @@ -import asyncio -import re -import wave -import uuid -import traceback -import os - -from typing import AsyncGenerator -from astrbot.core.utils.io import download_file -from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk -from astrbot.api import logger -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember -from astrbot.api.message_components import ( - Plain, - Image, - Record, - At, - File, - Video, - WechatEmoji as Emoji, -) -from .client import SimpleGewechatClient -from astrbot.core.utils.astrbot_path import get_astrbot_data_path - - -def get_wav_duration(file_path): - with wave.open(file_path, "rb") as wav_file: - file_size = os.path.getsize(file_path) - n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4] - if n_frames == 2147483647: - duration = (file_size - 44) / (n_channels * sampwidth * framerate) - elif n_frames == 0: - duration = (file_size - 44) / (n_channels * sampwidth * framerate) - else: - duration = n_frames / float(framerate) - return duration - - -class GewechatPlatformEvent(AstrMessageEvent): - def __init__( - self, - message_str: str, - message_obj: AstrBotMessage, - platform_meta: PlatformMetadata, - session_id: str, - client: SimpleGewechatClient, - ): - super().__init__(message_str, message_obj, platform_meta, session_id) - self.client = client - - @staticmethod - async def send_with_client( - message: MessageChain, to_wxid: str, client: SimpleGewechatClient - ): - if not to_wxid: - logger.error("无法获取到 to_wxid。") - return - - # 检查@ - ats = [] - ats_names = [] - for comp in message.chain: - if isinstance(comp, At): - ats.append(comp.qq) - ats_names.append(comp.name) - has_at = False - - for comp in message.chain: - if isinstance(comp, Plain): - text = comp.text - payload = { - "to_wxid": to_wxid, - "content": text, - } - if not has_at and ats: - ats = f"{','.join(ats)}" - ats_names = f"@{' @'.join(ats_names)}" - text = f"{ats_names} {text}" - payload["content"] = text - payload["ats"] = ats - has_at = True - await client.post_text(**payload) - - elif isinstance(comp, Image): - img_path = await comp.convert_to_file_path() - # 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token - token = await client._register_file(img_path) - img_url = f"{client.file_server_url}/{token}" - logger.debug(f"gewe callback img url: {img_url}") - await client.post_image(to_wxid, img_url) - elif isinstance(comp, Video): - if comp.cover != "": - await client.forward_video(to_wxid, comp.cover) - else: - try: - from pyffmpeg import FFmpeg - except (ImportError, ModuleNotFoundError): - logger.error( - "需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg" - ) - raise ModuleNotFoundError( - "需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg" - ) - - video_url = comp.file - # 根据 url 下载视频 - if video_url.startswith("http"): - video_filename = f"{uuid.uuid4()}.mp4" - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - video_path = os.path.join(temp_dir, video_filename) - await download_file(video_url, video_path) - else: - video_path = video_url - - video_token = await client._register_file(video_path) - video_callback_url = f"{client.file_server_url}/{video_token}" - - # 获取视频第一帧 - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - thumb_path = os.path.join( - temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg" - ) - - video_path = video_path.replace(" ", "\\ ") - try: - ff = FFmpeg() - command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}" - ff.options(command) - thumb_token = await client._register_file(thumb_path) - thumb_url = f"{client.file_server_url}/{thumb_token}" - except Exception as e: - logger.error(f"获取视频第一帧失败: {e}") - - # 获取视频时长 - try: - from pyffmpeg import FFprobe - - # 创建 FFprobe 实例 - ffprobe = FFprobe(video_url) - # 获取时长字符串 - duration_str = ffprobe.duration - # 处理时长字符串 - video_duration = float(duration_str.replace(":", "")) - except Exception as e: - logger.error(f"获取时长失败: {e}") - video_duration = 10 - - # 发送视频 - await client.post_video( - to_wxid, video_callback_url, thumb_url, video_duration - ) - - # 删除临时缩略图文件 - if os.path.exists(thumb_path): - os.remove(thumb_path) - elif isinstance(comp, Record): - # 默认已经存在 data/temp 中 - record_url = comp.file - record_path = await comp.convert_to_file_path() - - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk") - try: - duration = await wav_to_tencent_silk(record_path, silk_path) - except Exception as e: - logger.error(traceback.format_exc()) - await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}") - logger.info("Silk 语音文件格式转换至: " + record_path) - if duration == 0: - duration = get_wav_duration(record_path) - token = await client._register_file(silk_path) - record_url = f"{client.file_server_url}/{token}" - logger.debug(f"gewe callback record url: {record_url}") - await client.post_voice(to_wxid, record_url, duration * 1000) - elif isinstance(comp, File): - file_path = comp.file - file_name = comp.name - if file_path.startswith("file:///"): - file_path = file_path[8:] - elif file_path.startswith("http"): - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - temp_file_path = os.path.join(temp_dir, file_name) - await download_file(file_path, temp_file_path) - file_path = temp_file_path - else: - file_path = file_path - - token = await client._register_file(file_path) - file_url = f"{client.file_server_url}/{token}" - logger.debug(f"gewe callback file url: {file_url}") - await client.post_file(to_wxid, file_url, file_name) - elif isinstance(comp, Emoji): - await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl) - elif isinstance(comp, At): - pass - else: - logger.debug(f"gewechat 忽略: {comp.type}") - - async def send(self, message: MessageChain): - to_wxid = self.message_obj.raw_message.get("to_wxid", None) - await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client) - await super().send(message) - - async def get_group(self, group_id=None, **kwargs): - # 确定有效的 group_id - if group_id is None: - group_id = self.get_group_id() - - if not group_id: - return None - - res = await self.client.get_group(group_id) - data: dict = res["data"] - - if not data["chatroomId"]: - return None - - members = [ - MessageMember(user_id=member["wxid"], nickname=member["nickName"]) - for member in data.get("memberList", []) - ] - - return Group( - group_id=data["chatroomId"], - group_name=data.get("nickName"), - group_avatar=data.get("smallHeadImgUrl"), - group_owner=data.get("chatRoomOwner"), - members=members, - ) - - async def send_streaming( - self, generator: AsyncGenerator, use_fallback: bool = False - ): - if not use_fallback: - buffer = None - async for chain in generator: - if not buffer: - buffer = chain - else: - buffer.chain.extend(chain.chain) - if not buffer: - return - buffer.squash_plain() - await self.send(buffer) - return await super().send_streaming(generator, use_fallback) - - buffer = "" - pattern = re.compile(r"[^。?!~…]+[。?!~…]+") - - async for chain in generator: - if isinstance(chain, MessageChain): - for comp in chain.chain: - if isinstance(comp, Plain): - buffer += comp.text - if any(p in buffer for p in "。?!~…"): - buffer = await self.process_buffer(buffer, pattern) - else: - await self.send(MessageChain(chain=[comp])) - await asyncio.sleep(1.5) # 限速 - - if buffer.strip(): - await self.send(MessageChain([Plain(buffer)])) - return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py deleted file mode 100644 index 7d8dddfca..000000000 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ /dev/null @@ -1,103 +0,0 @@ -import sys -import asyncio -import os - -from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata -from astrbot.api.event import MessageChain -from astrbot.core.platform.astr_message_event import MessageSesion -from ...register import register_platform_adapter -from .gewechat_event import GewechatPlatformEvent -from .client import SimpleGewechatClient -from astrbot import logger - -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - - -@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器") -class GewechatPlatformAdapter(Platform): - def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue - ) -> None: - super().__init__(event_queue) - self.config = platform_config - self.settingss = platform_settings - self.test_mode = os.environ.get("TEST_MODE", "off") == "on" - self.client = None - - self.client = SimpleGewechatClient( - self.config["base_url"], - self.config["nickname"], - self.config["host"], - self.config["port"], - self._event_queue, - ) - - async def on_event_received(abm: AstrBotMessage): - await self.handle_msg(abm) - - self.client.on_event_received = on_event_received - - @override - async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain - ): - session_id = session.session_id - if "#" in session_id: - # unique session - to_wxid = session_id.split("#")[1] - else: - to_wxid = session_id - - await GewechatPlatformEvent.send_with_client( - message_chain, to_wxid, self.client - ) - - await super().send_by_session(session, message_chain) - - @override - def meta(self) -> PlatformMetadata: - return PlatformMetadata( - name="gewechat", - description="基于 gewechat 的 Wechat 适配器", - id=self.config.get("id"), - ) - - async def terminate(self): - self.client.shutdown_event.set() - try: - await self.client.server.shutdown() - except Exception as _: - pass - logger.info("Gewechat 适配器已被优雅地关闭。") - - async def logout(self): - await self.client.logout() - - @override - def run(self): - return self._run() - - async def _run(self): - await self.client.login() - await self.client.start_polling() - - async def handle_msg(self, message: AstrBotMessage): - if message.type == MessageType.GROUP_MESSAGE: - if self.settingss["unique_session"]: - message.session_id = message.sender.user_id + "#" + message.group_id - - message_event = GewechatPlatformEvent( - message_str=message.message_str, - message_obj=message, - platform_meta=self.meta(), - session_id=message.session_id, - client=self.client, - ) - - self.commit_event(message_event) - - def get_client(self) -> SimpleGewechatClient: - return self.client diff --git a/astrbot/core/platform/sources/gewechat/xml_data_parser.py b/astrbot/core/platform/sources/gewechat/xml_data_parser.py deleted file mode 100644 index 1af4a051a..000000000 --- a/astrbot/core/platform/sources/gewechat/xml_data_parser.py +++ /dev/null @@ -1,110 +0,0 @@ -from defusedxml import ElementTree as eT -from astrbot.api import logger -from astrbot.api.message_components import ( - WechatEmoji as Emoji, - Reply, - Plain, - BaseMessageComponent, -) - - -class GeweDataParser: - def __init__(self, data, is_private_chat): - self.data = data - self.is_private_chat = is_private_chat - - def _format_to_xml(self): - return eT.fromstring(self.data) - - def parse_mutil_49(self) -> list[BaseMessageComponent] | None: - appmsg_type = self._format_to_xml().find(".//appmsg/type") - if appmsg_type is None: - return - - match appmsg_type.text: - case "57": - return self.parse_reply() - - def parse_emoji(self) -> Emoji | None: - try: - emoji_element = self._format_to_xml().find(".//emoji") - # 提取 md5 和 len 属性 - if emoji_element is not None: - md5_value = emoji_element.get("md5") - emoji_size = emoji_element.get("len") - cdnurl = emoji_element.get("cdnurl") - - return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl) - - except Exception as e: - logger.error(f"gewechat: parse_emoji failed, {e}") - - def parse_reply(self) -> list[Reply, Plain] | None: - """解析引用消息 - - Returns: - list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。 - """ - try: - replied_id = -1 - replied_uid = 0 - replied_nickname = "" - replied_content = "" # 被引用者说的内容 - content = "" # 引用者说的内容 - - root = self._format_to_xml() - refermsg = root.find(".//refermsg") - if refermsg is not None: - # 被引用的信息 - svrid = refermsg.find("svrid") - fromusr = refermsg.find("fromusr") - displayname = refermsg.find("displayname") - refermsg_content = refermsg.find("content") - if svrid is not None: - replied_id = svrid.text - if fromusr is not None: - replied_uid = fromusr.text - if displayname is not None: - replied_nickname = displayname.text - if refermsg_content is not None: - # 处理引用嵌套,包括嵌套公众号消息 - if refermsg_content.text.startswith( - "" - ) or refermsg_content.text.startswith(" dict: + """准备配置,处理嵌套格式""" + if "mcpServers" in config and config["mcpServers"]: + first_key = next(iter(config["mcpServers"])) + config = config["mcpServers"][first_key] + config.pop("active", None) + return config + + +async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: + """快速测试 MCP 服务器可达性""" + import aiohttp + + cfg = _prepare_config(config.copy()) + + url = cfg["url"] + headers = cfg.get("headers", {}) + timeout = cfg.get("timeout", 10) + + try: + async with aiohttp.ClientSession() as session: + if cfg.get("transport") == "streamable_http": + test_payload = { + "jsonrpc": "2.0", + "method": "initialize", + "id": 0, + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.2.3"}, + }, + } + async with session.post( + url, + headers={ + **headers, + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + json=test_payload, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + else: + async with session.get( + url, + headers={ + **headers, + "Accept": "application/json, text/event-stream", + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + + except asyncio.TimeoutError: + return False, f"连接超时: {timeout}秒" + except Exception as e: + return False, f"{e!s}" + + @dataclass class FuncTool: """ @@ -80,12 +146,10 @@ class FuncTool: if not self.mcp_client or not self.mcp_client.session: raise Exception(f"MCP client for {self.name} is not available") # 使用name属性而不是额外的mcp_tool_name - if ":" in self.name: - # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名 - actual_tool_name = self.name.split(":")[-1] - return await self.mcp_client.session.call_tool(actual_tool_name, args) - else: - return await self.mcp_client.session.call_tool(self.name, args) + actual_tool_name = ( + self.name.split(":")[-1] if ":" in self.name else self.name + ) + return await self.mcp_client.session.call_tool(actual_tool_name, args) else: raise Exception(f"Unknown function origin: {self.origin}") @@ -100,6 +164,7 @@ class MCPClient: self.active: bool = True self.tools: List[mcp.Tool] = [] self.server_errlogs: List[str] = [] + self.running_event = asyncio.Event() async def connect_to_server(self, mcp_server_config: dict, name: str): """连接到 MCP 服务器 @@ -112,17 +177,19 @@ class MCPClient: Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ - cfg = mcp_server_config.copy() - if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0: - key_0 = list(cfg["mcpServers"].keys())[0] - cfg = cfg["mcpServers"][key_0] - cfg.pop("active", None) # Remove active flag from config + cfg = _prepare_config(mcp_server_config.copy()) + + def logging_callback(msg: str): + # 处理 MCP 服务的错误日志 + print(f"MCP Server {name} Error: {msg}") + self.server_errlogs.append(msg) if "url" in cfg: - is_sse = True - if cfg.get("transport") == "streamable_http": - is_sse = False - if is_sse: + success, error_msg = await _quick_test_mcp_connection(cfg) + if not success: + raise Exception(error_msg) + + if cfg.get("transport") != "streamable_http": # SSE transport method self._streams_context = sse_client( url=cfg["url"], @@ -130,11 +197,18 @@ class MCPClient: timeout=cfg.get("timeout", 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), ) - streams = await self._streams_context.__aenter__() + streams = await self.exit_stack.enter_async_context( + self._streams_context + ) # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(*streams) + mcp.ClientSession( + *streams, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) ) else: timeout = timedelta(seconds=cfg.get("timeout", 30)) @@ -148,11 +222,19 @@ class MCPClient: sse_read_timeout=sse_read_timeout, terminate_on_close=cfg.get("terminate_on_close", True), ) - read_s, write_s, _ = await self._streams_context.__aenter__() + read_s, write_s, _ = await self.exit_stack.enter_async_context( + self._streams_context + ) # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(read_stream=read_s, write_stream=write_s) + mcp.ClientSession( + read_stream=read_s, + write_stream=write_s, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) ) else: @@ -172,7 +254,7 @@ class MCPClient: logger=logger, identifier=f"MCPServer-{name}", callback=callback, - ), + ), # type: ignore ), ) @@ -180,19 +262,18 @@ class MCPClient: self.session = await self.exit_stack.enter_async_context( mcp.ClientSession(*stdio_transport) ) - await self.session.initialize() async def list_tools_and_save(self) -> mcp.ListToolsResult: """List all tools from the server and save them to self.tools""" response = await self.session.list_tools() - logger.debug(f"MCP server {self.name} list tools response: {response}") self.tools = response.tools return response async def cleanup(self): """Clean up resources""" await self.exit_stack.aclose() + self.running_event.set() # Set the running event to indicate cleanup is done class FuncCall: @@ -201,8 +282,6 @@ class FuncCall: """内部加载的 func tools""" self.mcp_client_dict: Dict[str, MCPClient] = {} """MCP 服务列表""" - self.mcp_service_queue = asyncio.Queue() - """用于外部控制 MCP 服务的启停""" self.mcp_client_event: Dict[str, asyncio.Event] = {} def empty(self) -> bool: @@ -258,7 +337,7 @@ class FuncCall: return f return None - async def _init_mcp_clients(self) -> None: + async def init_mcp_clients(self) -> None: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: ``` { @@ -300,113 +379,64 @@ class FuncCall: ) self.mcp_client_event[name] = event - async def mcp_service_selector(self): - """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制 - - 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下: - - {"type": "init"} 初始化所有MCP客户端 - - {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端 - - {"type": "terminate"} 终止所有MCP客户端 - - {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端 - """ - while True: - data = await self.mcp_service_queue.get() - if data["type"] == "init": - if "name" in data: - event = asyncio.Event() - asyncio.create_task( - self._init_mcp_client_task_wrapper( - data["name"], data["cfg"], event - ) - ) - self.mcp_client_event[data["name"]] = event - else: - await self._init_mcp_clients() - elif data["type"] == "terminate": - if "name" in data: - # await self._terminate_mcp_client(data["name"]) - if data["name"] in self.mcp_client_event: - self.mcp_client_event[data["name"]].set() - self.mcp_client_event.pop(data["name"], None) - self.func_list = [ - f - for f in self.func_list - if not ( - f.origin == "mcp" and f.mcp_server_name == data["name"] - ) - ] - else: - for name in self.mcp_client_dict.keys(): - # await self._terminate_mcp_client(name) - # self.mcp_client_event[name].set() - if name in self.mcp_client_event: - self.mcp_client_event[name].set() - self.mcp_client_event.pop(name, None) - self.func_list = [f for f in self.func_list if f.origin != "mcp"] - async def _init_mcp_client_task_wrapper( - self, name: str, cfg: dict, event: asyncio.Event + self, + name: str, + cfg: dict, + event: asyncio.Event, + ready_future: asyncio.Future = None, ) -> None: """初始化 MCP 客户端的包装函数,用于捕获异常""" try: await self._init_mcp_client(name, cfg) + tools = await self.mcp_client_dict[name].list_tools_and_save() + if ready_future and not ready_future.done(): + # tell the caller we are ready + ready_future.set_result(tools) await event.wait() logger.info(f"收到 MCP 客户端 {name} 终止信号") - await self._terminate_mcp_client(name) except Exception as e: - import traceback - - traceback.print_exc() - logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") + logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) + if ready_future and not ready_future.done(): + ready_future.set_exception(e) + finally: + # 无论如何都能清理 + await self._terminate_mcp_client(name) async def _init_mcp_client(self, name: str, config: dict) -> None: """初始化单个MCP客户端""" - try: - # 先清理之前的客户端,如果存在 - if name in self.mcp_client_dict: - await self._terminate_mcp_client(name) + # 先清理之前的客户端,如果存在 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) - mcp_client = MCPClient() - mcp_client.name = name - self.mcp_client_dict[name] = mcp_client - await mcp_client.connect_to_server(config, name) - tools_res = await mcp_client.list_tools_and_save() - tool_names = [tool.name for tool in tools_res.tools] + mcp_client = MCPClient() + mcp_client.name = name + self.mcp_client_dict[name] = mcp_client + await mcp_client.connect_to_server(config, name) + tools_res = await mcp_client.list_tools_and_save() + logger.debug(f"MCP server {name} list tools response: {tools_res}") + tool_names = [tool.name for tool in tools_res.tools] - # 移除该MCP服务之前的工具(如有) - self.func_list = [ - f - for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) - ] + # 移除该MCP服务之前的工具(如有) + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] - # 将 MCP 工具转换为 FuncTool 并添加到 func_list - for tool in mcp_client.tools: - func_tool = FuncTool( - name=tool.name, - parameters=tool.inputSchema, - description=tool.description, - origin="mcp", - mcp_server_name=name, - mcp_client=mcp_client, - ) - self.func_list.append(func_tool) + # 将 MCP 工具转换为 FuncTool 并添加到 func_list + for tool in mcp_client.tools: + func_tool = FuncTool( + name=tool.name, + parameters=tool.inputSchema, + description=tool.description, + origin="mcp", + mcp_server_name=name, + mcp_client=mcp_client, + ) + self.func_list.append(func_tool) - logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") - return - except Exception as e: - import traceback - - logger.error(traceback.format_exc()) - logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") - # 发生错误时确保客户端被清理 - if name in self.mcp_client_dict: - await self._terminate_mcp_client(name) - return + logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" @@ -414,9 +444,9 @@ class FuncCall: try: # 关闭MCP连接 await self.mcp_client_dict[name].cleanup() - del self.mcp_client_dict[name] + self.mcp_client_dict.pop(name) except Exception as e: - logger.info(f"清空 MCP 客户端资源 {name}: {e}。") + logger.error(f"清空 MCP 客户端资源 {name}: {e}。") # 移除关联的FuncTool self.func_list = [ f @@ -425,6 +455,103 @@ class FuncCall: ] logger.info(f"已关闭 MCP 服务 {name}") + @staticmethod + async def test_mcp_server_connection(config: dict) -> list[str]: + if "url" in config: + success, error_msg = await _quick_test_mcp_connection(config) + if not success: + raise Exception(error_msg) + + mcp_client = MCPClient() + try: + logger.debug(f"testing MCP server connection with config: {config}") + await mcp_client.connect_to_server(config, "test") + tools_res = await mcp_client.list_tools_and_save() + tool_names = [tool.name for tool in tools_res.tools] + finally: + logger.debug("Cleaning up MCP client after testing connection.") + await mcp_client.cleanup() + return tool_names + + async def enable_mcp_server( + self, + name: str, + config: dict, + event: asyncio.Event | None = None, + ready_future: asyncio.Future | None = None, + timeout: int = 30, + ) -> None: + """Enable_mcp_server a new MCP server to the manager and initialize it. + + Args: + name (str): The name of the MCP server. + config (dict): Configuration for the MCP server. + event (asyncio.Event): Event to signal when the MCP client is ready. + ready_future (asyncio.Future): Future to signal when the MCP client is ready. + timeout (int): Timeout for the initialization. + Raises: + TimeoutError: If the initialization does not complete within the specified timeout. + Exception: If there is an error during initialization. + """ + if not event: + event = asyncio.Event() + if not ready_future: + ready_future = asyncio.Future() + if name in self.mcp_client_dict: + return + asyncio.create_task( + self._init_mcp_client_task_wrapper(name, config, event, ready_future) + ) + try: + await asyncio.wait_for(ready_future, timeout=timeout) + finally: + self.mcp_client_event[name] = event + + if ready_future.done() and ready_future.exception(): + exc = ready_future.exception() + if exc is not None: + raise exc + + async def disable_mcp_server( + self, name: str | None = None, timeout: float = 10 + ) -> None: + """Disable an MCP server by its name. + + Args: + name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. + timeout (int): Timeout. + """ + if name: + if name not in self.mcp_client_event: + return + client = self.mcp_client_dict.get(name) + self.mcp_client_event[name].set() + if not client: + return + client_running_event = client.running_event + try: + await asyncio.wait_for(client_running_event.wait(), timeout=timeout) + finally: + self.mcp_client_event.pop(name, None) + self.func_list = [ + f + for f in self.func_list + if f.origin != "mcp" or f.mcp_server_name != name + ] + else: + running_events = [ + client.running_event.wait() for client in self.mcp_client_dict.values() + ] + for key, event in self.mcp_client_event.items(): + event.set() + # waiting for all clients to finish + try: + await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout) + finally: + self.mcp_client_event.clear() + self.mcp_client_dict.clear() + self.func_list = [f for f in self.func_list if f.origin != "mcp"] + def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 @@ -629,8 +756,3 @@ class FuncCall: def __repr__(self): return str(self.func_list) - - async def terminate(self): - for name in self.mcp_client_dict.keys(): - await self._terminate_mcp_client(name) - logger.debug(f"清理 MCP 客户端 {name} 资源") diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index df21e6a12..370c5322b 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -169,10 +169,7 @@ class ProviderManager: self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 - asyncio.create_task( - self.llm_tools.mcp_service_selector(), name="mcp-service-handler" - ) - self.llm_tools.mcp_service_queue.put_nowait({"type": "init"}) + asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") async def load_provider(self, provider_config: dict): if not provider_config["enable"]: @@ -422,7 +419,7 @@ class ProviderManager: self.curr_tts_provider_inst = None if getattr(self.inst_map[provider_id], "terminate", None): - await self.inst_map[provider_id].terminate() # type: ignore + await self.inst_map[provider_id].terminate() # type: ignore logger.info( f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" @@ -432,6 +429,8 @@ class ProviderManager: async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): - await provider_inst.terminate() # type: ignore - # 清理 MCP Client 连接 - await self.llm_tools.mcp_service_queue.put({"type": "terminate"}) + await provider_inst.terminate() # type: ignore + try: + await self.llm_tools.disable_mcp_server() + except Exception: + logger.error("Error while disabling MCP servers", exc_info=True) diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index f43152473..79b2e83b2 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -22,7 +22,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): timeout=int(provider_config.get("timeout", 20)), ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") - self.dimension = provider_config.get("embedding_dimensions", 1536) + self.dimension = provider_config.get("embedding_dimensions", 1024) async def get_embedding(self, text: str) -> list[float]: """ diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index e7ce0d54a..14c2da2de 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -176,7 +176,7 @@ class ProviderOpenAIOfficial(Provider): raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - if choice.message.content: + if choice.message.content is not None: # text completion completion_text = str(choice.message.content).strip() llm_response.result_chain = MessageChain().message(completion_text) @@ -210,7 +210,7 @@ class ProviderOpenAIOfficial(Provider): "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。" ) - if not llm_response.completion_text and not llm_response.tools_call_args: + if llm_response.completion_text is None and not llm_response.tools_call_args: logger.error(f"API 返回的 completion 无法解析:{completion}。") raise Exception(f"API 返回的 completion 无法解析:{completion}。") diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 25c291659..86318f8b7 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools class Star(CommandParserMixin): """所有插件(Star)的父类,所有插件都应该继承于这个类""" - def __init__(self, context: Context): + def __init__(self, context: Context, config: dict | None = None): StarTools.initialize(context) self.context = context @@ -41,9 +41,17 @@ class Star(CommandParserMixin): tmpl, data, return_url=return_url, options=options ) + async def initialize(self): + """当插件被激活时会调用这个方法""" + pass + async def terminate(self): """当插件被禁用、重载插件时会调用这个方法""" pass + def __del__(self): + """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" + pass + __all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"] diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index fffaf8553..6c2d38572 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -8,7 +8,6 @@ from typing import Union class PlatformAdapterType(enum.Flag): AIOCQHTTP = enum.auto() QQOFFICIAL = enum.auto() - GEWECHAT = enum.auto() TELEGRAM = enum.auto() WECOM = enum.auto() LARK = enum.auto() @@ -22,7 +21,6 @@ class PlatformAdapterType(enum.Flag): ALL = ( AIOCQHTTP | QQOFFICIAL - | GEWECHAT | TELEGRAM | WECOM | LARK @@ -39,7 +37,6 @@ class PlatformAdapterType(enum.Flag): ADAPTER_NAME_2_TYPE = { "aiocqhttp": PlatformAdapterType.AIOCQHTTP, "qq_official": PlatformAdapterType.QQOFFICIAL, - "gewechat": PlatformAdapterType.GEWECHAT, "telegram": PlatformAdapterType.TELEGRAM, "wecom": PlatformAdapterType.WECOM, "lark": PlatformAdapterType.LARK, diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index d44388238..2fe9dd7f3 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass, field from types import ModuleType +from typing import TYPE_CHECKING from astrbot.core.config import AstrBotConfig @@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = [] star_map: dict[str, StarMetadata] = {} """key 是模块路径,__module__""" +if TYPE_CHECKING: + from . import Star + @dataclass class StarMetadata: @@ -29,12 +33,12 @@ class StarMetadata: repo: str | None = None """插件仓库地址""" - star_cls_type: type | None = None + star_cls_type: type[Star] | None = None """插件的类对象的类型""" module_path: str | None = None """插件的模块路径""" - star_cls: object | None = None + star_cls: Star | None = None """插件的类对象""" module: ModuleType | None = None """插件的模块对象""" diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 4a6d4d902..b64b4aa85 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -163,7 +163,7 @@ class PluginManager: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str = None): + async def _check_plugin_dept_update(self, target_plugin: str | None = None): """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -187,7 +187,7 @@ class PluginManager: logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") @staticmethod - def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata: + def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None: """先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。 @@ -253,8 +253,8 @@ class PluginManager: def _purge_modules( self, - module_patterns: list[str] = None, - root_dir_name: str = None, + module_patterns: list[str] | None = None, + root_dir_name: str | None = None, is_reserved: bool = False, ): """从 sys.modules 中移除指定的模块 @@ -314,8 +314,8 @@ class PluginManager: logger.warning( f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" ) - - await self._unbind_plugin(smd.name, smd.module_path) + if smd.name and smd.module_path: + await self._unbind_plugin(smd.name, smd.module_path) star_handlers_registry.clear() star_map.clear() @@ -331,8 +331,8 @@ class PluginManager: logger.warning( f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" ) - - await self._unbind_plugin(smd.name, specified_module_path) + if smd.name: + await self._unbind_plugin(smd.name, specified_module_path) result = await self.load(specified_module_path) @@ -460,8 +460,7 @@ class PluginManager: metadata.config = plugin_config if path not in inactivated_plugins: # 只有没有禁用插件时才实例化插件类 - if plugin_config: - # metadata.config = plugin_config + if plugin_config and metadata.star_cls_type: try: metadata.star_cls = metadata.star_cls_type( context=self.context, config=plugin_config @@ -470,7 +469,7 @@ class PluginManager: metadata.star_cls = metadata.star_cls_type( context=self.context ) - else: + elif metadata.star_cls_type: metadata.star_cls = metadata.star_cls_type( context=self.context ) @@ -487,6 +486,10 @@ class PluginManager: ) metadata.update_platform_compatibility(plugin_enable_config) + assert metadata.module_path is not None, ( + f"插件 {metadata.name} 的模块路径为空。" + ) + # 绑定 handler related_handlers = ( star_handlers_registry.get_handlers_by_module_name( @@ -495,7 +498,8 @@ class PluginManager: ) for handler in related_handlers: handler.handler = functools.partial( - handler.handler, metadata.star_cls + handler.handler, + metadata.star_cls, # type: ignore ) # 绑定 llm_tool handler for func_tool in llm_tools.func_list: @@ -505,7 +509,8 @@ class PluginManager: ): func_tool.handler_module_path = metadata.module_path func_tool.handler = functools.partial( - func_tool.handler, metadata.star_cls + func_tool.handler, + metadata.star_cls, # type: ignore ) if func_tool.name in inactivated_llm_tools: func_tool.active = False @@ -532,13 +537,12 @@ class PluginManager: obj = getattr(module, classes[0])( context=self.context ) # 实例化插件类 - else: - logger.info(f"插件 {metadata.name} 已被禁用。") - metadata = None metadata = self._load_plugin_metadata( plugin_path=plugin_dir_path, plugin_obj=obj ) + if not metadata: + raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。") metadata.star_cls = obj metadata.config = plugin_config metadata.module = module @@ -553,6 +557,10 @@ class PluginManager: if metadata.module_path in inactivated_plugins: metadata.activated = False + assert metadata.module_path is not None, ( + f"插件 {metadata.name} 的模块路径为空。" + ) + full_names = [] for handler in star_handlers_registry.get_handlers_by_module_name( metadata.module_path @@ -592,7 +600,7 @@ class PluginManager: metadata.star_handler_full_names = full_names # 执行 initialize() 方法 - if hasattr(metadata.star_cls, "initialize"): + if hasattr(metadata.star_cls, "initialize") and metadata.star_cls: await metadata.star_cls.initialize() except BaseException as e: @@ -734,6 +742,9 @@ class PluginManager: ]: del star_handlers_registry.star_handlers_map[k] + if plugin is None: + return + self._purge_modules( root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved ) @@ -795,6 +806,9 @@ class PluginManager: logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。") return + if star_metadata.star_cls is None: + return + if hasattr(star_metadata.star_cls, "__del__"): asyncio.get_event_loop().run_in_executor( None, star_metadata.star_cls.__del__ diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 2cd8fd9c2..2b34c2a14 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -30,7 +30,7 @@ def on_error(func, path, exc_info): raise exc_info[1] -def remove_dir(file_path) -> bool: +def remove_dir(file_path: str) -> bool: if not os.path.exists(file_path): return True shutil.rmtree(file_path, onerror=on_error) diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py new file mode 100644 index 000000000..912d91e53 --- /dev/null +++ b/astrbot/core/utils/session_lock.py @@ -0,0 +1,29 @@ +import asyncio +from collections import defaultdict +from contextlib import asynccontextmanager + + +class SessionLockManager: + def __init__(self): + self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._lock_count: dict[str, int] = defaultdict(int) + self._access_lock = asyncio.Lock() + + @asynccontextmanager + async def acquire_lock(self, session_id: str): + async with self._access_lock: + lock = self._locks[session_id] + self._lock_count[session_id] += 1 + + try: + async with lock: + yield + finally: + async with self._access_lock: + self._lock_count[session_id] -= 1 + if self._lock_count[session_id] == 0: + self._locks.pop(session_id, None) + self._lock_count.pop(session_id, None) + + +session_lock_manager = SessionLockManager() diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 7a503583b..42018d19e 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,7 +1,9 @@ import json import os +from typing import TypeVar from .astrbot_path import get_astrbot_data_path +_VT = TypeVar("_VT") class SharedPreferences: def __init__(self, path=None): @@ -24,7 +26,7 @@ class SharedPreferences: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default=None): + def get(self, key, default: _VT = None) -> _VT: return self._data.get(key, default) def put(self, key, value): diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index a8cf34c95..5ae98cc21 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -2,7 +2,7 @@ import asyncio import json from quart import make_response from astrbot.core import logger, LogBroker -from .route import Route, RouteContext +from .route import Route, RouteContext, Response class LogRoute(Route): @@ -10,6 +10,7 @@ class LogRoute(Route): super().__init__(context) self.log_broker = log_broker self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"]) + self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"]) async def log(self): async def stream(): @@ -23,7 +24,6 @@ class LogRoute(Route): **message, # see astrbot/core/log.py } yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - await asyncio.sleep(0.07) # 控制发送频率,避免过快 except asyncio.CancelledError: pass except BaseException as e: @@ -43,3 +43,14 @@ class LogRoute(Route): ) response.timeout = None return response + + async def log_history(self): + """获取日志历史""" + try: + logs = list(self.log_broker.log_cache) + return Response().ok(data={ + "logs": logs, + }).__dict__ + except BaseException as e: + logger.error(f"获取日志历史失败: {e}") + return Response().error(f"获取日志历史失败: {e}").__dict__ diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 6f37609b2..179b45428 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,6 +1,8 @@ import traceback import aiohttp import os +import json +from datetime import datetime import ssl import certifi @@ -75,15 +77,33 @@ class PluginRoute(Route): async def get_online_plugins(self): custom = request.args.get("custom_registry") + force_refresh = request.args.get("force_refresh", "false").lower() == "true" + + cache_file = "data/plugins.json" if custom: urls = [custom] else: - urls = ["https://api.soulter.top/astrbot/plugins"] + urls = [ + "https://api.soulter.top/astrbot/plugins", + "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json", + ] - # 新增:创建 SSL 上下文,使用 certifi 提供的根证书 + # 如果不是强制刷新,先检查缓存是否有效 + cached_data = None + if not force_refresh: + # 先检查MD5是否匹配,如果匹配则使用缓存 + if await self._is_cache_valid(cache_file): + cached_data = self._load_plugin_cache(cache_file) + if cached_data: + logger.debug("缓存MD5匹配,使用缓存的插件市场数据") + return Response().ok(cached_data).__dict__ + + # 尝试获取远程数据 + remote_data = None ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) + for url in urls: try: async with aiohttp.ClientSession( @@ -91,14 +111,123 @@ class PluginRoute(Route): ) as session: async with session.get(url) as response: if response.status == 200: - result = await response.json() - return Response().ok(result).__dict__ + remote_data = await response.json() + + # 检查远程数据是否为空 + if not remote_data or ( + isinstance(remote_data, dict) and len(remote_data) == 0 + ): + logger.warning(f"远程插件市场数据为空: {url}") + continue # 继续尝试其他URL或使用缓存 + + logger.info("成功获取远程插件市场数据") + # 获取最新的MD5并保存到缓存 + current_md5 = await self._get_remote_md5() + self._save_plugin_cache( + cache_file, remote_data, current_md5 + ) + return Response().ok(remote_data).__dict__ else: logger.error(f"请求 {url} 失败,状态码:{response.status}") except Exception as e: logger.error(f"请求 {url} 失败,错误:{e}") - return Response().error("获取插件列表失败").__dict__ + # 如果远程获取失败,尝试使用缓存数据 + if not cached_data: + cached_data = self._load_plugin_cache(cache_file) + + if cached_data: + logger.warning("远程插件市场数据获取失败,使用缓存数据") + return Response().ok(cached_data, "使用缓存数据,可能不是最新版本").__dict__ + + return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__ + + async def _is_cache_valid(self, cache_file: str) -> bool: + """检查缓存是否有效(基于MD5)""" + try: + if not os.path.exists(cache_file): + return False + + # 加载缓存文件 + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + + cached_md5 = cache_data.get("md5") + if not cached_md5: + logger.debug("缓存文件中没有MD5信息") + return False + + # 获取远程MD5 + remote_md5 = await self._get_remote_md5() + if not remote_md5: + logger.warning("无法获取远程MD5,将使用缓存") + return True # 如果无法获取远程MD5,认为缓存有效 + + is_valid = cached_md5 == remote_md5 + logger.debug( + f"插件数据MD5: 本地={cached_md5}, 远程={remote_md5}, 有效={is_valid}" + ) + return is_valid + + except Exception as e: + logger.warning(f"检查缓存有效性失败: {e}") + return False + + async def _get_remote_md5(self) -> str: + """获取远程插件数据的MD5""" + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + + async with aiohttp.ClientSession( + trust_env=True, connector=connector + ) as session: + async with session.get( + "https://api.soulter.top/astrbot/plugins-md5" + ) as response: + if response.status == 200: + data = await response.json() + return data.get("md5", "") + else: + logger.error(f"获取MD5失败,状态码:{response.status}") + return "" + except Exception as e: + logger.error(f"获取远程MD5失败: {e}") + return "" + + def _load_plugin_cache(self, cache_file: str): + """加载本地缓存的插件市场数据""" + try: + if os.path.exists(cache_file): + with open(cache_file, "r", encoding="utf-8") as f: + cache_data = json.load(f) + # 检查缓存是否有效 + if "data" in cache_data and "timestamp" in cache_data: + logger.debug( + f"加载缓存文件: {cache_file}, 缓存时间: {cache_data['timestamp']}" + ) + return cache_data["data"] + except Exception as e: + logger.warning(f"加载插件市场缓存失败: {e}") + return None + + def _save_plugin_cache(self, cache_file: str, data, md5: str = None): + """保存插件市场数据到本地缓存""" + try: + # 确保目录存在 + os.makedirs(os.path.dirname(cache_file), exist_ok=True) + + cache_data = { + "timestamp": datetime.now().isoformat(), + "data": data, + "md5": md5 or "", + } + + with open(cache_file, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + logger.debug(f"插件市场数据已缓存到: {cache_file}, MD5: {md5}") + except Exception as e: + logger.warning(f"保存插件市场缓存失败: {e}") async def get_plugins(self): _plugin_resp = [] diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 5bc401a0d..1890002c7 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -2,6 +2,7 @@ import traceback import psutil import time import threading +import aiohttp from .route import Route, Response, RouteContext from astrbot.core import logger from quart import request @@ -25,6 +26,7 @@ class StatRoute(Route): "/stat/version": ("GET", self.get_version), "/stat/start-time": ("GET", self.get_start_time), "/stat/restart-core": ("POST", self.restart_core), + "/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection), } self.db_helper = db_helper self.register_routes() @@ -45,11 +47,7 @@ class StatRoute(Route): """将总秒数转换为时分秒组件""" minutes, seconds = divmod(total_seconds, 60) hours, minutes = divmod(minutes, 60) - return { - "hours": hours, - "minutes": minutes, - "seconds": seconds - } + return {"hours": hours, "minutes": minutes, "seconds": seconds} def is_default_cred(self): username = self.config["dashboard"]["username"] @@ -144,3 +142,40 @@ class StatRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(e.__str__()).__dict__ + + async def test_ghproxy_connection(self): + """ + 测试 GitHub 代理连接是否可用。 + """ + try: + data = await request.get_json() + proxy_url: str = data.get("proxy_url") + + if not proxy_url: + return Response().error("proxy_url is required").__dict__ + + proxy_url = proxy_url.rstrip("/") + + test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version" + start_time = time.time() + + async with aiohttp.ClientSession() as session: + async with session.get( + test_url, timeout=aiohttp.ClientTimeout(total=10) + ) as response: + if response.status == 200: + end_time = time.time() + _ = await response.text() + ret = { + "latency": round((end_time - start_time) * 1000, 2), + } + return Response().ok(data=ret).__dict__ + else: + return ( + Response() + .error(f"Failed. Status code: {response.status}") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"Error: {str(e)}").__dict__ diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index d38014c71..5dad2576b 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -26,6 +26,7 @@ class ToolsRoute(Route): "/tools/mcp/update": ("POST", self.update_mcp_server), "/tools/mcp/delete": ("POST", self.delete_mcp_server), "/tools/mcp/market": ("GET", self.get_mcp_markets), + "/tools/mcp/test": ("POST", self.test_mcp_connection), } self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools @@ -132,12 +133,19 @@ class ToolsRoute(Route): config["mcpServers"][name] = server_config if self.save_mcp_config(config): - # 动态初始化新MCP客户端 - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) + try: + await self.tool_mgr.enable_mcp_server( + name, server_config, timeout=30 + ) + except TimeoutError: + return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ else: return Response().error("保存配置失败").__dict__ @@ -193,31 +201,55 @@ class ToolsRoute(Route): if self.save_mcp_config(config): # 处理MCP客户端状态变化 if active: - # 如果要激活服务器或者配置已更改 if name in self.tool_mgr.mcp_client_dict or not only_update_active: - await self.tool_mgr.mcp_service_queue.put({ - "type": "terminate", - "name": name, - }) - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) - else: - # 客户端不存在,初始化 - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError as e: + return ( + Response() + .error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}") + .__dict__ + ) + try: + await self.tool_mgr.enable_mcp_server( + name, config["mcpServers"][name], timeout=30 + ) + except TimeoutError: + return ( + Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) else: # 如果要停用服务器 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait({ - "type": "terminate", - "name": name, - }) + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError: + return ( + Response() + .error(f"停用 MCP 服务器 {name} 超时。") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ else: @@ -239,17 +271,23 @@ class ToolsRoute(Route): if name not in config["mcpServers"]: return Response().error(f"服务器 {name} 不存在").__dict__ - # 删除服务器配置 del config["mcpServers"][name] if self.save_mcp_config(config): - # 关闭并删除MCP客户端 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait({ - "type": "terminate", - "name": name, - }) - + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError: + return ( + Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ else: return Response().error("保存配置失败").__dict__ @@ -281,3 +319,20 @@ class ToolsRoute(Route): except Exception as _: logger.error(traceback.format_exc()) return Response().error("获取市场数据失败").__dict__ + + async def test_mcp_connection(self): + """ + 测试 MCP 服务器连接 + """ + try: + server_data = await request.json + config = server_data.get("mcp_server_config", None) + + tools_name = await self.tool_mgr.test_mcp_server_connection(config) + return ( + Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__ + ) + + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__ diff --git a/changelogs/v3.5.23.md b/changelogs/v3.5.23.md new file mode 100644 index 000000000..0aff7def8 --- /dev/null +++ b/changelogs/v3.5.23.md @@ -0,0 +1,18 @@ +1. 改进: WebUI提供商徽标显示 +2. 修复:在LLMRequestSubStage中添加对提供商请求处理的调试日志记录 +3. 修复: 为嵌入模型提供商添加状态检查 +4. 新增: 支持在WebUI上管理会话 +5. 新增: 为ProviderMetadata添加provider_type字段并优化提供商可用性测试 +6. 改进: WebUI聊天页面Markdown代码块 +7. 修复: 讯飞模型工具使用错误 +8. 修复: 修复mcp导致的持续占用100% CPU +9. 重构: mcp服务器重载机制 +10. 新增: 为WebChat页面添加文件上传按钮 +11. 优化: 工具使用页面用户界面 +12. 新增: 添加测试GitHub加速地址的组件 +13. 新增: 使用会话锁保证分段回复时的消息发送顺序 +14. 新增: 实现日志历史记录检索并改进日志流处理 +15. 杂务: 修改openai的嵌入模型默认维度为1024 +16. 修复:更新axios版本范围 +17. chore: remove adapters of WeChat personal account(gewechat) +18. 新增: 为AstrBotConfig中的嵌套对象添加展开状态管理 \ No newline at end of file diff --git a/dashboard/package.json b/dashboard/package.json index 4f7ca6753..56d6540e5 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -17,7 +17,7 @@ "@tiptap/starter-kit": "2.1.7", "@tiptap/vue-3": "2.1.7", "apexcharts": "3.42.0", - "axios": "^1.6.2", + "axios": ">=1.6.2 <1.10.0 || >1.10.0 <2.0.0", "axios-mock-adapter": "^1.22.0", "chance": "1.1.11", "d3": "^7.9.0", diff --git a/dashboard/src/components/shared/ItemCard.vue b/dashboard/src/components/shared/ItemCard.vue index ff790cb7b..6152c531f 100644 --- a/dashboard/src/components/shared/ItemCard.vue +++ b/dashboard/src/components/shared/ItemCard.vue @@ -9,6 +9,8 @@ hide-details density="compact" :model-value="getItemEnabled()" + :loading="loading" + :disabled="loading" v-bind="props" @update:model-value="toggleEnabled" > @@ -77,6 +79,10 @@ export default { bglogo: { type: String, default: null + }, + loading: { + type: Boolean, + default: false } }, emits: ['toggle-enabled', 'delete', 'edit'], diff --git a/dashboard/src/components/shared/ProxySelector.vue b/dashboard/src/components/shared/ProxySelector.vue new file mode 100644 index 000000000..d45a0f520 --- /dev/null +++ b/dashboard/src/components/shared/ProxySelector.vue @@ -0,0 +1,152 @@ + + + + + + \ No newline at end of file diff --git a/dashboard/src/i18n/locales/en-US/features/extension.json b/dashboard/src/i18n/locales/en-US/features/extension.json index a9d0d1b2f..a586da59d 100644 --- a/dashboard/src/i18n/locales/en-US/features/extension.json +++ b/dashboard/src/i18n/locales/en-US/features/extension.json @@ -32,7 +32,8 @@ "cancel": "Cancel", "actions": "Actions", "back": "Back", - "selectFile": "Select File" + "selectFile": "Select File", + "refresh": "Refresh" }, "status": { "enabled": "Enabled", diff --git a/dashboard/src/i18n/locales/en-US/features/platform.json b/dashboard/src/i18n/locales/en-US/features/platform.json index 182627bc5..31ce1d2c4 100644 --- a/dashboard/src/i18n/locales/en-US/features/platform.json +++ b/dashboard/src/i18n/locales/en-US/features/platform.json @@ -24,6 +24,7 @@ "addPlatform": "Add Platform Adapter", "connectTitle": "Connect {name}", "viewTutorial": "View Tutorial", + "noTemplates": "No platform templates available", "idConflict": { "title": "ID Conflict Warning", "message": "Detected duplicate ID \"{id}\". Please use a new ID.", diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index fad67a0d5..96c4760e8 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -15,7 +15,9 @@ "buttons": { "refresh": "Refresh", "add": "Add Server", - "useTemplate": "Use Template" + "useTemplateStdio": "Stdio Template", + "useTemplateStreamableHttp": "Streamable HTTP Template", + "useTemplateSse": "SSE Template" }, "empty": "No MCP servers available, click Add Server to add one", "status": { @@ -28,8 +30,7 @@ "functionTools": { "title": "Function Tools", "buttons": { - "expand": "Expand", - "collapse": "Collapse" + "view": "View Tools" }, "search": "Search function tools", "empty": "No function tools available", @@ -68,10 +69,6 @@ "enable": "Enable Server", "config": "Server Configuration" }, - "configNotes": { - "note1": "1. Some MCP servers may require filling in `API_KEY` or `TOKEN` information in env according to their requirements, please check if filled.", - "note2": "2. When url parameter is specified in configuration: if `transport` parameter is also specified as `streamable_http`, Streamable HTTP is used, otherwise SSE connection is used." - }, "errors": { "configEmpty": "Configuration cannot be empty", "jsonFormat": "JSON format error: {error}", @@ -79,7 +76,8 @@ }, "buttons": { "cancel": "Cancel", - "save": "Save" + "save": "Save", + "testConnection": "Test Connection" } }, "serverDetail": { diff --git a/dashboard/src/i18n/locales/zh-CN/features/extension.json b/dashboard/src/i18n/locales/zh-CN/features/extension.json index a8e4559eb..61b30183e 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/extension.json +++ b/dashboard/src/i18n/locales/zh-CN/features/extension.json @@ -32,7 +32,8 @@ "cancel": "取消", "actions": "操作", "back": "返回", - "selectFile": "选择文件" + "selectFile": "选择文件", + "refresh": "刷新" }, "status": { "enabled": "启用", diff --git a/dashboard/src/i18n/locales/zh-CN/features/platform.json b/dashboard/src/i18n/locales/zh-CN/features/platform.json index 26e672cb0..f5432ff14 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/platform.json +++ b/dashboard/src/i18n/locales/zh-CN/features/platform.json @@ -24,6 +24,7 @@ "addPlatform": "添加平台适配器", "connectTitle": "接入 {name}", "viewTutorial": "查看接入教程", + "noTemplates": "暂无平台模板", "idConflict": { "title": "ID 冲突警告", "message": "检测到 ID \"{id}\" 重复。请使用一个新的 ID。", diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json index f44a16d59..61b8691bc 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json +++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json @@ -15,7 +15,9 @@ "buttons": { "refresh": "刷新", "add": "新增服务器", - "useTemplate": "使用模板" + "useTemplateStdio": "Stdio 模板", + "useTemplateStreamableHttp": "Streamable HTTP 模板", + "useTemplateSse": "SSE 模板" }, "empty": "暂无 MCP 服务器,点击 新增服务器 添加", "status": { @@ -28,8 +30,7 @@ "functionTools": { "title": "函数工具", "buttons": { - "expand": "展开", - "collapse": "收起" + "view": "查看工具" }, "search": "搜索函数工具", "empty": "没有可用的函数工具", @@ -68,10 +69,6 @@ "enable": "启用服务器", "config": "服务器配置" }, - "configNotes": { - "note1": "1. 某些 MCP 服务器可能需要按照其要求在 env 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。", - "note2": "2. 当配置中指定 url 参数时:如果还同时指定 `transport` 参数的值为 `streamable_http`,则使用 Steamable HTTP,否则使用 SSE 连接。" - }, "errors": { "configEmpty": "配置不能为空", "jsonFormat": "JSON 格式错误: {error}", @@ -79,7 +76,8 @@ }, "buttons": { "cancel": "取消", - "save": "保存" + "save": "保存", + "testConnection": "测试连接" } }, "serverDetail": { diff --git a/dashboard/src/stores/common.js b/dashboard/src/stores/common.js index 8fc489760..fb951fc0a 100644 --- a/dashboard/src/stores/common.js +++ b/dashboard/src/stores/common.js @@ -15,7 +15,22 @@ export const useCommonStore = defineStore({ pluginMarketData: [], }), actions: { - createEventSource() { + async createEventSource() { + + const fetchLogHistory = async () => { + try { + const res = await axios.get('/api/log-history'); + if (res.data.data.logs) { + this.log_cache.push(...res.data.data.logs); + } else { + this.log_cache = []; + } + } catch (err) { + console.error('Failed to fetch log history:', err); + } + }; + await fetchLogHistory(); + if (this.eventSource) { return } @@ -40,7 +55,24 @@ export const useCommonStore = defineStore({ const reader = response.body.getReader(); const decoder = new TextDecoder(); + let incompleteLine = ""; // 用于存储不完整的行 + + const handleIncompleteLine = (line) => { + incompleteLine += line; + // if can parse as JSON, return it + try { + const data_json = JSON.parse(incompleteLine); + incompleteLine = ""; // 清空不完整行 + return data_json; + } catch (e) { + return null; + } + } + const processStream = ({ done, value }) => { + // get bytes length + const bytesLength = value ? value.byteLength : 0; + console.log(`Received ${bytesLength} bytes from live log`); if (done) { console.log('SSE stream closed'); setTimeout(() => { @@ -53,6 +85,9 @@ export const useCommonStore = defineStore({ const text = decoder.decode(value); const lines = text.split('\n\n'); lines.forEach(line => { + if (!line.trim()) { + return; + } if (line.startsWith('data:')) { const data = line.substring(5).trim(); // {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"} @@ -60,21 +95,29 @@ export const useCommonStore = defineStore({ try { data_json = JSON.parse(data); } catch (e) { - console.error('Invalid JSON:', data); - data_json = { - type: 'log', - data: data, - level: 'INFO', - time: new Date().toISOString(), + console.warn('Invalid JSON:', data); + // 尝试处理不完整的行 + const parsedData = handleIncompleteLine(data); + if (parsedData) { + data_json = parsedData; + } else { + return; // 如果无法解析,跳过当前行 } } if (data_json.type === 'log') { - // let log = data_json.data this.log_cache.push(data_json); if (this.log_cache.length > this.log_cache_max_len) { this.log_cache.shift(); } } + } else { + const parsedData = handleIncompleteLine(line); + if (parsedData && parsedData.type === 'log') { + this.log_cache.push(parsedData); + if (this.log_cache.length > this.log_cache_max_len) { + this.log_cache.shift(); + } + } } }); return reader.read().then(processStream); @@ -116,7 +159,11 @@ export const useCommonStore = defineStore({ if (!force && this.pluginMarketData.length > 0) { return Promise.resolve(this.pluginMarketData); } - return axios.get('/api/plugin/market_list') + + // 如果是强制刷新,添加 force_refresh 参数 + const url = force ? '/api/plugin/market_list?force_refresh=true' : '/api/plugin/market_list'; + + return axios.get(url) .then((res) => { let data = [] for (let key in res.data.data) { diff --git a/dashboard/src/theme/DarkTheme.ts b/dashboard/src/theme/DarkTheme.ts index 9899fcfff..177bee39c 100644 --- a/dashboard/src/theme/DarkTheme.ts +++ b/dashboard/src/theme/DarkTheme.ts @@ -36,12 +36,13 @@ const PurpleThemeDark: ThemeTypes = { gray100: '#cccccccc', primary200: '#90caf9', secondary200: '#b39ddb', - background: '#111111', + background: '#1d1d1d', overlay: '#111111aa', codeBg: '#282833', preBg: 'rgb(23, 23, 23)', code: '#ffffffdd', chatMessageBubble: '#2d2e30', + mcpCardBg: '#2a2a2a', } }; diff --git a/dashboard/src/theme/LightTheme.ts b/dashboard/src/theme/LightTheme.ts index 03630523f..b8fdec259 100644 --- a/dashboard/src/theme/LightTheme.ts +++ b/dashboard/src/theme/LightTheme.ts @@ -27,7 +27,7 @@ const PurpleTheme: ThemeTypes = { borderLight: '#d0d0d0', border: '#d0d0d0', inputBorder: '#787878', - containerBg: '#f7f1f6', + containerBg: '#f9fafcf4', surface: '#fff', 'on-surface-variant': '#fff', facebook: '#4267b2', @@ -36,12 +36,13 @@ const PurpleTheme: ThemeTypes = { gray100: '#fafafacc', primary200: '#90caf9', secondary200: '#b39ddb', - background: '#f9fafcf4', + background: '#ffffff', overlay: '#ffffffaa', codeBg: '#ececec', preBg: 'rgb(249, 249, 249)', code: 'rgb(13, 13, 13)', chatMessageBubble: '#e7ebf4', + mcpCardBg: '#f7f2f9', } }; diff --git a/dashboard/src/types/themeTypes/ThemeType.ts b/dashboard/src/types/themeTypes/ThemeType.ts index b18ee3dc5..8d2760044 100644 --- a/dashboard/src/types/themeTypes/ThemeType.ts +++ b/dashboard/src/types/themeTypes/ThemeType.ts @@ -37,5 +37,6 @@ export type ThemeTypes = { preBg?: string; code?: string; chatMessageBubble?: string; + mcpCardBg?: string; }; }; diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 208a751ab..a7e5c3eb3 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -226,6 +226,9 @@
+ + @@ -668,34 +671,44 @@ export default { }; }, + async processAndUploadImage(file) { + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await axios.post('/api/chat/post_image', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const img = response.data.data.filename; + this.stagedImagesName.push(img); // Store just the filename + this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display + + } catch (err) { + console.error('Error uploading image:', err); + } + }, + async handlePaste(event) { console.log('Pasting image...'); const items = event.clipboardData.items; for (let i = 0; i < items.length; i++) { if (items[i].type.indexOf('image') !== -1) { const file = items[i].getAsFile(); - const formData = new FormData(); - formData.append('file', file); - - try { - const response = await axios.post('/api/chat/post_image', formData, { - headers: { - 'Content-Type': 'multipart/form-data' - } - }); - - const img = response.data.data.filename; - this.stagedImagesName.push(img); // Store just the filename - this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display - - } catch (err) { - console.error('Error uploading image:', err); - } + this.processAndUploadImage(file); } } }, removeImage(index) { + // Revoke the blob URL to prevent memory leaks + const urlToRevoke = this.stagedImagesUrl[index]; + if (urlToRevoke && urlToRevoke.startsWith('blob:')) { + URL.revokeObjectURL(urlToRevoke); + } + this.stagedImagesName.splice(index, 1); this.stagedImagesUrl.splice(index, 1); }, @@ -703,6 +716,21 @@ export default { clearMessage() { this.prompt = ''; }, + + triggerImageInput() { + this.$refs.imageInput.click(); + }, + + handleFileSelect(event) { + const files = event.target.files; + if (files) { + for (const file of files) { + this.processAndUploadImage(file); + } + } + // Reset the input value to allow selecting the same file again + event.target.value = ''; + }, getConversations() { axios.get('/api/chat/conversations').then(response => { this.conversations = response.data.data; @@ -848,33 +876,42 @@ export default { // URL is already updated in newConversation method } + // 保存当前要发送的数据到临时变量 + const promptToSend = this.prompt.trim(); + const imageNamesToSend = [...this.stagedImagesName]; + const audioNameToSend = this.stagedAudioUrl; + + // 立即清空输入和附件预览 + this.prompt = ''; + this.stagedImagesName = []; + this.stagedImagesUrl = []; + this.stagedAudioUrl = ""; + // Create a message object with actual URLs for display const userMessage = { type: 'user', - message: this.prompt.trim(), // 使用 trim() 去除前后空格 + message: promptToSend, image_url: [], audio_url: null }; // Convert image filenames to blob URLs for display - if (this.stagedImagesName.length > 0) { - for (let i = 0; i < this.stagedImagesName.length; i++) { - // If it's just a filename, get the blob URL - if (!this.stagedImagesName[i].startsWith('blob:')) { - const imgUrl = await this.getMediaFile(this.stagedImagesName[i]); - userMessage.image_url.push(imgUrl); - } else { - userMessage.image_url.push(this.stagedImagesName[i]); + if (imageNamesToSend.length > 0) { + const imagePromises = imageNamesToSend.map(name => { + if (!name.startsWith('blob:')) { + return this.getMediaFile(name); } - } + return Promise.resolve(name); + }); + userMessage.image_url = await Promise.all(imagePromises); } // Convert audio filename to blob URL for display - if (this.stagedAudioUrl) { - if (!this.stagedAudioUrl.startsWith('blob:')) { - userMessage.audio_url = await this.getMediaFile(this.stagedAudioUrl); + if (audioNameToSend) { + if (!audioNameToSend.startsWith('blob:')) { + userMessage.audio_url = await this.getMediaFile(audioNameToSend); } else { - userMessage.audio_url = this.stagedAudioUrl; + userMessage.audio_url = audioNameToSend; } } @@ -889,8 +926,6 @@ export default { const selection = this.$refs.providerModelSelector?.getCurrentSelection(); const selectedProviderId = selection?.providerId || ''; const selectedModelName = selection?.modelName || ''; - let prompt = this.prompt.trim(); - this.prompt = ''; // 清空输入框 try { const response = await fetch('/api/chat/send', { @@ -900,10 +935,10 @@ export default { 'Authorization': 'Bearer ' + localStorage.getItem('token') }, body: JSON.stringify({ - message: prompt, + message: promptToSend, conversation_id: this.currCid, - image_url: this.stagedImagesName, - audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [], + image_url: imageNamesToSend, + audio_url: audioNameToSend ? [audioNameToSend] : [], selected_provider: selectedProviderId, selected_model: selectedModelName }) @@ -1013,11 +1048,7 @@ export default { } } - // Clear input after successful send - this.prompt = ''; - this.stagedImagesName = []; - this.stagedImagesUrl = []; - this.stagedAudioUrl = ""; + // Input and attachments are already cleared this.loadingChat = false; // get the latest conversations diff --git a/dashboard/src/views/ConversationPage.vue b/dashboard/src/views/ConversationPage.vue index 15ec44ca2..3f35d91d2 100644 --- a/dashboard/src/views/ConversationPage.vue +++ b/dashboard/src/views/ConversationPage.vue @@ -364,7 +364,6 @@ export default { 'telegram': 'blue-lighten-1', 'qq_official': 'purple-lighten-1', 'qq_official_webhook': 'purple-lighten-2', - 'gewechat': 'green-lighten-1', 'aiocqhttp': 'deep-purple-lighten-1', 'lark': 'cyan-darken-1', 'wecom': 'green-darken-1', diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue index 892034add..9d159990b 100644 --- a/dashboard/src/views/ExtensionPage.vue +++ b/dashboard/src/views/ExtensionPage.vue @@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue'; import AstrBotConfig from '@/components/shared/AstrBotConfig.vue'; import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue'; import ReadmeDialog from '@/components/shared/ReadmeDialog.vue'; +import ProxySelector from '@/components/shared/ProxySelector.vue'; import axios from 'axios'; import { useCommonStore } from '@/stores/common'; import { useI18n, useModuleI18n } from '@/i18n/composables'; @@ -29,12 +30,12 @@ const extension_config = reactive({ config: {} }); const pluginMarketData = ref([]); - const loadingDialog = reactive({ - show: false, - title: "", - statusCode: 0, // 0: loading, 1: success, 2: error, - result: "" - }); +const loadingDialog = reactive({ + show: false, + title: "", + statusCode: 0, // 0: loading, 1: success, 2: error, + result: "" +}); const showPluginInfoDialog = ref(false); const selectedPlugin = ref({}); const curr_namespace = ref(""); @@ -70,6 +71,7 @@ const uploadTab = ref('file'); const showPluginFullName = ref(false); const marketSearch = ref(""); const filterKeys = ['name', 'desc', 'author']; +const refreshingMarket = ref(false); const plugin_handler_info_headers = computed(() => [ { title: tm('table.headers.eventType'), key: 'event_type_h' }, @@ -184,8 +186,8 @@ const checkUpdate = () => { if (matchedPlugin) { extension.online_version = matchedPlugin.version; - extension.has_update = extension.version !== matchedPlugin.version && - matchedPlugin.version !== tm('status.unknown'); + extension.has_update = extension.version !== matchedPlugin.version && + matchedPlugin.version !== tm('status.unknown'); } else { extension.has_update = false; } @@ -559,6 +561,25 @@ const newExtension = async () => { } }; +// 刷新插件市场数据 +const refreshPluginMarket = async () => { + refreshingMarket.value = true; + try { + // 强制刷新插件市场数据 + const data = await commonStore.getPluginCollections(true); + pluginMarketData.value = data; + trimExtensionName(); + checkAlreadyInstalled(); + checkUpdate(); + + toast(tm('messages.refreshSuccess'), "success"); + } catch (err) { + toast(tm('messages.refreshFailed') + " " + err, "error"); + } finally { + refreshingMarket.value = false; + } +}; + // 生命周期 onMounted(async () => { await getExtensions(); @@ -622,27 +643,12 @@ onMounted(async () => { - + - + @@ -678,33 +684,32 @@ onMounted(async () => { mdi-plus {{ tm('buttons.install') }} - - - - - - + + + + + + @@ -726,7 +731,8 @@ onMounted(async () => {
{{ item.name }}
- {{ tm('status.system') }} + {{ tm('status.system') + }}
@@ -847,8 +853,8 @@ onMounted(async () => { - +
@@ -865,8 +871,20 @@ onMounted(async () => {

{{ tm('market.allPlugins') }}

- + + mdi-refresh + {{ tm('buttons.refresh') }} + + +
@@ -904,7 +922,8 @@ onMounted(async () => { \ No newline at end of file diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue index cc70c415b..6060088a4 100644 --- a/dashboard/src/views/ToolUsePage.vue +++ b/dashboard/src/views/ToolUsePage.vue @@ -20,9 +20,16 @@

- - {{ tm('mcpServers.buttons.add') }} - +
+ + {{ tm('functionTools.buttons.view') }}({{ tools.length }}) + + + {{ tm('mcpServers.buttons.add') }} + +
@@ -44,169 +51,79 @@ - - - mdi-server - {{ tm('mcpServers.title') }} - - - {{ tm('mcpServers.buttons.refresh') }} - - - {{ tm('mcpServers.buttons.add') }} - - - +
+ mdi-server-off +

{{ tm('mcpServers.empty') }}

+
- -
- mdi-server-off -

{{ tm('mcpServers.empty') }}

-
+ + + + + + + - - - -

- mdi-information - {{ tm('functionTools.description') }} -

-

{{ tool.function.description }}

- - -
- mdi-code-brackets -

{{ tm('functionTools.noParameters') }}

-
-
-
-
- - - - -
- -
@@ -216,9 +133,9 @@ mdi-store {{ tm('marketplace.title') }} - + {{ tm('marketplace.buttons.refresh') }} @@ -256,7 +173,8 @@
mdi-tools - {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) }} + {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) + }}
@@ -310,31 +228,25 @@ - - - +
{{ tm('dialogs.addServer.fields.config') }} - - -
- {{ tm('tooltip.serverConfig') }} -
-
- - {{ tm('mcpServers.buttons.useTemplate') }} + + {{ tm('mcpServers.buttons.useTemplateStdio') }} + + + {{ tm('mcpServers.buttons.useTemplateStreamableHttp') }} + + + {{ tm('mcpServers.buttons.useTemplateSse') }}
- {{ tm('dialogs.addServer.configNotes.note1') }} -
- {{ tm('dialogs.addServer.configNotes.note2') }} -
+
+ {{ addServerDialogMessage }} +
- + {{ tm('dialogs.addServer.buttons.cancel') }} + + {{ tm('dialogs.addServer.buttons.testConnection') }} + {{ tm('dialogs.addServer.buttons.save') }} @@ -469,6 +386,106 @@ + + + + + {{ tm('functionTools.title') }} + {{ tools.length }} + + + +
+
+ mdi-api-off +

{{ tm('functionTools.empty') }}

+
+ +
+ + + + + + + +
+ + {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }} + + + {{ formatToolName(tool.function.name) }} + +
+
+ + {{ tool.function.description }} + +
+
+ + + + +

+ mdi-information + {{ tm('functionTools.description') }} +

+

{{ tool.function.description }}

+ + +
+ mdi-code-brackets +

{{ tm('functionTools.noParameters') }}

+
+
+
+
+
+
+
+
+
+
+ + + + + {{ tm('dialogs.serverDetail.buttons.close') }} + + +
+
+ @@ -504,8 +521,12 @@ export default { tools: [], showMcpServerDialog: false, showServerDetailDialog: false, + addServerDialogMessage: "", + showToolsDialog: false, showTools: true, loading: false, + loadingGettingServers: false, + mcpServerUpdateLoaders: {}, // record loading state for each server update isEditMode: false, serverConfigJson: '', jsonError: null, @@ -575,10 +596,10 @@ export default { if (!this.marketplaceSearch.trim()) { return this.marketplaceServers; } - + const searchTerm = this.marketplaceSearch.toLowerCase(); - return this.marketplaceServers.filter(server => - server.name.toLowerCase().includes(searchTerm) || + return this.marketplaceServers.filter(server => + server.name.toLowerCase().includes(searchTerm) || (server.name_h && server.name_h.toLowerCase().includes(searchTerm)) || (server.description && server.description.toLowerCase().includes(searchTerm)) ); @@ -618,17 +639,21 @@ export default { }, getServers() { - this.loading = true + this.loadingGettingServers = true; axios.get('/api/tools/mcp/servers') .then(response => { this.mcpServers = response.data.data || []; + this.mcpServers.forEach(server => { + // Ensure each server has a loader state + if (!this.mcpServerUpdateLoaders[server.name]) { + this.mcpServerUpdateLoaders[server.name] = false; + } + }); }) .catch(error => { this.showError(this.tm('messages.getServersError', { error: error.message })); }).finally(() => { - setTimeout(() => { - this.loading = false; - }, 500); + this.loadingGettingServers = false; }); }, @@ -658,14 +683,28 @@ export default { } }, - setConfigTemplate() { - // 设置一个基本的配置模板 - const template = { - command: "python", - args: ["-m", "your_module"], - // 可以添加其他 MCP 支持的配置项 - }; - + setConfigTemplate(type = 'stdio') { + let template = {}; + if (type === 'streamable_http') { + template = { + transport: "streamable_http", + url: "your mcp server url", + headers: {}, + timeout: 30, + }; + } else if (type === 'sse') { + template = { + transport: "sse", + url: "your mcp server url", + headers: {}, + timeout: 30, + }; + } else { + template = { + command: "python", + args: ["-m", "your_module"], + }; + } this.serverConfigJson = JSON.stringify(template, null, 2); }, @@ -693,6 +732,7 @@ export default { .then(response => { this.loading = false; this.showMcpServerDialog = false; + this.addServerDialogMessage = ""; this.getServers(); this.getTools(); this.showSuccess(response.data.message || this.tm('messages.saveSuccess')); @@ -753,6 +793,7 @@ export default { updateServerStatus(server) { // 切换服务器状态 + this.mcpServerUpdateLoaders[server.name] = true; server.active = !server.active; axios.post('/api/tools/mcp/update', server) .then(response => { @@ -761,16 +802,48 @@ export default { }) .catch(error => { this.showError(this.tm('messages.updateError', { error: error.response?.data?.message || error.message })); - // 回滚状态 server.active = !server.active; + }) + .finally(() => { + this.mcpServerUpdateLoaders[server.name] = false; }); }, closeServerDialog() { this.showMcpServerDialog = false; + this.addServerDialogMessage = ''; this.resetForm(); }, + testServerConnection() { + if (!this.validateJson()) { + return; + } + + this.loading = true; + + let configObj; + try { + configObj = JSON.parse(this.serverConfigJson); + } catch (e) { + this.loading = false; + this.showError(this.tm('dialogs.addServer.errors.jsonParse', { error: e.message })); + return; + } + + axios.post('/api/tools/mcp/test', { + "mcp_server_config": configObj, + }) + .then(response => { + this.loading = false; + this.addServerDialogMessage = `${response.data.message} (tools: ${response.data.data})`; + }) + .catch(error => { + this.loading = false; + this.showError(this.tm('messages.testError', { error: error.response?.data?.message || error.message })); + }); + }, + resetForm() { this.currentServer = { name: '', @@ -939,7 +1012,7 @@ export default { .monaco-container { border: 1px solid rgba(0, 0, 0, 0.1); - border-radius: 4px; + border-radius: 8px; height: 300px; margin-top: 4px; overflow: hidden; diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index fcb34e250..1dd2cbe2f 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,4 +1,3 @@ -import os import aiohttp import datetime import builtins @@ -16,7 +15,6 @@ from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.sources.dify_source import ProviderDify from astrbot.core.utils.io import download_dashboard, get_dashboard_version -from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata from astrbot.core.star.star import star_map from astrbot.core.star.star_manager import PluginManager @@ -1153,24 +1151,6 @@ UID: {user_id} 此 ID 可用于设置管理员。 sp.put("session_variables", session_vars) yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。") - @filter.command("gewe_logout") - async def gewe_logout(self, event: AstrMessageEvent): - platforms = self.context.platform_manager.platform_insts - for platform in platforms: - if platform.meta().name == "gewechat": - yield event.plain_result("正在登出 gewechat") - await platform.logout() - yield event.plain_result("已登出 gewechat,请重启 AstrBot") - return - - @filter.command("gewe_code") - async def gewe_code(self, event: AstrMessageEvent, code: str): - """保存 gewechat 验证码""" - code_path = os.path.join(get_astrbot_data_path(), "temp", "gewe_code") - with open(code_path, "w", encoding="utf-8") as f: - f.write(code) - yield event.plain_result("验证码已保存。") - @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) async def on_message(self, event: AstrMessageEvent): """群聊记忆增强""" @@ -1242,6 +1222,10 @@ UID: {user_id} 此 ID 可用于设置管理员。 logger.error(traceback.format_exc()) logger.error(f"主动回复失败: {e}") + @filter.on_decorating_result() + async def decorate_result(self, event: AstrMessageEvent): + logger.debug("Decorating result for event: %s", event) + @filter.on_llm_request() async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" diff --git a/pyproject.toml b/pyproject.toml index 12639b5c9..2c271592f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "AstrBot" -version = "3.5.22" +version = "3.5.23" description = "易上手的多平台 LLM 聊天机器人及开发框架" readme = "README.md" requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 96ecfeda6..bd8f0eca0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,7 +37,6 @@ watchfiles websockets faiss-cpu aiosqlite -nh3 py-cord>=2.6.1 slack-sdk pydub \ No newline at end of file