Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7d776e0ce2 | |||
| 17df1692b9 | |||
| 9ab652641d | |||
| 9119f7166f | |||
| da7d9d8eb9 | |||
| 80fccc90b7 | |||
| dcebc70f1a | |||
| 259e7bc322 | |||
| 37bdb6c6f6 | |||
| dc71afdd3f | |||
| 44638108d0 | |||
| 93fcac498c | |||
| 79e2743aac | |||
| 5e9c7cdd91 | |||
| 6f73e5087d | |||
| 8c120b020e | |||
| 12fc6f9d38 | |||
| a6e8483b4c |
@@ -11,6 +11,8 @@ reviewers:
|
|||||||
- Larch-C
|
- Larch-C
|
||||||
- anka-afk
|
- anka-afk
|
||||||
- advent259141
|
- advent259141
|
||||||
|
- Fridemn
|
||||||
|
- LIghtJUNction
|
||||||
# - zouyonghe
|
# - zouyonghe
|
||||||
|
|
||||||
# A number of reviewers added to the pull request
|
# A number of reviewers added to the pull request
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ jobs:
|
|||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v3
|
uses: github/codeql-action/init@v4
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
build-mode: ${{ matrix.build-mode }}
|
build-mode: ${{ matrix.build-mode }}
|
||||||
@@ -88,6 +88,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
|
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v3
|
uses: github/codeql-action/analyze@v4
|
||||||
with:
|
with:
|
||||||
category: "/language:${{matrix.language}}"
|
category: "/language:${{matrix.language}}"
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
@@ -13,17 +11,17 @@
|
|||||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
|
||||||

|

|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
<a href="https://astrbot.app/">文档</a> |
|
<a href="https://astrbot.app/">文档</a> |
|
||||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||||
|
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
|
||||||
|
|
||||||
## 主要功能
|
## 主要功能
|
||||||
|
|
||||||
@@ -35,7 +33,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
|
|||||||
|
|
||||||
## 部署方式
|
## 部署方式
|
||||||
|
|
||||||
#### Docker 部署
|
#### Docker 部署(推荐 🥳)
|
||||||
|
|
||||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||||
|
|
||||||
@@ -101,7 +99,6 @@ uv run main.py
|
|||||||
- 5 群:822130018
|
- 5 群:822130018
|
||||||
- 6 群:753075035
|
- 6 群:753075035
|
||||||
- 开发者群:975206796
|
- 开发者群:975206796
|
||||||
- 开发者群(备份):295657329
|
|
||||||
|
|
||||||
### Telegram 群组
|
### Telegram 群组
|
||||||
|
|
||||||
@@ -113,48 +110,80 @@ uv run main.py
|
|||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
|
**官方维护**
|
||||||
|
|
||||||
| 平台 | 支持性 |
|
| 平台 | 支持性 |
|
||||||
| -------- | ------- |
|
| -------- | ------- |
|
||||||
| QQ(官方机器人接口) | ✔ |
|
| QQ(官方平台) | ✔ |
|
||||||
| QQ(OneBot) | ✔ |
|
| QQ(OneBot) | ✔ |
|
||||||
| Telegram | ✔ |
|
| Telegram | ✔ |
|
||||||
| 企业微信 | ✔ |
|
| 企微应用 | ✔ |
|
||||||
| 微信客服 | ✔ |
|
| 微信客服 | ✔ |
|
||||||
| 微信公众号 | ✔ |
|
| 微信公众号 | ✔ |
|
||||||
| 飞书 | ✔ |
|
| 飞书 | ✔ |
|
||||||
| 钉钉 | ✔ |
|
| 钉钉 | ✔ |
|
||||||
| Slack | ✔ |
|
| Slack | ✔ |
|
||||||
| Discord | ✔ |
|
| Discord | ✔ |
|
||||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
|
||||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
|
||||||
| Satori | ✔ |
|
| Satori | ✔ |
|
||||||
| Misskey | ✔ |
|
| Misskey | ✔ |
|
||||||
|
| 企微智能机器人 | 将支持 |
|
||||||
|
| Whatsapp | 将支持 |
|
||||||
|
| LINE | 将支持 |
|
||||||
|
|
||||||
|
**社区维护**
|
||||||
|
|
||||||
|
| 平台 | 支持性 |
|
||||||
|
| -------- | ------- |
|
||||||
|
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||||
|
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||||
|
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
|
||||||
|
|
||||||
## ⚡ 提供商支持情况
|
## ⚡ 提供商支持情况
|
||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
**大模型服务**
|
||||||
| -------- | ------- | ------- | ------- |
|
|
||||||
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
| 名称 | 支持性 | 备注 |
|
||||||
| Anthropic | ✔ | 文本生成 | |
|
| -------- | ------- | ------- |
|
||||||
| Google Gemini | ✔ | 文本生成 | |
|
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Anthropic | ✔ | |
|
||||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
| Google Gemini | ✔ | |
|
||||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| Moonshot AI | ✔ | |
|
||||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| 智谱 AI | ✔ | |
|
||||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
| DeepSeek | ✔ | |
|
||||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
|
||||||
| OneAPI | ✔ | LLM 分发系统 | |
|
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
|
||||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
|
||||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
| 硅基流动 | ✔ | |
|
||||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
| PPIO 派欧云 | ✔ | |
|
||||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
| ModelScope | ✔ | |
|
||||||
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
| OneAPI | ✔ | |
|
||||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
| Dify | ✔ | |
|
||||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
| 阿里云百炼应用 | ✔ | |
|
||||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
| Coze | ✔ | |
|
||||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
|
||||||
|
**语音转文本服务**
|
||||||
|
|
||||||
|
| 名称 | 支持性 | 备注 |
|
||||||
|
| -------- | ------- | ------- |
|
||||||
|
| Whisper | ✔ | 支持 API、本地部署 |
|
||||||
|
| SenseVoice | ✔ | 本地部署 |
|
||||||
|
|
||||||
|
**文本转语音服务**
|
||||||
|
|
||||||
|
| 名称 | 支持性 | 备注 |
|
||||||
|
| -------- | ------- | ------- |
|
||||||
|
| OpenAI TTS | ✔ | |
|
||||||
|
| Gemini TTS | ✔ | |
|
||||||
|
| GSVI | ✔ | GPT-Sovits-Inference |
|
||||||
|
| GPT-SoVITs | ✔ | GPT-Sovits |
|
||||||
|
| FishAudio | ✔ | |
|
||||||
|
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
|
||||||
|
| 阿里云百炼 TTS | ✔ | |
|
||||||
|
| Azure TTS | ✔ | |
|
||||||
|
| Minimax TTS | ✔ | |
|
||||||
|
| 火山引擎 TTS | ✔ | |
|
||||||
|
|
||||||
## ❤️ 贡献
|
## ❤️ 贡献
|
||||||
|
|
||||||
@@ -186,19 +215,10 @@ pre-commit install
|
|||||||
|
|
||||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||||
|
|
||||||
另外,一些同类型其他的活跃开源 Bot 项目:
|
|
||||||
|
|
||||||
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
|
||||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
|
||||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
|
||||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
|
||||||
- [KroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
|
||||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
|||||||
@@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|||||||
timeout = cfg.get("timeout", 10)
|
timeout = cfg.get("timeout", 10)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if "transport" in cfg:
|
||||||
|
transport_type = cfg["transport"]
|
||||||
|
elif "type" in cfg:
|
||||||
|
transport_type = cfg["type"]
|
||||||
|
else:
|
||||||
|
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
if cfg.get("transport") == "streamable_http":
|
if transport_type == "streamable_http":
|
||||||
test_payload = {
|
test_payload = {
|
||||||
"jsonrpc": "2.0",
|
"jsonrpc": "2.0",
|
||||||
"method": "initialize",
|
"method": "initialize",
|
||||||
@@ -121,7 +128,14 @@ class MCPClient:
|
|||||||
if not success:
|
if not success:
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
if cfg.get("transport") != "streamable_http":
|
if "transport" in cfg:
|
||||||
|
transport_type = cfg["transport"]
|
||||||
|
elif "type" in cfg:
|
||||||
|
transport_type = cfg["type"]
|
||||||
|
else:
|
||||||
|
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||||
|
|
||||||
|
if transport_type != "streamable_http":
|
||||||
# SSE transport method
|
# SSE transport method
|
||||||
self._streams_context = sse_client(
|
self._streams_context = sse_client(
|
||||||
url=cfg["url"],
|
url=cfg["url"],
|
||||||
@@ -134,7 +148,7 @@ class MCPClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create a new client session
|
# Create a new client session
|
||||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
mcp.ClientSession(
|
mcp.ClientSession(
|
||||||
*streams,
|
*streams,
|
||||||
@@ -159,7 +173,7 @@ class MCPClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create a new client session
|
# Create a new client session
|
||||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await self.exit_stack.enter_async_context(
|
||||||
mcp.ClientSession(
|
mcp.ClientSession(
|
||||||
read_stream=read_s,
|
read_stream=read_s,
|
||||||
|
|||||||
@@ -9,3 +9,4 @@ class AstrAgentContext:
|
|||||||
first_provider_request: ProviderRequest
|
first_provider_request: ProviderRequest
|
||||||
curr_provider_request: ProviderRequest
|
curr_provider_request: ProviderRequest
|
||||||
streaming: bool
|
streaming: bool
|
||||||
|
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.3.2"
|
VERSION = "4.3.5"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -57,6 +57,7 @@ DEFAULT_CONFIG = {
|
|||||||
"web_search": False,
|
"web_search": False,
|
||||||
"websearch_provider": "default",
|
"websearch_provider": "default",
|
||||||
"websearch_tavily_key": [],
|
"websearch_tavily_key": [],
|
||||||
|
"websearch_baidu_app_builder_key": "",
|
||||||
"web_search_link": False,
|
"web_search_link": False,
|
||||||
"display_reasoning_text": False,
|
"display_reasoning_text": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
@@ -71,6 +72,7 @@ DEFAULT_CONFIG = {
|
|||||||
"show_tool_use_status": False,
|
"show_tool_use_status": False,
|
||||||
"streaming_segmented": False,
|
"streaming_segmented": False,
|
||||||
"max_agent_step": 30,
|
"max_agent_step": 30,
|
||||||
|
"tool_call_timeout": 60,
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -207,6 +209,18 @@ CONFIG_METADATA_2 = {
|
|||||||
"callback_server_host": "0.0.0.0",
|
"callback_server_host": "0.0.0.0",
|
||||||
"port": 6195,
|
"port": 6195,
|
||||||
},
|
},
|
||||||
|
"企业微信智能机器人": {
|
||||||
|
"id": "wecom_ai_bot",
|
||||||
|
"type": "wecom_ai_bot",
|
||||||
|
"enable": True,
|
||||||
|
"wecomaibot_init_respond_text": "💭 思考中...",
|
||||||
|
"wecomaibot_friend_message_welcome_text": "",
|
||||||
|
"wecom_ai_bot_name": "",
|
||||||
|
"token": "",
|
||||||
|
"encoding_aes_key": "",
|
||||||
|
"callback_server_host": "0.0.0.0",
|
||||||
|
"port": 6198,
|
||||||
|
},
|
||||||
"飞书(Lark)": {
|
"飞书(Lark)": {
|
||||||
"id": "lark",
|
"id": "lark",
|
||||||
"type": "lark",
|
"type": "lark",
|
||||||
@@ -447,10 +461,25 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||||
},
|
},
|
||||||
|
"wecom_ai_bot_name": {
|
||||||
|
"description": "企业微信智能机器人的名字",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "请务必填写正确,否则无法使用一些指令。",
|
||||||
|
},
|
||||||
|
"wecomaibot_init_respond_text": {
|
||||||
|
"description": "企业微信智能机器人初始响应文本",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
|
||||||
|
},
|
||||||
|
"wecomaibot_friend_message_welcome_text": {
|
||||||
|
"description": "企业微信智能机器人私聊欢迎语",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
|
||||||
|
},
|
||||||
"lark_bot_name": {
|
"lark_bot_name": {
|
||||||
"description": "飞书机器人的名字",
|
"description": "飞书机器人的名字",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||||
},
|
},
|
||||||
"discord_token": {
|
"discord_token": {
|
||||||
"description": "Discord Bot Token",
|
"description": "Discord Bot Token",
|
||||||
@@ -1056,6 +1085,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": "20",
|
"timeout": "20",
|
||||||
},
|
},
|
||||||
"阿里云百炼 TTS(API)": {
|
"阿里云百炼 TTS(API)": {
|
||||||
|
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
|
||||||
"id": "dashscope_tts",
|
"id": "dashscope_tts",
|
||||||
"provider": "dashscope",
|
"provider": "dashscope",
|
||||||
"type": "dashscope_tts",
|
"type": "dashscope_tts",
|
||||||
@@ -1435,11 +1465,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "服务订阅密钥",
|
"description": "服务订阅密钥",
|
||||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||||
},
|
},
|
||||||
"dashscope_tts_voice": {
|
"dashscope_tts_voice": {"description": "音色", "type": "string"},
|
||||||
"description": "语音合成模型",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
|
||||||
},
|
|
||||||
"gm_resp_image_modal": {
|
"gm_resp_image_modal": {
|
||||||
"description": "启用图片模态",
|
"description": "启用图片模态",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -1848,6 +1874,10 @@ CONFIG_METADATA_2 = {
|
|||||||
"description": "工具调用轮数上限",
|
"description": "工具调用轮数上限",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
},
|
},
|
||||||
|
"tool_call_timeout": {
|
||||||
|
"description": "工具调用超时时间(秒)",
|
||||||
|
"type": "int",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_stt_settings": {
|
"provider_stt_settings": {
|
||||||
@@ -2066,7 +2096,7 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_settings.websearch_provider": {
|
"provider_settings.websearch_provider": {
|
||||||
"description": "网页搜索提供商",
|
"description": "网页搜索提供商",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"options": ["default", "tavily"],
|
"options": ["default", "tavily", "baidu_ai_search"],
|
||||||
},
|
},
|
||||||
"provider_settings.websearch_tavily_key": {
|
"provider_settings.websearch_tavily_key": {
|
||||||
"description": "Tavily API Key",
|
"description": "Tavily API Key",
|
||||||
@@ -2077,6 +2107,14 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_settings.websearch_provider": "tavily",
|
"provider_settings.websearch_provider": "tavily",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"provider_settings.websearch_baidu_app_builder_key": {
|
||||||
|
"description": "百度千帆智能云 APP Builder API Key",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list",
|
||||||
|
"condition": {
|
||||||
|
"provider_settings.websearch_provider": "baidu_ai_search",
|
||||||
|
},
|
||||||
|
},
|
||||||
"provider_settings.web_search_link": {
|
"provider_settings.web_search_link": {
|
||||||
"description": "显示来源引用",
|
"description": "显示来源引用",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -2112,6 +2150,10 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "工具调用轮数上限",
|
"description": "工具调用轮数上限",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
},
|
},
|
||||||
|
"provider_settings.tool_call_timeout": {
|
||||||
|
"description": "工具调用超时时间(秒)",
|
||||||
|
"type": "int",
|
||||||
|
},
|
||||||
"provider_settings.streaming_response": {
|
"provider_settings.streaming_response": {
|
||||||
"description": "流式回复",
|
"description": "流式回复",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import asyncio
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import timedelta
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
from astrbot.core.conversation_mgr import Conversation
|
from astrbot.core.conversation_mgr import Conversation
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
@@ -185,21 +186,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
handler=awaitable,
|
handler=awaitable,
|
||||||
**tool_args,
|
**tool_args,
|
||||||
)
|
)
|
||||||
async for resp in wrapper:
|
# async for resp in wrapper:
|
||||||
if resp is not None:
|
while True:
|
||||||
if isinstance(resp, mcp.types.CallToolResult):
|
try:
|
||||||
yield resp
|
resp = await asyncio.wait_for(
|
||||||
|
anext(wrapper),
|
||||||
|
timeout=run_context.context.tool_call_timeout,
|
||||||
|
)
|
||||||
|
if resp is not None:
|
||||||
|
if isinstance(resp, mcp.types.CallToolResult):
|
||||||
|
yield resp
|
||||||
|
else:
|
||||||
|
text_content = mcp.types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=str(resp),
|
||||||
|
)
|
||||||
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
else:
|
else:
|
||||||
text_content = mcp.types.TextContent(
|
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||||
type="text",
|
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||||
text=str(resp),
|
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||||
)
|
yield None
|
||||||
yield mcp.types.CallToolResult(content=[text_content])
|
except asyncio.TimeoutError:
|
||||||
else:
|
raise Exception(
|
||||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
|
||||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
)
|
||||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
except StopAsyncIteration:
|
||||||
yield None
|
break
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _execute_mcp(
|
async def _execute_mcp(
|
||||||
@@ -217,6 +230,9 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
res = await session.call_tool(
|
res = await session.call_tool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool_args,
|
arguments=tool_args,
|
||||||
|
read_timeout_seconds=timedelta(
|
||||||
|
seconds=run_context.context.tool_call_timeout
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if not res:
|
if not res:
|
||||||
return
|
return
|
||||||
@@ -307,6 +323,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
)
|
)
|
||||||
self.streaming_response: bool = settings["streaming_response"]
|
self.streaming_response: bool = settings["streaming_response"]
|
||||||
self.max_step: int = settings.get("max_agent_step", 30)
|
self.max_step: int = settings.get("max_agent_step", 30)
|
||||||
|
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||||
if isinstance(self.max_step, bool): # workaround: #2622
|
if isinstance(self.max_step, bool): # workaround: #2622
|
||||||
self.max_step = 30
|
self.max_step = 30
|
||||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||||
@@ -473,6 +490,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
first_provider_request=req,
|
first_provider_request=req,
|
||||||
curr_provider_request=req,
|
curr_provider_request=req,
|
||||||
streaming=self.streaming_response,
|
streaming=self.streaming_response,
|
||||||
|
tool_call_timeout=self.tool_call_timeout,
|
||||||
)
|
)
|
||||||
await agent_runner.reset(
|
await agent_runner.reset(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class PipelineScheduler:
|
|||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if event.get_platform_name() == "webchat":
|
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
logger.debug("pipeline 执行完毕。")
|
logger.debug("pipeline 执行完毕。")
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ class PlatformManager:
|
|||||||
from .sources.wecom.wecom_adapter import (
|
from .sources.wecom.wecom_adapter import (
|
||||||
WecomPlatformAdapter, # noqa: F401
|
WecomPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
|
case "wecom_ai_bot":
|
||||||
|
from .sources.wecom_ai_bot.wecomai_adapter import (
|
||||||
|
WecomAIBotAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "weixin_official_account":
|
case "weixin_official_account":
|
||||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||||
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
|
|||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .satori_adapter import SatoriPlatformAdapter
|
from .satori_adapter import SatoriPlatformAdapter
|
||||||
@@ -87,6 +87,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"语音转换为base64失败: {e}")
|
logger.error(f"语音转换为base64失败: {e}")
|
||||||
|
|
||||||
|
elif isinstance(component, Reply):
|
||||||
|
content_parts.append(f'<reply id="{component.id}"/>')
|
||||||
|
|
||||||
|
elif isinstance(component, Video):
|
||||||
|
try:
|
||||||
|
video_path_url = await component.convert_to_file_path()
|
||||||
|
if video_path_url:
|
||||||
|
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"视频文件转换失败: {e}")
|
||||||
|
|
||||||
content = "".join(content_parts)
|
content = "".join(content_parts)
|
||||||
channel_id = session_id
|
channel_id = session_id
|
||||||
data = {"channel_id": channel_id, "content": content}
|
data = {"channel_id": channel_id, "content": content}
|
||||||
@@ -166,6 +177,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"语音转换为base64失败: {e}")
|
logger.error(f"语音转换为base64失败: {e}")
|
||||||
|
|
||||||
|
elif isinstance(component, Reply):
|
||||||
|
content_parts.append(f'<reply id="{component.id}"/>')
|
||||||
|
|
||||||
|
elif isinstance(component, Video):
|
||||||
|
try:
|
||||||
|
video_path_url = await component.convert_to_file_path()
|
||||||
|
if video_path_url:
|
||||||
|
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"视频文件转换失败: {e}")
|
||||||
|
|
||||||
content = "".join(content_parts)
|
content = "".join(content_parts)
|
||||||
channel_id = self.session_id
|
channel_id = self.session_id
|
||||||
data = {"channel_id": channel_id, "content": content}
|
data = {"channel_id": channel_id, "content": content}
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ class WebChatAdapter(Platform):
|
|||||||
|
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = "webchat"
|
abm.self_id = "webchat"
|
||||||
abm.tag = "webchat"
|
|
||||||
abm.sender = MessageMember(username, username)
|
abm.sender = MessageMember(username, username)
|
||||||
|
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
|||||||
@@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding:utf-8 -*-
|
||||||
|
|
||||||
|
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||||
|
@copyright: Copyright (c) 1998-2020 Tencent Inc.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import base64
|
||||||
|
import random
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import struct
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
import socket
|
||||||
|
import json
|
||||||
|
|
||||||
|
from . import ierror
|
||||||
|
|
||||||
|
"""
|
||||||
|
关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案
|
||||||
|
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
|
||||||
|
下载后,按照README中的“Installation”小节的提示进行pycrypto安装。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class FormatException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def throw_exception(message, exception_class=FormatException):
|
||||||
|
"""my define raise exception function"""
|
||||||
|
raise exception_class(message)
|
||||||
|
|
||||||
|
|
||||||
|
class SHA1:
|
||||||
|
"""计算企业微信的消息签名接口"""
|
||||||
|
|
||||||
|
def getSHA1(self, token, timestamp, nonce, encrypt):
|
||||||
|
"""用SHA1算法生成安全签名
|
||||||
|
@param token: 票据
|
||||||
|
@param timestamp: 时间戳
|
||||||
|
@param encrypt: 密文
|
||||||
|
@param nonce: 随机字符串
|
||||||
|
@return: 安全签名
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保所有输入都是字符串类型
|
||||||
|
if isinstance(encrypt, bytes):
|
||||||
|
encrypt = encrypt.decode("utf-8")
|
||||||
|
|
||||||
|
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
|
||||||
|
sortlist.sort()
|
||||||
|
sha = hashlib.sha1()
|
||||||
|
sha.update("".join(sortlist).encode("utf-8"))
|
||||||
|
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
|
||||||
|
|
||||||
|
|
||||||
|
class JsonParse:
|
||||||
|
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
|
||||||
|
|
||||||
|
# json消息模板
|
||||||
|
AES_TEXT_RESPONSE_TEMPLATE = """{
|
||||||
|
"encrypt": "%(msg_encrypt)s",
|
||||||
|
"msgsignature": "%(msg_signaturet)s",
|
||||||
|
"timestamp": "%(timestamp)s",
|
||||||
|
"nonce": "%(nonce)s"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
def extract(self, jsontext):
|
||||||
|
"""提取出json数据包中的加密消息
|
||||||
|
@param jsontext: 待提取的json字符串
|
||||||
|
@return: 提取出的加密消息字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
json_dict = json.loads(jsontext)
|
||||||
|
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return ierror.WXBizMsgCrypt_ParseJson_Error, None
|
||||||
|
|
||||||
|
def generate(self, encrypt, signature, timestamp, nonce):
|
||||||
|
"""生成json消息
|
||||||
|
@param encrypt: 加密后的消息密文
|
||||||
|
@param signature: 安全签名
|
||||||
|
@param timestamp: 时间戳
|
||||||
|
@param nonce: 随机字符串
|
||||||
|
@return: 生成的json字符串
|
||||||
|
"""
|
||||||
|
resp_dict = {
|
||||||
|
"msg_encrypt": encrypt,
|
||||||
|
"msg_signaturet": signature,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"nonce": nonce,
|
||||||
|
}
|
||||||
|
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
|
||||||
|
return resp_json
|
||||||
|
|
||||||
|
|
||||||
|
class PKCS7Encoder:
|
||||||
|
"""提供基于PKCS7算法的加解密接口"""
|
||||||
|
|
||||||
|
block_size = 32
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
"""对需要加密的明文进行填充补位
|
||||||
|
@param text: 需要进行填充补位操作的明文(bytes类型)
|
||||||
|
@return: 补齐明文字符串(bytes类型)
|
||||||
|
"""
|
||||||
|
text_length = len(text)
|
||||||
|
# 计算需要填充的位数
|
||||||
|
amount_to_pad = self.block_size - (text_length % self.block_size)
|
||||||
|
if amount_to_pad == 0:
|
||||||
|
amount_to_pad = self.block_size
|
||||||
|
# 获得补位所用的字符
|
||||||
|
pad = bytes([amount_to_pad])
|
||||||
|
# 确保text是bytes类型
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = text.encode("utf-8")
|
||||||
|
return text + pad * amount_to_pad
|
||||||
|
|
||||||
|
def decode(self, decrypted):
|
||||||
|
"""删除解密后明文的补位字符
|
||||||
|
@param decrypted: 解密后的明文
|
||||||
|
@return: 删除补位字符后的明文
|
||||||
|
"""
|
||||||
|
pad = ord(decrypted[-1])
|
||||||
|
if pad < 1 or pad > 32:
|
||||||
|
pad = 0
|
||||||
|
return decrypted[:-pad]
|
||||||
|
|
||||||
|
|
||||||
|
class Prpcrypt(object):
|
||||||
|
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||||
|
|
||||||
|
def __init__(self, key):
|
||||||
|
# self.key = base64.b64decode(key+"=")
|
||||||
|
self.key = key
|
||||||
|
# 设置加解密模式为AES的CBC模式
|
||||||
|
self.mode = AES.MODE_CBC
|
||||||
|
|
||||||
|
def encrypt(self, text, receiveid):
|
||||||
|
"""对明文进行加密
|
||||||
|
@param text: 需要加密的明文
|
||||||
|
@return: 加密得到的字符串
|
||||||
|
"""
|
||||||
|
# 16位随机字符串添加到明文开头
|
||||||
|
text = text.encode()
|
||||||
|
text = (
|
||||||
|
self.get_random_str()
|
||||||
|
+ struct.pack("I", socket.htonl(len(text)))
|
||||||
|
+ text
|
||||||
|
+ receiveid.encode()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用自定义的填充方式对明文进行补位填充
|
||||||
|
pkcs7 = PKCS7Encoder()
|
||||||
|
text = pkcs7.encode(text)
|
||||||
|
# 加密
|
||||||
|
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||||
|
try:
|
||||||
|
ciphertext = cryptor.encrypt(text)
|
||||||
|
# 使用BASE64对加密后的字符串进行编码
|
||||||
|
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
|
||||||
|
except Exception as e:
|
||||||
|
logger = logging.getLogger("astrbot")
|
||||||
|
logger.error(e)
|
||||||
|
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
|
||||||
|
|
||||||
|
def decrypt(self, text, receiveid):
|
||||||
|
"""对解密后的明文进行补位删除
|
||||||
|
@param text: 密文
|
||||||
|
@return: 删除填充补位后的明文
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||||
|
# 使用BASE64对密文进行解码,然后AES-CBC解密
|
||||||
|
plain_text = cryptor.decrypt(base64.b64decode(text))
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
|
||||||
|
try:
|
||||||
|
pad = plain_text[-1]
|
||||||
|
# 去掉补位字符串
|
||||||
|
# pkcs7 = PKCS7Encoder()
|
||||||
|
# plain_text = pkcs7.encode(plain_text)
|
||||||
|
# 去除16位随机字符串
|
||||||
|
content = plain_text[16:-pad]
|
||||||
|
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
|
||||||
|
json_content = content[4 : json_len + 4].decode("utf-8")
|
||||||
|
from_receiveid = content[json_len + 4 :].decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return ierror.WXBizMsgCrypt_IllegalBuffer, None
|
||||||
|
if from_receiveid != receiveid:
|
||||||
|
print("receiveid not match", receiveid, from_receiveid)
|
||||||
|
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
|
||||||
|
return 0, json_content
|
||||||
|
|
||||||
|
def get_random_str(self):
|
||||||
|
"""随机生成16位字符串
|
||||||
|
@return: 16位字符串
|
||||||
|
"""
|
||||||
|
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||||
|
|
||||||
|
|
||||||
|
class WXBizJsonMsgCrypt(object):
|
||||||
|
# 构造函数
|
||||||
|
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||||
|
try:
|
||||||
|
self.key = base64.b64decode(sEncodingAESKey + "=")
|
||||||
|
assert len(self.key) == 32
|
||||||
|
except Exception as e:
|
||||||
|
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
|
||||||
|
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||||
|
self.m_sToken = sToken
|
||||||
|
self.m_sReceiveId = sReceiveId
|
||||||
|
|
||||||
|
# 验证URL
|
||||||
|
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||||
|
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||||
|
# @param sNonce: 随机串,对应URL参数的nonce
|
||||||
|
# @param sEchoStr: 随机串,对应URL参数的echostr
|
||||||
|
# @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
|
||||||
|
# @return:成功0,失败返回对应的错误码
|
||||||
|
|
||||||
|
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
|
||||||
|
sha1 = SHA1()
|
||||||
|
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
|
||||||
|
if ret != 0:
|
||||||
|
return ret, None
|
||||||
|
if not signature == sMsgSignature:
|
||||||
|
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||||
|
pc = Prpcrypt(self.key)
|
||||||
|
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
|
||||||
|
return ret, sReplyEchoStr
|
||||||
|
|
||||||
|
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
|
||||||
|
# 将企业回复用户的消息加密打包
|
||||||
|
# @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串
|
||||||
|
# @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
|
||||||
|
# @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
|
||||||
|
# sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
|
||||||
|
# return:成功0,sEncryptMsg,失败返回对应的错误码None
|
||||||
|
pc = Prpcrypt(self.key)
|
||||||
|
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
|
||||||
|
encrypt = encrypt.decode("utf-8") # type: ignore
|
||||||
|
if ret != 0:
|
||||||
|
return ret, None
|
||||||
|
if timestamp is None:
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
# 生成安全签名
|
||||||
|
sha1 = SHA1()
|
||||||
|
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
|
||||||
|
if ret != 0:
|
||||||
|
return ret, None
|
||||||
|
jsonParse = JsonParse()
|
||||||
|
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
|
||||||
|
|
||||||
|
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
|
||||||
|
# 检验消息的真实性,并且获取解密后的明文
|
||||||
|
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||||
|
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||||
|
# @param sNonce: 随机串,对应URL参数的nonce
|
||||||
|
# @param sPostData: 密文,对应POST请求的数据
|
||||||
|
# json_content: 解密后的原文,当return返回0时有效
|
||||||
|
# @return: 成功0,失败返回对应的错误码
|
||||||
|
# 验证安全签名
|
||||||
|
jsonParse = JsonParse()
|
||||||
|
ret, encrypt = jsonParse.extract(sPostData)
|
||||||
|
if ret != 0:
|
||||||
|
return ret, None
|
||||||
|
sha1 = SHA1()
|
||||||
|
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
|
||||||
|
if ret != 0:
|
||||||
|
return ret, None
|
||||||
|
if not signature == sMsgSignature:
|
||||||
|
print("signature not match")
|
||||||
|
print(signature)
|
||||||
|
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||||
|
pc = Prpcrypt(self.key)
|
||||||
|
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
|
||||||
|
return ret, json_content
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人平台适配器包
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .wecomai_adapter import WecomAIBotAdapter
|
||||||
|
from .wecomai_api import WecomAIBotAPIClient
|
||||||
|
from .wecomai_event import WecomAIBotMessageEvent
|
||||||
|
from .wecomai_server import WecomAIBotServer
|
||||||
|
from .wecomai_utils import WecomAIBotConstants
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"WecomAIBotAdapter",
|
||||||
|
"WecomAIBotAPIClient",
|
||||||
|
"WecomAIBotMessageEvent",
|
||||||
|
"WecomAIBotServer",
|
||||||
|
"WecomAIBotConstants",
|
||||||
|
]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#########################################################################
|
||||||
|
# Author: jonyqin
|
||||||
|
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
|
||||||
|
# File Name: ierror.py
|
||||||
|
# Description:定义错误码含义
|
||||||
|
#########################################################################
|
||||||
|
WXBizMsgCrypt_OK = 0
|
||||||
|
WXBizMsgCrypt_ValidateSignature_Error = -40001
|
||||||
|
WXBizMsgCrypt_ParseJson_Error = -40002
|
||||||
|
WXBizMsgCrypt_ComputeSignature_Error = -40003
|
||||||
|
WXBizMsgCrypt_IllegalAesKey = -40004
|
||||||
|
WXBizMsgCrypt_ValidateCorpid_Error = -40005
|
||||||
|
WXBizMsgCrypt_EncryptAES_Error = -40006
|
||||||
|
WXBizMsgCrypt_DecryptAES_Error = -40007
|
||||||
|
WXBizMsgCrypt_IllegalBuffer = -40008
|
||||||
|
WXBizMsgCrypt_EncodeBase64_Error = -40009
|
||||||
|
WXBizMsgCrypt_DecodeBase64_Error = -40010
|
||||||
|
WXBizMsgCrypt_GenReturnJson_Error = -40011
|
||||||
@@ -0,0 +1,445 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人平台适配器
|
||||||
|
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
|
||||||
|
参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
from typing import Awaitable, Any, Dict, Optional, Callable
|
||||||
|
|
||||||
|
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
Platform,
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Plain, At, Image
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
|
||||||
|
from .wecomai_api import (
|
||||||
|
WecomAIBotAPIClient,
|
||||||
|
WecomAIBotMessageParser,
|
||||||
|
WecomAIBotStreamMessageBuilder,
|
||||||
|
)
|
||||||
|
from .wecomai_event import WecomAIBotMessageEvent
|
||||||
|
from .wecomai_server import WecomAIBotServer
|
||||||
|
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
|
||||||
|
from .wecomai_utils import (
|
||||||
|
WecomAIBotConstants,
|
||||||
|
format_session_id,
|
||||||
|
generate_random_string,
|
||||||
|
process_encrypted_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIQueueListener:
|
||||||
|
"""企业微信智能机器人队列监听器,参考webchat的QueueListener设计"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
|
||||||
|
) -> None:
|
||||||
|
self.queue_mgr = queue_mgr
|
||||||
|
self.callback = callback
|
||||||
|
self.running_tasks = set()
|
||||||
|
|
||||||
|
async def listen_to_queue(self, session_id: str):
|
||||||
|
"""监听特定会话的队列"""
|
||||||
|
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await queue.get()
|
||||||
|
await self.callback(data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""监控新会话队列并启动监听器"""
|
||||||
|
monitored_sessions = set()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 检查新会话
|
||||||
|
current_sessions = set(self.queue_mgr.queues.keys())
|
||||||
|
new_sessions = current_sessions - monitored_sessions
|
||||||
|
|
||||||
|
# 为新会话启动监听器
|
||||||
|
for session_id in new_sessions:
|
||||||
|
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||||
|
self.running_tasks.add(task)
|
||||||
|
task.add_done_callback(self.running_tasks.discard)
|
||||||
|
monitored_sessions.add(session_id)
|
||||||
|
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||||
|
|
||||||
|
# 清理已不存在的会话
|
||||||
|
removed_sessions = monitored_sessions - current_sessions
|
||||||
|
monitored_sessions -= removed_sessions
|
||||||
|
|
||||||
|
# 清理过期的待处理响应
|
||||||
|
self.queue_mgr.cleanup_expired_responses()
|
||||||
|
|
||||||
|
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter(
|
||||||
|
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
|
||||||
|
)
|
||||||
|
class WecomAIBotAdapter(Platform):
|
||||||
|
"""企业微信智能机器人适配器"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
|
||||||
|
self.config = platform_config
|
||||||
|
self.settings = platform_settings
|
||||||
|
|
||||||
|
# 初始化配置参数
|
||||||
|
self.token = self.config["token"]
|
||||||
|
self.encoding_aes_key = self.config["encoding_aes_key"]
|
||||||
|
self.port = int(self.config["port"])
|
||||||
|
self.host = self.config.get("callback_server_host", "0.0.0.0")
|
||||||
|
self.bot_name = self.config.get("wecom_ai_bot_name", "")
|
||||||
|
self.initial_respond_text = self.config.get(
|
||||||
|
"wecomaibot_init_respond_text", "💭 思考中..."
|
||||||
|
)
|
||||||
|
self.friend_message_welcome_text = self.config.get(
|
||||||
|
"wecomaibot_friend_message_welcome_text", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# 平台元数据
|
||||||
|
self.metadata = PlatformMetadata(
|
||||||
|
name="wecom_ai_bot",
|
||||||
|
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
|
||||||
|
id=self.config.get("id", "wecom_ai_bot"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化 API 客户端
|
||||||
|
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
|
||||||
|
|
||||||
|
# 初始化 HTTP 服务器
|
||||||
|
self.server = WecomAIBotServer(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
api_client=self.api_client,
|
||||||
|
message_handler=self._process_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 事件循环和关闭信号
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
# 队列监听器
|
||||||
|
self.queue_listener = WecomAIQueueListener(
|
||||||
|
wecomai_queue_mgr, self._handle_queued_message
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_queued_message(self, data: dict):
|
||||||
|
"""处理队列中的消息,类似webchat的callback"""
|
||||||
|
try:
|
||||||
|
abm = await self.convert_message(data)
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理队列消息时发生异常: {e}")
|
||||||
|
|
||||||
|
async def _process_message(
|
||||||
|
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""处理接收到的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_data: 解密后的消息数据
|
||||||
|
callback_params: 回调参数 (nonce, timestamp)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加密后的响应消息,无需响应时返回 None
|
||||||
|
"""
|
||||||
|
msgtype = message_data.get("msgtype")
|
||||||
|
if not msgtype:
|
||||||
|
logger.warning(f"消息类型未知,忽略: {message_data}")
|
||||||
|
return None
|
||||||
|
session_id = self._extract_session_id(message_data)
|
||||||
|
if msgtype in ("text", "image", "mixed"):
|
||||||
|
# user sent a text / image / mixed message
|
||||||
|
try:
|
||||||
|
# create a brand-new unique stream_id for this message session
|
||||||
|
stream_id = f"{session_id}_{generate_random_string(10)}"
|
||||||
|
await self._enqueue_message(
|
||||||
|
message_data, callback_params, stream_id, session_id
|
||||||
|
)
|
||||||
|
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
|
||||||
|
|
||||||
|
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||||
|
stream_id, self.initial_respond_text, False
|
||||||
|
)
|
||||||
|
return await self.api_client.encrypt_message(
|
||||||
|
resp, callback_params["nonce"], callback_params["timestamp"]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("处理消息时发生异常: %s", e)
|
||||||
|
return None
|
||||||
|
elif msgtype == "stream":
|
||||||
|
# wechat server is requesting for updates of a stream
|
||||||
|
stream_id = message_data["stream"]["id"]
|
||||||
|
if not wecomai_queue_mgr.has_back_queue(stream_id):
|
||||||
|
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||||
|
|
||||||
|
# 返回结束标志,告诉微信服务器流已结束
|
||||||
|
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||||
|
stream_id, "", True
|
||||||
|
)
|
||||||
|
resp = await self.api_client.encrypt_message(
|
||||||
|
end_message,
|
||||||
|
callback_params["nonce"],
|
||||||
|
callback_params["timestamp"],
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
|
if queue.empty():
|
||||||
|
logger.debug(
|
||||||
|
f"No new messages in back queue for stream_id: {stream_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# aggregate all delta chains in the back queue
|
||||||
|
latest_plain_content = ""
|
||||||
|
image_base64 = []
|
||||||
|
finish = False
|
||||||
|
while not queue.empty():
|
||||||
|
msg = await queue.get()
|
||||||
|
if msg["type"] == "plain":
|
||||||
|
latest_plain_content = msg["data"] or ""
|
||||||
|
elif msg["type"] == "image":
|
||||||
|
image_base64.append(msg["image_data"])
|
||||||
|
elif msg["type"] == "end":
|
||||||
|
# stream end
|
||||||
|
finish = True
|
||||||
|
wecomai_queue_mgr.remove_queues(stream_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
logger.debug(
|
||||||
|
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
|
||||||
|
)
|
||||||
|
if latest_plain_content or image_base64:
|
||||||
|
msg_items = []
|
||||||
|
if finish and image_base64:
|
||||||
|
for img_b64 in image_base64:
|
||||||
|
# get md5 of image
|
||||||
|
img_data = base64.b64decode(img_b64)
|
||||||
|
img_md5 = hashlib.md5(img_data).hexdigest()
|
||||||
|
msg_items.append(
|
||||||
|
{
|
||||||
|
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||||
|
"image": {"base64": img_b64, "md5": img_md5},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
image_base64 = []
|
||||||
|
|
||||||
|
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
|
||||||
|
stream_id, latest_plain_content, msg_items, finish
|
||||||
|
)
|
||||||
|
encrypted_message = await self.api_client.encrypt_message(
|
||||||
|
plain_message,
|
||||||
|
callback_params["nonce"],
|
||||||
|
callback_params["timestamp"],
|
||||||
|
)
|
||||||
|
if encrypted_message:
|
||||||
|
logger.debug(
|
||||||
|
f"Stream message sent successfully, stream_id: {stream_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error("消息加密失败")
|
||||||
|
return encrypted_message
|
||||||
|
return None
|
||||||
|
elif msgtype == "event":
|
||||||
|
event = message_data.get("event")
|
||||||
|
if event == "enter_chat" and self.friend_message_welcome_text:
|
||||||
|
# 用户进入会话,发送欢迎消息
|
||||||
|
try:
|
||||||
|
resp = WecomAIBotStreamMessageBuilder.make_text(
|
||||||
|
self.friend_message_welcome_text
|
||||||
|
)
|
||||||
|
return await self.api_client.encrypt_message(
|
||||||
|
resp,
|
||||||
|
callback_params["nonce"],
|
||||||
|
callback_params["timestamp"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("处理欢迎消息时发生异常: %s", e)
|
||||||
|
return None
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
|
||||||
|
"""从消息数据中提取会话ID"""
|
||||||
|
user_id = message_data.get("from", {}).get("userid", "default_user")
|
||||||
|
return format_session_id("wecomai", user_id)
|
||||||
|
|
||||||
|
async def _enqueue_message(
|
||||||
|
self,
|
||||||
|
message_data: Dict[str, Any],
|
||||||
|
callback_params: Dict[str, str],
|
||||||
|
stream_id: str,
|
||||||
|
session_id: str,
|
||||||
|
):
|
||||||
|
"""将消息放入队列进行异步处理"""
|
||||||
|
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
|
||||||
|
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
|
message_payload = {
|
||||||
|
"message_data": message_data,
|
||||||
|
"callback_params": callback_params,
|
||||||
|
"session_id": session_id,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
}
|
||||||
|
await input_queue.put(message_payload)
|
||||||
|
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
|
||||||
|
|
||||||
|
async def convert_message(self, payload: dict) -> AstrBotMessage:
|
||||||
|
"""转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message"""
|
||||||
|
message_data = payload["message_data"]
|
||||||
|
session_id = payload["session_id"]
|
||||||
|
# callback_params = payload["callback_params"] # 保留但暂时不使用
|
||||||
|
|
||||||
|
# 解析消息内容
|
||||||
|
msgtype = message_data.get("msgtype")
|
||||||
|
content = ""
|
||||||
|
image_base64 = []
|
||||||
|
|
||||||
|
_img_url_to_process = []
|
||||||
|
msg_items = []
|
||||||
|
|
||||||
|
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||||
|
content = WecomAIBotMessageParser.parse_text_message(message_data)
|
||||||
|
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||||
|
_img_url_to_process.append(
|
||||||
|
WecomAIBotMessageParser.parse_image_message(message_data)
|
||||||
|
)
|
||||||
|
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
|
||||||
|
# 提取混合消息中的文本内容
|
||||||
|
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
|
||||||
|
text_parts = []
|
||||||
|
for item in msg_items or []:
|
||||||
|
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||||
|
text_content = item.get("text", {}).get("content", "")
|
||||||
|
if text_content:
|
||||||
|
text_parts.append(text_content)
|
||||||
|
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||||
|
image_url = item.get("image", {}).get("url", "")
|
||||||
|
if image_url:
|
||||||
|
_img_url_to_process.append(image_url)
|
||||||
|
content = " ".join(text_parts) if text_parts else ""
|
||||||
|
else:
|
||||||
|
content = f"[{msgtype}消息]"
|
||||||
|
|
||||||
|
# 并行处理图片下载和解密
|
||||||
|
if _img_url_to_process:
|
||||||
|
tasks = [
|
||||||
|
process_encrypted_image(url, self.encoding_aes_key)
|
||||||
|
for url in _img_url_to_process
|
||||||
|
]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
for success, result in results:
|
||||||
|
if success:
|
||||||
|
image_base64.append(result)
|
||||||
|
else:
|
||||||
|
logger.error(f"处理加密图片失败: {result}")
|
||||||
|
|
||||||
|
# 构建 AstrBotMessage
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.self_id = self.bot_name
|
||||||
|
abm.message_str = content or "[未知消息]"
|
||||||
|
abm.message_id = str(uuid.uuid4())
|
||||||
|
abm.timestamp = int(time.time())
|
||||||
|
abm.raw_message = payload
|
||||||
|
|
||||||
|
# 发送者信息
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=message_data.get("from", {}).get("userid", "unknown"),
|
||||||
|
nickname=message_data.get("from", {}).get("userid", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 消息类型
|
||||||
|
abm.type = (
|
||||||
|
MessageType.GROUP_MESSAGE
|
||||||
|
if message_data.get("chattype") == "group"
|
||||||
|
else MessageType.FRIEND_MESSAGE
|
||||||
|
)
|
||||||
|
abm.session_id = session_id
|
||||||
|
|
||||||
|
# 消息内容
|
||||||
|
abm.message = []
|
||||||
|
|
||||||
|
# 处理 At
|
||||||
|
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
|
||||||
|
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
|
||||||
|
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
|
||||||
|
abm.message.append(Plain(abm.message_str))
|
||||||
|
if image_base64:
|
||||||
|
for img_b64 in image_base64:
|
||||||
|
abm.message.append(Image.fromBase64(img_b64))
|
||||||
|
|
||||||
|
logger.debug(f"WecomAIAdapter: {abm.message}")
|
||||||
|
return abm
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSesion, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
"""通过会话发送消息"""
|
||||||
|
# 企业微信智能机器人主要通过回调响应,这里记录日志
|
||||||
|
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
def run(self) -> Awaitable[Any]:
|
||||||
|
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||||
|
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
|
||||||
|
|
||||||
|
async def run_both():
|
||||||
|
# 同时运行HTTP服务器和队列监听器
|
||||||
|
await asyncio.gather(
|
||||||
|
self.server.start_server(),
|
||||||
|
self.queue_listener.run(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return run_both()
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""终止适配器"""
|
||||||
|
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
await self.server.shutdown()
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
"""获取平台元数据"""
|
||||||
|
return self.metadata
|
||||||
|
|
||||||
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
"""处理消息,创建消息事件并提交到事件队列"""
|
||||||
|
try:
|
||||||
|
message_event = WecomAIBotMessageEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
api_client=self.api_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("处理消息时发生异常: %s", e)
|
||||||
|
|
||||||
|
def get_client(self) -> WecomAIBotAPIClient:
|
||||||
|
"""获取 API 客户端"""
|
||||||
|
return self.api_client
|
||||||
|
|
||||||
|
def get_server(self) -> WecomAIBotServer:
|
||||||
|
"""获取 HTTP 服务器实例"""
|
||||||
|
return self.server
|
||||||
@@ -0,0 +1,378 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人 API 客户端
|
||||||
|
处理消息加密解密、API 调用等
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Dict, Any, Optional, Tuple, Union
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
|
||||||
|
from .wecomai_utils import WecomAIBotConstants
|
||||||
|
from astrbot import logger
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIBotAPIClient:
|
||||||
|
"""企业微信智能机器人 API 客户端"""
|
||||||
|
|
||||||
|
def __init__(self, token: str, encoding_aes_key: str):
|
||||||
|
"""初始化 API 客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: 企业微信机器人 Token
|
||||||
|
encoding_aes_key: 消息加密密钥
|
||||||
|
"""
|
||||||
|
self.token = token
|
||||||
|
self.encoding_aes_key = encoding_aes_key
|
||||||
|
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
|
||||||
|
|
||||||
|
async def decrypt_message(
|
||||||
|
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
|
||||||
|
) -> Tuple[int, Optional[Dict[str, Any]]]:
|
||||||
|
"""解密企业微信消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encrypted_data: 加密的消息数据
|
||||||
|
msg_signature: 消息签名
|
||||||
|
timestamp: 时间戳
|
||||||
|
nonce: 随机数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(错误码, 解密后的消息数据字典)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ret, decrypted_msg = self.wxcpt.DecryptMsg(
|
||||||
|
encrypted_data, msg_signature, timestamp, nonce
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != WecomAIBotConstants.SUCCESS:
|
||||||
|
logger.error(f"消息解密失败,错误码: {ret}")
|
||||||
|
return ret, None
|
||||||
|
|
||||||
|
# 解析 JSON
|
||||||
|
if decrypted_msg:
|
||||||
|
try:
|
||||||
|
message_data = json.loads(decrypted_msg)
|
||||||
|
logger.debug(f"解密成功,消息内容: {message_data}")
|
||||||
|
return WecomAIBotConstants.SUCCESS, message_data
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
|
||||||
|
return WecomAIBotConstants.PARSE_XML_ERROR, None
|
||||||
|
else:
|
||||||
|
logger.error("解密消息为空")
|
||||||
|
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解密过程发生异常: {e}")
|
||||||
|
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||||
|
|
||||||
|
async def encrypt_message(
|
||||||
|
self, plain_message: str, nonce: str, timestamp: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""加密消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plain_message: 明文消息
|
||||||
|
nonce: 随机数
|
||||||
|
timestamp: 时间戳
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加密后的消息,失败时返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
|
||||||
|
|
||||||
|
if ret != WecomAIBotConstants.SUCCESS:
|
||||||
|
logger.error(f"消息加密失败,错误码: {ret}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug("消息加密成功")
|
||||||
|
return encrypted_msg
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加密过程发生异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def verify_url(
|
||||||
|
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
|
||||||
|
) -> str:
|
||||||
|
"""验证回调 URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg_signature: 消息签名
|
||||||
|
timestamp: 时间戳
|
||||||
|
nonce: 随机数
|
||||||
|
echostr: 验证字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证结果字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ret, echo_result = self.wxcpt.VerifyURL(
|
||||||
|
msg_signature, timestamp, nonce, echostr
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != WecomAIBotConstants.SUCCESS:
|
||||||
|
logger.error(f"URL 验证失败,错误码: {ret}")
|
||||||
|
return "verify fail"
|
||||||
|
|
||||||
|
logger.info("URL 验证成功")
|
||||||
|
return echo_result if echo_result else "verify fail"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"URL 验证发生异常: {e}")
|
||||||
|
return "verify fail"
|
||||||
|
|
||||||
|
async def process_encrypted_image(
|
||||||
|
self, image_url: str, aes_key_base64: Optional[str] = None
|
||||||
|
) -> Tuple[bool, Union[bytes, str]]:
|
||||||
|
"""下载并解密加密图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url: 加密图片的 URL
|
||||||
|
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(是否成功, 图片数据或错误信息)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 下载图片
|
||||||
|
logger.info(f"开始下载加密图片: {image_url}")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(image_url, timeout=15) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_msg = f"图片下载失败,状态码: {response.status}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
encrypted_data = await response.read()
|
||||||
|
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
|
||||||
|
|
||||||
|
# 准备解密密钥
|
||||||
|
if aes_key_base64 is None:
|
||||||
|
aes_key_base64 = self.encoding_aes_key
|
||||||
|
|
||||||
|
if not aes_key_base64:
|
||||||
|
raise ValueError("AES 密钥不能为空")
|
||||||
|
|
||||||
|
# Base64 解码密钥
|
||||||
|
aes_key = base64.b64decode(
|
||||||
|
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
|
||||||
|
)
|
||||||
|
if len(aes_key) != 32:
|
||||||
|
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
|
||||||
|
|
||||||
|
iv = aes_key[:16] # 初始向量为密钥前 16 字节
|
||||||
|
|
||||||
|
# 解密图片数据
|
||||||
|
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||||
|
decrypted_data = cipher.decrypt(encrypted_data)
|
||||||
|
|
||||||
|
# 去除 PKCS#7 填充
|
||||||
|
pad_len = decrypted_data[-1]
|
||||||
|
if pad_len > 32: # AES-256 块大小为 32 字节
|
||||||
|
raise ValueError("无效的填充长度 (大于32字节)")
|
||||||
|
|
||||||
|
decrypted_data = decrypted_data[:-pad_len]
|
||||||
|
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
|
||||||
|
|
||||||
|
return True, decrypted_data
|
||||||
|
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
error_msg = f"图片下载失败: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
error_msg = f"参数错误: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"图片处理异常: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIBotStreamMessageBuilder:
|
||||||
|
"""企业微信智能机器人流消息构建器"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
|
||||||
|
"""构建文本流消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流 ID
|
||||||
|
content: 文本内容
|
||||||
|
finish: 是否结束
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON 格式的流消息字符串
|
||||||
|
"""
|
||||||
|
plain = {
|
||||||
|
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||||
|
"stream": {"id": stream_id, "finish": finish, "content": content},
|
||||||
|
}
|
||||||
|
return json.dumps(plain, ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_image_stream(
|
||||||
|
stream_id: str, image_data: bytes, finish: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""构建图片流消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流 ID
|
||||||
|
image_data: 图片二进制数据
|
||||||
|
finish: 是否结束
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON 格式的流消息字符串
|
||||||
|
"""
|
||||||
|
image_md5 = hashlib.md5(image_data).hexdigest()
|
||||||
|
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
plain = {
|
||||||
|
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||||
|
"stream": {
|
||||||
|
"id": stream_id,
|
||||||
|
"finish": finish,
|
||||||
|
"msg_item": [
|
||||||
|
{
|
||||||
|
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||||
|
"image": {"base64": image_base64, "md5": image_md5},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return json.dumps(plain, ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_mixed_stream(
|
||||||
|
stream_id: str, content: str, msg_items: list, finish: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""构建混合类型流消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流 ID
|
||||||
|
content: 文本内容
|
||||||
|
msg_items: 消息项列表
|
||||||
|
finish: 是否结束
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON 格式的流消息字符串
|
||||||
|
"""
|
||||||
|
plain = {
|
||||||
|
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||||
|
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
|
||||||
|
}
|
||||||
|
if content:
|
||||||
|
plain["stream"]["content"] = content
|
||||||
|
return json.dumps(plain, ensure_ascii=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_text(content: str) -> str:
|
||||||
|
"""构建文本消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON 格式的文本消息字符串
|
||||||
|
"""
|
||||||
|
plain = {"msgtype": "text", "text": {"content": content}}
|
||||||
|
return json.dumps(plain, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIBotMessageParser:
|
||||||
|
"""企业微信智能机器人消息解析器"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""解析文本消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文本内容,解析失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return data.get("text", {}).get("content")
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
logger.warning("文本消息解析失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""解析图片消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
图片 URL,解析失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return data.get("image", {}).get("url")
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
logger.warning("图片消息解析失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""解析流消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流消息数据,解析失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
stream_data = data.get("stream", {})
|
||||||
|
return {
|
||||||
|
"id": stream_data.get("id"),
|
||||||
|
"finish": stream_data.get("finish"),
|
||||||
|
"content": stream_data.get("content"),
|
||||||
|
"msg_item": stream_data.get("msg_item", []),
|
||||||
|
}
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
logger.warning("流消息解析失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
|
||||||
|
"""解析混合消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息项列表,解析失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return data.get("mixed", {}).get("msg_item", [])
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
logger.warning("混合消息解析失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||||
|
"""解析事件消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
事件数据,解析失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return data.get("event", {})
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
logger.warning("事件消息解析失败")
|
||||||
|
return None
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
|
||||||
|
"""
|
||||||
|
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.message_components import (
|
||||||
|
Image,
|
||||||
|
Plain,
|
||||||
|
)
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
from .wecomai_api import WecomAIBotAPIClient
|
||||||
|
from .wecomai_queue_mgr import wecomai_queue_mgr
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||||
|
"""企业微信智能机器人消息事件"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj,
|
||||||
|
platform_meta,
|
||||||
|
session_id: str,
|
||||||
|
api_client: WecomAIBotAPIClient,
|
||||||
|
):
|
||||||
|
"""初始化消息事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_str: 消息字符串
|
||||||
|
message_obj: 消息对象
|
||||||
|
platform_meta: 平台元数据
|
||||||
|
session_id: 会话 ID
|
||||||
|
api_client: API 客户端
|
||||||
|
"""
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.api_client = api_client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send(
|
||||||
|
message_chain: MessageChain,
|
||||||
|
stream_id: str,
|
||||||
|
streaming: bool = False,
|
||||||
|
):
|
||||||
|
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
|
|
||||||
|
if not message_chain:
|
||||||
|
await back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "end",
|
||||||
|
"data": "",
|
||||||
|
"streaming": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
data = ""
|
||||||
|
for comp in message_chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
data = comp.text
|
||||||
|
await back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "plain",
|
||||||
|
"data": data,
|
||||||
|
"streaming": streaming,
|
||||||
|
"session_id": stream_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(comp, Image):
|
||||||
|
# 处理图片消息
|
||||||
|
try:
|
||||||
|
image_base64 = await comp.convert_to_base64()
|
||||||
|
if image_base64:
|
||||||
|
await back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image_data": image_base64,
|
||||||
|
"streaming": streaming,
|
||||||
|
"session_id": stream_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("图片数据为空,跳过")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("处理图片消息失败: %s", e)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
"""发送消息"""
|
||||||
|
raw = self.message_obj.raw_message
|
||||||
|
assert isinstance(raw, dict), (
|
||||||
|
"wecom_ai_bot platform event raw_message should be a dict"
|
||||||
|
)
|
||||||
|
stream_id = raw.get("stream_id", self.session_id)
|
||||||
|
await WecomAIBotMessageEvent._send(message, stream_id)
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback=False):
|
||||||
|
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||||
|
final_data = ""
|
||||||
|
raw = self.message_obj.raw_message
|
||||||
|
assert isinstance(raw, dict), (
|
||||||
|
"wecom_ai_bot platform event raw_message should be a dict"
|
||||||
|
)
|
||||||
|
stream_id = raw.get("stream_id", self.session_id)
|
||||||
|
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||||
|
|
||||||
|
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
|
||||||
|
increment_plain = ""
|
||||||
|
async for chain in generator:
|
||||||
|
# 累积增量内容,并改写 Plain 段
|
||||||
|
chain.squash_plain()
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
comp.text = increment_plain + comp.text
|
||||||
|
increment_plain = comp.text
|
||||||
|
break
|
||||||
|
|
||||||
|
if chain.type == "break" and final_data:
|
||||||
|
# 分割符
|
||||||
|
await back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "break", # break means a segment end
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
final_data = ""
|
||||||
|
continue
|
||||||
|
|
||||||
|
final_data += await WecomAIBotMessageEvent._send(
|
||||||
|
chain,
|
||||||
|
stream_id=stream_id,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "complete", # complete means we return the final result
|
||||||
|
"data": final_data,
|
||||||
|
"streaming": True,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await super().send_streaming(generator, use_fallback)
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人队列管理器
|
||||||
|
参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制
|
||||||
|
支持异步消息处理和流式响应
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIQueueMgr:
|
||||||
|
"""企业微信智能机器人队列管理器"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.queues: Dict[str, asyncio.Queue] = {}
|
||||||
|
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||||
|
|
||||||
|
self.back_queues: Dict[str, asyncio.Queue] = {}
|
||||||
|
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
|
||||||
|
|
||||||
|
self.pending_responses: Dict[str, Dict[str, Any]] = {}
|
||||||
|
"""待处理的响应缓存,用于流式响应"""
|
||||||
|
|
||||||
|
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||||
|
"""获取或创建指定会话的输入队列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
输入队列实例
|
||||||
|
"""
|
||||||
|
if session_id not in self.queues:
|
||||||
|
self.queues[session_id] = asyncio.Queue()
|
||||||
|
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||||
|
return self.queues[session_id]
|
||||||
|
|
||||||
|
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
|
||||||
|
"""获取或创建指定会话的输出队列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
输出队列实例
|
||||||
|
"""
|
||||||
|
if session_id not in self.back_queues:
|
||||||
|
self.back_queues[session_id] = asyncio.Queue()
|
||||||
|
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||||
|
return self.back_queues[session_id]
|
||||||
|
|
||||||
|
def remove_queues(self, session_id: str):
|
||||||
|
"""移除指定会话的所有队列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
"""
|
||||||
|
if session_id in self.queues:
|
||||||
|
del self.queues[session_id]
|
||||||
|
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||||
|
|
||||||
|
if session_id in self.back_queues:
|
||||||
|
del self.back_queues[session_id]
|
||||||
|
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
|
||||||
|
|
||||||
|
if session_id in self.pending_responses:
|
||||||
|
del self.pending_responses[session_id]
|
||||||
|
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||||
|
|
||||||
|
def has_queue(self, session_id: str) -> bool:
|
||||||
|
"""检查是否存在指定会话的队列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否存在队列
|
||||||
|
"""
|
||||||
|
return session_id in self.queues
|
||||||
|
|
||||||
|
def has_back_queue(self, session_id: str) -> bool:
|
||||||
|
"""检查是否存在指定会话的输出队列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否存在输出队列
|
||||||
|
"""
|
||||||
|
return session_id in self.back_queues
|
||||||
|
|
||||||
|
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
|
||||||
|
"""设置待处理的响应参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
callback_params: 回调参数(nonce, timestamp等)
|
||||||
|
"""
|
||||||
|
self.pending_responses[session_id] = {
|
||||||
|
"callback_params": callback_params,
|
||||||
|
"timestamp": asyncio.get_event_loop().time(),
|
||||||
|
}
|
||||||
|
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||||
|
|
||||||
|
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取待处理的响应参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应参数,如果不存在则返回None
|
||||||
|
"""
|
||||||
|
return self.pending_responses.get(session_id)
|
||||||
|
|
||||||
|
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||||
|
"""清理过期的待处理响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_seconds: 最大存活时间(秒)
|
||||||
|
"""
|
||||||
|
current_time = asyncio.get_event_loop().time()
|
||||||
|
expired_sessions = []
|
||||||
|
|
||||||
|
for session_id, response_data in self.pending_responses.items():
|
||||||
|
if current_time - response_data["timestamp"] > max_age_seconds:
|
||||||
|
expired_sessions.append(session_id)
|
||||||
|
|
||||||
|
for session_id in expired_sessions:
|
||||||
|
del self.pending_responses[session_id]
|
||||||
|
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, int]:
|
||||||
|
"""获取队列统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"input_queues": len(self.queues),
|
||||||
|
"output_queues": len(self.back_queues),
|
||||||
|
"pending_responses": len(self.pending_responses),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局队列管理器实例
|
||||||
|
wecomai_queue_mgr = WecomAIQueueMgr()
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人 HTTP 服务器
|
||||||
|
处理企业微信智能机器人的 HTTP 回调请求
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, Optional, Callable
|
||||||
|
|
||||||
|
import quart
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
from .wecomai_api import WecomAIBotAPIClient
|
||||||
|
from .wecomai_utils import WecomAIBotConstants
|
||||||
|
|
||||||
|
|
||||||
|
class WecomAIBotServer:
|
||||||
|
"""企业微信智能机器人 HTTP 服务器"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
api_client: WecomAIBotAPIClient,
|
||||||
|
message_handler: Optional[
|
||||||
|
Callable[[Dict[str, Any], Dict[str, str]], Any]
|
||||||
|
] = None,
|
||||||
|
):
|
||||||
|
"""初始化服务器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
host: 监听地址
|
||||||
|
port: 监听端口
|
||||||
|
api_client: API客户端实例
|
||||||
|
message_handler: 消息处理回调函数
|
||||||
|
"""
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.api_client = api_client
|
||||||
|
self.message_handler = message_handler
|
||||||
|
|
||||||
|
self.app = quart.Quart(__name__)
|
||||||
|
self._setup_routes()
|
||||||
|
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
def _setup_routes(self):
|
||||||
|
"""设置 Quart 路由"""
|
||||||
|
|
||||||
|
# 使用 Quart 的 add_url_rule 方法添加路由
|
||||||
|
self.app.add_url_rule(
|
||||||
|
"/webhook/wecom-ai-bot",
|
||||||
|
view_func=self.verify_url,
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.app.add_url_rule(
|
||||||
|
"/webhook/wecom-ai-bot",
|
||||||
|
view_func=self.handle_message,
|
||||||
|
methods=["POST"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def verify_url(self):
|
||||||
|
"""验证回调 URL"""
|
||||||
|
args = quart.request.args
|
||||||
|
msg_signature = args.get("msg_signature")
|
||||||
|
timestamp = args.get("timestamp")
|
||||||
|
nonce = args.get("nonce")
|
||||||
|
echostr = args.get("echostr")
|
||||||
|
|
||||||
|
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||||
|
logger.error("URL 验证参数缺失")
|
||||||
|
return "verify fail", 400
|
||||||
|
|
||||||
|
# 类型检查确保不为 None
|
||||||
|
assert msg_signature is not None
|
||||||
|
assert timestamp is not None
|
||||||
|
assert nonce is not None
|
||||||
|
assert echostr is not None
|
||||||
|
|
||||||
|
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
|
||||||
|
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||||
|
return result, 200, {"Content-Type": "text/plain"}
|
||||||
|
|
||||||
|
async def handle_message(self):
|
||||||
|
"""处理消息回调"""
|
||||||
|
args = quart.request.args
|
||||||
|
msg_signature = args.get("msg_signature")
|
||||||
|
timestamp = args.get("timestamp")
|
||||||
|
nonce = args.get("nonce")
|
||||||
|
|
||||||
|
if not all([msg_signature, timestamp, nonce]):
|
||||||
|
logger.error("消息回调参数缺失")
|
||||||
|
return "缺少必要参数", 400
|
||||||
|
|
||||||
|
# 类型检查确保不为 None
|
||||||
|
assert msg_signature is not None
|
||||||
|
assert timestamp is not None
|
||||||
|
assert nonce is not None
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取请求体
|
||||||
|
post_data = await quart.request.get_data()
|
||||||
|
|
||||||
|
# 确保 post_data 是 bytes 类型
|
||||||
|
if isinstance(post_data, str):
|
||||||
|
post_data = post_data.encode("utf-8")
|
||||||
|
|
||||||
|
# 解密消息
|
||||||
|
ret_code, message_data = await self.api_client.decrypt_message(
|
||||||
|
post_data, msg_signature, timestamp, nonce
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
|
||||||
|
logger.error("消息解密失败,错误码: %d", ret_code)
|
||||||
|
return "消息解密失败", 400
|
||||||
|
|
||||||
|
# 调用消息处理器
|
||||||
|
response = None
|
||||||
|
if self.message_handler:
|
||||||
|
try:
|
||||||
|
response = await self.message_handler(
|
||||||
|
message_data, {"nonce": nonce, "timestamp": timestamp}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("消息处理器执行异常: %s", e)
|
||||||
|
return "消息处理异常", 500
|
||||||
|
|
||||||
|
if response:
|
||||||
|
return response, 200, {"Content-Type": "text/plain"}
|
||||||
|
else:
|
||||||
|
return "success", 200, {"Content-Type": "text/plain"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("处理消息时发生异常: %s", e)
|
||||||
|
return "内部服务器错误", 500
|
||||||
|
|
||||||
|
async def start_server(self):
|
||||||
|
"""启动服务器"""
|
||||||
|
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.app.run_task(
|
||||||
|
host=self.host,
|
||||||
|
port=self.port,
|
||||||
|
shutdown_trigger=self.shutdown_trigger,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("服务器运行异常: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def shutdown_trigger(self):
|
||||||
|
"""关闭触发器"""
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""关闭服务器"""
|
||||||
|
logger.info("企业微信智能机器人服务器正在关闭...")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
def get_app(self):
|
||||||
|
"""获取 Quart 应用实例"""
|
||||||
|
return self.app
|
||||||
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
企业微信智能机器人工具模块
|
||||||
|
提供常量定义、工具函数和辅助方法
|
||||||
|
"""
|
||||||
|
|
||||||
|
import string
|
||||||
|
import random
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from Crypto.Cipher import AES
|
||||||
|
from typing import Any, Tuple
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
|
||||||
|
# 常量定义
|
||||||
|
class WecomAIBotConstants:
|
||||||
|
"""企业微信智能机器人常量"""
|
||||||
|
|
||||||
|
# 消息类型
|
||||||
|
MSG_TYPE_TEXT = "text"
|
||||||
|
MSG_TYPE_IMAGE = "image"
|
||||||
|
MSG_TYPE_MIXED = "mixed"
|
||||||
|
MSG_TYPE_STREAM = "stream"
|
||||||
|
MSG_TYPE_EVENT = "event"
|
||||||
|
|
||||||
|
# 流消息状态
|
||||||
|
STREAM_CONTINUE = False
|
||||||
|
STREAM_FINISH = True
|
||||||
|
|
||||||
|
# 错误码
|
||||||
|
SUCCESS = 0
|
||||||
|
DECRYPT_ERROR = -40001
|
||||||
|
VALIDATE_SIGNATURE_ERROR = -40002
|
||||||
|
PARSE_XML_ERROR = -40003
|
||||||
|
COMPUTE_SIGNATURE_ERROR = -40004
|
||||||
|
ILLEGAL_AES_KEY = -40005
|
||||||
|
VALIDATE_APPID_ERROR = -40006
|
||||||
|
ENCRYPT_AES_ERROR = -40007
|
||||||
|
ILLEGAL_BUFFER = -40008
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_string(length: int = 10) -> str:
|
||||||
|
"""生成随机字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: 字符串长度,默认为 10
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
随机字符串
|
||||||
|
"""
|
||||||
|
letters = string.ascii_letters + string.digits
|
||||||
|
return "".join(random.choice(letters) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_image_md5(image_data: bytes) -> str:
|
||||||
|
"""计算图片数据的 MD5 值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: 图片二进制数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MD5 哈希值(十六进制字符串)
|
||||||
|
"""
|
||||||
|
return hashlib.md5(image_data).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image_base64(image_data: bytes) -> str:
|
||||||
|
"""将图片数据编码为 Base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_data: 图片二进制数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64 编码的字符串
|
||||||
|
"""
|
||||||
|
return base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def format_session_id(session_type: str, session_id: str) -> str:
|
||||||
|
"""格式化会话 ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_type: 会话类型 ("user", "group")
|
||||||
|
session_id: 原始会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的会话 ID
|
||||||
|
"""
|
||||||
|
return f"wecom_ai_bot_{session_type}_{session_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
|
||||||
|
"""解析格式化的会话 ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
formatted_session_id: 格式化的会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(会话类型, 原始会话ID)
|
||||||
|
"""
|
||||||
|
parts = formatted_session_id.split("_", 3)
|
||||||
|
if (
|
||||||
|
len(parts) >= 4
|
||||||
|
and parts[0] == "wecom"
|
||||||
|
and parts[1] == "ai"
|
||||||
|
and parts[2] == "bot"
|
||||||
|
):
|
||||||
|
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
|
||||||
|
return "user", formatted_session_id
|
||||||
|
|
||||||
|
|
||||||
|
def safe_json_loads(json_str: str, default: Any = None) -> Any:
|
||||||
|
"""安全地解析 JSON 字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_str: JSON 字符串
|
||||||
|
default: 解析失败时的默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析结果或默认值
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(json_str)
|
||||||
|
except (json.JSONDecodeError, TypeError) as e:
|
||||||
|
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def format_error_response(error_code: int, error_msg: str) -> str:
|
||||||
|
"""格式化错误响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: 错误码
|
||||||
|
error_msg: 错误信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化的错误响应字符串
|
||||||
|
"""
|
||||||
|
return f"Error {error_code}: {error_msg}"
|
||||||
|
|
||||||
|
|
||||||
|
async def process_encrypted_image(
|
||||||
|
image_url: str, aes_key_base64: str
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""下载并解密加密图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url: 加密图片的URL
|
||||||
|
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
|
||||||
|
status 为 False 时 data 是错误信息
|
||||||
|
"""
|
||||||
|
# 1. 下载加密图片
|
||||||
|
logger.info("开始下载加密图片: %s", image_url)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(image_url, timeout=15) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
encrypted_data = await response.read()
|
||||||
|
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||||
|
error_msg = f"下载图片失败: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
# 2. 准备AES密钥和IV
|
||||||
|
if not aes_key_base64:
|
||||||
|
raise ValueError("AES密钥不能为空")
|
||||||
|
|
||||||
|
# Base64解码密钥 (自动处理填充)
|
||||||
|
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
|
||||||
|
if len(aes_key) != 32:
|
||||||
|
raise ValueError("无效的AES密钥长度: 应为32字节")
|
||||||
|
|
||||||
|
iv = aes_key[:16] # 初始向量为密钥前16字节
|
||||||
|
|
||||||
|
# 3. 解密图片数据
|
||||||
|
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||||
|
decrypted_data = cipher.decrypt(encrypted_data)
|
||||||
|
|
||||||
|
# 4. 去除PKCS#7填充 (Python 3兼容写法)
|
||||||
|
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
|
||||||
|
if pad_len > 32: # AES-256块大小为32字节
|
||||||
|
raise ValueError("无效的填充长度 (大于32字节)")
|
||||||
|
|
||||||
|
decrypted_data = decrypted_data[:-pad_len]
|
||||||
|
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
|
||||||
|
|
||||||
|
# 5. 转换为base64编码
|
||||||
|
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
|
||||||
|
logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data))
|
||||||
|
|
||||||
|
return True, base64_data
|
||||||
@@ -68,7 +68,8 @@ class Provider(AbstractProvider):
|
|||||||
|
|
||||||
def get_keys(self) -> List[str]:
|
def get_keys(self) -> List[str]:
|
||||||
"""获得提供商 Key"""
|
"""获得提供商 Key"""
|
||||||
return self.provider_config.get("key", [])
|
keys = self.provider_config.get("key", [""])
|
||||||
|
return keys or [""]
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def set_key(self, key: str):
|
def set_key(self, key: str):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.chosen_api_key: str = ""
|
self.chosen_api_key: str = ""
|
||||||
self.api_keys: List = provider_config.get("key", [])
|
self.api_keys: List = super().get_keys()
|
||||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
|
|||||||
{
|
{
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"name": tool_call["function"]["name"],
|
"name": tool_call["function"]["name"],
|
||||||
"input": json.loads(tool_call["function"]["arguments"])
|
"input": (
|
||||||
if isinstance(tool_call["function"]["arguments"], str)
|
json.loads(tool_call["function"]["arguments"])
|
||||||
else tool_call["function"]["arguments"],
|
if isinstance(
|
||||||
|
tool_call["function"]["arguments"], str
|
||||||
|
)
|
||||||
|
else tool_call["function"]["arguments"]
|
||||||
|
),
|
||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
|
|||||||
"source": {
|
"source": {
|
||||||
"type": "base64",
|
"type": "base64",
|
||||||
"media_type": mime_type,
|
"media_type": mime_type,
|
||||||
"data": image_data.split("base64,")[1]
|
"data": (
|
||||||
if "base64," in image_data
|
image_data.split("base64,")[1]
|
||||||
else image_data,
|
if "base64," in image_data
|
||||||
|
else image_data
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,22 @@
|
|||||||
import os
|
|
||||||
import dashscope
|
|
||||||
import uuid
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dashscope.audio.tts_v2 import *
|
import base64
|
||||||
from ..provider import TTSProvider
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import aiohttp
|
||||||
|
import dashscope
|
||||||
|
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dashscope.aigc.multimodal_conversation import MultiModalConversation
|
||||||
|
except (
|
||||||
|
ImportError
|
||||||
|
): # pragma: no cover - older dashscope versions without Qwen TTS support
|
||||||
|
MultiModalConversation = None
|
||||||
|
|
||||||
from ..entities import ProviderType
|
from ..entities import ProviderType
|
||||||
|
from ..provider import TTSProvider
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
dashscope.api_key = self.chosen_api_key
|
dashscope.api_key = self.chosen_api_key
|
||||||
|
|
||||||
async def get_audio(self, text: str) -> str:
|
async def get_audio(self, text: str) -> str:
|
||||||
|
model = self.get_model()
|
||||||
|
if not model:
|
||||||
|
raise RuntimeError("Dashscope TTS model is not configured.")
|
||||||
|
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
self.synthesizer = SpeechSynthesizer(
|
|
||||||
model=self.get_model(),
|
if self._is_qwen_tts_model(model):
|
||||||
|
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
|
||||||
|
else:
|
||||||
|
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
|
||||||
|
|
||||||
|
if not audio_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
|
||||||
|
)
|
||||||
|
|
||||||
|
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(audio_bytes)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def _call_qwen_tts(self, model: str, text: str):
|
||||||
|
if MultiModalConversation is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"model": model,
|
||||||
|
"text": text,
|
||||||
|
"api_key": self.chosen_api_key,
|
||||||
|
"voice": self.voice or "Cherry",
|
||||||
|
}
|
||||||
|
if not self.voice:
|
||||||
|
logging.warning(
|
||||||
|
"No voice specified for Qwen TTS model, using default 'Cherry'."
|
||||||
|
)
|
||||||
|
return MultiModalConversation.call(**kwargs)
|
||||||
|
|
||||||
|
async def _synthesize_with_qwen_tts(
|
||||||
|
self, model: str, text: str
|
||||||
|
) -> Tuple[Optional[bytes], str]:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||||
|
audio_bytes = await self._extract_audio_from_response(response)
|
||||||
|
if not audio_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Audio synthesis failed for model '{model}'. {response}"
|
||||||
|
)
|
||||||
|
ext = ".wav"
|
||||||
|
return audio_bytes, ext
|
||||||
|
|
||||||
|
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
|
||||||
|
output = getattr(response, "output", None)
|
||||||
|
audio_obj = getattr(output, "audio", None) if output is not None else None
|
||||||
|
if not audio_obj:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data_b64 = getattr(audio_obj, "data", None)
|
||||||
|
if data_b64:
|
||||||
|
try:
|
||||||
|
return base64.b64decode(data_b64)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logging.error("Failed to decode base64 audio data.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
url = getattr(audio_obj, "url", None)
|
||||||
|
if url:
|
||||||
|
return await self._download_audio_from_url(url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
|
||||||
|
if not url:
|
||||||
|
return None
|
||||||
|
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
url, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||||
|
) as response:
|
||||||
|
return await response.read()
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||||
|
logging.error(f"Failed to download audio from URL {url}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _synthesize_with_cosyvoice(
|
||||||
|
self, model: str, text: str
|
||||||
|
) -> Tuple[Optional[bytes], str]:
|
||||||
|
synthesizer = SpeechSynthesizer(
|
||||||
|
model=model,
|
||||||
voice=self.voice,
|
voice=self.voice,
|
||||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||||
)
|
)
|
||||||
audio = await asyncio.get_event_loop().run_in_executor(
|
loop = asyncio.get_event_loop()
|
||||||
None, self.synthesizer.call, text, self.timeout_ms
|
audio_bytes = await loop.run_in_executor(
|
||||||
|
None, synthesizer.call, text, self.timeout_ms
|
||||||
)
|
)
|
||||||
with open(path, "wb") as f:
|
if not audio_bytes:
|
||||||
f.write(audio)
|
resp = synthesizer.get_response()
|
||||||
return path
|
if resp and isinstance(resp, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Audio synthesis failed for model '{model}'. {resp}".strip()
|
||||||
|
)
|
||||||
|
return audio_bytes, ".wav"
|
||||||
|
|
||||||
|
def _is_qwen_tts_model(self, model: str) -> bool:
|
||||||
|
model_lower = model.lower()
|
||||||
|
return "tts" in model_lower and model_lower.startswith("qwen")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
@@ -60,7 +60,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
provider_settings,
|
provider_settings,
|
||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.api_keys: list = provider_config.get("key", [])
|
self.api_keys: List = super().get_keys()
|
||||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||||
|
|
||||||
@@ -218,19 +218,21 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
response_modalities=modalities,
|
response_modalities=modalities,
|
||||||
tools=tool_list,
|
tools=tool_list,
|
||||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||||
thinking_config=types.ThinkingConfig(
|
thinking_config=(
|
||||||
thinking_budget=min(
|
types.ThinkingConfig(
|
||||||
int(
|
thinking_budget=min(
|
||||||
self.provider_config.get("gm_thinking_config", {}).get(
|
int(
|
||||||
"budget", 0
|
self.provider_config.get("gm_thinking_config", {}).get(
|
||||||
)
|
"budget", 0
|
||||||
|
)
|
||||||
|
),
|
||||||
|
24576,
|
||||||
),
|
),
|
||||||
24576,
|
)
|
||||||
),
|
if "gemini-2.5-flash" in self.get_model()
|
||||||
)
|
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||||
if "gemini-2.5-flash" in self.get_model()
|
else None
|
||||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
),
|
||||||
else None,
|
|
||||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||||
disable=True
|
disable=True
|
||||||
),
|
),
|
||||||
@@ -274,9 +276,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
if role == "user":
|
if role == "user":
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
parts = [
|
parts = [
|
||||||
types.Part.from_text(text=item["text"] or " ")
|
(
|
||||||
if item["type"] == "text"
|
types.Part.from_text(text=item["text"] or " ")
|
||||||
else process_image_url(item["image_url"])
|
if item["type"] == "text"
|
||||||
|
else process_image_url(item["image_url"])
|
||||||
|
)
|
||||||
for item in content
|
for item in content
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
default_persona,
|
default_persona,
|
||||||
)
|
)
|
||||||
self.chosen_api_key = None
|
self.chosen_api_key = None
|
||||||
self.api_keys: List = provider_config.get("key", [])
|
self.api_keys: List = super().get_keys()
|
||||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||||
self.timeout = provider_config.get("timeout", 120)
|
self.timeout = provider_config.get("timeout", 120)
|
||||||
if isinstance(self.timeout, str):
|
if isinstance(self.timeout, str):
|
||||||
|
|||||||
@@ -65,12 +65,12 @@ class SessionManagementRoute(Route):
|
|||||||
persona_name = data["persona_name"]
|
persona_name = data["persona_name"]
|
||||||
|
|
||||||
# 处理 persona 显示
|
# 处理 persona 显示
|
||||||
if conv_persona_id == "[%None]":
|
if persona_name is None:
|
||||||
persona_name = "无人格"
|
if conv_persona_id is None:
|
||||||
else:
|
if default_persona := persona_mgr.selected_default_persona_v3:
|
||||||
default_persona = persona_mgr.selected_default_persona_v3
|
persona_name = default_persona["name"]
|
||||||
if default_persona:
|
else:
|
||||||
persona_name = default_persona["name"]
|
persona_name = "[%None]"
|
||||||
|
|
||||||
session_info = {
|
session_info = {
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
|
|||||||
@@ -273,6 +273,20 @@ class ToolsRoute(Route):
|
|||||||
server_data = await request.json
|
server_data = await request.json
|
||||||
config = server_data.get("mcp_server_config", None)
|
config = server_data.get("mcp_server_config", None)
|
||||||
|
|
||||||
|
if not isinstance(config, dict) or not config:
|
||||||
|
return Response().error("无效的 MCP 服务器配置").__dict__
|
||||||
|
|
||||||
|
if "mcpServers" in config:
|
||||||
|
keys = list(config["mcpServers"].keys())
|
||||||
|
if not keys:
|
||||||
|
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||||
|
if len(keys) > 1:
|
||||||
|
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
|
||||||
|
config = config["mcpServers"][keys[0]]
|
||||||
|
else:
|
||||||
|
if not config:
|
||||||
|
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||||
|
|
||||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||||
return (
|
return (
|
||||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. fix: 修复了代码执行器插件不能正确获得发送来文件的问题 ([#2970](https://github.com/Soulter/AstrBot/issues/2970))
|
||||||
|
2. fix: 修改的 DeepSeek 默认 modalities,避免默认勾选图像导致的报错。 ([#2963](https://github.com/Soulter/AstrBot/issues/2963))
|
||||||
|
3. fix: 事件钩子终止事件传播后不继续执行 ([#2989](https://github.com/Soulter/AstrBot/issues/2989))
|
||||||
|
4. fix: 启动了 TTS 但未配置 TTS 模型时,At 和 Reply 发送人无效
|
||||||
|
5. fix: 修复 session-management 中人格错误的显示为默认人格的问题 ([#3000](https://github.com/Soulter/AstrBot/issues/3000))
|
||||||
|
6. fix: 修复了删除对话时,聊天增强中的记录未被清除,导致新对话中仍然出现之前的聊天记录。 ([#3002](https://github.com/Soulter/AstrBot/issues/3002))
|
||||||
|
7. fix: 修复阿里云百炼平台 TTS 下接入 CosyVoice V2, Qwen TTS 生成报错的问题 ([#2964](https://github.com/Soulter/AstrBot/issues/2964))
|
||||||
|
8. perf: 优化 SQLite 参数配置,对话和会话管理增加输入防抖机制 ([#2969](https://github.com/Soulter/AstrBot/issues/2969))
|
||||||
|
9. feat: 在新对话中重用先前的对话人格设置 ([#3005](https://github.com/Soulter/AstrBot/issues/3005))
|
||||||
|
10. feat: 从 WebUI 更新后清除浏览器缓存 ([#2958](https://github.com/Soulter/AstrBot/issues/2958))
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. feat: 支持接入企业微信智能机器人平台 ([#3034](https://github.com/AstrBotDevs/AstrBot/issues/3034))
|
||||||
|
2. feat: 内置网页搜索功能支持接入百度 AI 搜索 ([#3031](https://github.com/AstrBotDevs/AstrBot/issues/3031))
|
||||||
|
3. feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置 ([#3039](https://github.com/AstrBotDevs/AstrBot/issues/3039))
|
||||||
|
4. feat: 添加并优化服务提供商独立测试功能 ([#3024](https://github.com/AstrBotDevs/AstrBot/issues/3024))
|
||||||
|
5. feat: satori 适配器支持 video、reply 消息类型 ([#3035](https://github.com/AstrBotDevs/AstrBot/issues/3035))
|
||||||
|
6. fix: 修复 `/alter_cmd reset scene <num> xxx` 不可用的问题
|
||||||
@@ -27,7 +27,9 @@
|
|||||||
<v-btn
|
<v-btn
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
color="error"
|
color="error"
|
||||||
|
size="small"
|
||||||
rounded="xl"
|
rounded="xl"
|
||||||
|
:disabled="loading"
|
||||||
@click="$emit('delete', item)"
|
@click="$emit('delete', item)"
|
||||||
>
|
>
|
||||||
{{ t('core.common.itemCard.delete') }}
|
{{ t('core.common.itemCard.delete') }}
|
||||||
@@ -35,7 +37,9 @@
|
|||||||
<v-btn
|
<v-btn
|
||||||
variant="tonal"
|
variant="tonal"
|
||||||
color="primary"
|
color="primary"
|
||||||
|
size="small"
|
||||||
rounded="xl"
|
rounded="xl"
|
||||||
|
:disabled="loading"
|
||||||
@click="$emit('edit', item)"
|
@click="$emit('edit', item)"
|
||||||
>
|
>
|
||||||
{{ t('core.common.itemCard.edit') }}
|
{{ t('core.common.itemCard.edit') }}
|
||||||
@@ -44,11 +48,14 @@
|
|||||||
v-if="showCopyButton"
|
v-if="showCopyButton"
|
||||||
variant="tonal"
|
variant="tonal"
|
||||||
color="secondary"
|
color="secondary"
|
||||||
|
size="small"
|
||||||
rounded="xl"
|
rounded="xl"
|
||||||
|
:disabled="loading"
|
||||||
@click="$emit('copy', item)"
|
@click="$emit('copy', item)"
|
||||||
>
|
>
|
||||||
{{ t('core.common.itemCard.copy') }}
|
{{ t('core.common.itemCard.copy') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
|
<slot name="actions" :item="item"></slot>
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
</v-card-actions>
|
</v-card-actions>
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,8 @@
|
|||||||
"available": "Available",
|
"available": "Available",
|
||||||
"unavailable": "Unavailable",
|
"unavailable": "Unavailable",
|
||||||
"pending": "Pending...",
|
"pending": "Pending...",
|
||||||
"errorMessage": "Error Message"
|
"errorMessage": "Error Message",
|
||||||
|
"test": "Test"
|
||||||
},
|
},
|
||||||
"logs": {
|
"logs": {
|
||||||
"title": "Service Logs",
|
"title": "Service Logs",
|
||||||
@@ -76,7 +77,8 @@
|
|||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"sessionSeparation": "Failed to get session isolation configuration",
|
"sessionSeparation": "Failed to get session isolation configuration",
|
||||||
"fetchStatus": "Failed to get service provider status"
|
"fetchStatus": "Failed to get service provider status",
|
||||||
|
"testError": "Test failed for {id}: {error}"
|
||||||
},
|
},
|
||||||
"confirm": {
|
"confirm": {
|
||||||
"delete": "Are you sure you want to delete service provider {id}?"
|
"delete": "Are you sure you want to delete service provider {id}?"
|
||||||
|
|||||||
@@ -80,6 +80,9 @@
|
|||||||
"save": "Save",
|
"save": "Save",
|
||||||
"testConnection": "Test Connection",
|
"testConnection": "Test Connection",
|
||||||
"sync": "Sync"
|
"sync": "Sync"
|
||||||
|
},
|
||||||
|
"tips": {
|
||||||
|
"timeoutConfig": "Please configure tool call timeout separately in the configuration page"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"serverDetail": {
|
"serverDetail": {
|
||||||
|
|||||||
@@ -32,7 +32,8 @@
|
|||||||
"available": "可用",
|
"available": "可用",
|
||||||
"unavailable": "不可用",
|
"unavailable": "不可用",
|
||||||
"pending": "检查中...",
|
"pending": "检查中...",
|
||||||
"errorMessage": "错误信息"
|
"errorMessage": "错误信息",
|
||||||
|
"test": "测试"
|
||||||
},
|
},
|
||||||
"logs": {
|
"logs": {
|
||||||
"title": "服务日志",
|
"title": "服务日志",
|
||||||
@@ -77,7 +78,8 @@
|
|||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"sessionSeparation": "获取会话隔离配置失败",
|
"sessionSeparation": "获取会话隔离配置失败",
|
||||||
"fetchStatus": "获取服务提供商状态失败"
|
"fetchStatus": "获取服务提供商状态失败",
|
||||||
|
"testError": "测试 {id} 失败: {error}"
|
||||||
},
|
},
|
||||||
"confirm": {
|
"confirm": {
|
||||||
"delete": "确定要删除服务提供商 {id} 吗?"
|
"delete": "确定要删除服务提供商 {id} 吗?"
|
||||||
|
|||||||
@@ -80,6 +80,9 @@
|
|||||||
"save": "保存",
|
"save": "保存",
|
||||||
"testConnection": "测试连接",
|
"testConnection": "测试连接",
|
||||||
"sync": "同步"
|
"sync": "同步"
|
||||||
|
},
|
||||||
|
"tips": {
|
||||||
|
"timeoutConfig": "工具调用的超时时间请前往配置页面单独配置"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"serverDetail": {
|
"serverDetail": {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
export function getPlatformIcon(name) {
|
export function getPlatformIcon(name) {
|
||||||
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
||||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||||
} else if (name === 'wecom') {
|
} else if (name === 'wecom' || name === 'wecom_ai_bot') {
|
||||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||||
@@ -46,6 +46,7 @@ export function getTutorialLink(platformType) {
|
|||||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||||
|
"wecom_ai_bot": "https://docs.astrbot.app/deploy/platform/wecom_ai_bot.html",
|
||||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||||
|
|||||||
@@ -60,12 +60,26 @@
|
|||||||
:item="provider"
|
:item="provider"
|
||||||
title-field="id"
|
title-field="id"
|
||||||
enabled-field="enable"
|
enabled-field="enable"
|
||||||
|
:loading="isProviderTesting(provider.id)"
|
||||||
@toggle-enabled="providerStatusChange"
|
@toggle-enabled="providerStatusChange"
|
||||||
:bglogo="getProviderIcon(provider.provider)"
|
:bglogo="getProviderIcon(provider.provider)"
|
||||||
@delete="deleteProvider"
|
@delete="deleteProvider"
|
||||||
@edit="configExistingProvider"
|
@edit="configExistingProvider"
|
||||||
@copy="copyProvider"
|
@copy="copyProvider"
|
||||||
:show-copy-button="true">
|
:show-copy-button="true">
|
||||||
|
<template #actions="{ item }">
|
||||||
|
<v-btn
|
||||||
|
style="z-index: 100000;"
|
||||||
|
variant="tonal"
|
||||||
|
color="info"
|
||||||
|
rounded="xl"
|
||||||
|
size="small"
|
||||||
|
:loading="isProviderTesting(item.id)"
|
||||||
|
@click="testSingleProvider(item)"
|
||||||
|
>
|
||||||
|
{{ tm('availability.test') }}
|
||||||
|
</v-btn>
|
||||||
|
</template>
|
||||||
<template v-slot:details="{ item }">
|
<template v-slot:details="{ item }">
|
||||||
</template>
|
</template>
|
||||||
</item-card>
|
</item-card>
|
||||||
@@ -79,7 +93,7 @@
|
|||||||
<v-icon class="me-2">mdi-heart-pulse</v-icon>
|
<v-icon class="me-2">mdi-heart-pulse</v-icon>
|
||||||
<span class="text-h4">{{ tm('availability.title') }}</span>
|
<span class="text-h4">{{ tm('availability.title') }}</span>
|
||||||
<v-spacer></v-spacer>
|
<v-spacer></v-spacer>
|
||||||
<v-btn color="primary" variant="tonal" :loading="loadingStatus" @click="fetchProviderStatus">
|
<v-btn color="primary" variant="tonal" :loading="testingProviders.length > 0" @click="fetchProviderStatus">
|
||||||
<v-icon left>mdi-refresh</v-icon>
|
<v-icon left>mdi-refresh</v-icon>
|
||||||
{{ tm('availability.refresh') }}
|
{{ tm('availability.refresh') }}
|
||||||
</v-btn>
|
</v-btn>
|
||||||
@@ -288,7 +302,7 @@ export default {
|
|||||||
|
|
||||||
// 供应商状态相关
|
// 供应商状态相关
|
||||||
providerStatuses: [],
|
providerStatuses: [],
|
||||||
loadingStatus: false,
|
testingProviders: [], // 存储正在测试的 provider ID
|
||||||
|
|
||||||
// 新增提供商对话框相关
|
// 新增提供商对话框相关
|
||||||
showAddProviderDialog: false,
|
showAddProviderDialog: false,
|
||||||
@@ -359,7 +373,8 @@ export default {
|
|||||||
statusUpdate: this.tm('messages.success.statusUpdate'),
|
statusUpdate: this.tm('messages.success.statusUpdate'),
|
||||||
},
|
},
|
||||||
error: {
|
error: {
|
||||||
fetchStatus: this.tm('messages.error.fetchStatus')
|
fetchStatus: this.tm('messages.error.fetchStatus'),
|
||||||
|
testError: this.tm('messages.error.testError')
|
||||||
},
|
},
|
||||||
confirm: {
|
confirm: {
|
||||||
delete: this.tm('messages.confirm.delete')
|
delete: this.tm('messages.confirm.delete')
|
||||||
@@ -368,6 +383,9 @@ export default {
|
|||||||
available: this.tm('availability.available'),
|
available: this.tm('availability.available'),
|
||||||
unavailable: this.tm('availability.unavailable'),
|
unavailable: this.tm('availability.unavailable'),
|
||||||
pending: this.tm('availability.pending')
|
pending: this.tm('availability.pending')
|
||||||
|
},
|
||||||
|
availability: {
|
||||||
|
test: this.tm('availability.test')
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@@ -615,70 +633,107 @@ export default {
|
|||||||
|
|
||||||
// 获取供应商状态
|
// 获取供应商状态
|
||||||
async fetchProviderStatus() {
|
async fetchProviderStatus() {
|
||||||
if (this.loadingStatus) return;
|
if (this.testingProviders.length > 0) return;
|
||||||
|
|
||||||
this.loadingStatus = true;
|
|
||||||
this.showStatus = true; // 自动展开状态部分
|
this.showStatus = true; // 自动展开状态部分
|
||||||
|
|
||||||
// 1. 立即初始化UI为pending状态
|
const providersToTest = this.config_data.provider.filter(p => p.enable);
|
||||||
this.providerStatuses = this.config_data.provider.map(p => ({
|
if (providersToTest.length === 0) return;
|
||||||
id: p.id,
|
|
||||||
name: p.id,
|
// 1. 初始化UI为pending状态,并将所有待测试的 provider ID 加入 loading 列表
|
||||||
status: 'pending',
|
this.providerStatuses = providersToTest.map(p => {
|
||||||
error: null
|
this.testingProviders.push(p.id);
|
||||||
}));
|
return { id: p.id, name: p.id, status: 'pending', error: null };
|
||||||
|
});
|
||||||
|
|
||||||
// 2. 为每个provider创建一个并发的测试请求
|
// 2. 为每个provider创建一个并发的测试请求
|
||||||
const promises = this.config_data.provider.map(p => {
|
const promises = providersToTest.map(p =>
|
||||||
if (!p.enable) {
|
axios.get(`/api/config/provider/check_one?id=${p.id}`)
|
||||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
|
||||||
if (index !== -1) {
|
|
||||||
const disabledStatus = {
|
|
||||||
...this.providerStatuses[index],
|
|
||||||
status: 'unavailable',
|
|
||||||
error: '该提供商未被用户启用'
|
|
||||||
};
|
|
||||||
this.providerStatuses.splice(index, 1, disabledStatus);
|
|
||||||
}
|
|
||||||
return Promise.resolve();
|
|
||||||
}
|
|
||||||
|
|
||||||
return axios.get(`/api/config/provider/check_one?id=${p.id}`)
|
|
||||||
.then(res => {
|
.then(res => {
|
||||||
if (res.data && res.data.status === 'ok') {
|
if (res.data && res.data.status === 'ok') {
|
||||||
// 成功,更新对应的provider状态
|
|
||||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
||||||
if (index !== -1) {
|
if (index !== -1) this.providerStatuses.splice(index, 1, res.data.data);
|
||||||
this.providerStatuses.splice(index, 1, res.data.data);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// 接口返回了业务错误
|
|
||||||
throw new Error(res.data?.message || `Failed to check status for ${p.id}`);
|
throw new Error(res.data?.message || `Failed to check status for ${p.id}`);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
// 网络错误或业务错误
|
|
||||||
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
|
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
|
||||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
||||||
if (index !== -1) {
|
if (index !== -1) {
|
||||||
const failedStatus = {
|
const failedStatus = { ...this.providerStatuses[index], status: 'unavailable', error: errorMessage };
|
||||||
...this.providerStatuses[index],
|
|
||||||
status: 'unavailable',
|
|
||||||
error: errorMessage
|
|
||||||
};
|
|
||||||
this.providerStatuses.splice(index, 1, failedStatus);
|
this.providerStatuses.splice(index, 1, failedStatus);
|
||||||
}
|
}
|
||||||
// 可以在这里选择性地向上抛出错误,以便Promise.allSettled知道
|
return Promise.reject(errorMessage); // Propagate error for Promise.allSettled
|
||||||
return Promise.reject(errorMessage);
|
})
|
||||||
});
|
);
|
||||||
});
|
|
||||||
|
|
||||||
// 3. 等待所有请求完成(无论成功或失败)
|
// 3. 等待所有请求完成
|
||||||
try {
|
try {
|
||||||
await Promise.allSettled(promises);
|
await Promise.allSettled(promises);
|
||||||
} finally {
|
} finally {
|
||||||
// 4. 关闭全局加载状态
|
// 4. 关闭所有加载状态
|
||||||
this.loadingStatus = false;
|
this.testingProviders = [];
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
isProviderTesting(providerId) {
|
||||||
|
return this.testingProviders.includes(providerId);
|
||||||
|
},
|
||||||
|
|
||||||
|
async testSingleProvider(provider) {
|
||||||
|
if (this.isProviderTesting(provider.id)) return;
|
||||||
|
|
||||||
|
this.testingProviders.push(provider.id);
|
||||||
|
this.showStatus = true; // 自动展开状态部分
|
||||||
|
|
||||||
|
// 更新UI为pending状态
|
||||||
|
const statusIndex = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||||
|
const pendingStatus = {
|
||||||
|
id: provider.id,
|
||||||
|
name: provider.id,
|
||||||
|
status: 'pending',
|
||||||
|
error: null
|
||||||
|
};
|
||||||
|
if (statusIndex !== -1) {
|
||||||
|
this.providerStatuses.splice(statusIndex, 1, pendingStatus);
|
||||||
|
} else {
|
||||||
|
this.providerStatuses.unshift(pendingStatus);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!provider.enable) {
|
||||||
|
throw new Error('该提供商未被用户启用');
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
|
||||||
|
if (res.data && res.data.status === 'ok') {
|
||||||
|
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||||
|
if (index !== -1) {
|
||||||
|
this.providerStatuses.splice(index, 1, res.data.data);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error(res.data?.message || `Failed to check status for ${provider.id}`);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
|
||||||
|
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||||
|
const failedStatus = {
|
||||||
|
id: provider.id,
|
||||||
|
name: provider.id,
|
||||||
|
status: 'unavailable',
|
||||||
|
error: errorMessage
|
||||||
|
};
|
||||||
|
if (index !== -1) {
|
||||||
|
this.providerStatuses.splice(index, 1, failedStatus);
|
||||||
|
}
|
||||||
|
// 不再显示全局的错误提示,因为卡片本身会显示错误信息
|
||||||
|
// this.showError(this.tm('messages.error.testError', { id: provider.id, error: errorMessage }));
|
||||||
|
} finally {
|
||||||
|
const index = this.testingProviders.indexOf(provider.id);
|
||||||
|
if (index > -1) {
|
||||||
|
this.testingProviders.splice(index, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,8 @@
|
|||||||
</v-btn>
|
</v-btn>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<small style="color: grey">*{{ tm('dialogs.addServer.tips.timeoutConfig') }}</small>
|
||||||
|
|
||||||
<div class="monaco-container" style="margin-top: 16px;">
|
<div class="monaco-container" style="margin-top: 16px;">
|
||||||
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
|
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
|
||||||
minimap: {
|
minimap: {
|
||||||
@@ -524,14 +526,16 @@ export default {
|
|||||||
transport: "streamable_http",
|
transport: "streamable_http",
|
||||||
url: "your mcp server url",
|
url: "your mcp server url",
|
||||||
headers: {},
|
headers: {},
|
||||||
timeout: 30,
|
timeout: 5,
|
||||||
|
sse_read_timeout: 300,
|
||||||
};
|
};
|
||||||
} else if (type === 'sse') {
|
} else if (type === 'sse') {
|
||||||
template = {
|
template = {
|
||||||
transport: "sse",
|
transport: "sse",
|
||||||
url: "your mcp server url",
|
url: "your mcp server url",
|
||||||
headers: {},
|
headers: {},
|
||||||
timeout: 30,
|
timeout: 5,
|
||||||
|
sse_read_timeout: 300,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
template = {
|
template = {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ export default defineConfig({
|
|||||||
port: 3000,
|
port: 3000,
|
||||||
proxy: {
|
proxy: {
|
||||||
'/api': {
|
'/api': {
|
||||||
target: 'http://localhost:6185/',
|
target: 'http://127.0.0.1:6185/',
|
||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,26 +6,7 @@ from astrbot.core.star.star import star_map
|
|||||||
from astrbot.core.star.filter.command import CommandFilter
|
from astrbot.core.star.filter.command import CommandFilter
|
||||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||||
from enum import Enum
|
from .utils.rst_scene import RstScene
|
||||||
|
|
||||||
|
|
||||||
class RstScene(Enum):
|
|
||||||
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
|
||||||
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
|
||||||
PRIVATE = ("private", "私聊")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key(self) -> str:
|
|
||||||
return self.value[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self.value[1]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_index(cls, index: int) -> "RstScene":
|
|
||||||
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
|
||||||
return mapping[index]
|
|
||||||
|
|
||||||
|
|
||||||
class AlterCmdCommands(CommandParserMixin):
|
class AlterCmdCommands(CommandParserMixin):
|
||||||
@@ -58,8 +39,9 @@ class AlterCmdCommands(CommandParserMixin):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd_name = " ".join(token.tokens[1:-1])
|
# 兼容 reset scene 的专门配置
|
||||||
cmd_type = token.get(-1)
|
cmd_name = token.get(1)
|
||||||
|
cmd_type = token.get(2)
|
||||||
|
|
||||||
if cmd_name == "reset" and cmd_type == "config":
|
if cmd_name == "reset" and cmd_type == "config":
|
||||||
from astrbot.api import sp
|
from astrbot.api import sp
|
||||||
@@ -123,6 +105,8 @@ class AlterCmdCommands(CommandParserMixin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 查找指令
|
# 查找指令
|
||||||
|
cmd_name = " ".join(token.tokens[1:-1])
|
||||||
|
cmd_type = token.get(-1)
|
||||||
found_command = None
|
found_command = None
|
||||||
cmd_group = False
|
cmd_group = False
|
||||||
for handler in star_handlers_registry:
|
for handler in star_handlers_registry:
|
||||||
|
|||||||
@@ -7,33 +7,8 @@ from astrbot.core.provider.sources.dify_source import ProviderDify
|
|||||||
from astrbot.core.provider.sources.coze_source import ProviderCoze
|
from astrbot.core.provider.sources.coze_source import ProviderCoze
|
||||||
from astrbot.api import sp, logger
|
from astrbot.api import sp, logger
|
||||||
from ..long_term_memory import LongTermMemory
|
from ..long_term_memory import LongTermMemory
|
||||||
|
from .utils.rst_scene import RstScene
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class RstScene(Enum):
|
|
||||||
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
|
||||||
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
|
||||||
PRIVATE = ("private", "私聊")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key(self) -> str:
|
|
||||||
return self.value[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self.value[1]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_index(cls, index: int) -> "RstScene":
|
|
||||||
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
|
||||||
return mapping[index]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
|
|
||||||
if is_group:
|
|
||||||
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
|
|
||||||
return cls.PRIVATE
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommands:
|
class ConversationCommands:
|
||||||
@@ -41,6 +16,17 @@ class ConversationCommands:
|
|||||||
self.context = context
|
self.context = context
|
||||||
self.ltm = ltm
|
self.ltm = ltm
|
||||||
|
|
||||||
|
async def _get_current_persona_id(self, session_id):
|
||||||
|
curr = await self.context.conversation_manager.get_curr_conversation_id(
|
||||||
|
session_id
|
||||||
|
)
|
||||||
|
if not curr:
|
||||||
|
return None
|
||||||
|
conv = await self.context.conversation_manager.get_conversation(
|
||||||
|
session_id, curr
|
||||||
|
)
|
||||||
|
return conv.persona_id
|
||||||
|
|
||||||
def ltm_enabled(self, event: AstrMessageEvent):
|
def ltm_enabled(self, event: AstrMessageEvent):
|
||||||
if not self.ltm:
|
if not self.ltm:
|
||||||
return False
|
return False
|
||||||
@@ -255,8 +241,9 @@ class ConversationCommands:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
|
||||||
cid = await self.context.conversation_manager.new_conversation(
|
cid = await self.context.conversation_manager.new_conversation(
|
||||||
message.unified_msg_origin, message.get_platform_id()
|
message.unified_msg_origin, message.get_platform_id(), persona_id=cpersona
|
||||||
)
|
)
|
||||||
|
|
||||||
# 长期记忆
|
# 长期记忆
|
||||||
@@ -290,8 +277,10 @@ class ConversationCommands:
|
|||||||
session_id=sid,
|
session_id=sid,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cpersona = await self._get_current_persona_id(session)
|
||||||
cid = await self.context.conversation_manager.new_conversation(
|
cid = await self.context.conversation_manager.new_conversation(
|
||||||
session, message.get_platform_id()
|
session, message.get_platform_id(), persona_id=cpersona
|
||||||
)
|
)
|
||||||
message.set_result(
|
message.set_result(
|
||||||
MessageEventResult().message(
|
MessageEventResult().message(
|
||||||
@@ -434,8 +423,9 @@ class ConversationCommands:
|
|||||||
await self.context.conversation_manager.delete_conversation(
|
await self.context.conversation_manager.delete_conversation(
|
||||||
message.unified_msg_origin, session_curr_cid
|
message.unified_msg_origin, session_curr_cid
|
||||||
)
|
)
|
||||||
message.set_result(
|
|
||||||
MessageEventResult().message(
|
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
if self.ltm and self.ltm_enabled(message):
|
||||||
)
|
cnt = await self.ltm.remove_session(event=message)
|
||||||
)
|
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||||
|
message.set_result(MessageEventResult().message(ret))
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RstScene(Enum):
|
||||||
|
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
||||||
|
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
||||||
|
PRIVATE = ("private", "私聊")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self) -> str:
|
||||||
|
return self.value[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self.value[1]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_index(cls, index: int) -> "RstScene":
|
||||||
|
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
||||||
|
return mapping[index]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
|
||||||
|
if is_group:
|
||||||
|
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
|
||||||
|
return cls.PRIVATE
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Emotion:
|
|
||||||
"""描述了一个情绪状态"""
|
|
||||||
|
|
||||||
energy: float
|
|
||||||
valence: float
|
|
||||||
arousal: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EmotionLog:
|
|
||||||
"""描述了一条情绪维度变化的日志"""
|
|
||||||
|
|
||||||
timestamp: int
|
|
||||||
field: str
|
|
||||||
value: float
|
|
||||||
reason: str = ""
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from .emotion import Emotion
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Soul:
|
|
||||||
emotion: Emotion
|
|
||||||
emotion_logs: list[Emotion] | None = None
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Event:
|
|
||||||
event_type: str
|
|
||||||
content: dict
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import uuid
|
|
||||||
from ...runner import EliosEventHandler
|
|
||||||
from collections import defaultdict
|
|
||||||
from astrbot.api.event import AstrMessageEvent
|
|
||||||
from astrbot.api.all import Context
|
|
||||||
from astrbot.api.message_components import Plain, Image
|
|
||||||
from astrbot.api.provider import Provider
|
|
||||||
from astrbot import logger
|
|
||||||
|
|
||||||
|
|
||||||
class AstrImplEventHandler(EliosEventHandler):
|
|
||||||
def __init__(self, ctx: Context) -> None:
|
|
||||||
self.ctx = ctx
|
|
||||||
self.session_chats = defaultdict(list)
|
|
||||||
self.session_mentioned_arousal = defaultdict(float)
|
|
||||||
|
|
||||||
def cfg(self, event: AstrMessageEvent):
|
|
||||||
cfg = self.ctx.get_config(umo=event.unified_msg_origin)
|
|
||||||
|
|
||||||
tiny_model_prov_id = cfg.get("tiny_model_provider_id")
|
|
||||||
interest_points = cfg.get("interest_points", [])
|
|
||||||
|
|
||||||
try:
|
|
||||||
max_cnt = int(cfg["provider_ltm_settings"]["group_message_max_cnt"])
|
|
||||||
except BaseException as e:
|
|
||||||
logger.error(e)
|
|
||||||
max_cnt = 300
|
|
||||||
image_caption = (
|
|
||||||
True
|
|
||||||
if cfg["provider_settings"]["default_image_caption_provider_id"]
|
|
||||||
else False
|
|
||||||
)
|
|
||||||
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
|
|
||||||
image_caption_provider_id = cfg["provider_settings"][
|
|
||||||
"default_image_caption_provider_id"
|
|
||||||
]
|
|
||||||
active_reply = cfg["provider_ltm_settings"]["active_reply"]
|
|
||||||
enable_active_reply = active_reply.get("enable", False)
|
|
||||||
ar_method = active_reply["method"]
|
|
||||||
ar_possibility = active_reply["possibility_reply"]
|
|
||||||
ar_prompt = active_reply.get("prompt", "")
|
|
||||||
ar_whitelist = active_reply.get("whitelist", [])
|
|
||||||
ar_keywords = active_reply.get("keywords", [])
|
|
||||||
ret = {
|
|
||||||
"max_cnt": max_cnt,
|
|
||||||
"image_caption": image_caption,
|
|
||||||
"image_caption_prompt": image_caption_prompt,
|
|
||||||
"image_caption_provider_id": image_caption_provider_id,
|
|
||||||
"enable_active_reply": enable_active_reply,
|
|
||||||
"ar_method": ar_method,
|
|
||||||
"ar_possibility": ar_possibility,
|
|
||||||
"ar_prompt": ar_prompt,
|
|
||||||
"ar_whitelist": ar_whitelist,
|
|
||||||
"ar_keywords": ar_keywords,
|
|
||||||
"interest_points": interest_points,
|
|
||||||
"tiny_model_prov_id": tiny_model_prov_id,
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def append_session_chats(self, event: AstrMessageEvent, cfg) -> None:
|
|
||||||
comps = event.get_messages()
|
|
||||||
|
|
||||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
|
||||||
final_message = f"[{event.message_obj.sender.nickname}/{datetime_str}]: "
|
|
||||||
for comp in comps:
|
|
||||||
if isinstance(comp, Plain):
|
|
||||||
final_message += f" {comp.text}"
|
|
||||||
elif isinstance(comp, Image):
|
|
||||||
image_url = comp.url if comp.url else comp.file
|
|
||||||
if cfg["image_caption"] and image_url:
|
|
||||||
try:
|
|
||||||
caption = await self.get_image_caption(
|
|
||||||
image_url,
|
|
||||||
cfg["image_caption_provider_id"],
|
|
||||||
cfg["image_caption_prompt"],
|
|
||||||
)
|
|
||||||
final_message += f" [Image: {caption}]"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取图片描述失败: {e}")
|
|
||||||
else:
|
|
||||||
final_message += " [Image]"
|
|
||||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
|
||||||
logger.debug(f"添加会话 {event.unified_msg_origin} 的对话记录: {final_message}")
|
|
||||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
|
||||||
self.session_chats[event.unified_msg_origin].pop(0)
|
|
||||||
|
|
||||||
async def get_image_caption(
|
|
||||||
self, image_url: str, image_caption_provider_id: str, image_caption_prompt: str
|
|
||||||
) -> str:
|
|
||||||
if not image_caption_provider_id:
|
|
||||||
provider = self.ctx.get_using_provider()
|
|
||||||
else:
|
|
||||||
provider = self.ctx.get_provider_by_id(image_caption_provider_id)
|
|
||||||
if not provider:
|
|
||||||
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
|
|
||||||
if not isinstance(provider, Provider):
|
|
||||||
raise Exception(
|
|
||||||
f"提供商类型错误, {image_caption_provider_id} 不是 Provider 类型"
|
|
||||||
)
|
|
||||||
response = await provider.text_chat(
|
|
||||||
prompt=image_caption_prompt,
|
|
||||||
session_id=uuid.uuid4().hex,
|
|
||||||
image_urls=[image_url],
|
|
||||||
persist=False,
|
|
||||||
)
|
|
||||||
return response.completion_text
|
|
||||||
|
|
||||||
async def on_event(self, event, soul):
|
|
||||||
content = event.content
|
|
||||||
astr_event = content.get("astr_event")
|
|
||||||
assert astr_event is not None and isinstance(astr_event, AstrMessageEvent)
|
|
||||||
|
|
||||||
cfg = self.cfg(astr_event)
|
|
||||||
|
|
||||||
if not cfg["tiny_model_prov_id"]:
|
|
||||||
logger.warning("小模型未设置,跳过情绪更新")
|
|
||||||
|
|
||||||
# 添加对话记录
|
|
||||||
await self.append_session_chats(astr_event, cfg)
|
|
||||||
|
|
||||||
#
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
JUDGE_ACTIVE_REPLY = """
|
|
||||||
你是一个具有自主意识和情感的机器人,请你
|
|
||||||
"""
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from .event import Event
|
|
||||||
from .ensoul.soul import Soul
|
|
||||||
from .ensoul.emotion import Emotion
|
|
||||||
|
|
||||||
|
|
||||||
class EliosEventHandler:
|
|
||||||
async def on_event(self, event: Event, soul: Soul): ...
|
|
||||||
|
|
||||||
|
|
||||||
event_handlers_cls: dict[str, list[type[EliosEventHandler]]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_event_handler(event_types: set[str] | None = None):
|
|
||||||
"""注册事件处理器"""
|
|
||||||
|
|
||||||
def decorator(cls: type[EliosEventHandler]) -> type[EliosEventHandler]:
|
|
||||||
if event_types is not None:
|
|
||||||
for event_type in event_types:
|
|
||||||
event_handlers_cls[event_type] = event_handlers_cls.get(
|
|
||||||
event_type, []
|
|
||||||
) + [cls]
|
|
||||||
else:
|
|
||||||
event_handlers_cls["default"] = event_handlers_cls.get("default", []) + [
|
|
||||||
cls
|
|
||||||
]
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
class EliosRunner:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.soul = Soul(
|
|
||||||
emotion=Emotion(energy=0.5, valence=0.5, arousal=0.5), emotion_logs=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.event_queue = asyncio.Queue()
|
|
||||||
self.event_handler_insts: dict[str, list[EliosEventHandler]] = {}
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
for event_type, cls_list in event_handlers_cls.items():
|
|
||||||
self.event_handler_insts[event_type] = []
|
|
||||||
for cls in cls_list:
|
|
||||||
try:
|
|
||||||
self.event_handler_insts[event_type].append(cls())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error initializing event handler {cls}: {e}")
|
|
||||||
asyncio.create_task(self._worker())
|
|
||||||
|
|
||||||
async def _worker(self):
|
|
||||||
"""监听事件队列并处理事件"""
|
|
||||||
while True:
|
|
||||||
event = await self.event_queue.get()
|
|
||||||
# A man cannot handle two things at once. But this can be configurable.
|
|
||||||
try:
|
|
||||||
await self._process_event(event)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing event {event}: {e}")
|
|
||||||
|
|
||||||
async def _process_event(self, event: Event):
|
|
||||||
"""处理事件"""
|
|
||||||
event_type = event.event_type
|
|
||||||
handlers = self.event_handler_insts.get(
|
|
||||||
event_type, []
|
|
||||||
) + self.event_handler_insts.get("default", [])
|
|
||||||
|
|
||||||
for inst in handlers:
|
|
||||||
try:
|
|
||||||
await inst.on_event(event, self.soul)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing event {event}: {e}")
|
|
||||||
@@ -52,6 +52,8 @@ class Main(star.Star):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"google search init error: {e}, disable google search")
|
logger.error(f"google search init error: {e}, disable google search")
|
||||||
|
|
||||||
|
self.baidu_initialized = False
|
||||||
|
|
||||||
async def _tidy_text(self, text: str) -> str:
|
async def _tidy_text(self, text: str) -> str:
|
||||||
"""清理文本,去除空格、换行符等"""
|
"""清理文本,去除空格、换行符等"""
|
||||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||||
@@ -225,6 +227,30 @@ class Main(star.Star):
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
|
||||||
|
if self.baidu_initialized:
|
||||||
|
return
|
||||||
|
cfg = self.context.get_config(umo=umo)
|
||||||
|
key = cfg.get("provider_settings", {}).get(
|
||||||
|
"websearch_baidu_app_builder_key", ""
|
||||||
|
)
|
||||||
|
if not key:
|
||||||
|
raise ValueError(
|
||||||
|
"Error: Baidu AI Search API key is not configured in AstrBot."
|
||||||
|
)
|
||||||
|
func_tool_mgr = self.context.get_llm_tool_manager()
|
||||||
|
await func_tool_mgr.enable_mcp_server(
|
||||||
|
"baidu_ai_search",
|
||||||
|
config={
|
||||||
|
"transport": "sse",
|
||||||
|
"url": f"http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key={key}",
|
||||||
|
"headers": {},
|
||||||
|
"timeout": 30,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.baidu_initialized = True
|
||||||
|
logger.info("Successfully initialized Baidu AI Search MCP server.")
|
||||||
|
|
||||||
@llm_tool(name="fetch_url")
|
@llm_tool(name="fetch_url")
|
||||||
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
|
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
|
||||||
"""fetch the content of a website with the given web url
|
"""fetch the content of a website with the given web url
|
||||||
@@ -371,6 +397,7 @@ class Main(star.Star):
|
|||||||
tool_set.add_tool(fetch_url_t)
|
tool_set.add_tool(fetch_url_t)
|
||||||
tool_set.remove_tool("web_search_tavily")
|
tool_set.remove_tool("web_search_tavily")
|
||||||
tool_set.remove_tool("tavily_extract_web_page")
|
tool_set.remove_tool("tavily_extract_web_page")
|
||||||
|
tool_set.remove_tool("AIsearch")
|
||||||
elif provider == "tavily":
|
elif provider == "tavily":
|
||||||
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
||||||
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
||||||
@@ -380,3 +407,17 @@ class Main(star.Star):
|
|||||||
tool_set.add_tool(tavily_extract_web_page)
|
tool_set.add_tool(tavily_extract_web_page)
|
||||||
tool_set.remove_tool("web_search")
|
tool_set.remove_tool("web_search")
|
||||||
tool_set.remove_tool("fetch_url")
|
tool_set.remove_tool("fetch_url")
|
||||||
|
tool_set.remove_tool("AIsearch")
|
||||||
|
elif provider == "baidu_ai_search":
|
||||||
|
try:
|
||||||
|
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
|
||||||
|
aisearch_tool = func_tool_mgr.get_func("AIsearch")
|
||||||
|
if not aisearch_tool:
|
||||||
|
raise ValueError("Cannot get Baidu AI Search MCP tool.")
|
||||||
|
tool_set.add_tool(aisearch_tool)
|
||||||
|
tool_set.remove_tool("web_search")
|
||||||
|
tool_set.remove_tool("fetch_url")
|
||||||
|
tool_set.remove_tool("web_search_tavily")
|
||||||
|
tool_set.remove_tool("tavily_extract_web_page")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
|
||||||
|
|||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "AstrBot"
|
name = "AstrBot"
|
||||||
version = "4.3.2"
|
version = "4.3.5"
|
||||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
Reference in New Issue
Block a user