Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7d776e0ce2 | |||
| 17df1692b9 | |||
| 9ab652641d | |||
| 9119f7166f | |||
| da7d9d8eb9 | |||
| 80fccc90b7 | |||
| dcebc70f1a | |||
| 259e7bc322 | |||
| 37bdb6c6f6 | |||
| dc71afdd3f | |||
| 44638108d0 | |||
| 93fcac498c | |||
| 79e2743aac | |||
| 5e9c7cdd91 | |||
| 6f73e5087d | |||
| 8c120b020e | |||
| 12fc6f9d38 | |||
| a6e8483b4c | |||
| 7191d28ada | |||
| e6b5e3d282 | |||
| 1413d6b5fe | |||
| dcd8a1094c | |||
| e64b31b9ba | |||
| 080f347511 | |||
| eaaff4298d | |||
| dd5a02e8ef | |||
| 3211ec57ee | |||
| 6796afdaee | |||
| cc6fe57773 | |||
| 1dfc831938 | |||
| cafeda4abf | |||
| d951b99718 | |||
| 0ad87209e5 | |||
| 1b50c5404d | |||
| 3007f67cab | |||
| ee08659f01 | |||
| baf5ad0fab | |||
| 8bdd748aec | |||
| cef0c22f52 | |||
| 13d3fc5cfe | |||
| b91141e2be | |||
| f8a4b54165 | |||
| afe007ca0b | |||
| 8a9a044f95 | |||
| 5eaf03e227 | |||
| a8437d9331 | |||
| e0392fa98b | |||
| 68ff8951de | |||
| 9c6b31e71c | |||
| 50f74f5ba2 |
@@ -11,6 +11,8 @@ reviewers:
|
||||
- Larch-C
|
||||
- anka-afk
|
||||
- advent259141
|
||||
- Fridemn
|
||||
- LIghtJUNction
|
||||
# - zouyonghe
|
||||
|
||||
# A number of reviewers added to the pull request
|
||||
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
uses: github/codeql-action/init@v4
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
build-mode: ${{ matrix.build-mode }}
|
||||
@@ -88,6 +88,6 @@ jobs:
|
||||
exit 1
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
uses: github/codeql-action/analyze@v4
|
||||
with:
|
||||
category: "/language:${{matrix.language}}"
|
||||
|
||||
+3
-1
@@ -30,4 +30,6 @@ packages/python_interpreter/workplace
|
||||
.conda/
|
||||
.idea
|
||||
pytest.ini
|
||||
.astrbot
|
||||
.astrbot
|
||||
|
||||
uv.lock
|
||||
+8
-13
@@ -4,8 +4,6 @@ WORKDIR /AstrBot
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -13,23 +11,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libssl-dev \
|
||||
ca-certificates \
|
||||
bash \
|
||||
ffmpeg \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN apt-get update && apt-get install -y curl gnupg && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
# 释出 ffmpeg
|
||||
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||
|
||||
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
<img width="430" height="31" alt="image" src="https://github.com/user-attachments/assets/474c822c-fab7-41be-8c23-6dae252823ed" /><p align="center">
|
||||
|
||||

|
||||
|
||||
</p>
|
||||
@@ -13,17 +11,17 @@
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
<a href="https://astrbot.app/">文档</a> |
|
||||
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||
<a href="https://astrbot.featurebase.app/roadmap">路线图</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台及开发框架。
|
||||
|
||||
## 主要功能
|
||||
|
||||
@@ -35,7 +33,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
|
||||
|
||||
## 部署方式
|
||||
|
||||
#### Docker 部署
|
||||
#### Docker 部署(推荐 🥳)
|
||||
|
||||
推荐使用 Docker / Docker Compose 方式部署 AstrBot。
|
||||
|
||||
@@ -101,7 +99,6 @@ uv run main.py
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 开发者群:975206796
|
||||
- 开发者群(备份):295657329
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
@@ -113,48 +110,80 @@ uv run main.py
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
**官方维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| QQ(官方机器人接口) | ✔ |
|
||||
| QQ(官方平台) | ✔ |
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企业微信 | ✔ |
|
||||
| 企微应用 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
| 钉钉 | ✔ |
|
||||
| Slack | ✔ |
|
||||
| Discord | ✔ |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| Satori | ✔ |
|
||||
| Misskey | ✔ |
|
||||
| 企微智能机器人 | 将支持 |
|
||||
| Whatsapp | 将支持 |
|
||||
| LINE | 将支持 |
|
||||
|
||||
**社区维护**
|
||||
|
||||
| 平台 | 支持性 |
|
||||
| -------- | ------- |
|
||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||
| [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) | ✔ |
|
||||
|
||||
## ⚡ 提供商支持情况
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | 文本生成 | |
|
||||
| Google Gemini | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||
| PPIO 派欧云 | ✔ | 模型 API 服务平台 | |
|
||||
| OneAPI | ✔ | LLM 分发系统 | |
|
||||
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
|
||||
| OpenAI TTS API | ✔ | 文本转语音 | |
|
||||
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | 文本转语音 | GPT-Sovits-Inference |
|
||||
| FishAudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
|
||||
| Edge TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||
**大模型服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI | ✔ | 支持任何兼容 OpenAI API 的服务 |
|
||||
| Anthropic | ✔ | |
|
||||
| Google Gemini | ✔ | |
|
||||
| Moonshot AI | ✔ | |
|
||||
| 智谱 AI | ✔ | |
|
||||
| DeepSeek | ✔ | |
|
||||
| Ollama | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| LM Studio | ✔ | 本地部署 DeepSeek 等开源语言模型 |
|
||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | |
|
||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | |
|
||||
| [小马算力](https://www.tokenpony.cn/3YPyf) | ✔ | |
|
||||
| 硅基流动 | ✔ | |
|
||||
| PPIO 派欧云 | ✔ | |
|
||||
| ModelScope | ✔ | |
|
||||
| OneAPI | ✔ | |
|
||||
| Dify | ✔ | |
|
||||
| 阿里云百炼应用 | ✔ | |
|
||||
| Coze | ✔ | |
|
||||
|
||||
**语音转文本服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| Whisper | ✔ | 支持 API、本地部署 |
|
||||
| SenseVoice | ✔ | 本地部署 |
|
||||
|
||||
**文本转语音服务**
|
||||
|
||||
| 名称 | 支持性 | 备注 |
|
||||
| -------- | ------- | ------- |
|
||||
| OpenAI TTS | ✔ | |
|
||||
| Gemini TTS | ✔ | |
|
||||
| GSVI | ✔ | GPT-Sovits-Inference |
|
||||
| GPT-SoVITs | ✔ | GPT-Sovits |
|
||||
| FishAudio | ✔ | |
|
||||
| Edge TTS | ✔ | Edge 浏览器的免费 TTS |
|
||||
| 阿里云百炼 TTS | ✔ | |
|
||||
| Azure TTS | ✔ | |
|
||||
| Minimax TTS | ✔ | |
|
||||
| 火山引擎 TTS | ✔ | |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
@@ -186,19 +215,10 @@ pre-commit install
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||
|
||||
另外,一些同类型其他的活跃开源 Bot 项目:
|
||||
|
||||
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
||||
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
||||
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
||||
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
||||
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我们维护这个开源项目的动力 <3
|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
@@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
if transport_type == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
@@ -121,7 +128,14 @@ class MCPClient:
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
if "transport" in cfg:
|
||||
transport_type = cfg["transport"]
|
||||
elif "type" in cfg:
|
||||
transport_type = cfg["type"]
|
||||
else:
|
||||
raise Exception("MCP 连接配置缺少 transport 或 type 字段")
|
||||
|
||||
if transport_type != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
@@ -134,7 +148,7 @@ class MCPClient:
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
@@ -159,7 +173,7 @@ class MCPClient:
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
|
||||
@@ -198,6 +198,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
if not func_tool:
|
||||
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: 未找到工具 {func_tool_name}",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
@@ -210,9 +221,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
|
||||
_final_resp: CallToolResult | None = None
|
||||
async for resp in executor: # type: ignore
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
_final_resp = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -279,13 +293,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool, func_tool_args, None
|
||||
self.run_context, func_tool, func_tool_args, _final_resp
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
|
||||
@@ -9,3 +9,4 @@ class AstrAgentContext:
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
||||
|
||||
+114
-18
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.2.1"
|
||||
VERSION = "4.3.5"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -57,6 +57,7 @@ DEFAULT_CONFIG = {
|
||||
"web_search": False,
|
||||
"websearch_provider": "default",
|
||||
"websearch_tavily_key": [],
|
||||
"websearch_baidu_app_builder_key": "",
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
"identifier": False,
|
||||
@@ -64,13 +65,14 @@ DEFAULT_CONFIG = {
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "",
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"max_agent_step": 30,
|
||||
"tool_call_timeout": 60,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -116,6 +118,15 @@ DEFAULT_CONFIG = {
|
||||
"port": 6185,
|
||||
},
|
||||
"platform": [],
|
||||
"platform_specific": {
|
||||
# 平台特异配置:按平台分类,平台下按功能分组
|
||||
"lark": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["Typing"]},
|
||||
},
|
||||
"telegram": {
|
||||
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
|
||||
},
|
||||
},
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
@@ -198,6 +209,18 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
},
|
||||
"企业微信智能机器人": {
|
||||
"id": "wecom_ai_bot",
|
||||
"type": "wecom_ai_bot",
|
||||
"enable": True,
|
||||
"wecomaibot_init_respond_text": "💭 思考中...",
|
||||
"wecomaibot_friend_message_welcome_text": "",
|
||||
"wecom_ai_bot_name": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6198,
|
||||
},
|
||||
"飞书(Lark)": {
|
||||
"id": "lark",
|
||||
"type": "lark",
|
||||
@@ -438,10 +461,25 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"wecom_ai_bot_name": {
|
||||
"description": "企业微信智能机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填写正确,否则无法使用一些指令。",
|
||||
},
|
||||
"wecomaibot_init_respond_text": {
|
||||
"description": "企业微信智能机器人初始响应文本",
|
||||
"type": "string",
|
||||
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。",
|
||||
},
|
||||
"wecomaibot_friend_message_welcome_text": {
|
||||
"description": "企业微信智能机器人私聊欢迎语",
|
||||
"type": "string",
|
||||
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
},
|
||||
"discord_token": {
|
||||
"description": "Discord Bot Token",
|
||||
@@ -766,7 +804,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -812,6 +850,21 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
@@ -1032,6 +1085,7 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
|
||||
"id": "dashscope_tts",
|
||||
"provider": "dashscope",
|
||||
"type": "dashscope_tts",
|
||||
@@ -1411,11 +1465,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "服务订阅密钥",
|
||||
"hint": "Azure_TTS 服务的订阅密钥(注意不是令牌)",
|
||||
},
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
"type": "string",
|
||||
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
||||
},
|
||||
"dashscope_tts_voice": {"description": "音色", "type": "string"},
|
||||
"gm_resp_image_modal": {
|
||||
"description": "启用图片模态",
|
||||
"type": "bool",
|
||||
@@ -1824,6 +1874,10 @@ CONFIG_METADATA_2 = {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -1976,26 +2030,28 @@ CONFIG_METADATA_3 = {
|
||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||
},
|
||||
"provider_stt_settings.enable": {
|
||||
"description": "默认启用语音转文本",
|
||||
"description": "启用语音转文本",
|
||||
"type": "bool",
|
||||
"hint": "STT 总开关。",
|
||||
},
|
||||
"provider_stt_settings.provider_id": {
|
||||
"description": "语音转文本模型",
|
||||
"description": "默认语音转文本模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 指令单独选择会话的 STT 模型。",
|
||||
"_special": "select_provider_stt",
|
||||
"condition": {
|
||||
"provider_stt_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.enable": {
|
||||
"description": "默认启用文本转语音",
|
||||
"description": "启用文本转语音",
|
||||
"type": "bool",
|
||||
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
|
||||
},
|
||||
"provider_tts_settings.provider_id": {
|
||||
"description": "文本转语音模型",
|
||||
"description": "默认文本转语音模型",
|
||||
"type": "string",
|
||||
"hint": "留空代表不使用。",
|
||||
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
|
||||
"_special": "select_provider_tts",
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
@@ -2040,7 +2096,7 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": {
|
||||
"description": "网页搜索提供商",
|
||||
"type": "string",
|
||||
"options": ["default", "tavily"],
|
||||
"options": ["default", "tavily", "baidu_ai_search"],
|
||||
},
|
||||
"provider_settings.websearch_tavily_key": {
|
||||
"description": "Tavily API Key",
|
||||
@@ -2051,6 +2107,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": "tavily",
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"type": "string",
|
||||
"hint": "参考:https://console.bce.baidu.com/iam/#/iam/apikey/list",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "baidu_ai_search",
|
||||
},
|
||||
},
|
||||
"provider_settings.web_search_link": {
|
||||
"description": "显示来源引用",
|
||||
"type": "bool",
|
||||
@@ -2086,6 +2150,10 @@ CONFIG_METADATA_3 = {
|
||||
"description": "工具调用轮数上限",
|
||||
"type": "int",
|
||||
},
|
||||
"provider_settings.tool_call_timeout": {
|
||||
"description": "工具调用超时时间(秒)",
|
||||
"type": "int",
|
||||
},
|
||||
"provider_settings.streaming_response": {
|
||||
"description": "流式回复",
|
||||
"type": "bool",
|
||||
@@ -2107,12 +2175,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "额外前缀提示词",
|
||||
"description": "用户提示词",
|
||||
"type": "string",
|
||||
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
|
||||
},
|
||||
"provider_settings.dual_output": {
|
||||
"provider_tts_settings.dual_output": {
|
||||
"description": "开启 TTS 时同时输出语音和文字内容",
|
||||
"type": "bool",
|
||||
},
|
||||
@@ -2293,6 +2363,32 @@ CONFIG_METADATA_3 = {
|
||||
"description": "用户权限不足时是否回复",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.enable": {
|
||||
"description": "[飞书] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.lark.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(飞书表情枚举名)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "表情枚举名参考:https://open.feishu.cn/document/server-docs/im-v1/message-reaction/emojis-introduce",
|
||||
"condition": {
|
||||
"platform_specific.lark.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": {
|
||||
"description": "[Telegram] 启用预回应表情",
|
||||
"type": "bool",
|
||||
},
|
||||
"platform_specific.telegram.pre_ack_emoji.emojis": {
|
||||
"description": "表情列表(Unicode)",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "Telegram 仅支持固定反应集合,参考:https://gist.github.com/Soulter/3f22c8e5f9c7e152e967e8bc28c97fc9",
|
||||
"condition": {
|
||||
"platform_specific.telegram.pre_ack_emoji.enable": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -164,7 +164,7 @@ class BaseDatabase(abc.ABC):
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
content: dict,
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
@@ -287,3 +287,14 @@ class BaseDatabase(abc.ABC):
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
search_query: str | None = None,
|
||||
platform: str | None = None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details, support search and platform filter."""
|
||||
...
|
||||
|
||||
@@ -75,7 +75,9 @@ class Persona(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
@@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
@@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
|
||||
+151
-34
@@ -15,10 +15,8 @@ from astrbot.core.db.po import (
|
||||
SQLModel,
|
||||
)
|
||||
|
||||
from sqlalchemy import select, update, delete, text
|
||||
from sqlmodel import select, update, delete, text, func, or_, desc, col
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import or_
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
@@ -34,6 +32,12 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Initialize the database by creating tables if they do not exist."""
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
# ====
|
||||
@@ -42,10 +46,10 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime = None,
|
||||
platform_id,
|
||||
platform_type,
|
||||
count=1,
|
||||
timestamp=None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
async with self.get_db() as session:
|
||||
@@ -76,7 +80,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(
|
||||
select(func.count(PlatformStat.platform_id)).select_from(PlatformStat)
|
||||
select(func.count(col(PlatformStat.platform_id))).select_from(
|
||||
PlatformStat
|
||||
)
|
||||
)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
@@ -96,7 +102,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""),
|
||||
{"start_time": start_time},
|
||||
)
|
||||
return result.scalars().all()
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ====
|
||||
# Conversation Management
|
||||
@@ -112,7 +118,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
if platform_id:
|
||||
query = query.where(ConversationV2.platform_id == platform_id)
|
||||
# order by
|
||||
query = query.order_by(ConversationV2.created_at.desc())
|
||||
query = query.order_by(desc(ConversationV2.created_at))
|
||||
result = await session.execute(query)
|
||||
|
||||
return result.scalars().all()
|
||||
@@ -130,7 +136,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
offset = (page - 1) * page_size
|
||||
result = await session.execute(
|
||||
select(ConversationV2)
|
||||
.order_by(ConversationV2.created_at.desc())
|
||||
.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -151,25 +157,26 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
if platform_ids:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(platform_ids)
|
||||
col(ConversationV2.platform_id).in_(platform_ids)
|
||||
)
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
ConversationV2.title.ilike(f"%{search_query}%"),
|
||||
ConversationV2.content.ilike(f"%{search_query}%"),
|
||||
ConversationV2.user_id.ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.title).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
||||
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
||||
)
|
||||
)
|
||||
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||
for msg_type in kwargs["message_types"]:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
|
||||
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
|
||||
)
|
||||
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||
base_query = base_query.where(
|
||||
ConversationV2.platform_id.in_(kwargs["platforms"])
|
||||
col(ConversationV2.platform_id).in_(kwargs["platforms"])
|
||||
)
|
||||
|
||||
# Get total count matching the filters
|
||||
@@ -180,7 +187,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
# Get paginated results
|
||||
offset = (page - 1) * page_size
|
||||
result_query = (
|
||||
base_query.order_by(ConversationV2.created_at.desc())
|
||||
base_query.order_by(desc(ConversationV2.created_at))
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
@@ -226,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(ConversationV2).where(
|
||||
ConversationV2.conversation_id == cid
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
values = {}
|
||||
if title is not None:
|
||||
@@ -246,7 +253,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.conversation_id) == cid
|
||||
)
|
||||
)
|
||||
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
@@ -254,9 +263,116 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(ConversationV2.user_id == user_id)
|
||||
delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
|
||||
)
|
||||
|
||||
async def get_session_conversations(
|
||||
self,
|
||||
page=1,
|
||||
page_size=20,
|
||||
search_query=None,
|
||||
platform=None,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Get paginated session conversations with joined conversation and persona details."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
base_query = (
|
||||
select(
|
||||
col(Preference.scope_id).label("session_id"),
|
||||
func.json_extract(Preference.value, "$.val").label(
|
||||
"conversation_id"
|
||||
), # type: ignore
|
||||
col(ConversationV2.persona_id).label("persona_id"),
|
||||
col(ConversationV2.title).label("title"),
|
||||
col(Persona.persona_id).label("persona_name"),
|
||||
)
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 搜索筛选
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
# 平台筛选
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
base_query = base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
# 排序
|
||||
base_query = base_query.order_by(Preference.scope_id)
|
||||
|
||||
# 分页结果
|
||||
result_query = base_query.offset(offset).limit(page_size)
|
||||
result = await session.execute(result_query)
|
||||
rows = result.fetchall()
|
||||
|
||||
# 查询总数(应用相同的筛选条件)
|
||||
count_base_query = (
|
||||
select(func.count(col(Preference.scope_id)))
|
||||
.select_from(Preference)
|
||||
.outerjoin(
|
||||
ConversationV2,
|
||||
func.json_extract(Preference.value, "$.val")
|
||||
== ConversationV2.conversation_id,
|
||||
)
|
||||
.outerjoin(
|
||||
Persona, col(ConversationV2.persona_id) == Persona.persona_id
|
||||
)
|
||||
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
||||
)
|
||||
|
||||
# 应用相同的搜索和平台筛选条件到计数查询
|
||||
if search_query:
|
||||
search_pattern = f"%{search_query}%"
|
||||
count_base_query = count_base_query.where(
|
||||
or_(
|
||||
col(Preference.scope_id).ilike(search_pattern),
|
||||
col(ConversationV2.title).ilike(search_pattern),
|
||||
col(Persona.persona_id).ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
if platform:
|
||||
platform_pattern = f"{platform}:%"
|
||||
count_base_query = count_base_query.where(
|
||||
col(Preference.scope_id).like(platform_pattern)
|
||||
)
|
||||
|
||||
total_result = await session.execute(count_base_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
sessions_data = [
|
||||
{
|
||||
"session_id": row.session_id,
|
||||
"conversation_id": row.conversation_id,
|
||||
"persona_id": row.persona_id,
|
||||
"title": row.title,
|
||||
"persona_name": row.persona_name,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
return sessions_data, total
|
||||
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
@@ -290,9 +406,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
cutoff_time = now - timedelta(seconds=offset_sec)
|
||||
await session.execute(
|
||||
delete(PlatformMessageHistory).where(
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
PlatformMessageHistory.created_at < cutoff_time,
|
||||
col(PlatformMessageHistory.platform_id) == platform_id,
|
||||
col(PlatformMessageHistory.user_id) == user_id,
|
||||
col(PlatformMessageHistory.created_at) < cutoff_time,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -309,7 +425,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
PlatformMessageHistory.platform_id == platform_id,
|
||||
PlatformMessageHistory.user_id == user_id,
|
||||
)
|
||||
.order_by(PlatformMessageHistory.created_at.desc())
|
||||
.order_by(desc(PlatformMessageHistory.created_at))
|
||||
)
|
||||
result = await session.execute(query.offset(offset).limit(page_size))
|
||||
return result.scalars().all()
|
||||
@@ -331,7 +447,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
"""Get an attachment by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(Attachment.id == attachment_id)
|
||||
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -374,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(Persona).where(Persona.persona_id == persona_id)
|
||||
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
values = {}
|
||||
if system_prompt is not None:
|
||||
values["system_prompt"] = system_prompt
|
||||
@@ -394,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Persona).where(Persona.persona_id == persona_id)
|
||||
delete(Persona).where(col(Persona.persona_id) == persona_id)
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
||||
@@ -449,9 +565,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope,
|
||||
Preference.scope_id == scope_id,
|
||||
Preference.key == key,
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
col(Preference.key) == key,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -463,7 +579,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Preference).where(
|
||||
Preference.scope == scope, Preference.scope_id == scope_id
|
||||
col(Preference.scope) == scope,
|
||||
col(Preference.scope_id) == scope_id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
@@ -490,7 +607,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=data.platform_id,
|
||||
count=data.count,
|
||||
timestamp=data.timestamp.timestamp(),
|
||||
timestamp=int(data.timestamp.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
@@ -548,7 +665,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
DeprecatedPlatformStat(
|
||||
name=platform_id,
|
||||
count=count,
|
||||
timestamp=start_time.timestamp(),
|
||||
timestamp=int(start_time.timestamp()),
|
||||
)
|
||||
)
|
||||
return deprecated_stats
|
||||
|
||||
@@ -97,5 +97,6 @@ async def call_event_hook(
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return True
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import random
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
@@ -22,6 +23,26 @@ class PreProcessStage(Stage):
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""在处理事件之前的预处理"""
|
||||
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
|
||||
supported = {"telegram", "lark"}
|
||||
platform = event.get_platform_name()
|
||||
cfg = (
|
||||
self.config.get("platform_specific", {})
|
||||
.get(platform, {})
|
||||
.get("pre_ack_emoji", {})
|
||||
) or {}
|
||||
emojis = cfg.get("emojis") or []
|
||||
if (
|
||||
cfg.get("enable", False)
|
||||
and platform in supported
|
||||
and emojis
|
||||
and event.is_at_or_wake_command
|
||||
):
|
||||
try:
|
||||
await event.react(random.choice(emojis))
|
||||
except Exception as e:
|
||||
logger.warning(f"{platform} 预回应表情发送失败: {e}")
|
||||
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
@@ -46,6 +67,9 @@ class PreProcessStage(Stage):
|
||||
ctx = self.plugin_manager.context
|
||||
stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin)
|
||||
if not stt_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置语音转文本模型。"
|
||||
)
|
||||
return
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
|
||||
@@ -6,6 +6,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core import logger
|
||||
@@ -185,21 +186,33 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
handler=awaitable,
|
||||
**tool_args,
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
# async for resp in wrapper:
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
anext(wrapper),
|
||||
timeout=run_context.context.tool_call_timeout,
|
||||
)
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(
|
||||
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
@@ -217,6 +230,9 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
res = await session.call_tool(
|
||||
name=tool.name,
|
||||
arguments=tool_args,
|
||||
read_timeout_seconds=timedelta(
|
||||
seconds=run_context.context.tool_call_timeout
|
||||
),
|
||||
)
|
||||
if not res:
|
||||
return
|
||||
@@ -307,6 +323,7 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.max_step: int = settings.get("max_agent_step", 30)
|
||||
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
||||
if isinstance(self.max_step, bool): # workaround: #2622
|
||||
self.max_step = 30
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
@@ -473,6 +490,7 @@ class LLMRequestSubStage(Stage):
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
|
||||
@@ -190,6 +190,16 @@ class RespondStage(Stage):
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
# 将 Plain 为空的消息段移除
|
||||
result.chain = [
|
||||
comp
|
||||
for comp in result.chain
|
||||
if not (
|
||||
isinstance(comp, Comp.Plain)
|
||||
and (not comp.text or not comp.text.strip())
|
||||
)
|
||||
]
|
||||
|
||||
# 发送消息链
|
||||
# Record 需要强制单独发送
|
||||
need_separately = {ComponentType.Record}
|
||||
|
||||
@@ -183,56 +183,60 @@ class ResultDecorateStage(Stage):
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
and tts_provider
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。"
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}"
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
)
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
@@ -275,7 +279,6 @@ class ResultDecorateStage(Stage):
|
||||
result.chain = [Image.fromFileSystem(url)]
|
||||
|
||||
# 触发转发消息
|
||||
has_forwarded = False
|
||||
if event.get_platform_name() == "aiocqhttp":
|
||||
word_cnt = 0
|
||||
for comp in result.chain:
|
||||
@@ -286,9 +289,9 @@ class ResultDecorateStage(Stage):
|
||||
uin=event.get_self_id(), name="AstrBot", content=[*result.chain]
|
||||
)
|
||||
result.chain = [node]
|
||||
has_forwarded = True
|
||||
|
||||
if not has_forwarded:
|
||||
has_plain = any(isinstance(item, Plain) for item in result.chain)
|
||||
if has_plain:
|
||||
# at 回复
|
||||
if (
|
||||
self.reply_with_mention
|
||||
|
||||
@@ -74,7 +74,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() == "webchat":
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -416,6 +416,16 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def react(self, emoji: str):
|
||||
"""
|
||||
对消息添加表情回应。
|
||||
|
||||
默认实现为发送一条包含该表情的消息。
|
||||
注意:此实现并不一定符合所有平台的原生“表情回应”行为。
|
||||
如需支持平台原生的消息反应功能,请在对应平台的子类中重写本方法。
|
||||
"""
|
||||
await self.send(MessageChain([Plain(emoji)]))
|
||||
|
||||
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
|
||||
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
|
||||
|
||||
|
||||
@@ -82,6 +82,10 @@ class PlatformManager:
|
||||
from .sources.wecom.wecom_adapter import (
|
||||
WecomPlatformAdapter, # noqa: F401
|
||||
)
|
||||
case "wecom_ai_bot":
|
||||
from .sources.wecom_ai_bot.wecomai_adapter import (
|
||||
WecomAIBotAdapter, # noqa: F401
|
||||
)
|
||||
case "weixin_official_account":
|
||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
||||
|
||||
@@ -14,3 +14,5 @@ class PlatformMetadata:
|
||||
"""平台的默认配置模板"""
|
||||
adapter_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
|
||||
logo_path: str = None
|
||||
"""平台适配器的 logo 文件路径(相对于插件目录)"""
|
||||
|
||||
@@ -13,10 +13,12 @@ def register_platform_adapter(
|
||||
desc: str,
|
||||
default_config_tmpl: dict = None,
|
||||
adapter_display_name: str = None,
|
||||
logo_path: str = None,
|
||||
):
|
||||
"""用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
@@ -39,6 +41,7 @@ def register_platform_adapter(
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -107,6 +107,22 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji).build())
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
return None
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .satori_adapter import SatoriPlatformAdapter
|
||||
@@ -87,6 +87,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
content_parts.append(f'<reply id="{component.id}"/>')
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = session_id
|
||||
data = {"channel_id": channel_id, "content": content}
|
||||
@@ -166,6 +177,17 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
content_parts.append(f'<reply id="{component.id}"/>')
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = self.session_id
|
||||
data = {"channel_id": channel_id, "content": content}
|
||||
|
||||
@@ -95,9 +95,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||
)
|
||||
id_ = self.config.get("id") or "telegram"
|
||||
return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
@@ -117,6 +116,10 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
if not self.application.updater:
|
||||
logger.error("Telegram Updater is not initialized. Cannot start polling.")
|
||||
return
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
await queue
|
||||
@@ -194,6 +197,11 @@ class TelegramPlatformAdapter(Platform):
|
||||
return cmd_name, description
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if not update.effective_chat:
|
||||
logger.warning(
|
||||
"Received a start command without an effective chat, skipping /start reply."
|
||||
)
|
||||
return
|
||||
await context.bot.send_message(
|
||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||
)
|
||||
@@ -206,15 +214,20 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def convert_message(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
|
||||
) -> AstrBotMessage:
|
||||
) -> AstrBotMessage | None:
|
||||
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
|
||||
|
||||
@param update: Telegram 的 Update 对象。
|
||||
@param context: Telegram 的 Context 对象。
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
if not update.message:
|
||||
logger.warning("Received an update without a message.")
|
||||
return None
|
||||
|
||||
message = AstrBotMessage()
|
||||
message.session_id = str(update.message.chat.id)
|
||||
|
||||
# 获得是群聊还是私聊
|
||||
if update.message.chat.type == ChatType.PRIVATE:
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -225,10 +238,13 @@ class TelegramPlatformAdapter(Platform):
|
||||
# Topic Group
|
||||
message.group_id += "#" + str(update.message.message_thread_id)
|
||||
message.session_id = message.group_id
|
||||
|
||||
message.message_id = str(update.message.message_id)
|
||||
_from_user = update.message.from_user
|
||||
if not _from_user:
|
||||
logger.warning("[Telegram] Received a message without a from_user.")
|
||||
return None
|
||||
message.sender = MessageMember(
|
||||
str(update.message.from_user.id), update.message.from_user.username
|
||||
str(_from_user.id), _from_user.username or "Unknown"
|
||||
)
|
||||
message.self_id = str(context.bot.username)
|
||||
message.raw_message = update
|
||||
@@ -247,22 +263,32 @@ class TelegramPlatformAdapter(Platform):
|
||||
)
|
||||
reply_abm = await self.convert_message(reply_update, context, False)
|
||||
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
if reply_abm:
|
||||
message.message.append(
|
||||
Comp.Reply(
|
||||
id=reply_abm.message_id,
|
||||
chain=reply_abm.message,
|
||||
sender_id=reply_abm.sender.user_id,
|
||||
sender_nickname=reply_abm.sender.nickname,
|
||||
time=reply_abm.timestamp,
|
||||
message_str=reply_abm.message_str,
|
||||
text=reply_abm.message_str,
|
||||
qq=reply_abm.sender.user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if update.message.text:
|
||||
# 处理文本消息
|
||||
plain_text = update.message.text
|
||||
if (
|
||||
message.type == MessageType.GROUP_MESSAGE
|
||||
and update.message
|
||||
and update.message.reply_to_message
|
||||
and update.message.reply_to_message.from_user
|
||||
and update.message.reply_to_message.from_user.id == context.bot.id
|
||||
):
|
||||
plain_text2 = f"/@{context.bot.username} " + plain_text
|
||||
plain_text = plain_text2
|
||||
|
||||
# 群聊场景命令特殊处理
|
||||
if plain_text.startswith("/"):
|
||||
@@ -328,15 +354,25 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
elif update.message.document:
|
||||
file = await update.message.document.get_file()
|
||||
message.message = [
|
||||
Comp.File(file=file.file_path, name=update.message.document.file_name),
|
||||
]
|
||||
file_name = update.message.document.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram document file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.File(file=file_path, name=file_name))
|
||||
|
||||
elif update.message.video:
|
||||
file = await update.message.video.get_file()
|
||||
message.message = [
|
||||
Comp.Video(file=file.file_path, path=file.file_path),
|
||||
]
|
||||
file_name = update.message.video.file_name or uuid.uuid4().hex
|
||||
file_path = file.file_path
|
||||
if file_path is None:
|
||||
logger.warning(
|
||||
f"Telegram video file_path is None, cannot save the file {file_name}."
|
||||
)
|
||||
else:
|
||||
message.message.append(Comp.Video(file=file_path, path=file.file_path))
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -135,6 +136,39 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str | None, big: bool = False):
|
||||
"""
|
||||
给原消息添加 Telegram 反应:
|
||||
- 普通 emoji:传入 '👍'、'😂' 等
|
||||
- 自定义表情:传入其 custom_emoji_id(纯数字字符串)
|
||||
- 取消本机器人的反应:传入 None 或空字符串
|
||||
"""
|
||||
try:
|
||||
# 解析 chat_id(去掉超级群的 "#<thread_id>" 片段)
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
chat_id = (self.message_obj.group_id or "").split("#")[0]
|
||||
else:
|
||||
chat_id = self.get_sender_id()
|
||||
|
||||
message_id = int(self.message_obj.message_id)
|
||||
|
||||
# 组装 reaction 参数(必须是 ReactionType 的列表)
|
||||
if not emoji: # 清空本 bot 的反应
|
||||
reaction_param = [] # 空列表表示移除本 bot 的反应
|
||||
elif emoji.isdigit(): # 自定义表情:传 custom_emoji_id
|
||||
reaction_param = [ReactionTypeCustomEmoji(emoji)]
|
||||
else: # 普通 emoji
|
||||
reaction_param = [ReactionTypeEmoji(emoji)]
|
||||
|
||||
await self.client.set_message_reaction(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
reaction=reaction_param, # 注意是列表
|
||||
is_big=big, # 可选:大动画
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] 添加反应失败: {e}")
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
|
||||
@@ -91,7 +91,6 @@ class WebChatAdapter(Platform):
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding:utf-8 -*-
|
||||
|
||||
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||
@copyright: Copyright (c) 1998-2020 Tencent Inc.
|
||||
|
||||
"""
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import random
|
||||
import hashlib
|
||||
import time
|
||||
import struct
|
||||
from Crypto.Cipher import AES
|
||||
import socket
|
||||
import json
|
||||
|
||||
from . import ierror
|
||||
|
||||
"""
|
||||
关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案
|
||||
请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。
|
||||
下载后,按照README中的“Installation”小节的提示进行pycrypto安装。
|
||||
"""
|
||||
|
||||
|
||||
class FormatException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def throw_exception(message, exception_class=FormatException):
|
||||
"""my define raise exception function"""
|
||||
raise exception_class(message)
|
||||
|
||||
|
||||
class SHA1:
|
||||
"""计算企业微信的消息签名接口"""
|
||||
|
||||
def getSHA1(self, token, timestamp, nonce, encrypt):
|
||||
"""用SHA1算法生成安全签名
|
||||
@param token: 票据
|
||||
@param timestamp: 时间戳
|
||||
@param encrypt: 密文
|
||||
@param nonce: 随机字符串
|
||||
@return: 安全签名
|
||||
"""
|
||||
try:
|
||||
# 确保所有输入都是字符串类型
|
||||
if isinstance(encrypt, bytes):
|
||||
encrypt = encrypt.decode("utf-8")
|
||||
|
||||
sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)]
|
||||
sortlist.sort()
|
||||
sha = hashlib.sha1()
|
||||
sha.update("".join(sortlist).encode("utf-8"))
|
||||
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
|
||||
|
||||
|
||||
class JsonParse:
|
||||
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
|
||||
|
||||
# json消息模板
|
||||
AES_TEXT_RESPONSE_TEMPLATE = """{
|
||||
"encrypt": "%(msg_encrypt)s",
|
||||
"msgsignature": "%(msg_signaturet)s",
|
||||
"timestamp": "%(timestamp)s",
|
||||
"nonce": "%(nonce)s"
|
||||
}"""
|
||||
|
||||
def extract(self, jsontext):
|
||||
"""提取出json数据包中的加密消息
|
||||
@param jsontext: 待提取的json字符串
|
||||
@return: 提取出的加密消息字符串
|
||||
"""
|
||||
try:
|
||||
json_dict = json.loads(jsontext)
|
||||
return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_ParseJson_Error, None
|
||||
|
||||
def generate(self, encrypt, signature, timestamp, nonce):
|
||||
"""生成json消息
|
||||
@param encrypt: 加密后的消息密文
|
||||
@param signature: 安全签名
|
||||
@param timestamp: 时间戳
|
||||
@param nonce: 随机字符串
|
||||
@return: 生成的json字符串
|
||||
"""
|
||||
resp_dict = {
|
||||
"msg_encrypt": encrypt,
|
||||
"msg_signaturet": signature,
|
||||
"timestamp": timestamp,
|
||||
"nonce": nonce,
|
||||
}
|
||||
resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
|
||||
return resp_json
|
||||
|
||||
|
||||
class PKCS7Encoder:
|
||||
"""提供基于PKCS7算法的加解密接口"""
|
||||
|
||||
block_size = 32
|
||||
|
||||
def encode(self, text):
|
||||
"""对需要加密的明文进行填充补位
|
||||
@param text: 需要进行填充补位操作的明文(bytes类型)
|
||||
@return: 补齐明文字符串(bytes类型)
|
||||
"""
|
||||
text_length = len(text)
|
||||
# 计算需要填充的位数
|
||||
amount_to_pad = self.block_size - (text_length % self.block_size)
|
||||
if amount_to_pad == 0:
|
||||
amount_to_pad = self.block_size
|
||||
# 获得补位所用的字符
|
||||
pad = bytes([amount_to_pad])
|
||||
# 确保text是bytes类型
|
||||
if isinstance(text, str):
|
||||
text = text.encode("utf-8")
|
||||
return text + pad * amount_to_pad
|
||||
|
||||
def decode(self, decrypted):
|
||||
"""删除解密后明文的补位字符
|
||||
@param decrypted: 解密后的明文
|
||||
@return: 删除补位字符后的明文
|
||||
"""
|
||||
pad = ord(decrypted[-1])
|
||||
if pad < 1 or pad > 32:
|
||||
pad = 0
|
||||
return decrypted[:-pad]
|
||||
|
||||
|
||||
class Prpcrypt(object):
|
||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||
|
||||
def __init__(self, key):
|
||||
# self.key = base64.b64decode(key+"=")
|
||||
self.key = key
|
||||
# 设置加解密模式为AES的CBC模式
|
||||
self.mode = AES.MODE_CBC
|
||||
|
||||
def encrypt(self, text, receiveid):
|
||||
"""对明文进行加密
|
||||
@param text: 需要加密的明文
|
||||
@return: 加密得到的字符串
|
||||
"""
|
||||
# 16位随机字符串添加到明文开头
|
||||
text = text.encode()
|
||||
text = (
|
||||
self.get_random_str()
|
||||
+ struct.pack("I", socket.htonl(len(text)))
|
||||
+ text
|
||||
+ receiveid.encode()
|
||||
)
|
||||
|
||||
# 使用自定义的填充方式对明文进行补位填充
|
||||
pkcs7 = PKCS7Encoder()
|
||||
text = pkcs7.encode(text)
|
||||
# 加密
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||
try:
|
||||
ciphertext = cryptor.encrypt(text)
|
||||
# 使用BASE64对加密后的字符串进行编码
|
||||
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
|
||||
except Exception as e:
|
||||
logger = logging.getLogger("astrbot")
|
||||
logger.error(e)
|
||||
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
|
||||
|
||||
def decrypt(self, text, receiveid):
|
||||
"""对解密后的明文进行补位删除
|
||||
@param text: 密文
|
||||
@return: 删除填充补位后的明文
|
||||
"""
|
||||
try:
|
||||
cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore
|
||||
# 使用BASE64对密文进行解码,然后AES-CBC解密
|
||||
plain_text = cryptor.decrypt(base64.b64decode(text))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
|
||||
try:
|
||||
pad = plain_text[-1]
|
||||
# 去掉补位字符串
|
||||
# pkcs7 = PKCS7Encoder()
|
||||
# plain_text = pkcs7.encode(plain_text)
|
||||
# 去除16位随机字符串
|
||||
content = plain_text[16:-pad]
|
||||
json_len = socket.ntohl(struct.unpack("I", content[:4])[0])
|
||||
json_content = content[4 : json_len + 4].decode("utf-8")
|
||||
from_receiveid = content[json_len + 4 :].decode("utf-8")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return ierror.WXBizMsgCrypt_IllegalBuffer, None
|
||||
if from_receiveid != receiveid:
|
||||
print("receiveid not match", receiveid, from_receiveid)
|
||||
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
|
||||
return 0, json_content
|
||||
|
||||
def get_random_str(self):
|
||||
"""随机生成16位字符串
|
||||
@return: 16位字符串
|
||||
"""
|
||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||
|
||||
|
||||
class WXBizJsonMsgCrypt(object):
|
||||
# 构造函数
|
||||
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||
try:
|
||||
self.key = base64.b64decode(sEncodingAESKey + "=")
|
||||
assert len(self.key) == 32
|
||||
except Exception as e:
|
||||
throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException)
|
||||
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||
self.m_sToken = sToken
|
||||
self.m_sReceiveId = sReceiveId
|
||||
|
||||
# 验证URL
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sEchoStr: 随机串,对应URL参数的echostr
|
||||
# @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效
|
||||
# @return:成功0,失败返回对应的错误码
|
||||
|
||||
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
|
||||
return ret, sReplyEchoStr
|
||||
|
||||
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
|
||||
# 将企业回复用户的消息加密打包
|
||||
# @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串
|
||||
# @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间
|
||||
# @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce
|
||||
# sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串,
|
||||
# return:成功0,sEncryptMsg,失败返回对应的错误码None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
|
||||
encrypt = encrypt.decode("utf-8") # type: ignore
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if timestamp is None:
|
||||
timestamp = str(int(time.time()))
|
||||
# 生成安全签名
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
jsonParse = JsonParse()
|
||||
return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce)
|
||||
|
||||
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
|
||||
# 检验消息的真实性,并且获取解密后的明文
|
||||
# @param sMsgSignature: 签名串,对应URL参数的msg_signature
|
||||
# @param sTimeStamp: 时间戳,对应URL参数的timestamp
|
||||
# @param sNonce: 随机串,对应URL参数的nonce
|
||||
# @param sPostData: 密文,对应POST请求的数据
|
||||
# json_content: 解密后的原文,当return返回0时有效
|
||||
# @return: 成功0,失败返回对应的错误码
|
||||
# 验证安全签名
|
||||
jsonParse = JsonParse()
|
||||
ret, encrypt = jsonParse.extract(sPostData)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
sha1 = SHA1()
|
||||
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
|
||||
if ret != 0:
|
||||
return ret, None
|
||||
if not signature == sMsgSignature:
|
||||
print("signature not match")
|
||||
print(signature)
|
||||
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
|
||||
pc = Prpcrypt(self.key)
|
||||
ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId)
|
||||
return ret, json_content
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
企业微信智能机器人平台适配器包
|
||||
"""
|
||||
|
||||
from .wecomai_adapter import WecomAIBotAdapter
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_event import WecomAIBotMessageEvent
|
||||
from .wecomai_server import WecomAIBotServer
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
|
||||
__all__ = [
|
||||
"WecomAIBotAdapter",
|
||||
"WecomAIBotAPIClient",
|
||||
"WecomAIBotMessageEvent",
|
||||
"WecomAIBotServer",
|
||||
"WecomAIBotConstants",
|
||||
]
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#########################################################################
|
||||
# Author: jonyqin
|
||||
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
|
||||
# File Name: ierror.py
|
||||
# Description:定义错误码含义
|
||||
#########################################################################
|
||||
WXBizMsgCrypt_OK = 0
|
||||
WXBizMsgCrypt_ValidateSignature_Error = -40001
|
||||
WXBizMsgCrypt_ParseJson_Error = -40002
|
||||
WXBizMsgCrypt_ComputeSignature_Error = -40003
|
||||
WXBizMsgCrypt_IllegalAesKey = -40004
|
||||
WXBizMsgCrypt_ValidateCorpid_Error = -40005
|
||||
WXBizMsgCrypt_EncryptAES_Error = -40006
|
||||
WXBizMsgCrypt_DecryptAES_Error = -40007
|
||||
WXBizMsgCrypt_IllegalBuffer = -40008
|
||||
WXBizMsgCrypt_EncodeBase64_Error = -40009
|
||||
WXBizMsgCrypt_DecodeBase64_Error = -40010
|
||||
WXBizMsgCrypt_GenReturnJson_Error = -40011
|
||||
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
企业微信智能机器人平台适配器
|
||||
基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调
|
||||
参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Awaitable, Any, Dict, Optional, Callable
|
||||
|
||||
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, At, Image
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
from .wecomai_api import (
|
||||
WecomAIBotAPIClient,
|
||||
WecomAIBotMessageParser,
|
||||
WecomAIBotStreamMessageBuilder,
|
||||
)
|
||||
from .wecomai_event import WecomAIBotMessageEvent
|
||||
from .wecomai_server import WecomAIBotServer
|
||||
from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr
|
||||
from .wecomai_utils import (
|
||||
WecomAIBotConstants,
|
||||
format_session_id,
|
||||
generate_random_string,
|
||||
process_encrypted_image,
|
||||
)
|
||||
|
||||
|
||||
class WecomAIQueueListener:
|
||||
"""企业微信智能机器人队列监听器,参考webchat的QueueListener设计"""
|
||||
|
||||
def __init__(
|
||||
self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]]
|
||||
) -> None:
|
||||
self.queue_mgr = queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, session_id: str):
|
||||
"""监听特定会话的队列"""
|
||||
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""监控新会话队列并启动监听器"""
|
||||
monitored_sessions = set()
|
||||
|
||||
while True:
|
||||
# 检查新会话
|
||||
current_sessions = set(self.queue_mgr.queues.keys())
|
||||
new_sessions = current_sessions - monitored_sessions
|
||||
|
||||
# 为新会话启动监听器
|
||||
for session_id in new_sessions:
|
||||
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_sessions.add(session_id)
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
# 清理已不存在的会话
|
||||
removed_sessions = monitored_sessions - current_sessions
|
||||
monitored_sessions -= removed_sessions
|
||||
|
||||
# 清理过期的待处理响应
|
||||
self.queue_mgr.cleanup_expired_responses()
|
||||
|
||||
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
"wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息"
|
||||
)
|
||||
class WecomAIBotAdapter(Platform):
|
||||
"""企业微信智能机器人适配器"""
|
||||
|
||||
def __init__(
|
||||
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||
) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
|
||||
# 初始化配置参数
|
||||
self.token = self.config["token"]
|
||||
self.encoding_aes_key = self.config["encoding_aes_key"]
|
||||
self.port = int(self.config["port"])
|
||||
self.host = self.config.get("callback_server_host", "0.0.0.0")
|
||||
self.bot_name = self.config.get("wecom_ai_bot_name", "")
|
||||
self.initial_respond_text = self.config.get(
|
||||
"wecomaibot_init_respond_text", "💭 思考中..."
|
||||
)
|
||||
self.friend_message_welcome_text = self.config.get(
|
||||
"wecomaibot_friend_message_welcome_text", ""
|
||||
)
|
||||
|
||||
# 平台元数据
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wecom_ai_bot",
|
||||
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
|
||||
id=self.config.get("id", "wecom_ai_bot"),
|
||||
)
|
||||
|
||||
# 初始化 API 客户端
|
||||
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
|
||||
|
||||
# 初始化 HTTP 服务器
|
||||
self.server = WecomAIBotServer(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
api_client=self.api_client,
|
||||
message_handler=self._process_message,
|
||||
)
|
||||
|
||||
# 事件循环和关闭信号
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
# 队列监听器
|
||||
self.queue_listener = WecomAIQueueListener(
|
||||
wecomai_queue_mgr, self._handle_queued_message
|
||||
)
|
||||
|
||||
async def _handle_queued_message(self, data: dict):
|
||||
"""处理队列中的消息,类似webchat的callback"""
|
||||
try:
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
except Exception as e:
|
||||
logger.error(f"处理队列消息时发生异常: {e}")
|
||||
|
||||
async def _process_message(
|
||||
self, message_data: Dict[str, Any], callback_params: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""处理接收到的消息
|
||||
|
||||
Args:
|
||||
message_data: 解密后的消息数据
|
||||
callback_params: 回调参数 (nonce, timestamp)
|
||||
|
||||
Returns:
|
||||
加密后的响应消息,无需响应时返回 None
|
||||
"""
|
||||
msgtype = message_data.get("msgtype")
|
||||
if not msgtype:
|
||||
logger.warning(f"消息类型未知,忽略: {message_data}")
|
||||
return None
|
||||
session_id = self._extract_session_id(message_data)
|
||||
if msgtype in ("text", "image", "mixed"):
|
||||
# user sent a text / image / mixed message
|
||||
try:
|
||||
# create a brand-new unique stream_id for this message session
|
||||
stream_id = f"{session_id}_{generate_random_string(10)}"
|
||||
await self._enqueue_message(
|
||||
message_data, callback_params, stream_id, session_id
|
||||
)
|
||||
wecomai_queue_mgr.set_pending_response(stream_id, callback_params)
|
||||
|
||||
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
stream_id, self.initial_respond_text, False
|
||||
)
|
||||
return await self.api_client.encrypt_message(
|
||||
resp, callback_params["nonce"], callback_params["timestamp"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
return None
|
||||
elif msgtype == "stream":
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not wecomai_queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
stream_id, "", True
|
||||
)
|
||||
resp = await self.api_client.encrypt_message(
|
||||
end_message,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
return resp
|
||||
queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
if queue.empty():
|
||||
logger.debug(
|
||||
f"No new messages in back queue for stream_id: {stream_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# aggregate all delta chains in the back queue
|
||||
latest_plain_content = ""
|
||||
image_base64 = []
|
||||
finish = False
|
||||
while not queue.empty():
|
||||
msg = await queue.get()
|
||||
if msg["type"] == "plain":
|
||||
latest_plain_content = msg["data"] or ""
|
||||
elif msg["type"] == "image":
|
||||
image_base64.append(msg["image_data"])
|
||||
elif msg["type"] == "end":
|
||||
# stream end
|
||||
finish = True
|
||||
wecomai_queue_mgr.remove_queues(stream_id)
|
||||
break
|
||||
else:
|
||||
pass
|
||||
logger.debug(
|
||||
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}"
|
||||
)
|
||||
if latest_plain_content or image_base64:
|
||||
msg_items = []
|
||||
if finish and image_base64:
|
||||
for img_b64 in image_base64:
|
||||
# get md5 of image
|
||||
img_data = base64.b64decode(img_b64)
|
||||
img_md5 = hashlib.md5(img_data).hexdigest()
|
||||
msg_items.append(
|
||||
{
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||
"image": {"base64": img_b64, "md5": img_md5},
|
||||
}
|
||||
)
|
||||
image_base64 = []
|
||||
|
||||
plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream(
|
||||
stream_id, latest_plain_content, msg_items, finish
|
||||
)
|
||||
encrypted_message = await self.api_client.encrypt_message(
|
||||
plain_message,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
if encrypted_message:
|
||||
logger.debug(
|
||||
f"Stream message sent successfully, stream_id: {stream_id}"
|
||||
)
|
||||
else:
|
||||
logger.error("消息加密失败")
|
||||
return encrypted_message
|
||||
return None
|
||||
elif msgtype == "event":
|
||||
event = message_data.get("event")
|
||||
if event == "enter_chat" and self.friend_message_welcome_text:
|
||||
# 用户进入会话,发送欢迎消息
|
||||
try:
|
||||
resp = WecomAIBotStreamMessageBuilder.make_text(
|
||||
self.friend_message_welcome_text
|
||||
)
|
||||
return await self.api_client.encrypt_message(
|
||||
resp,
|
||||
callback_params["nonce"],
|
||||
callback_params["timestamp"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("处理欢迎消息时发生异常: %s", e)
|
||||
return None
|
||||
pass
|
||||
|
||||
def _extract_session_id(self, message_data: Dict[str, Any]) -> str:
|
||||
"""从消息数据中提取会话ID"""
|
||||
user_id = message_data.get("from", {}).get("userid", "default_user")
|
||||
return format_session_id("wecomai", user_id)
|
||||
|
||||
async def _enqueue_message(
|
||||
self,
|
||||
message_data: Dict[str, Any],
|
||||
callback_params: Dict[str, str],
|
||||
stream_id: str,
|
||||
session_id: str,
|
||||
):
|
||||
"""将消息放入队列进行异步处理"""
|
||||
input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id)
|
||||
_ = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
message_payload = {
|
||||
"message_data": message_data,
|
||||
"callback_params": callback_params,
|
||||
"session_id": session_id,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
await input_queue.put(message_payload)
|
||||
logger.debug(f"[WecomAI] 消息已入队: {stream_id}")
|
||||
|
||||
async def convert_message(self, payload: dict) -> AstrBotMessage:
|
||||
"""转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message"""
|
||||
message_data = payload["message_data"]
|
||||
session_id = payload["session_id"]
|
||||
# callback_params = payload["callback_params"] # 保留但暂时不使用
|
||||
|
||||
# 解析消息内容
|
||||
msgtype = message_data.get("msgtype")
|
||||
content = ""
|
||||
image_base64 = []
|
||||
|
||||
_img_url_to_process = []
|
||||
msg_items = []
|
||||
|
||||
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||
content = WecomAIBotMessageParser.parse_text_message(message_data)
|
||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||
_img_url_to_process.append(
|
||||
WecomAIBotMessageParser.parse_image_message(message_data)
|
||||
)
|
||||
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
|
||||
# 提取混合消息中的文本内容
|
||||
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
|
||||
text_parts = []
|
||||
for item in msg_items or []:
|
||||
if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT:
|
||||
text_content = item.get("text", {}).get("content", "")
|
||||
if text_content:
|
||||
text_parts.append(text_content)
|
||||
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
|
||||
image_url = item.get("image", {}).get("url", "")
|
||||
if image_url:
|
||||
_img_url_to_process.append(image_url)
|
||||
content = " ".join(text_parts) if text_parts else ""
|
||||
else:
|
||||
content = f"[{msgtype}消息]"
|
||||
|
||||
# 并行处理图片下载和解密
|
||||
if _img_url_to_process:
|
||||
tasks = [
|
||||
process_encrypted_image(url, self.encoding_aes_key)
|
||||
for url in _img_url_to_process
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
for success, result in results:
|
||||
if success:
|
||||
image_base64.append(result)
|
||||
else:
|
||||
logger.error(f"处理加密图片失败: {result}")
|
||||
|
||||
# 构建 AstrBotMessage
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_name
|
||||
abm.message_str = content or "[未知消息]"
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = payload
|
||||
|
||||
# 发送者信息
|
||||
abm.sender = MessageMember(
|
||||
user_id=message_data.get("from", {}).get("userid", "unknown"),
|
||||
nickname=message_data.get("from", {}).get("userid", "unknown"),
|
||||
)
|
||||
|
||||
# 消息类型
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message_data.get("chattype") == "group"
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.session_id = session_id
|
||||
|
||||
# 消息内容
|
||||
abm.message = []
|
||||
|
||||
# 处理 At
|
||||
if self.bot_name and f"@{self.bot_name}" in abm.message_str:
|
||||
abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip()
|
||||
abm.message.append(At(qq=self.bot_name, name=self.bot_name))
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
if image_base64:
|
||||
for img_b64 in image_base64:
|
||||
abm.message.append(Image.fromBase64(img_b64))
|
||||
|
||||
logger.debug(f"WecomAIAdapter: {abm.message}")
|
||||
return abm
|
||||
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
# 企业微信智能机器人主要通过回调响应,这里记录日志
|
||||
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
"""运行适配器,同时启动HTTP服务器和队列监听器"""
|
||||
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
async def run_both():
|
||||
# 同时运行HTTP服务器和队列监听器
|
||||
await asyncio.gather(
|
||||
self.server.start_server(),
|
||||
self.queue_listener.run(),
|
||||
)
|
||||
|
||||
return run_both()
|
||||
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("企业微信智能机器人适配器正在关闭...")
|
||||
self.shutdown_event.set()
|
||||
await self.server.shutdown()
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
"""获取平台元数据"""
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
"""处理消息,创建消息事件并提交到事件队列"""
|
||||
try:
|
||||
message_event = WecomAIBotMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
api_client=self.api_client,
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
|
||||
def get_client(self) -> WecomAIBotAPIClient:
|
||||
"""获取 API 客户端"""
|
||||
return self.api_client
|
||||
|
||||
def get_server(self) -> WecomAIBotServer:
|
||||
"""获取 HTTP 服务器实例"""
|
||||
return self.server
|
||||
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
企业微信智能机器人 API 客户端
|
||||
处理消息加密解密、API 调用等
|
||||
"""
|
||||
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import Dict, Any, Optional, Tuple, Union
|
||||
from Crypto.Cipher import AES
|
||||
import aiohttp
|
||||
|
||||
from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class WecomAIBotAPIClient:
|
||||
"""企业微信智能机器人 API 客户端"""
|
||||
|
||||
def __init__(self, token: str, encoding_aes_key: str):
|
||||
"""初始化 API 客户端
|
||||
|
||||
Args:
|
||||
token: 企业微信机器人 Token
|
||||
encoding_aes_key: 消息加密密钥
|
||||
"""
|
||||
self.token = token
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串
|
||||
|
||||
async def decrypt_message(
|
||||
self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str
|
||||
) -> Tuple[int, Optional[Dict[str, Any]]]:
|
||||
"""解密企业微信消息
|
||||
|
||||
Args:
|
||||
encrypted_data: 加密的消息数据
|
||||
msg_signature: 消息签名
|
||||
timestamp: 时间戳
|
||||
nonce: 随机数
|
||||
|
||||
Returns:
|
||||
(错误码, 解密后的消息数据字典)
|
||||
"""
|
||||
try:
|
||||
ret, decrypted_msg = self.wxcpt.DecryptMsg(
|
||||
encrypted_data, msg_signature, timestamp, nonce
|
||||
)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"消息解密失败,错误码: {ret}")
|
||||
return ret, None
|
||||
|
||||
# 解析 JSON
|
||||
if decrypted_msg:
|
||||
try:
|
||||
message_data = json.loads(decrypted_msg)
|
||||
logger.debug(f"解密成功,消息内容: {message_data}")
|
||||
return WecomAIBotConstants.SUCCESS, message_data
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}")
|
||||
return WecomAIBotConstants.PARSE_XML_ERROR, None
|
||||
else:
|
||||
logger.error("解密消息为空")
|
||||
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解密过程发生异常: {e}")
|
||||
return WecomAIBotConstants.DECRYPT_ERROR, None
|
||||
|
||||
async def encrypt_message(
|
||||
self, plain_message: str, nonce: str, timestamp: str
|
||||
) -> Optional[str]:
|
||||
"""加密消息
|
||||
|
||||
Args:
|
||||
plain_message: 明文消息
|
||||
nonce: 随机数
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
加密后的消息,失败时返回 None
|
||||
"""
|
||||
try:
|
||||
ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"消息加密失败,错误码: {ret}")
|
||||
return None
|
||||
|
||||
logger.debug("消息加密成功")
|
||||
return encrypted_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加密过程发生异常: {e}")
|
||||
return None
|
||||
|
||||
def verify_url(
|
||||
self, msg_signature: str, timestamp: str, nonce: str, echostr: str
|
||||
) -> str:
|
||||
"""验证回调 URL
|
||||
|
||||
Args:
|
||||
msg_signature: 消息签名
|
||||
timestamp: 时间戳
|
||||
nonce: 随机数
|
||||
echostr: 验证字符串
|
||||
|
||||
Returns:
|
||||
验证结果字符串
|
||||
"""
|
||||
try:
|
||||
ret, echo_result = self.wxcpt.VerifyURL(
|
||||
msg_signature, timestamp, nonce, echostr
|
||||
)
|
||||
|
||||
if ret != WecomAIBotConstants.SUCCESS:
|
||||
logger.error(f"URL 验证失败,错误码: {ret}")
|
||||
return "verify fail"
|
||||
|
||||
logger.info("URL 验证成功")
|
||||
return echo_result if echo_result else "verify fail"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"URL 验证发生异常: {e}")
|
||||
return "verify fail"
|
||||
|
||||
async def process_encrypted_image(
|
||||
self, image_url: str, aes_key_base64: Optional[str] = None
|
||||
) -> Tuple[bool, Union[bytes, str]]:
|
||||
"""下载并解密加密图片
|
||||
|
||||
Args:
|
||||
image_url: 加密图片的 URL
|
||||
aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥
|
||||
|
||||
Returns:
|
||||
(是否成功, 图片数据或错误信息)
|
||||
"""
|
||||
try:
|
||||
# 下载图片
|
||||
logger.info(f"开始下载加密图片: {image_url}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=15) as response:
|
||||
if response.status != 200:
|
||||
error_msg = f"图片下载失败,状态码: {response.status}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
encrypted_data = await response.read()
|
||||
logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节")
|
||||
|
||||
# 准备解密密钥
|
||||
if aes_key_base64 is None:
|
||||
aes_key_base64 = self.encoding_aes_key
|
||||
|
||||
if not aes_key_base64:
|
||||
raise ValueError("AES 密钥不能为空")
|
||||
|
||||
# Base64 解码密钥
|
||||
aes_key = base64.b64decode(
|
||||
aes_key_base64 + "=" * (-len(aes_key_base64) % 4)
|
||||
)
|
||||
if len(aes_key) != 32:
|
||||
raise ValueError("无效的 AES 密钥长度: 应为 32 字节")
|
||||
|
||||
iv = aes_key[:16] # 初始向量为密钥前 16 字节
|
||||
|
||||
# 解密图片数据
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted_data = cipher.decrypt(encrypted_data)
|
||||
|
||||
# 去除 PKCS#7 填充
|
||||
pad_len = decrypted_data[-1]
|
||||
if pad_len > 32: # AES-256 块大小为 32 字节
|
||||
raise ValueError("无效的填充长度 (大于32字节)")
|
||||
|
||||
decrypted_data = decrypted_data[:-pad_len]
|
||||
logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节")
|
||||
|
||||
return True, decrypted_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
error_msg = f"图片下载失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"参数错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"图片处理异常: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
class WecomAIBotStreamMessageBuilder:
|
||||
"""企业微信智能机器人流消息构建器"""
|
||||
|
||||
@staticmethod
|
||||
def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str:
|
||||
"""构建文本流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
content: 文本内容
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {"id": stream_id, "finish": finish, "content": content},
|
||||
}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_image_stream(
|
||||
stream_id: str, image_data: bytes, finish: bool = False
|
||||
) -> str:
|
||||
"""构建图片流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
image_data: 图片二进制数据
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
image_md5 = hashlib.md5(image_data).hexdigest()
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": finish,
|
||||
"msg_item": [
|
||||
{
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE,
|
||||
"image": {"base64": image_base64, "md5": image_md5},
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_mixed_stream(
|
||||
stream_id: str, content: str, msg_items: list, finish: bool = False
|
||||
) -> str:
|
||||
"""构建混合类型流消息
|
||||
|
||||
Args:
|
||||
stream_id: 流 ID
|
||||
content: 文本内容
|
||||
msg_items: 消息项列表
|
||||
finish: 是否结束
|
||||
|
||||
Returns:
|
||||
JSON 格式的流消息字符串
|
||||
"""
|
||||
plain = {
|
||||
"msgtype": WecomAIBotConstants.MSG_TYPE_STREAM,
|
||||
"stream": {"id": stream_id, "finish": finish, "msg_item": msg_items},
|
||||
}
|
||||
if content:
|
||||
plain["stream"]["content"] = content
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def make_text(content: str) -> str:
|
||||
"""构建文本消息
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
|
||||
Returns:
|
||||
JSON 格式的文本消息字符串
|
||||
"""
|
||||
plain = {"msgtype": "text", "text": {"content": content}}
|
||||
return json.dumps(plain, ensure_ascii=False)
|
||||
|
||||
|
||||
class WecomAIBotMessageParser:
|
||||
"""企业微信智能机器人消息解析器"""
|
||||
|
||||
@staticmethod
|
||||
def parse_text_message(data: Dict[str, Any]) -> Optional[str]:
|
||||
"""解析文本消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
文本内容,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("text", {}).get("content")
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("文本消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_image_message(data: Dict[str, Any]) -> Optional[str]:
|
||||
"""解析图片消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
图片 URL,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("image", {}).get("url")
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("图片消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""解析流消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
流消息数据,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
stream_data = data.get("stream", {})
|
||||
return {
|
||||
"id": stream_data.get("id"),
|
||||
"finish": stream_data.get("finish"),
|
||||
"content": stream_data.get("content"),
|
||||
"msg_item": stream_data.get("msg_item", []),
|
||||
}
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("流消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]:
|
||||
"""解析混合消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
消息项列表,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("mixed", {}).get("msg_item", [])
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("混合消息解析失败")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""解析事件消息
|
||||
|
||||
Args:
|
||||
data: 消息数据
|
||||
|
||||
Returns:
|
||||
事件数据,解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return data.get("event", {})
|
||||
except (KeyError, TypeError):
|
||||
logger.warning("事件消息解析失败")
|
||||
return None
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
企业微信智能机器人事件处理模块,处理消息事件的发送和接收
|
||||
"""
|
||||
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import (
|
||||
Image,
|
||||
Plain,
|
||||
)
|
||||
from astrbot.api import logger
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_queue_mgr import wecomai_queue_mgr
|
||||
|
||||
|
||||
class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
"""企业微信智能机器人消息事件"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj,
|
||||
platform_meta,
|
||||
session_id: str,
|
||||
api_client: WecomAIBotAPIClient,
|
||||
):
|
||||
"""初始化消息事件
|
||||
|
||||
Args:
|
||||
message_str: 消息字符串
|
||||
message_obj: 消息对象
|
||||
platform_meta: 平台元数据
|
||||
session_id: 会话 ID
|
||||
api_client: API 客户端
|
||||
"""
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.api_client = api_client
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
stream_id: str,
|
||||
streaming: bool = False,
|
||||
):
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
if not message_chain:
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
}
|
||||
)
|
||||
return ""
|
||||
|
||||
data = ""
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
data = comp.text
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"session_id": stream_id,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# 处理图片消息
|
||||
try:
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
if image_base64:
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"image_data": image_base64,
|
||||
"streaming": streaming,
|
||||
"session_id": stream_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning("图片数据为空,跳过")
|
||||
except Exception as e:
|
||||
logger.error("处理图片消息失败: %s", e)
|
||||
else:
|
||||
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息"""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||
final_data = ""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
"wecom_ai_bot platform event raw_message should be a dict"
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id)
|
||||
|
||||
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
|
||||
increment_plain = ""
|
||||
async for chain in generator:
|
||||
# 累积增量内容,并改写 Plain 段
|
||||
chain.squash_plain()
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
comp.text = increment_plain + comp.text
|
||||
increment_plain = comp.text
|
||||
break
|
||||
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
|
||||
final_data += await WecomAIBotMessageEvent._send(
|
||||
chain,
|
||||
stream_id=stream_id,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
await back_queue.put(
|
||||
{
|
||||
"type": "complete", # complete means we return the final result
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
企业微信智能机器人队列管理器
|
||||
参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制
|
||||
支持异步消息处理和流式响应
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
class WecomAIQueueMgr:
|
||||
"""企业微信智能机器人队列管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.queues: Dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||
|
||||
self.back_queues: Dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输出队列的映射 - 用于发送机器人响应"""
|
||||
|
||||
self.pending_responses: Dict[str, Dict[str, Any]] = {}
|
||||
"""待处理的响应缓存,用于流式响应"""
|
||||
|
||||
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输入队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
输入队列实例
|
||||
"""
|
||||
if session_id not in self.queues:
|
||||
self.queues[session_id] = asyncio.Queue()
|
||||
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||
return self.queues[session_id]
|
||||
|
||||
def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输出队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
输出队列实例
|
||||
"""
|
||||
if session_id not in self.back_queues:
|
||||
self.back_queues[session_id] = asyncio.Queue()
|
||||
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||
return self.back_queues[session_id]
|
||||
|
||||
def remove_queues(self, session_id: str):
|
||||
"""移除指定会话的所有队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
|
||||
if session_id in self.back_queues:
|
||||
del self.back_queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输出队列: {session_id}")
|
||||
|
||||
if session_id in self.pending_responses:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
|
||||
def has_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
是否存在队列
|
||||
"""
|
||||
return session_id in self.queues
|
||||
|
||||
def has_back_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的输出队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
是否存在输出队列
|
||||
"""
|
||||
return session_id in self.back_queues
|
||||
|
||||
def set_pending_response(self, session_id: str, callback_params: Dict[str, str]):
|
||||
"""设置待处理的响应参数
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
callback_params: 回调参数(nonce, timestamp等)
|
||||
"""
|
||||
self.pending_responses[session_id] = {
|
||||
"callback_params": callback_params,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
|
||||
|
||||
def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取待处理的响应参数
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
响应参数,如果不存在则返回None
|
||||
"""
|
||||
return self.pending_responses.get(session_id)
|
||||
|
||||
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||
"""清理过期的待处理响应
|
||||
|
||||
Args:
|
||||
max_age_seconds: 最大存活时间(秒)
|
||||
"""
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, response_data in self.pending_responses.items():
|
||||
if current_time - response_data["timestamp"] > max_age_seconds:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取队列统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return {
|
||||
"input_queues": len(self.queues),
|
||||
"output_queues": len(self.back_queues),
|
||||
"pending_responses": len(self.pending_responses),
|
||||
}
|
||||
|
||||
|
||||
# 全局队列管理器实例
|
||||
wecomai_queue_mgr = WecomAIQueueMgr()
|
||||
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
企业微信智能机器人 HTTP 服务器
|
||||
处理企业微信智能机器人的 HTTP 回调请求
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Callable
|
||||
|
||||
import quart
|
||||
from astrbot.api import logger
|
||||
|
||||
from .wecomai_api import WecomAIBotAPIClient
|
||||
from .wecomai_utils import WecomAIBotConstants
|
||||
|
||||
|
||||
class WecomAIBotServer:
|
||||
"""企业微信智能机器人 HTTP 服务器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
api_client: WecomAIBotAPIClient,
|
||||
message_handler: Optional[
|
||||
Callable[[Dict[str, Any], Dict[str, str]], Any]
|
||||
] = None,
|
||||
):
|
||||
"""初始化服务器
|
||||
|
||||
Args:
|
||||
host: 监听地址
|
||||
port: 监听端口
|
||||
api_client: API客户端实例
|
||||
message_handler: 消息处理回调函数
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.api_client = api_client
|
||||
self.message_handler = message_handler
|
||||
|
||||
self.app = quart.Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置 Quart 路由"""
|
||||
|
||||
# 使用 Quart 的 add_url_rule 方法添加路由
|
||||
self.app.add_url_rule(
|
||||
"/webhook/wecom-ai-bot",
|
||||
view_func=self.verify_url,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/webhook/wecom-ai-bot",
|
||||
view_func=self.handle_message,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def verify_url(self):
|
||||
"""验证回调 URL"""
|
||||
args = quart.request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
echostr = args.get("echostr")
|
||||
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
logger.error("URL 验证参数缺失")
|
||||
return "verify fail", 400
|
||||
|
||||
# 类型检查确保不为 None
|
||||
assert msg_signature is not None
|
||||
assert timestamp is not None
|
||||
assert nonce is not None
|
||||
assert echostr is not None
|
||||
|
||||
logger.info("收到企业微信智能机器人 WebHook URL 验证请求。")
|
||||
result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr)
|
||||
return result, 200, {"Content-Type": "text/plain"}
|
||||
|
||||
async def handle_message(self):
|
||||
"""处理消息回调"""
|
||||
args = quart.request.args
|
||||
msg_signature = args.get("msg_signature")
|
||||
timestamp = args.get("timestamp")
|
||||
nonce = args.get("nonce")
|
||||
|
||||
if not all([msg_signature, timestamp, nonce]):
|
||||
logger.error("消息回调参数缺失")
|
||||
return "缺少必要参数", 400
|
||||
|
||||
# 类型检查确保不为 None
|
||||
assert msg_signature is not None
|
||||
assert timestamp is not None
|
||||
assert nonce is not None
|
||||
|
||||
logger.debug(
|
||||
f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取请求体
|
||||
post_data = await quart.request.get_data()
|
||||
|
||||
# 确保 post_data 是 bytes 类型
|
||||
if isinstance(post_data, str):
|
||||
post_data = post_data.encode("utf-8")
|
||||
|
||||
# 解密消息
|
||||
ret_code, message_data = await self.api_client.decrypt_message(
|
||||
post_data, msg_signature, timestamp, nonce
|
||||
)
|
||||
|
||||
if ret_code != WecomAIBotConstants.SUCCESS or not message_data:
|
||||
logger.error("消息解密失败,错误码: %d", ret_code)
|
||||
return "消息解密失败", 400
|
||||
|
||||
# 调用消息处理器
|
||||
response = None
|
||||
if self.message_handler:
|
||||
try:
|
||||
response = await self.message_handler(
|
||||
message_data, {"nonce": nonce, "timestamp": timestamp}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("消息处理器执行异常: %s", e)
|
||||
return "消息处理异常", 500
|
||||
|
||||
if response:
|
||||
return response, 200, {"Content-Type": "text/plain"}
|
||||
else:
|
||||
return "success", 200, {"Content-Type": "text/plain"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("处理消息时发生异常: %s", e)
|
||||
return "内部服务器错误", 500
|
||||
|
||||
async def start_server(self):
|
||||
"""启动服务器"""
|
||||
logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port)
|
||||
|
||||
try:
|
||||
await self.app.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("服务器运行异常: %s", e)
|
||||
raise
|
||||
|
||||
async def shutdown_trigger(self):
|
||||
"""关闭触发器"""
|
||||
await self.shutdown_event.wait()
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
logger.info("企业微信智能机器人服务器正在关闭...")
|
||||
self.shutdown_event.set()
|
||||
|
||||
def get_app(self):
|
||||
"""获取 Quart 应用实例"""
|
||||
return self.app
|
||||
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
企业微信智能机器人工具模块
|
||||
提供常量定义、工具函数和辅助方法
|
||||
"""
|
||||
|
||||
import string
|
||||
import random
|
||||
import hashlib
|
||||
import base64
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from Crypto.Cipher import AES
|
||||
from typing import Any, Tuple
|
||||
from astrbot.api import logger
|
||||
|
||||
|
||||
# 常量定义
|
||||
class WecomAIBotConstants:
|
||||
"""企业微信智能机器人常量"""
|
||||
|
||||
# 消息类型
|
||||
MSG_TYPE_TEXT = "text"
|
||||
MSG_TYPE_IMAGE = "image"
|
||||
MSG_TYPE_MIXED = "mixed"
|
||||
MSG_TYPE_STREAM = "stream"
|
||||
MSG_TYPE_EVENT = "event"
|
||||
|
||||
# 流消息状态
|
||||
STREAM_CONTINUE = False
|
||||
STREAM_FINISH = True
|
||||
|
||||
# 错误码
|
||||
SUCCESS = 0
|
||||
DECRYPT_ERROR = -40001
|
||||
VALIDATE_SIGNATURE_ERROR = -40002
|
||||
PARSE_XML_ERROR = -40003
|
||||
COMPUTE_SIGNATURE_ERROR = -40004
|
||||
ILLEGAL_AES_KEY = -40005
|
||||
VALIDATE_APPID_ERROR = -40006
|
||||
ENCRYPT_AES_ERROR = -40007
|
||||
ILLEGAL_BUFFER = -40008
|
||||
|
||||
|
||||
def generate_random_string(length: int = 10) -> str:
|
||||
"""生成随机字符串
|
||||
|
||||
Args:
|
||||
length: 字符串长度,默认为 10
|
||||
|
||||
Returns:
|
||||
随机字符串
|
||||
"""
|
||||
letters = string.ascii_letters + string.digits
|
||||
return "".join(random.choice(letters) for _ in range(length))
|
||||
|
||||
|
||||
def calculate_image_md5(image_data: bytes) -> str:
|
||||
"""计算图片数据的 MD5 值
|
||||
|
||||
Args:
|
||||
image_data: 图片二进制数据
|
||||
|
||||
Returns:
|
||||
MD5 哈希值(十六进制字符串)
|
||||
"""
|
||||
return hashlib.md5(image_data).hexdigest()
|
||||
|
||||
|
||||
def encode_image_base64(image_data: bytes) -> str:
|
||||
"""将图片数据编码为 Base64
|
||||
|
||||
Args:
|
||||
image_data: 图片二进制数据
|
||||
|
||||
Returns:
|
||||
Base64 编码的字符串
|
||||
"""
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
def format_session_id(session_type: str, session_id: str) -> str:
|
||||
"""格式化会话 ID
|
||||
|
||||
Args:
|
||||
session_type: 会话类型 ("user", "group")
|
||||
session_id: 原始会话 ID
|
||||
|
||||
Returns:
|
||||
格式化后的会话 ID
|
||||
"""
|
||||
return f"wecom_ai_bot_{session_type}_{session_id}"
|
||||
|
||||
|
||||
def parse_session_id(formatted_session_id: str) -> Tuple[str, str]:
|
||||
"""解析格式化的会话 ID
|
||||
|
||||
Args:
|
||||
formatted_session_id: 格式化的会话 ID
|
||||
|
||||
Returns:
|
||||
(会话类型, 原始会话ID)
|
||||
"""
|
||||
parts = formatted_session_id.split("_", 3)
|
||||
if (
|
||||
len(parts) >= 4
|
||||
and parts[0] == "wecom"
|
||||
and parts[1] == "ai"
|
||||
and parts[2] == "bot"
|
||||
):
|
||||
return parts[3], "_".join(parts[4:]) if len(parts) > 4 else ""
|
||||
return "user", formatted_session_id
|
||||
|
||||
|
||||
def safe_json_loads(json_str: str, default: Any = None) -> Any:
|
||||
"""安全地解析 JSON 字符串
|
||||
|
||||
Args:
|
||||
json_str: JSON 字符串
|
||||
default: 解析失败时的默认值
|
||||
|
||||
Returns:
|
||||
解析结果或默认值
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}")
|
||||
return default
|
||||
|
||||
|
||||
def format_error_response(error_code: int, error_msg: str) -> str:
|
||||
"""格式化错误响应
|
||||
|
||||
Args:
|
||||
error_code: 错误码
|
||||
error_msg: 错误信息
|
||||
|
||||
Returns:
|
||||
格式化的错误响应字符串
|
||||
"""
|
||||
return f"Error {error_code}: {error_msg}"
|
||||
|
||||
|
||||
async def process_encrypted_image(
|
||||
image_url: str, aes_key_base64: str
|
||||
) -> Tuple[bool, str]:
|
||||
"""下载并解密加密图片
|
||||
|
||||
Args:
|
||||
image_url: 加密图片的URL
|
||||
aes_key_base64: Base64编码的AES密钥(与回调加解密相同)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码,
|
||||
status 为 False 时 data 是错误信息
|
||||
"""
|
||||
# 1. 下载加密图片
|
||||
logger.info("开始下载加密图片: %s", image_url)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=15) as response:
|
||||
response.raise_for_status()
|
||||
encrypted_data = await response.read()
|
||||
logger.info("图片下载成功,大小: %d 字节", len(encrypted_data))
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
error_msg = f"下载图片失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 2. 准备AES密钥和IV
|
||||
if not aes_key_base64:
|
||||
raise ValueError("AES密钥不能为空")
|
||||
|
||||
# Base64解码密钥 (自动处理填充)
|
||||
aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4))
|
||||
if len(aes_key) != 32:
|
||||
raise ValueError("无效的AES密钥长度: 应为32字节")
|
||||
|
||||
iv = aes_key[:16] # 初始向量为密钥前16字节
|
||||
|
||||
# 3. 解密图片数据
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted_data = cipher.decrypt(encrypted_data)
|
||||
|
||||
# 4. 去除PKCS#7填充 (Python 3兼容写法)
|
||||
pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值
|
||||
if pad_len > 32: # AES-256块大小为32字节
|
||||
raise ValueError("无效的填充长度 (大于32字节)")
|
||||
|
||||
decrypted_data = decrypted_data[:-pad_len]
|
||||
logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data))
|
||||
|
||||
# 5. 转换为base64编码
|
||||
base64_data = base64.b64encode(decrypted_data).decode("utf-8")
|
||||
logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data))
|
||||
|
||||
return True, base64_data
|
||||
@@ -68,14 +68,15 @@ class Provider(AbstractProvider):
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""获得提供商 Key"""
|
||||
return self.provider_config.get("key", [])
|
||||
keys = self.provider_config.get("key", [""])
|
||||
return keys or [""]
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_key(self, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_models(self) -> List[str]:
|
||||
async def get_models(self) -> List[str]:
|
||||
"""获得支持的模型列表"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class ProviderAnthropic(Provider):
|
||||
)
|
||||
|
||||
self.chosen_api_key: str = ""
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.api_keys: List = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
@@ -70,9 +70,13 @@ class ProviderAnthropic(Provider):
|
||||
{
|
||||
"type": "tool_use",
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], str)
|
||||
else tool_call["function"]["arguments"],
|
||||
"input": (
|
||||
json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(
|
||||
tool_call["function"]["arguments"], str
|
||||
)
|
||||
else tool_call["function"]["arguments"]
|
||||
),
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
)
|
||||
@@ -355,9 +359,11 @@ class ProviderAnthropic(Provider):
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
import os
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
import aiohttp
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer
|
||||
|
||||
try:
|
||||
from dashscope.aigc.multimodal_conversation import MultiModalConversation
|
||||
except (
|
||||
ImportError
|
||||
): # pragma: no cover - older dashscope versions without Qwen TTS support
|
||||
MultiModalConversation = None
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
@@ -26,16 +38,112 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
model = self.get_model()
|
||||
if not model:
|
||||
raise RuntimeError("Dashscope TTS model is not configured.")
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav")
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
if self._is_qwen_tts_model(model):
|
||||
audio_bytes, ext = await self._synthesize_with_qwen_tts(model, text)
|
||||
else:
|
||||
audio_bytes, ext = await self._synthesize_with_cosyvoice(model, text)
|
||||
|
||||
if not audio_bytes:
|
||||
raise RuntimeError(
|
||||
"Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable."
|
||||
)
|
||||
|
||||
path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}")
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
return path
|
||||
|
||||
def _call_qwen_tts(self, model: str, text: str):
|
||||
if MultiModalConversation is None:
|
||||
raise RuntimeError(
|
||||
"dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"api_key": self.chosen_api_key,
|
||||
"voice": self.voice or "Cherry",
|
||||
}
|
||||
if not self.voice:
|
||||
logging.warning(
|
||||
"No voice specified for Qwen TTS model, using default 'Cherry'."
|
||||
)
|
||||
return MultiModalConversation.call(**kwargs)
|
||||
|
||||
async def _synthesize_with_qwen_tts(
|
||||
self, model: str, text: str
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
|
||||
audio_bytes = await self._extract_audio_from_response(response)
|
||||
if not audio_bytes:
|
||||
raise RuntimeError(
|
||||
f"Audio synthesis failed for model '{model}'. {response}"
|
||||
)
|
||||
ext = ".wav"
|
||||
return audio_bytes, ext
|
||||
|
||||
async def _extract_audio_from_response(self, response) -> Optional[bytes]:
|
||||
output = getattr(response, "output", None)
|
||||
audio_obj = getattr(output, "audio", None) if output is not None else None
|
||||
if not audio_obj:
|
||||
return None
|
||||
|
||||
data_b64 = getattr(audio_obj, "data", None)
|
||||
if data_b64:
|
||||
try:
|
||||
return base64.b64decode(data_b64)
|
||||
except (ValueError, TypeError):
|
||||
logging.error("Failed to decode base64 audio data.")
|
||||
return None
|
||||
|
||||
url = getattr(audio_obj, "url", None)
|
||||
if url:
|
||||
return await self._download_audio_from_url(url)
|
||||
return None
|
||||
|
||||
async def _download_audio_from_url(self, url: str) -> Optional[bytes]:
|
||||
if not url:
|
||||
return None
|
||||
timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
return await response.read()
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||
logging.error(f"Failed to download audio from URL {url}: {e}")
|
||||
return None
|
||||
|
||||
async def _synthesize_with_cosyvoice(
|
||||
self, model: str, text: str
|
||||
) -> Tuple[Optional[bytes], str]:
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=model,
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
audio = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.synthesizer.call, text, self.timeout_ms
|
||||
loop = asyncio.get_event_loop()
|
||||
audio_bytes = await loop.run_in_executor(
|
||||
None, synthesizer.call, text, self.timeout_ms
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio)
|
||||
return path
|
||||
if not audio_bytes:
|
||||
resp = synthesizer.get_response()
|
||||
if resp and isinstance(resp, dict):
|
||||
raise RuntimeError(
|
||||
f"Audio synthesis failed for model '{model}'. {resp}".strip()
|
||||
)
|
||||
return audio_bytes, ".wav"
|
||||
|
||||
def _is_qwen_tts_model(self, model: str) -> bool:
|
||||
model_lower = model.lower()
|
||||
return "tts" in model_lower and model_lower.startswith("qwen")
|
||||
|
||||
@@ -3,7 +3,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
@@ -60,7 +60,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
self.api_keys: List = super().get_keys()
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||
|
||||
@@ -218,19 +218,21 @@ class ProviderGoogleGenAI(Provider):
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
),
|
||||
24576,
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None,
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None
|
||||
),
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
@@ -274,9 +276,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
if role == "user":
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
types.Part.from_text(text=item["text"] or " ")
|
||||
if item["type"] == "text"
|
||||
else process_image_url(item["image_url"])
|
||||
(
|
||||
types.Part.from_text(text=item["text"] or " ")
|
||||
if item["type"] == "text"
|
||||
else process_image_url(item["image_url"])
|
||||
)
|
||||
for item in content
|
||||
]
|
||||
else:
|
||||
|
||||
@@ -38,7 +38,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.api_keys: List = super().get_keys()
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
# This file was originally created to adapt to glm-4v-flash, which only supports one image in the context.
|
||||
# It is no longer specifically adapted to Zhipu's models. To ensure compatibility, this
|
||||
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||
@register_provider_adapter("zhipu_chat_completion", "智谱 Chat Completion 提供商适配器")
|
||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -19,63 +19,3 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
model = model or self.get_model()
|
||||
# glm-4v-flash 只支持一张图片
|
||||
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
new_context_query_ = []
|
||||
for i in range(0, len(context_query) - 1, 2):
|
||||
if isinstance(context_query[i].get("content", ""), list):
|
||||
continue
|
||||
new_context_query_.append(context_query[i])
|
||||
new_context_query_.append(context_query[i + 1])
|
||||
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||
context_query = new_context_query_
|
||||
logger.debug(context_query)
|
||||
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {"messages": context_query, **model_cfgs}
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。"
|
||||
)
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import re
|
||||
import inspect
|
||||
import types
|
||||
import typing
|
||||
from typing import List, Any, Type, Dict
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -14,6 +16,18 @@ class GreedyStr(str):
|
||||
pass
|
||||
|
||||
|
||||
def unwrap_optional(annotation) -> tuple:
|
||||
"""去掉 Optional[T] / Union[T, None] / T|None,返回 T"""
|
||||
args = typing.get_args(annotation)
|
||||
non_none_args = [a for a in args if a is not type(None)]
|
||||
if len(non_none_args) == 1:
|
||||
return (non_none_args[0],)
|
||||
elif len(non_none_args) > 1:
|
||||
return tuple(non_none_args)
|
||||
else:
|
||||
return ()
|
||||
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
"""标准指令过滤器"""
|
||||
@@ -40,6 +54,8 @@ class CommandFilter(HandlerFilter):
|
||||
for k, v in self.handler_params.items():
|
||||
if isinstance(v, type):
|
||||
result += f"{k}({v.__name__}),"
|
||||
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
|
||||
result += f"{k}({v}),"
|
||||
else:
|
||||
result += f"{k}({type(v).__name__})={v},"
|
||||
result = result.rstrip(",")
|
||||
@@ -95,7 +111,8 @@ class CommandFilter(HandlerFilter):
|
||||
# 没有 GreedyStr 的情况
|
||||
if i >= len(params):
|
||||
if (
|
||||
isinstance(param_type_or_default_val, Type)
|
||||
isinstance(param_type_or_default_val, (Type, types.UnionType))
|
||||
or typing.get_origin(param_type_or_default_val) is typing.Union
|
||||
or param_type_or_default_val is inspect.Parameter.empty
|
||||
):
|
||||
# 是类型
|
||||
@@ -132,7 +149,20 @@ class CommandFilter(HandlerFilter):
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
origin = typing.get_origin(param_type_or_default_val)
|
||||
if origin in (typing.Union, types.UnionType):
|
||||
# 注解是联合类型
|
||||
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
|
||||
nn_types = unwrap_optional(param_type_or_default_val)
|
||||
if len(nn_types) == 1:
|
||||
# 只有一个非 NoneType 类型
|
||||
result[param_name] = nn_types[0](params[i])
|
||||
else:
|
||||
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
|
||||
# NOTE: 目前还没有做类型校验
|
||||
result[param_name] = params[i]
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
|
||||
|
||||
@@ -205,7 +205,6 @@ def register_command_group(
|
||||
new_group = CommandGroupFilter(command_group_name, alias)
|
||||
|
||||
def decorator(obj):
|
||||
# 根指令组
|
||||
if new_group:
|
||||
handler_md = get_handler_or_create(
|
||||
obj, EventType.AdapterMessageEvent, **kwargs
|
||||
@@ -213,6 +212,7 @@ def register_command_group(
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
raise ValueError("注册指令组失败。")
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -220,9 +220,11 @@ def register_command_group(
|
||||
class RegisteringCommandable:
|
||||
"""用于指令组级联注册"""
|
||||
|
||||
group: CommandGroupFilter = register_command_group
|
||||
command: CommandFilter = register_command
|
||||
custom_filter = register_custom_filter
|
||||
group: Callable[..., Callable[..., "RegisteringCommandable"]] = (
|
||||
register_command_group
|
||||
)
|
||||
command: Callable[..., Callable[..., None]] = register_command
|
||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
|
||||
@@ -6,7 +6,7 @@ class CommandTokens:
|
||||
self.tokens = []
|
||||
self.len = 0
|
||||
|
||||
def get(self, idx: int):
|
||||
def get(self, idx: int) -> str | None:
|
||||
if idx >= self.len:
|
||||
return None
|
||||
return self.tokens[idx].strip()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import typing
|
||||
import traceback
|
||||
import os
|
||||
import inspect
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
@@ -13,10 +14,10 @@ from astrbot.core.config.default import (
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.platform.register import platform_registry, platform_cls_map
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, file_token_service
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
import asyncio
|
||||
@@ -149,6 +150,7 @@ class ConfigRoute(Route):
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self._logo_token_cache = {} # 缓存logo token,避免重复注册
|
||||
self.acm = core_lifecycle.astrbot_config_mgr
|
||||
self.routes = {
|
||||
"/config/abconf/new": ("POST", self.create_abconf),
|
||||
@@ -655,6 +657,78 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = tool_mgr.get_func_desc_openai_style()
|
||||
return Response().ok(tools).__dict__
|
||||
|
||||
async def _register_platform_logo(self, platform, platform_default_tmpl):
|
||||
"""注册平台logo文件并生成访问令牌"""
|
||||
if not platform.logo_path:
|
||||
return
|
||||
|
||||
try:
|
||||
# 检查缓存
|
||||
cache_key = f"{platform.name}:{platform.logo_path}"
|
||||
if cache_key in self._logo_token_cache:
|
||||
cached_token = self._logo_token_cache[cache_key]
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
platform_default_tmpl[platform.name]["logo_token"] = cached_token
|
||||
logger.debug(f"Using cached logo token for platform {platform.name}")
|
||||
return
|
||||
|
||||
# 获取平台适配器类
|
||||
platform_cls = platform_cls_map.get(platform.name)
|
||||
if not platform_cls:
|
||||
logger.warning(f"Platform class not found for {platform.name}")
|
||||
return
|
||||
|
||||
# 获取插件目录路径
|
||||
module_file = inspect.getfile(platform_cls)
|
||||
plugin_dir = os.path.dirname(module_file)
|
||||
|
||||
# 解析logo文件路径
|
||||
logo_file_path = os.path.join(plugin_dir, platform.logo_path)
|
||||
|
||||
# 检查文件是否存在并注册令牌
|
||||
if os.path.exists(logo_file_path):
|
||||
logo_token = await file_token_service.register_file(
|
||||
logo_file_path, timeout=3600
|
||||
)
|
||||
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl:
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
elif not isinstance(platform_default_tmpl[platform.name], dict):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
|
||||
platform_default_tmpl[platform.name]["logo_token"] = logo_token
|
||||
|
||||
# 缓存token
|
||||
self._logo_token_cache[cache_key] = logo_token
|
||||
|
||||
logger.debug(f"Logo token registered for platform {platform.name}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Platform {platform.name} logo file not found: {logo_file_path}"
|
||||
)
|
||||
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning(
|
||||
f"Failed to import required modules for platform {platform.name}: {e}"
|
||||
)
|
||||
except (OSError, IOError) as e:
|
||||
logger.warning(f"File system error for platform {platform.name} logo: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Unexpected error registering logo for platform {platform.name}: {e}"
|
||||
)
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
@@ -662,9 +736,21 @@ class ConfigRoute(Route):
|
||||
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
|
||||
"platform"
|
||||
]["config_template"]
|
||||
|
||||
# 收集需要注册logo的平台
|
||||
logo_registration_tasks = []
|
||||
for platform in platform_registry:
|
||||
if platform.default_config_tmpl:
|
||||
platform_default_tmpl[platform.name] = platform.default_config_tmpl
|
||||
# 收集logo注册任务
|
||||
if platform.logo_path:
|
||||
logo_registration_tasks.append(
|
||||
self._register_platform_logo(platform, platform_default_tmpl)
|
||||
)
|
||||
|
||||
# 并行执行logo注册
|
||||
if logo_registration_tasks:
|
||||
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
|
||||
|
||||
# 服务提供商的默认配置模板注入
|
||||
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
|
||||
|
||||
@@ -20,6 +20,7 @@ class SessionManagementRoute(Route):
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db_helper = db_helper
|
||||
self.routes = {
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
@@ -39,22 +40,42 @@ class SessionManagementRoute(Route):
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
|
||||
session_conversations = {}
|
||||
for pref in preferences:
|
||||
session_conversations[pref.scope_id] = pref.value["val"]
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = int(request.args.get("page_size", 20))
|
||||
search_query = request.args.get("search", "")
|
||||
platform = request.args.get("platform", "")
|
||||
|
||||
# 获取活跃的会话数据(处于对话内的会话)
|
||||
sessions_data, total = await self.db_helper.get_session_conversations(
|
||||
page, page_size, search_query, platform
|
||||
)
|
||||
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = self.core_lifecycle.persona_mgr
|
||||
personas = persona_mgr.personas_v3
|
||||
|
||||
sessions = []
|
||||
|
||||
# 构建会话信息
|
||||
for session_id, conversation_id in session_conversations.items():
|
||||
# 循环补充非数据库信息,如 provider 和 session 状态
|
||||
for data in sessions_data:
|
||||
session_id = data["session_id"]
|
||||
conversation_id = data["conversation_id"]
|
||||
conv_persona_id = data["persona_id"]
|
||||
title = data["title"]
|
||||
persona_name = data["persona_name"]
|
||||
|
||||
# 处理 persona 显示
|
||||
if persona_name is None:
|
||||
if conv_persona_id is None:
|
||||
if default_persona := persona_mgr.selected_default_persona_v3:
|
||||
persona_name = default_persona["name"]
|
||||
else:
|
||||
persona_name = "[%None]"
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_id": persona_name,
|
||||
"chat_provider_id": None,
|
||||
"stt_provider_id": None,
|
||||
"tts_provider_id": None,
|
||||
@@ -79,31 +100,10 @@ class SessionManagementRoute(Route):
|
||||
"session_raw_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
"title": title,
|
||||
}
|
||||
|
||||
# 获取对话信息
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=session_id, conversation_id=conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_id"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_id"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
chat_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
|
||||
)
|
||||
@@ -172,6 +172,14 @@ class SessionManagementRoute(Route):
|
||||
"available_chat_providers": available_chat_providers,
|
||||
"available_stt_providers": available_stt_providers,
|
||||
"available_tts_providers": available_tts_providers,
|
||||
"pagination": {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"total_pages": (total + page_size - 1) // page_size
|
||||
if page_size > 0
|
||||
else 0,
|
||||
},
|
||||
}
|
||||
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
@@ -273,6 +273,20 @@ class ToolsRoute(Route):
|
||||
server_data = await request.json
|
||||
config = server_data.get("mcp_server_config", None)
|
||||
|
||||
if not isinstance(config, dict) or not config:
|
||||
return Response().error("无效的 MCP 服务器配置").__dict__
|
||||
|
||||
if "mcpServers" in config:
|
||||
keys = list(config["mcpServers"].keys())
|
||||
if not keys:
|
||||
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||
if len(keys) > 1:
|
||||
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
|
||||
config = config["mcpServers"][keys[0]]
|
||||
else:
|
||||
if not config:
|
||||
return Response().error("MCP 服务器配置不能为空").__dict__
|
||||
|
||||
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
|
||||
return (
|
||||
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
|
||||
|
||||
@@ -9,6 +9,8 @@ from astrbot.core.config.default import VERSION
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4
|
||||
|
||||
CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'}
|
||||
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(
|
||||
@@ -113,17 +115,19 @@ class UpdateRoute(Route):
|
||||
|
||||
if reboot:
|
||||
await self.core_lifecycle.restart()
|
||||
return (
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
else:
|
||||
return (
|
||||
ret = (
|
||||
Response()
|
||||
.ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。")
|
||||
.__dict__
|
||||
)
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -135,9 +139,8 @@ class UpdateRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
return (
|
||||
Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
)
|
||||
ret = Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
return ret, 200, CLEAR_SITE_DATA_HEADERS
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复"开启 TTS 时同时输出语音和文字内容"功能不可用的问题 ([#2900](https://github.com/AstrBotDevs/AstrBot/issues/2900))
|
||||
2. feat: 优化了会话管理页的数据查询逻辑,添加分页和搜索功能,大幅度提高响应速度 ([#2906](https://github.com/AstrBotDevs/AstrBot/issues/2906))
|
||||
3. fix: 用 mi-googlesearch-python 库代替失效的 googlesearch-python 库 ([#2909](https://github.com/AstrBotDevs/AstrBot/issues/2909))
|
||||
4. feat: 支持在 Telegram 和飞书下请求 LLM 前预表态功能 ([#2737](https://github.com/AstrBotDevs/AstrBot/issues/2737))
|
||||
5. perf: 对于 Telegram 群聊,将回复机器人的消息视为唤醒机器人 ([#2926](https://github.com/AstrBotDevs/AstrBot/issues/2926))
|
||||
6. feat: 提示词前缀配置项升级为“用户提示词”,支持 `{{prompt}}` 作为用户输入的占位符。
|
||||
7. fix: 增加知识库插件的启用检查,避免部分情况下导致知识库页面白屏的问题。
|
||||
8. fix: 修复接入智谱提供商后,工具调用无限循环的问题,并停止支持 glm-4v-flash ([#2931](https://github.com/AstrBotDevs/AstrBot/issues/2931))
|
||||
9. fix: 修复注册指令组指令时的 Pyright 类型检查提示 ([#2923](https://github.com/AstrBotDevs/AstrBot/issues/2923))
|
||||
10. refactor: 优化 packages/astrbot 内置插件的代码结构以提高可维护性和可读性 ([#2924](https://github.com/AstrBotDevs/AstrBot/issues/2924))
|
||||
11. fix: 修复插件指令注解为联合类型时处理异常的问题 ([#2925](https://github.com/AstrBotDevs/AstrBot/issues/2925))
|
||||
12. feat: 支持注册消息平台适配器的 logo ([#2109](https://github.com/AstrBotDevs/AstrBot/issues/2109))
|
||||
@@ -0,0 +1 @@
|
||||
# What's Changed
|
||||
@@ -0,0 +1,7 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复 /reset 指令没有清除群聊上下文感知数据的问题 ([#2954](https://github.com/AstrBotDevs/AstrBot/issues/2954))
|
||||
2. fix: 修复自带的 WebSearch 插件可能在部分场景下无法使用的问题
|
||||
3. fix: 发送阶段强行将 Plain 为空的消息段移除
|
||||
4. fix: on_tool_end无法获得工具返回的结果 ([#2956](https://github.com/AstrBotDevs/AstrBot/issues/2956))
|
||||
5. feat: 为插件市场的搜索增加拼音与首字母搜索功能 ([#2936](https://github.com/AstrBotDevs/AstrBot/issues/2936))
|
||||
@@ -0,0 +1,12 @@
|
||||
# What's Changed
|
||||
|
||||
1. fix: 修复了代码执行器插件不能正确获得发送来文件的问题 ([#2970](https://github.com/Soulter/AstrBot/issues/2970))
|
||||
2. fix: 修改的 DeepSeek 默认 modalities,避免默认勾选图像导致的报错。 ([#2963](https://github.com/Soulter/AstrBot/issues/2963))
|
||||
3. fix: 事件钩子终止事件传播后不继续执行 ([#2989](https://github.com/Soulter/AstrBot/issues/2989))
|
||||
4. fix: 启动了 TTS 但未配置 TTS 模型时,At 和 Reply 发送人无效
|
||||
5. fix: 修复 session-management 中人格错误的显示为默认人格的问题 ([#3000](https://github.com/Soulter/AstrBot/issues/3000))
|
||||
6. fix: 修复了删除对话时,聊天增强中的记录未被清除,导致新对话中仍然出现之前的聊天记录。 ([#3002](https://github.com/Soulter/AstrBot/issues/3002))
|
||||
7. fix: 修复阿里云百炼平台 TTS 下接入 CosyVoice V2, Qwen TTS 生成报错的问题 ([#2964](https://github.com/Soulter/AstrBot/issues/2964))
|
||||
8. perf: 优化 SQLite 参数配置,对话和会话管理增加输入防抖机制 ([#2969](https://github.com/Soulter/AstrBot/issues/2969))
|
||||
9. feat: 在新对话中重用先前的对话人格设置 ([#3005](https://github.com/Soulter/AstrBot/issues/3005))
|
||||
10. feat: 从 WebUI 更新后清除浏览器缓存 ([#2958](https://github.com/Soulter/AstrBot/issues/2958))
|
||||
@@ -0,0 +1,8 @@
|
||||
# What's Changed
|
||||
|
||||
1. feat: 支持接入企业微信智能机器人平台 ([#3034](https://github.com/AstrBotDevs/AstrBot/issues/3034))
|
||||
2. feat: 内置网页搜索功能支持接入百度 AI 搜索 ([#3031](https://github.com/AstrBotDevs/AstrBot/issues/3031))
|
||||
3. feat: 支持配置工具调用超时时间并适配 ModelScope 的 MCP Server 配置 ([#3039](https://github.com/AstrBotDevs/AstrBot/issues/3039))
|
||||
4. feat: 添加并优化服务提供商独立测试功能 ([#3024](https://github.com/AstrBotDevs/AstrBot/issues/3024))
|
||||
5. feat: satori 适配器支持 video、reply 消息类型 ([#3035](https://github.com/AstrBotDevs/AstrBot/issues/3035))
|
||||
6. fix: 修复 `/alter_cmd reset scene <num> xxx` 不可用的问题
|
||||
@@ -27,6 +27,7 @@
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.7",
|
||||
"markdown-it": "^14.1.0",
|
||||
"pinyin-pro": "^3.26.0",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
|
||||
@@ -27,7 +27,9 @@
|
||||
<v-btn
|
||||
variant="outlined"
|
||||
color="error"
|
||||
size="small"
|
||||
rounded="xl"
|
||||
:disabled="loading"
|
||||
@click="$emit('delete', item)"
|
||||
>
|
||||
{{ t('core.common.itemCard.delete') }}
|
||||
@@ -35,7 +37,9 @@
|
||||
<v-btn
|
||||
variant="tonal"
|
||||
color="primary"
|
||||
size="small"
|
||||
rounded="xl"
|
||||
:disabled="loading"
|
||||
@click="$emit('edit', item)"
|
||||
>
|
||||
{{ t('core.common.itemCard.edit') }}
|
||||
@@ -44,11 +48,14 @@
|
||||
v-if="showCopyButton"
|
||||
variant="tonal"
|
||||
color="secondary"
|
||||
size="small"
|
||||
rounded="xl"
|
||||
:disabled="loading"
|
||||
@click="$emit('copy', item)"
|
||||
>
|
||||
{{ t('core.common.itemCard.copy') }}
|
||||
</v-btn>
|
||||
<slot name="actions" :item="item"></slot>
|
||||
<v-spacer></v-spacer>
|
||||
</v-card-actions>
|
||||
|
||||
|
||||
@@ -101,6 +101,7 @@
|
||||
},
|
||||
"messages": {
|
||||
"pluginNotAvailable": "Plugin not installed or unavailable",
|
||||
"pluginNotActivated": "astrbot_plugin_knowledge_base plugin not activated, please activate it in the plugin management page and restart AstrBot",
|
||||
"checkPluginFailed": "Failed to check plugin",
|
||||
"installFailed": "Installation failed",
|
||||
"installPluginFailed": "Failed to install plugin",
|
||||
|
||||
@@ -31,7 +31,8 @@
|
||||
"available": "Available",
|
||||
"unavailable": "Unavailable",
|
||||
"pending": "Pending...",
|
||||
"errorMessage": "Error Message"
|
||||
"errorMessage": "Error Message",
|
||||
"test": "Test"
|
||||
},
|
||||
"logs": {
|
||||
"title": "Service Logs",
|
||||
@@ -76,7 +77,8 @@
|
||||
},
|
||||
"error": {
|
||||
"sessionSeparation": "Failed to get session isolation configuration",
|
||||
"fetchStatus": "Failed to get service provider status"
|
||||
"fetchStatus": "Failed to get service provider status",
|
||||
"testError": "Test failed for {id}: {error}"
|
||||
},
|
||||
"confirm": {
|
||||
"delete": "Are you sure you want to delete service provider {id}?"
|
||||
|
||||
@@ -80,6 +80,9 @@
|
||||
"save": "Save",
|
||||
"testConnection": "Test Connection",
|
||||
"sync": "Sync"
|
||||
},
|
||||
"tips": {
|
||||
"timeoutConfig": "Please configure tool call timeout separately in the configuration page"
|
||||
}
|
||||
},
|
||||
"serverDetail": {
|
||||
|
||||
@@ -101,6 +101,7 @@
|
||||
},
|
||||
"messages": {
|
||||
"pluginNotAvailable": "插件未安装或不可用",
|
||||
"pluginNotActivated": "astrbot_plugin_knowledge_base 插件未启用,请前往插件管理页面启用,然后重启 AstrBot。",
|
||||
"checkPluginFailed": "检查插件失败",
|
||||
"installFailed": "安装失败",
|
||||
"installPluginFailed": "安装插件失败",
|
||||
|
||||
@@ -32,7 +32,8 @@
|
||||
"available": "可用",
|
||||
"unavailable": "不可用",
|
||||
"pending": "检查中...",
|
||||
"errorMessage": "错误信息"
|
||||
"errorMessage": "错误信息",
|
||||
"test": "测试"
|
||||
},
|
||||
"logs": {
|
||||
"title": "服务日志",
|
||||
@@ -77,7 +78,8 @@
|
||||
},
|
||||
"error": {
|
||||
"sessionSeparation": "获取会话隔离配置失败",
|
||||
"fetchStatus": "获取服务提供商状态失败"
|
||||
"fetchStatus": "获取服务提供商状态失败",
|
||||
"testError": "测试 {id} 失败: {error}"
|
||||
},
|
||||
"confirm": {
|
||||
"delete": "确定要删除服务提供商 {id} 吗?"
|
||||
|
||||
@@ -80,6 +80,9 @@
|
||||
"save": "保存",
|
||||
"testConnection": "测试连接",
|
||||
"sync": "同步"
|
||||
},
|
||||
"tips": {
|
||||
"timeoutConfig": "工具调用的超时时间请前往配置页面单独配置"
|
||||
}
|
||||
},
|
||||
"serverDetail": {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
export function getPlatformIcon(name) {
|
||||
if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') {
|
||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||
} else if (name === 'wecom') {
|
||||
} else if (name === 'wecom' || name === 'wecom_ai_bot') {
|
||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||
} else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') {
|
||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||
@@ -46,6 +46,7 @@ export function getTutorialLink(platformType) {
|
||||
"qq_official": "https://docs.astrbot.app/deploy/platform/qqofficial/websockets.html",
|
||||
"aiocqhttp": "https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html",
|
||||
"wecom": "https://docs.astrbot.app/deploy/platform/wecom.html",
|
||||
"wecom_ai_bot": "https://docs.astrbot.app/deploy/platform/wecom_ai_bot.html",
|
||||
"lark": "https://docs.astrbot.app/deploy/platform/lark.html",
|
||||
"telegram": "https://docs.astrbot.app/deploy/platform/telegram.html",
|
||||
"dingtalk": "https://docs.astrbot.app/deploy/platform/dingtalk.html",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
<v-col cols="12" sm="6" md="4">
|
||||
<v-combobox v-model="platformFilter" :label="tm('filters.platform')"
|
||||
:items="availablePlatforms" chips multiple clearable variant="solo-filled" flat
|
||||
density="compact" hide-details :disabled="loading">
|
||||
density="compact" hide-details>
|
||||
<template v-slot:selection="{ item }">
|
||||
<v-chip size="small" label>
|
||||
{{ item.title }}
|
||||
@@ -21,8 +21,7 @@
|
||||
|
||||
<v-col cols="12" sm="6" md="4">
|
||||
<v-select v-model="messageTypeFilter" :label="tm('filters.type')" :items="messageTypeItems"
|
||||
chips multiple clearable variant="solo-filled" density="compact" hide-details flat
|
||||
:disabled="loading">
|
||||
chips multiple clearable variant="solo-filled" density="compact" hide-details flat>
|
||||
<template v-slot:selection="{ item }">
|
||||
<v-chip size="small" variant="solo-filled" label>
|
||||
{{ item.title }}
|
||||
@@ -34,7 +33,7 @@
|
||||
<v-col cols="12" sm="12" md="4">
|
||||
<v-text-field v-model="search" prepend-inner-icon="mdi-magnify"
|
||||
:label="tm('filters.search')" hide-details density="compact" variant="solo-filled" flat
|
||||
clearable :disabled="loading"></v-text-field>
|
||||
clearable></v-text-field>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchConversations"
|
||||
@@ -79,6 +78,10 @@
|
||||
</v-chip>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.cid="{ item }">
|
||||
<span class="text-truncate">{{ item.cid || tm('status.unknown') }}</span>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.sessionId="{ item }">
|
||||
<span>{{ item.sessionInfo.sessionId || tm('status.unknown') }}</span>
|
||||
</template>
|
||||
@@ -313,6 +316,7 @@
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { debounce } from 'lodash';
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor';
|
||||
import MarkdownIt from 'markdown-it';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
@@ -417,8 +421,7 @@ export default {
|
||||
},
|
||||
|
||||
created() {
|
||||
// 创建一个防抖函数,避免频繁请求
|
||||
this.debouncedApplyFilters = this.debounce(() => {
|
||||
this.debouncedApplyFilters = debounce(() => {
|
||||
// 重置到第一页
|
||||
this.pagination.page = 1;
|
||||
this.fetchConversations();
|
||||
@@ -430,13 +433,14 @@ export default {
|
||||
tableHeaders() {
|
||||
return [
|
||||
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
|
||||
{ title: '会话 ID', key: 'cid', sortable: true, width: '100px' },
|
||||
{
|
||||
title: this.tm('table.headers.sessionId'),
|
||||
align: 'center',
|
||||
children: [
|
||||
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
|
||||
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
|
||||
{ title: '会话 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||
{ title: '用户 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||
],
|
||||
},
|
||||
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
|
||||
@@ -526,19 +530,6 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
// 添加防抖函数
|
||||
debounce(func, wait) {
|
||||
let timeout;
|
||||
return function () {
|
||||
const context = this;
|
||||
const args = arguments;
|
||||
clearTimeout(timeout);
|
||||
timeout = setTimeout(() => {
|
||||
func.apply(context, args);
|
||||
}, wait);
|
||||
};
|
||||
},
|
||||
|
||||
// 处理表格选项变更(页面大小等)
|
||||
handleTableOptions(options) {
|
||||
// 处理页面大小变更
|
||||
@@ -579,83 +570,93 @@ export default {
|
||||
},
|
||||
|
||||
// 获取对话列表
|
||||
async fetchConversations() {
|
||||
this.loading = true;
|
||||
try {
|
||||
// 准备请求参数,包含分页和筛选条件
|
||||
const params = {
|
||||
page: this.pagination.page,
|
||||
page_size: this.pagination.page_size
|
||||
};
|
||||
fetchConversations: (() => {
|
||||
let controller = new AbortController();
|
||||
|
||||
// 添加筛选条件 - 处理combobox的混合数据格式
|
||||
if (this.platformFilter.length > 0) {
|
||||
const platforms = this.platformFilter.map(item =>
|
||||
typeof item === 'object' ? item.value : item
|
||||
);
|
||||
params.platforms = platforms.join(',');
|
||||
}
|
||||
return async function () {
|
||||
// 新请求前停止之前的请求
|
||||
controller?.abort()
|
||||
controller = new AbortController();
|
||||
|
||||
if (this.messageTypeFilter.length > 0) {
|
||||
params.message_types = this.messageTypeFilter.join(',');
|
||||
}
|
||||
this.loading = true;
|
||||
try {
|
||||
// 准备请求参数,包含分页和筛选条件
|
||||
const params = {
|
||||
page: this.pagination.page,
|
||||
page_size: this.pagination.page_size
|
||||
};
|
||||
|
||||
if (this.search) {
|
||||
params.search = this.search.trim();
|
||||
}
|
||||
|
||||
// 添加排除条件
|
||||
params.exclude_ids = 'astrbot';
|
||||
params.exclude_platforms = 'webchat';
|
||||
|
||||
const response = await axios.get('/api/conversation/list', { params });
|
||||
|
||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
const data = response.data.data;
|
||||
|
||||
if (!data || !data.conversations) {
|
||||
console.error('API 返回数据格式不符合预期:', data);
|
||||
this.showErrorMessage(this.tm('messages.fetchError'));
|
||||
return;
|
||||
// 添加筛选条件 - 处理combobox的混合数据格式
|
||||
if (this.platformFilter.length > 0) {
|
||||
const platforms = this.platformFilter.map(item =>
|
||||
typeof item === 'object' ? item.value : item
|
||||
);
|
||||
params.platforms = platforms.join(',');
|
||||
}
|
||||
|
||||
// 处理会话数据,解析sessionId
|
||||
this.conversations = (data.conversations || []).map(conv => {
|
||||
// 为每个会话添加会话信息
|
||||
conv.sessionInfo = this.parseSessionId(conv.user_id);
|
||||
return conv;
|
||||
if (this.messageTypeFilter.length > 0) {
|
||||
params.message_types = this.messageTypeFilter.join(',');
|
||||
}
|
||||
|
||||
if (this.search) {
|
||||
params.search = this.search.trim();
|
||||
}
|
||||
|
||||
// 添加排除条件
|
||||
params.exclude_ids = 'astrbot';
|
||||
params.exclude_platforms = 'webchat';
|
||||
|
||||
const response = await axios.get('/api/conversation/list', {
|
||||
signal: controller.signal,
|
||||
params
|
||||
});
|
||||
|
||||
// 更新分页信息
|
||||
if (data.pagination) {
|
||||
this.pagination = {
|
||||
page: data.pagination.page || 1,
|
||||
page_size: data.pagination.page_size || 20,
|
||||
total: data.pagination.total || 0,
|
||||
total_pages: data.pagination.total_pages || 1
|
||||
};
|
||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
const data = response.data.data;
|
||||
|
||||
if (!data || !data.conversations) {
|
||||
console.error('API 返回数据格式不符合预期:', data);
|
||||
this.showErrorMessage(this.tm('messages.fetchError'));
|
||||
return;
|
||||
}
|
||||
|
||||
// 处理会话数据,解析sessionId
|
||||
this.conversations = (data.conversations || []).map(conv => {
|
||||
// 为每个会话添加会话信息
|
||||
conv.sessionInfo = this.parseSessionId(conv.user_id);
|
||||
return conv;
|
||||
});
|
||||
|
||||
// 更新分页信息
|
||||
if (data.pagination) {
|
||||
this.pagination = {
|
||||
page: data.pagination.page || 1,
|
||||
page_size: data.pagination.page_size || 20,
|
||||
total: data.pagination.total || 0,
|
||||
total_pages: data.pagination.total_pages || 1
|
||||
};
|
||||
} else {
|
||||
console.warn('API 响应中没有分页信息');
|
||||
}
|
||||
} else {
|
||||
console.warn('API 响应中没有分页信息');
|
||||
this.showErrorMessage(response.data.message || this.tm('messages.fetchError'));
|
||||
}
|
||||
} else {
|
||||
this.showErrorMessage(response.data.message || this.tm('messages.fetchError'));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取对话列表出错:', error);
|
||||
if (error.response) {
|
||||
console.error('错误响应数据:', error.response.data);
|
||||
console.error('错误状态码:', error.response.status);
|
||||
}
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.fetchError'));
|
||||
} finally {
|
||||
// this.loading = false;
|
||||
setTimeout(() => {
|
||||
} catch (error) {
|
||||
if (axios.isCancel(error)) return;
|
||||
|
||||
console.error('获取对话列表出错:', error);
|
||||
if (error.response) {
|
||||
console.error('错误响应数据:', error.response.data);
|
||||
console.error('错误状态码:', error.response.status);
|
||||
}
|
||||
this.showErrorMessage(error.response?.data?.message || error.message || this.tm('messages.fetchError'));
|
||||
} finally {
|
||||
this.loading = false;
|
||||
}, 200);
|
||||
}
|
||||
}
|
||||
},
|
||||
})(),
|
||||
|
||||
// 查看对话详情
|
||||
async viewConversation(item) {
|
||||
@@ -993,6 +994,14 @@ export default {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.text-truncate {
|
||||
display: inline-block;
|
||||
max-width: 100px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* 动画 */
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
|
||||
@@ -5,6 +5,7 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
|
||||
import ProxySelector from '@/components/shared/ProxySelector.vue';
|
||||
import axios from 'axios';
|
||||
import { pinyin } from 'pinyin-pro';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
@@ -65,6 +66,32 @@ const marketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
const refreshingMarket = ref(false);
|
||||
|
||||
// 插件市场拼音搜索
|
||||
const normalizeStr = (s) => (s ?? '').toString().toLowerCase().trim();
|
||||
const toPinyinText = (s) => pinyin(s ?? '', { toneType: 'none' }).toLowerCase().replace(/\s+/g, '');
|
||||
const toInitials = (s) => pinyin(s ?? '', { pattern: 'first', toneType: 'none' }).toLowerCase().replace(/\s+/g, '');
|
||||
const marketCustomFilter = (value, query, item) => {
|
||||
const q = normalizeStr(query);
|
||||
if (!q) return true;
|
||||
|
||||
const candidates = new Set();
|
||||
if (value != null) candidates.add(String(value));
|
||||
if (item?.name) candidates.add(String(item.name));
|
||||
if (item?.trimmedName) candidates.add(String(item.trimmedName));
|
||||
if (item?.desc) candidates.add(String(item.desc));
|
||||
if (item?.author) candidates.add(String(item.author));
|
||||
|
||||
for (const v of candidates) {
|
||||
const nv = normalizeStr(v);
|
||||
if (nv.includes(q)) return true;
|
||||
const pv = toPinyinText(v);
|
||||
if (pv.includes(q)) return true;
|
||||
const iv = toInitials(v);
|
||||
if (iv.includes(q)) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const plugin_handler_info_headers = computed(() => [
|
||||
{ title: tm('table.headers.eventType'), key: 'event_type_h' },
|
||||
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
|
||||
@@ -772,7 +799,7 @@ onMounted(async () => {
|
||||
|
||||
<v-col cols="12" md="12" style="padding: 0px;">
|
||||
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
|
||||
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys">
|
||||
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys" :custom-filter="marketCustomFilter">
|
||||
<template v-slot:item.name="{ item }">
|
||||
<div class="d-flex align-center"
|
||||
style="overflow-x: auto; scrollbar-width: thin; scrollbar-track-color: transparent;">
|
||||
|
||||
@@ -114,7 +114,7 @@
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -241,7 +241,15 @@ export default {
|
||||
|
||||
methods: {
|
||||
// 从工具函数导入
|
||||
getPlatformIcon,
|
||||
getPlatformIcon(platform_id) {
|
||||
// 首先检查是否有来自插件的 logo_token
|
||||
const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id];
|
||||
if (template && template.logo_token) {
|
||||
// 通过文件服务访问插件提供的 logo
|
||||
return `/api/file/${template.logo_token}`;
|
||||
}
|
||||
return getPlatformIcon(platform_id);
|
||||
},
|
||||
|
||||
openTutorial() {
|
||||
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||
|
||||
@@ -60,12 +60,26 @@
|
||||
:item="provider"
|
||||
title-field="id"
|
||||
enabled-field="enable"
|
||||
:loading="isProviderTesting(provider.id)"
|
||||
@toggle-enabled="providerStatusChange"
|
||||
:bglogo="getProviderIcon(provider.provider)"
|
||||
@delete="deleteProvider"
|
||||
@edit="configExistingProvider"
|
||||
@copy="copyProvider"
|
||||
:show-copy-button="true">
|
||||
<template #actions="{ item }">
|
||||
<v-btn
|
||||
style="z-index: 100000;"
|
||||
variant="tonal"
|
||||
color="info"
|
||||
rounded="xl"
|
||||
size="small"
|
||||
:loading="isProviderTesting(item.id)"
|
||||
@click="testSingleProvider(item)"
|
||||
>
|
||||
{{ tm('availability.test') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-slot:details="{ item }">
|
||||
</template>
|
||||
</item-card>
|
||||
@@ -79,7 +93,7 @@
|
||||
<v-icon class="me-2">mdi-heart-pulse</v-icon>
|
||||
<span class="text-h4">{{ tm('availability.title') }}</span>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" variant="tonal" :loading="loadingStatus" @click="fetchProviderStatus">
|
||||
<v-btn color="primary" variant="tonal" :loading="testingProviders.length > 0" @click="fetchProviderStatus">
|
||||
<v-icon left>mdi-refresh</v-icon>
|
||||
{{ tm('availability.refresh') }}
|
||||
</v-btn>
|
||||
@@ -288,7 +302,7 @@ export default {
|
||||
|
||||
// 供应商状态相关
|
||||
providerStatuses: [],
|
||||
loadingStatus: false,
|
||||
testingProviders: [], // 存储正在测试的 provider ID
|
||||
|
||||
// 新增提供商对话框相关
|
||||
showAddProviderDialog: false,
|
||||
@@ -359,7 +373,8 @@ export default {
|
||||
statusUpdate: this.tm('messages.success.statusUpdate'),
|
||||
},
|
||||
error: {
|
||||
fetchStatus: this.tm('messages.error.fetchStatus')
|
||||
fetchStatus: this.tm('messages.error.fetchStatus'),
|
||||
testError: this.tm('messages.error.testError')
|
||||
},
|
||||
confirm: {
|
||||
delete: this.tm('messages.confirm.delete')
|
||||
@@ -368,6 +383,9 @@ export default {
|
||||
available: this.tm('availability.available'),
|
||||
unavailable: this.tm('availability.unavailable'),
|
||||
pending: this.tm('availability.pending')
|
||||
},
|
||||
availability: {
|
||||
test: this.tm('availability.test')
|
||||
}
|
||||
};
|
||||
},
|
||||
@@ -615,70 +633,107 @@ export default {
|
||||
|
||||
// 获取供应商状态
|
||||
async fetchProviderStatus() {
|
||||
if (this.loadingStatus) return;
|
||||
if (this.testingProviders.length > 0) return;
|
||||
|
||||
this.loadingStatus = true;
|
||||
this.showStatus = true; // 自动展开状态部分
|
||||
|
||||
// 1. 立即初始化UI为pending状态
|
||||
this.providerStatuses = this.config_data.provider.map(p => ({
|
||||
id: p.id,
|
||||
name: p.id,
|
||||
status: 'pending',
|
||||
error: null
|
||||
}));
|
||||
const providersToTest = this.config_data.provider.filter(p => p.enable);
|
||||
if (providersToTest.length === 0) return;
|
||||
|
||||
// 1. 初始化UI为pending状态,并将所有待测试的 provider ID 加入 loading 列表
|
||||
this.providerStatuses = providersToTest.map(p => {
|
||||
this.testingProviders.push(p.id);
|
||||
return { id: p.id, name: p.id, status: 'pending', error: null };
|
||||
});
|
||||
|
||||
// 2. 为每个provider创建一个并发的测试请求
|
||||
const promises = this.config_data.provider.map(p => {
|
||||
if (!p.enable) {
|
||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
||||
if (index !== -1) {
|
||||
const disabledStatus = {
|
||||
...this.providerStatuses[index],
|
||||
status: 'unavailable',
|
||||
error: '该提供商未被用户启用'
|
||||
};
|
||||
this.providerStatuses.splice(index, 1, disabledStatus);
|
||||
}
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
return axios.get(`/api/config/provider/check_one?id=${p.id}`)
|
||||
const promises = providersToTest.map(p =>
|
||||
axios.get(`/api/config/provider/check_one?id=${p.id}`)
|
||||
.then(res => {
|
||||
if (res.data && res.data.status === 'ok') {
|
||||
// 成功,更新对应的provider状态
|
||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
||||
if (index !== -1) {
|
||||
this.providerStatuses.splice(index, 1, res.data.data);
|
||||
}
|
||||
if (index !== -1) this.providerStatuses.splice(index, 1, res.data.data);
|
||||
} else {
|
||||
// 接口返回了业务错误
|
||||
throw new Error(res.data?.message || `Failed to check status for ${p.id}`);
|
||||
}
|
||||
})
|
||||
.catch(err => {
|
||||
// 网络错误或业务错误
|
||||
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
|
||||
const index = this.providerStatuses.findIndex(s => s.id === p.id);
|
||||
if (index !== -1) {
|
||||
const failedStatus = {
|
||||
...this.providerStatuses[index],
|
||||
status: 'unavailable',
|
||||
error: errorMessage
|
||||
};
|
||||
const failedStatus = { ...this.providerStatuses[index], status: 'unavailable', error: errorMessage };
|
||||
this.providerStatuses.splice(index, 1, failedStatus);
|
||||
}
|
||||
// 可以在这里选择性地向上抛出错误,以便Promise.allSettled知道
|
||||
return Promise.reject(errorMessage);
|
||||
});
|
||||
});
|
||||
return Promise.reject(errorMessage); // Propagate error for Promise.allSettled
|
||||
})
|
||||
);
|
||||
|
||||
// 3. 等待所有请求完成(无论成功或失败)
|
||||
// 3. 等待所有请求完成
|
||||
try {
|
||||
await Promise.allSettled(promises);
|
||||
} finally {
|
||||
// 4. 关闭全局加载状态
|
||||
this.loadingStatus = false;
|
||||
// 4. 关闭所有加载状态
|
||||
this.testingProviders = [];
|
||||
}
|
||||
},
|
||||
|
||||
isProviderTesting(providerId) {
|
||||
return this.testingProviders.includes(providerId);
|
||||
},
|
||||
|
||||
async testSingleProvider(provider) {
|
||||
if (this.isProviderTesting(provider.id)) return;
|
||||
|
||||
this.testingProviders.push(provider.id);
|
||||
this.showStatus = true; // 自动展开状态部分
|
||||
|
||||
// 更新UI为pending状态
|
||||
const statusIndex = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||
const pendingStatus = {
|
||||
id: provider.id,
|
||||
name: provider.id,
|
||||
status: 'pending',
|
||||
error: null
|
||||
};
|
||||
if (statusIndex !== -1) {
|
||||
this.providerStatuses.splice(statusIndex, 1, pendingStatus);
|
||||
} else {
|
||||
this.providerStatuses.unshift(pendingStatus);
|
||||
}
|
||||
|
||||
try {
|
||||
if (!provider.enable) {
|
||||
throw new Error('该提供商未被用户启用');
|
||||
}
|
||||
|
||||
const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`);
|
||||
if (res.data && res.data.status === 'ok') {
|
||||
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||
if (index !== -1) {
|
||||
this.providerStatuses.splice(index, 1, res.data.data);
|
||||
}
|
||||
} else {
|
||||
throw new Error(res.data?.message || `Failed to check status for ${provider.id}`);
|
||||
}
|
||||
} catch (err) {
|
||||
const errorMessage = err.response?.data?.message || err.message || 'Unknown error';
|
||||
const index = this.providerStatuses.findIndex(s => s.id === provider.id);
|
||||
const failedStatus = {
|
||||
id: provider.id,
|
||||
name: provider.id,
|
||||
status: 'unavailable',
|
||||
error: errorMessage
|
||||
};
|
||||
if (index !== -1) {
|
||||
this.providerStatuses.splice(index, 1, failedStatus);
|
||||
}
|
||||
// 不再显示全局的错误提示,因为卡片本身会显示错误信息
|
||||
// this.showError(this.tm('messages.error.testError', { id: provider.id, error: errorMessage }));
|
||||
} finally {
|
||||
const index = this.testingProviders.indexOf(provider.id);
|
||||
if (index > -1) {
|
||||
this.testingProviders.splice(index, 1);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@
|
||||
<v-card flat>
|
||||
<v-card-title class="d-flex align-center py-3 px-4">
|
||||
<span class="text-h4">{{ tm('sessions.activeSessions') }}</span>
|
||||
<v-chip size="small" class="ml-2">{{ sessions.length }} {{ tm('sessions.sessionCount') }}</v-chip>
|
||||
<v-chip size="small" class="ml-2">{{ totalItems }} {{ tm('sessions.sessionCount') }}</v-chip>
|
||||
<v-row class="me-4 ms-4" dense>
|
||||
<v-text-field v-model="searchQuery" prepend-inner-icon="mdi-magnify" :label="tm('search.placeholder')"
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" density="compact"></v-text-field>
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" density="compact" @update:model-value="handleSearchChange"></v-text-field>
|
||||
<v-select v-model="filterPlatform" :items="platformOptions" :label="tm('search.platformFilter')"
|
||||
hide-details clearable variant="solo-filled" flat class="me-4" style="max-width: 150px;"
|
||||
density="compact"></v-select>
|
||||
density="compact" @update:model-value="handlePlatformChange"></v-select>
|
||||
</v-row>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="refreshSessions" :loading="loading"
|
||||
size="small">
|
||||
@@ -22,8 +22,17 @@
|
||||
|
||||
<v-card-text class="pa-0">
|
||||
<!-- 会话列表 -->
|
||||
<v-data-table :headers="headers" :items="filteredSessions" :loading="loading" :items-per-page="itemsPerPage" density="compact"
|
||||
class="elevation-0" style="font-size: 11px;">
|
||||
<v-data-table-server
|
||||
:headers="headers"
|
||||
:items="sessions"
|
||||
:loading="loading"
|
||||
:items-per-page="itemsPerPage"
|
||||
:page="currentPage"
|
||||
:items-length="totalItems"
|
||||
@update:options="handlePaginationUpdate"
|
||||
density="compact"
|
||||
class="elevation-0"
|
||||
style="font-size: 11px;">
|
||||
|
||||
<!-- 会话启停 -->
|
||||
<template v-slot:item.session_enabled="{ item }">
|
||||
@@ -160,7 +169,7 @@
|
||||
<div class="text-body-2 text-grey-500">{{ tm('sessions.noActiveSessionsDesc') }}</div>
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-data-table-server>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
@@ -336,6 +345,7 @@
|
||||
|
||||
<script>
|
||||
import axios from 'axios'
|
||||
import { debounce } from 'lodash'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
export default {
|
||||
@@ -357,7 +367,10 @@ export default {
|
||||
filterPlatform: null,
|
||||
|
||||
// 分页相关
|
||||
currentPage: 1,
|
||||
itemsPerPage: 10,
|
||||
totalItems: 0,
|
||||
totalPages: 0,
|
||||
|
||||
// 可用选项
|
||||
availablePersonas: [],
|
||||
@@ -424,30 +437,6 @@ export default {
|
||||
]
|
||||
},
|
||||
|
||||
// 懒加载过滤会话 - 使用客户端分页
|
||||
filteredSessions() {
|
||||
let filtered = this.sessions;
|
||||
|
||||
// 搜索筛选
|
||||
if (this.searchQuery) {
|
||||
const query = this.searchQuery.toLowerCase().trim();
|
||||
filtered = filtered.filter(session =>
|
||||
session.session_name.toLowerCase().includes(query) ||
|
||||
session.platform.toLowerCase().includes(query) ||
|
||||
session.persona_name?.toLowerCase().includes(query) ||
|
||||
session.chat_provider_name?.toLowerCase().includes(query) ||
|
||||
session.session_id.toLowerCase().includes(query)
|
||||
);
|
||||
}
|
||||
|
||||
// 平台筛选
|
||||
if (this.filterPlatform) {
|
||||
filtered = filtered.filter(session => session.platform === this.filterPlatform);
|
||||
}
|
||||
|
||||
return filtered;
|
||||
},
|
||||
|
||||
platformOptions() {
|
||||
const platforms = [...new Set(this.sessions.map(s => s.platform))];
|
||||
return platforms.map(p => ({ title: p, value: p }));
|
||||
@@ -494,7 +483,20 @@ export default {
|
||||
async loadSessions() {
|
||||
this.loading = true;
|
||||
try {
|
||||
const response = await axios.get('/api/session/list');
|
||||
const params = {
|
||||
page: this.currentPage,
|
||||
page_size: this.itemsPerPage
|
||||
};
|
||||
|
||||
// 添加搜索和平台筛选参数
|
||||
if (this.searchQuery) {
|
||||
params.search = this.searchQuery;
|
||||
}
|
||||
if (this.filterPlatform) {
|
||||
params.platform = this.filterPlatform;
|
||||
}
|
||||
|
||||
const response = await axios.get('/api/session/list', { params });
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data;
|
||||
this.sessions = data.sessions.map(session => ({
|
||||
@@ -507,6 +509,13 @@ export default {
|
||||
this.availableChatProviders = data.available_chat_providers;
|
||||
this.availableSttProviders = data.available_stt_providers;
|
||||
this.availableTtsProviders = data.available_tts_providers;
|
||||
|
||||
// 处理分页信息
|
||||
if (data.pagination) {
|
||||
this.totalItems = data.pagination.total;
|
||||
this.totalPages = data.pagination.total_pages;
|
||||
this.currentPage = data.pagination.page;
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadSessionsError'));
|
||||
}
|
||||
@@ -679,7 +688,7 @@ export default {
|
||||
let totalErrorCount = 0;
|
||||
let allErrorSessions = [];
|
||||
|
||||
const sessions = this.filteredSessions;
|
||||
const sessions = this.sessions;
|
||||
|
||||
try {
|
||||
// 定义批量操作任务
|
||||
@@ -936,6 +945,25 @@ export default {
|
||||
|
||||
session.deleting = false;
|
||||
},
|
||||
|
||||
// 处理分页更新事件
|
||||
handlePaginationUpdate(options) {
|
||||
this.currentPage = options.page;
|
||||
this.itemsPerPage = options.itemsPerPage;
|
||||
this.loadSessions();
|
||||
},
|
||||
|
||||
// 处理搜索变化
|
||||
handleSearchChange: debounce(function() {
|
||||
this.currentPage = 1; // 重置到第一页
|
||||
this.loadSessions();
|
||||
}, 300),
|
||||
|
||||
// 处理平台筛选变化
|
||||
handlePlatformChange() {
|
||||
this.currentPage = 1; // 重置到第一页
|
||||
this.loadSessions();
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -141,6 +141,8 @@
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<small style="color: grey">*{{ tm('dialogs.addServer.tips.timeoutConfig') }}</small>
|
||||
|
||||
<div class="monaco-container" style="margin-top: 16px;">
|
||||
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
|
||||
minimap: {
|
||||
@@ -524,14 +526,16 @@ export default {
|
||||
transport: "streamable_http",
|
||||
url: "your mcp server url",
|
||||
headers: {},
|
||||
timeout: 30,
|
||||
timeout: 5,
|
||||
sse_read_timeout: 300,
|
||||
};
|
||||
} else if (type === 'sse') {
|
||||
template = {
|
||||
transport: "sse",
|
||||
url: "your mcp server url",
|
||||
headers: {},
|
||||
timeout: 30,
|
||||
timeout: 5,
|
||||
sse_read_timeout: 300,
|
||||
};
|
||||
} else {
|
||||
template = {
|
||||
|
||||
@@ -601,8 +601,13 @@ export default {
|
||||
checkPlugin() {
|
||||
axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
|
||||
.then(response => {
|
||||
if (response.data.status !== 'ok') {
|
||||
if (response.data.status !== 'ok' || response.data.data.length === 0) {
|
||||
this.showSnackbar(this.tm('messages.pluginNotAvailable'), 'error');
|
||||
return
|
||||
}
|
||||
if (!response.data.data[0].activated) {
|
||||
this.showSnackbar(this.tm('messages.pluginNotActivated'), 'error');
|
||||
return
|
||||
}
|
||||
if (response.data.data.length > 0) {
|
||||
this.installed = true;
|
||||
@@ -708,6 +713,10 @@ export default {
|
||||
getKBCollections() {
|
||||
axios.get('/api/plug/alkaid/kb/collections')
|
||||
.then(response => {
|
||||
if (response.data.status !== 'ok') {
|
||||
this.showSnackbar(response.data.message || this.tm('messages.getKnowledgeBaseListFailed'), 'error');
|
||||
return;
|
||||
}
|
||||
this.kbCollections = response.data.data;
|
||||
})
|
||||
.catch(error => {
|
||||
|
||||
@@ -39,7 +39,7 @@ export default defineConfig({
|
||||
port: 3000,
|
||||
proxy: {
|
||||
'/api': {
|
||||
target: 'http://localhost:6185/',
|
||||
target: 'http://127.0.0.1:6185/',
|
||||
changeOrigin: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Commands module
|
||||
|
||||
from .help import HelpCommand
|
||||
from .llm import LLMCommands
|
||||
from .tool import ToolCommands
|
||||
from .plugin import PluginCommands
|
||||
from .admin import AdminCommands
|
||||
from .conversation import ConversationCommands
|
||||
from .provider import ProviderCommands
|
||||
from .persona import PersonaCommands
|
||||
from .alter_cmd import AlterCmdCommands
|
||||
from .setunset import SetUnsetCommands
|
||||
from .t2i import T2ICommand
|
||||
from .tts import TTSCommand
|
||||
from .sid import SIDCommand
|
||||
|
||||
__all__ = [
|
||||
"HelpCommand",
|
||||
"LLMCommands",
|
||||
"ToolCommands",
|
||||
"PluginCommands",
|
||||
"AdminCommands",
|
||||
"ConversationCommands",
|
||||
"ProviderCommands",
|
||||
"PersonaCommands",
|
||||
"AlterCmdCommands",
|
||||
"SetUnsetCommands",
|
||||
"T2ICommand",
|
||||
"TTSCommand",
|
||||
"SIDCommand",
|
||||
]
|
||||
@@ -0,0 +1,76 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, MessageChain
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
|
||||
class AdminCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
"""授权管理员。op <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /op <id> 授权管理员;/deop <id> 取消管理员。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
self.context.get_config()["admins_id"].append(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /deop <id> 取消管理员。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
try:
|
||||
self.context.get_config()["admins_id"].remove(str(admin_id))
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("取消授权成功。"))
|
||||
except ValueError:
|
||||
event.set_result(
|
||||
MessageEventResult().message("此用户 ID 不在管理员名单内。")
|
||||
)
|
||||
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
"""添加白名单。wl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /wl <id> 添加白名单;/dwl <id> 删除白名单。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].append(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
"""删除白名单。dwl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"使用方法: /dwl <id> 删除白名单。可通过 /sid 获取 ID。"
|
||||
)
|
||||
)
|
||||
return
|
||||
try:
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
cfg["platform_settings"]["id_whitelist"].remove(str(sid))
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
@@ -0,0 +1,172 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from .utils.rst_scene import RstScene
|
||||
|
||||
|
||||
class AlterCmdCommands(CommandParserMixin):
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
"""更新reset命令在特定场景下的权限设置"""
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_cfg = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_cfg.get("reset", {})
|
||||
reset_cfg[scene_key] = perm_type
|
||||
plugin_cfg["reset"] = reset_cfg
|
||||
alter_cmd_cfg["astrbot"] = plugin_cfg
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 3:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||
"/alter_cmd reset config 打开 reset 权限配置"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 兼容 reset scene 的专门配置
|
||||
cmd_name = token.get(1)
|
||||
cmd_type = token.get(2)
|
||||
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_.get("reset", {})
|
||||
|
||||
group_unique_on = reset_cfg.get("group_unique_on", "admin")
|
||||
group_unique_off = reset_cfg.get("group_unique_off", "admin")
|
||||
private = reset_cfg.get("private", "member")
|
||||
|
||||
config_menu = f"""reset命令权限细粒度配置
|
||||
当前配置:
|
||||
1. 群聊+会话隔离开: {group_unique_on}
|
||||
2. 群聊+会话隔离关: {group_unique_off}
|
||||
3. 私聊: {private}
|
||||
修改指令格式:
|
||||
/alter_cmd reset scene <场景编号> <admin/member>
|
||||
例如: /alter_cmd reset scene 2 member"""
|
||||
await event.send(MessageChain().message(config_menu))
|
||||
return
|
||||
|
||||
if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4:
|
||||
scene_num = token.get(3)
|
||||
perm_type = token.get(4)
|
||||
|
||||
if scene_num is None or perm_type is None:
|
||||
await event.send(MessageChain().message("场景编号和权限类型不能为空"))
|
||||
return
|
||||
|
||||
if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3:
|
||||
await event.send(
|
||||
MessageChain().message("场景编号必须是 1-3 之间的数字")
|
||||
)
|
||||
return
|
||||
|
||||
if perm_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("权限类型错误,只能是 admin 或 member")
|
||||
)
|
||||
return
|
||||
|
||||
scene_num = int(scene_num)
|
||||
scene = RstScene.from_index(scene_num)
|
||||
scene_key = scene.key
|
||||
|
||||
await self.update_reset_permission(scene_key, perm_type)
|
||||
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if cmd_type not in ["admin", "member"]:
|
||||
await event.send(
|
||||
MessageChain().message("指令类型错误,可选类型有 admin, member")
|
||||
)
|
||||
return
|
||||
|
||||
# 查找指令
|
||||
cmd_name = " ".join(token.tokens[1:-1])
|
||||
cmd_type = token.get(-1)
|
||||
found_command = None
|
||||
cmd_group = False
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
cmd_group = True
|
||||
break
|
||||
|
||||
if not found_command:
|
||||
await event.send(MessageChain().message("未找到该指令"))
|
||||
return
|
||||
|
||||
found_plugin = star_map[found_command.handler_module_path]
|
||||
|
||||
from astrbot.api import sp
|
||||
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(found_command.handler_name, {})
|
||||
cfg["permission"] = cmd_type
|
||||
plugin_[found_command.handler_name] = cfg
|
||||
alter_cmd_cfg[found_plugin.name] = plugin_
|
||||
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
# 注入权限过滤器
|
||||
found_permission_filter = False
|
||||
for filter_ in found_command.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
if cmd_type == "admin":
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
filter_.permission_type = filter.PermissionType.ADMIN
|
||||
else:
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
filter_.permission_type = filter.PermissionType.MEMBER
|
||||
found_permission_filter = True
|
||||
break
|
||||
if not found_permission_filter:
|
||||
import astrbot.api.event.filter as filter
|
||||
|
||||
found_command.event_filters.insert(
|
||||
0,
|
||||
PermissionTypeFilter(
|
||||
filter.PermissionType.ADMIN
|
||||
if cmd_type == "admin"
|
||||
else filter.PermissionType.MEMBER
|
||||
),
|
||||
)
|
||||
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。"
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,431 @@
|
||||
import datetime
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.sources.dify_source import ProviderDify
|
||||
from astrbot.core.provider.sources.coze_source import ProviderCoze
|
||||
from astrbot.api import sp, logger
|
||||
from ..long_term_memory import LongTermMemory
|
||||
from .utils.rst_scene import RstScene
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ConversationCommands:
|
||||
def __init__(self, context: star.Context, ltm: LongTermMemory | None = None):
|
||||
self.context = context
|
||||
self.ltm = ltm
|
||||
|
||||
async def _get_current_persona_id(self, session_id):
|
||||
curr = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
if not curr:
|
||||
return None
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
session_id, curr
|
||||
)
|
||||
return conv.persona_id
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
if not self.ltm:
|
||||
return False
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
]
|
||||
return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"]
|
||||
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
"""重置 LLM 会话"""
|
||||
|
||||
is_unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
is_group = bool(message.get_group_id())
|
||||
|
||||
scene = RstScene.get_scene(is_group, is_unique_session)
|
||||
|
||||
alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {})
|
||||
plugin_config = alter_cmd_cfg.get("astrbot", {})
|
||||
reset_cfg = plugin_config.get("reset", {})
|
||||
|
||||
required_perm = reset_cfg.get(
|
||||
scene.key, "admin" if is_group and not is_unique_session else "member"
|
||||
)
|
||||
|
||||
if required_perm == "admin" and message.role != "admin":
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"在{scene.name}场景下,reset命令需要管理员权限,"
|
||||
f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 切换或者 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation(
|
||||
message.unified_msg_origin, cid, []
|
||||
)
|
||||
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
"""查看对话记录"""
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
|
||||
conv_mgr = self.context.conversation_manager
|
||||
umo = message.unified_msg_origin
|
||||
session_curr_cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
|
||||
if not session_curr_cid:
|
||||
session_curr_cid = await conv_mgr.new_conversation(
|
||||
umo, message.get_platform_id()
|
||||
)
|
||||
|
||||
contexts, total_pages = await conv_mgr.get_human_readable_context(
|
||||
umo, session_curr_cid, page, size_per_page
|
||||
)
|
||||
|
||||
history = ""
|
||||
for context in contexts:
|
||||
if len(context) > 150:
|
||||
context = context[:150] + "..."
|
||||
history += f"{context}\n"
|
||||
|
||||
ret = (
|
||||
f"当前对话历史记录:"
|
||||
f"{history or '无历史记录'}\n\n"
|
||||
f"第 {page} 页 | 共 {total_pages} 页\n"
|
||||
f"*输入 /history 2 跳转到第 2 页"
|
||||
)
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
"""查看对话列表"""
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
"""原有的Dify处理逻辑保持不变"""
|
||||
ret = "Dify 对话列表:\n"
|
||||
assert isinstance(provider, ProviderDify)
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
idx = 1
|
||||
for conv in data["data"]:
|
||||
ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime(
|
||||
"%m-%d %H:%M"
|
||||
)
|
||||
ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
|
||||
idx += 1
|
||||
if idx == 1:
|
||||
ret += "没有找到任何对话。"
|
||||
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
size_per_page = 6
|
||||
"""获取所有对话列表"""
|
||||
conversations_all = await self.context.conversation_manager.get_conversations(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
"""计算总页数"""
|
||||
total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page
|
||||
"""确保页码有效"""
|
||||
page = max(1, min(page, total_pages))
|
||||
"""分页处理"""
|
||||
start_idx = (page - 1) * size_per_page
|
||||
end_idx = start_idx + size_per_page
|
||||
conversations_paged = conversations_all[start_idx:end_idx]
|
||||
|
||||
ret = "对话列表:\n---\n"
|
||||
"""全局序号从当前页的第一个开始"""
|
||||
global_index = start_idx + 1
|
||||
|
||||
"""生成所有对话的标题字典"""
|
||||
_titles = {}
|
||||
for conv in conversations_all:
|
||||
title = conv.title if conv.title else "新对话"
|
||||
_titles[conv.cid] = title
|
||||
|
||||
"""遍历分页后的对话生成列表显示"""
|
||||
for conv in conversations_paged:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=message.unified_msg_origin
|
||||
)
|
||||
persona_id = persona["name"]
|
||||
title = _titles.get(conv.cid, "新对话")
|
||||
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
|
||||
ret += "---\n"
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
if curr_cid:
|
||||
"""从所有对话的标题字典中获取标题"""
|
||||
title = _titles.get(curr_cid, "新对话")
|
||||
ret += f"\n当前对话: {title}({curr_cid[:4]})"
|
||||
else:
|
||||
ret += "\n当前对话: 无"
|
||||
|
||||
unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
if unique_session:
|
||||
ret += "\n会话隔离粒度: 个人"
|
||||
else:
|
||||
ret += "\n会话隔离粒度: 群聊"
|
||||
|
||||
ret += f"\n第 {page} 页 | 共 {total_pages} 页"
|
||||
ret += "\n*输入 /ls 2 跳转到第 2 页"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
return
|
||||
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
"""
|
||||
创建新对话
|
||||
"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
)
|
||||
return
|
||||
|
||||
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
|
||||
cid = await self.context.conversation_manager.new_conversation(
|
||||
message.unified_msg_origin, message.get_platform_id(), persona_id=cpersona
|
||||
)
|
||||
|
||||
# 长期记忆
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
try:
|
||||
await self.ltm.remove_session(event=message)
|
||||
except Exception as e:
|
||||
logger.error(f"清理聊天增强记录失败: {e}")
|
||||
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。")
|
||||
)
|
||||
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
|
||||
"""创建新群聊对话"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
assert isinstance(provider, (ProviderDify, ProviderCoze)), (
|
||||
"provider type is not dify or coze"
|
||||
)
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
)
|
||||
return
|
||||
if sid:
|
||||
session = str(
|
||||
MessageSesion(
|
||||
platform_name=message.platform_meta.id,
|
||||
message_type=MessageType("GroupMessage"),
|
||||
session_id=sid,
|
||||
)
|
||||
)
|
||||
|
||||
cpersona = await self._get_current_persona_id(session)
|
||||
cid = await self.context.conversation_manager.new_conversation(
|
||||
session, message.get_platform_id(), persona_id=cpersona
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。")
|
||||
)
|
||||
|
||||
async def switch_conv(
|
||||
self, message: AstrMessageEvent, index: Union[int, None] = None
|
||||
):
|
||||
"""通过 /ls 前面的序号切换对话"""
|
||||
|
||||
if not isinstance(index, int):
|
||||
message.set_result(
|
||||
MessageEventResult().message("类型错误,请输入数字对话序号。")
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify), "provider type is not dify"
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
if not data["data"]:
|
||||
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
||||
return
|
||||
selected_conv = None
|
||||
if index is not None:
|
||||
try:
|
||||
selected_conv = data["data"][index - 1]
|
||||
except IndexError:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看")
|
||||
)
|
||||
return
|
||||
else:
|
||||
selected_conv = data["data"][0]
|
||||
ret = (
|
||||
f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。"
|
||||
)
|
||||
provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"]
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
return
|
||||
|
||||
if index is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话"
|
||||
)
|
||||
)
|
||||
return
|
||||
conversations = await self.context.conversation_manager.get_conversations(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message("对话序号错误,请使用 /ls 查看")
|
||||
)
|
||||
else:
|
||||
conversation = conversations[index - 1]
|
||||
title = conversation.title if conversation.title else "新对话"
|
||||
await self.context.conversation_manager.switch_conversation(
|
||||
message.unified_msg_origin, conversation.cid
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"切换到对话: {title}({conversation.cid[:4]})。"
|
||||
)
|
||||
)
|
||||
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
|
||||
"""重命名对话"""
|
||||
if not new_name:
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("未找到当前对话。"))
|
||||
return
|
||||
await provider.api_client.rename(cid, new_name, message.unified_msg_origin)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.update_conversation_title(
|
||||
message.unified_msg_origin, new_name
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
"""删除当前对话"""
|
||||
is_unique_session = self.context.get_config()["platform_settings"][
|
||||
"unique_session"
|
||||
]
|
||||
if message.get_group_id() and not is_unique_session and message.role != "admin":
|
||||
# 群聊,没开独立会话,发送人不是管理员
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
|
||||
if dify_cid:
|
||||
await provider.api_client.delete_chat_conv(
|
||||
message.unified_msg_origin, dify_cid
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
session_curr_cid = (
|
||||
await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
)
|
||||
)
|
||||
|
||||
if not session_curr_cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前未处于对话状态,请 /switch 序号 切换或 /new 创建。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
await self.context.conversation_manager.delete_conversation(
|
||||
message.unified_msg_origin, session_curr_cid
|
||||
)
|
||||
|
||||
ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
if self.ltm and self.ltm_enabled(message):
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。"
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
@@ -0,0 +1,61 @@
|
||||
import aiohttp
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json", timeout=2
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
|
||||
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
|
||||
内置指令:
|
||||
[System]
|
||||
/plugin: 查看插件、插件帮助
|
||||
/t2i: 开关文本转图片
|
||||
/tts: 开关文本转语音
|
||||
/sid: 获取会话 ID
|
||||
/op: 管理员
|
||||
/wl: 白名单
|
||||
/dashboard_update: 更新管理面板(op)
|
||||
/alter_cmd: 设置指令权限(op)
|
||||
|
||||
[大模型]
|
||||
/llm: 开启/关闭 LLM
|
||||
/provider: 大模型提供商
|
||||
/model: 模型列表
|
||||
/ls: 对话列表
|
||||
/new: 创建新对话
|
||||
/groupnew 群号: 为群聊创建新对话(op)
|
||||
/switch 序号: 切换对话
|
||||
/rename 新名字: 重命名当前对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/key: API Key(op)
|
||||
/websearch: 网页搜索
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
@@ -0,0 +1,20 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
|
||||
|
||||
class LLMCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
"""开启/关闭 LLM"""
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
enable = cfg["provider_settings"].get("enable", True)
|
||||
if enable:
|
||||
cfg["provider_settings"]["enable"] = False
|
||||
status = "关闭"
|
||||
else:
|
||||
cfg["provider_settings"]["enable"] = True
|
||||
status = "开启"
|
||||
cfg.save_config()
|
||||
await event.send(MessageChain().message(f"{status} LLM 聊天功能。"))
|
||||
@@ -0,0 +1,122 @@
|
||||
import builtins
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class PersonaCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
l = message.message_str.split(" ") # noqa: E741
|
||||
umo = message.unified_msg_origin
|
||||
|
||||
curr_persona_name = "无"
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
|
||||
default_persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=umo
|
||||
)
|
||||
curr_cid_title = "无"
|
||||
if cid:
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
unified_msg_origin=umo,
|
||||
conversation_id=cid,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
if conv is None:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前对话不存在,请先使用 /new 新建一个对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
if not conv.persona_id and conv.persona_id != "[%None]":
|
||||
curr_persona_name = default_persona["name"]
|
||||
else:
|
||||
curr_persona_name = conv.persona_id
|
||||
|
||||
curr_cid_title = conv.title if conv.title else "新对话"
|
||||
curr_cid_title += f"({cid[:4]})"
|
||||
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message(
|
||||
f"""[Persona]
|
||||
|
||||
- 人格情景列表: `/persona list`
|
||||
- 设置人格情景: `/persona 人格`
|
||||
- 人格情景详细信息: `/persona view 人格`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
默认人格情景: {default_persona["name"]}
|
||||
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
"""
|
||||
)
|
||||
.use_t2i(False)
|
||||
)
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for persona in self.context.provider_manager.personas:
|
||||
msg += f"- {persona['name']}\n"
|
||||
msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "view":
|
||||
if len(l) == 2:
|
||||
message.set_result(MessageEventResult().message("请输入人格情景名"))
|
||||
return
|
||||
ps = l[2].strip()
|
||||
if persona := next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == ps,
|
||||
self.context.provider_manager.personas,
|
||||
),
|
||||
None,
|
||||
):
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg += f"{persona['prompt']}\n"
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message("当前没有对话,无法取消人格。")
|
||||
)
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_persona_id(
|
||||
message.unified_msg_origin, "[%None]"
|
||||
)
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前没有对话,请先开始对话或使用 /new 创建一个对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
if persona := next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == ps,
|
||||
self.context.provider_manager.personas,
|
||||
),
|
||||
None,
|
||||
):
|
||||
await self.context.conversation_manager.update_conversation_persona_id(
|
||||
message.unified_msg_origin, ps
|
||||
)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"不存在该人格情景。使用 /persona list 查看所有。"
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,117 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core import DEMO_MODE, logger
|
||||
|
||||
|
||||
class PluginCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
"""获取已经安装的插件列表。"""
|
||||
plugin_list_info = "已加载的插件:\n"
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
|
||||
if not plugin.activated:
|
||||
plugin_list_info += " (未启用)"
|
||||
plugin_list_info += "\n"
|
||||
if plugin_list_info.strip() == "":
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
|
||||
plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False)
|
||||
)
|
||||
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""禁用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin off <插件名> 禁用插件。")
|
||||
)
|
||||
return
|
||||
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
|
||||
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""启用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
|
||||
return
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin on <插件名> 启用插件。")
|
||||
)
|
||||
return
|
||||
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
|
||||
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
|
||||
"""安装插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
|
||||
return
|
||||
if not plugin_repo:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin get <插件仓库地址> 安装插件")
|
||||
)
|
||||
return
|
||||
logger.info(f"准备从 {plugin_repo} 安装插件。")
|
||||
if self.context._star_manager:
|
||||
star_mgr: PluginManager = self.context._star_manager
|
||||
try:
|
||||
await star_mgr.install_plugin(plugin_repo) # type: ignore
|
||||
event.set_result(MessageEventResult().message("安装插件成功。"))
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}")
|
||||
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
|
||||
return
|
||||
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
"""获取插件帮助"""
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
MessageEventResult().message("/plugin help <插件名> 查看插件信息。")
|
||||
)
|
||||
return
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
return
|
||||
help_msg = ""
|
||||
help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}"
|
||||
command_handlers = []
|
||||
command_names = []
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
if handler.handler_module_path != plugin.module_path:
|
||||
continue
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.command_name)
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
command_handlers.append(handler)
|
||||
command_names.append(filter_.group_name)
|
||||
|
||||
if len(command_handlers) > 0:
|
||||
help_msg += "\n\n🔧 指令列表:\n"
|
||||
for i in range(len(command_handlers)):
|
||||
help_msg += f"- {command_names[i]}"
|
||||
if command_handlers[i].desc:
|
||||
help_msg += f": {command_handlers[i].desc}"
|
||||
help_msg += "\n"
|
||||
|
||||
help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
|
||||
|
||||
ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg
|
||||
ret += "更多帮助信息请查看插件仓库 README。"
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
@@ -0,0 +1,201 @@
|
||||
import re
|
||||
from typing import Union
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
|
||||
|
||||
class ProviderCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def provider(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
idx: Union[str, int, None] = None,
|
||||
idx2: Union[int, None] = None,
|
||||
):
|
||||
"""查看或者切换 LLM Provider"""
|
||||
umo = event.unified_msg_origin
|
||||
|
||||
if idx is None:
|
||||
ret = "## 载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
id_ = llm.meta().id
|
||||
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
provider_using = self.context.get_using_provider(umo=umo)
|
||||
if provider_using and provider_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
tts_providers = self.context.get_all_tts_providers()
|
||||
if tts_providers:
|
||||
ret += "\n## 载入的 TTS 提供商\n"
|
||||
for idx, tts in enumerate(tts_providers):
|
||||
id_ = tts.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
tts_using = self.context.get_using_tts_provider(umo=umo)
|
||||
if tts_using and tts_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
stt_providers = self.context.get_all_stt_providers()
|
||||
if stt_providers:
|
||||
ret += "\n## 载入的 STT 提供商\n"
|
||||
for idx, stt in enumerate(stt_providers):
|
||||
id_ = stt.meta().id
|
||||
ret += f"{idx + 1}. {id_}"
|
||||
stt_using = self.context.get_using_stt_provider(umo=umo)
|
||||
if stt_using and stt_using.meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 /provider <序号> 切换 LLM 提供商。"
|
||||
|
||||
if tts_providers:
|
||||
ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。"
|
||||
if stt_providers:
|
||||
ret += "\n使用 /provider stt <切换> STT 提供商。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
elif idx == "tts":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif idx == "stt":
|
||||
if idx2 is None:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
else:
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
provider_id=id_,
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("无效的参数。"))
|
||||
|
||||
async def model_ls(
|
||||
self, message: AstrMessageEvent, idx_or_name: Union[int, str, None] = None
|
||||
):
|
||||
"""查看或者切换模型"""
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
# 定义正则表达式匹配 API 密钥
|
||||
api_key_pattern = re.compile(r"key=[^&'\" ]+")
|
||||
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
models = await prov.get_models()
|
||||
except BaseException as e:
|
||||
err_msg = api_key_pattern.sub("key=***", str(e))
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message("获取模型列表失败: " + err_msg)
|
||||
.use_t2i(False)
|
||||
)
|
||||
return
|
||||
i = 1
|
||||
ret = "下面列出了此服务提供商可用模型:"
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model}"
|
||||
i += 1
|
||||
|
||||
curr_model = prov.get_model() or "无"
|
||||
ret += f"\n当前模型: [{curr_model}]"
|
||||
|
||||
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
if isinstance(idx_or_name, int):
|
||||
models = []
|
||||
try:
|
||||
models = await prov.get_models()
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("获取模型列表失败: " + str(e))
|
||||
)
|
||||
return
|
||||
if idx_or_name > len(models) or idx_or_name < 1:
|
||||
message.set_result(MessageEventResult().message("模型序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_model = models[idx_or_name - 1]
|
||||
prov.set_model(new_model)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("切换模型未知错误: " + str(e))
|
||||
)
|
||||
message.set_result(MessageEventResult().message("切换模型成功。"))
|
||||
else:
|
||||
prov.set_model(idx_or_name)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换模型到 {prov.get_model()}。")
|
||||
)
|
||||
|
||||
async def key(self, message: AstrMessageEvent, index: Union[int, None] = None):
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
|
||||
if index is None:
|
||||
keys_data = prov.get_keys()
|
||||
curr_key = prov.get_current_key()
|
||||
ret = "Key:"
|
||||
for i, k in enumerate(keys_data):
|
||||
ret += f"\n{i + 1}. {k[:8]}"
|
||||
|
||||
ret += f"\n当前 Key: {curr_key[:8]}"
|
||||
ret += "\n当前模型: " + prov.get_model()
|
||||
ret += "\n使用 /key <idx> 切换 Key。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
keys_data = prov.get_keys()
|
||||
if index > len(keys_data) or index < 1:
|
||||
message.set_result(MessageEventResult().message("Key 序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_key = keys_data[index - 1]
|
||||
prov.set_key(new_key)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换 Key 未知错误: {str(e)}")
|
||||
)
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
@@ -0,0 +1,37 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
|
||||
|
||||
class SetUnsetCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
"""设置会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
session_var[key] = value
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。"
|
||||
)
|
||||
)
|
||||
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
"""移除会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
|
||||
if key not in session_var:
|
||||
event.set_result(
|
||||
MessageEventResult().message("没有那个变量名。格式 /unset 变量名。")
|
||||
)
|
||||
else:
|
||||
del session_var[key]
|
||||
await sp.session_put(uid, "session_variables", session_var)
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。")
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
"""会话ID命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class SIDCommand:
|
||||
"""会话ID命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
"""获取会话 ID 和 管理员 ID"""
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。
|
||||
/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
|
||||
|
||||
UID: {user_id} 此 ID 可用于设置管理员。
|
||||
/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
|
||||
|
||||
if (
|
||||
self.context.get_config()["platform_settings"]["unique_session"]
|
||||
and event.get_group_id()
|
||||
):
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
@@ -0,0 +1,23 @@
|
||||
"""文本转图片命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class T2ICommand:
|
||||
"""文本转图片命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
"""开关文本转图片"""
|
||||
config = self.context.get_config(umo=event.unified_msg_origin)
|
||||
if config["t2i"]:
|
||||
config["t2i"] = False
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已关闭文本转图片模式。"))
|
||||
return
|
||||
config["t2i"] = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
@@ -0,0 +1,31 @@
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class ToolCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
"""查看函数工具列表"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""):
|
||||
"""启用一个函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""):
|
||||
"""停用一个函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
|
||||
async def tool_all_off(self, event: AstrMessageEvent):
|
||||
"""停用所有函数工具"""
|
||||
event.set_result(
|
||||
MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。")
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""文本转语音命令"""
|
||||
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
|
||||
class TTSCommand:
|
||||
"""文本转语音命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
tts_enable = cfg["provider_tts_settings"]["enable"]
|
||||
|
||||
# 切换状态
|
||||
new_status = not ses_tts
|
||||
SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
|
||||
if new_status and not tts_enable:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。"
|
||||
)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。")
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RstScene(Enum):
|
||||
GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启")
|
||||
GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭")
|
||||
PRIVATE = ("private", "私聊")
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
return self.value[0]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.value[1]
|
||||
|
||||
@classmethod
|
||||
def from_index(cls, index: int) -> "RstScene":
|
||||
mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE}
|
||||
return mapping[index]
|
||||
|
||||
@classmethod
|
||||
def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene":
|
||||
if is_group:
|
||||
return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF
|
||||
return cls.PRIVATE
|
||||
+87
-1215
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,195 @@
|
||||
import astrbot.api.star as star
|
||||
import builtins
|
||||
import datetime
|
||||
import zoneinfo
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
|
||||
|
||||
class ProcessLLMRequest:
|
||||
def __init__(self, context: star.Context):
|
||||
self.ctx = context
|
||||
cfg = context.get_config()
|
||||
self.timezone = cfg.get("timezone")
|
||||
if not self.timezone:
|
||||
# 系统默认时区
|
||||
self.timezone = None
|
||||
else:
|
||||
logger.info(f"Timezone set to: {self.timezone}")
|
||||
|
||||
def _ensure_persona(self, req: ProviderRequest, cfg: dict):
|
||||
"""确保用户人格已加载"""
|
||||
if not req.conversation:
|
||||
return
|
||||
# persona inject
|
||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
default_persona = self.ctx.persona_manager.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_id = default_persona["name"]
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == persona_id,
|
||||
self.ctx.persona_manager.personas_v3,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if persona:
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += prompt
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# tools select
|
||||
tmgr = self.ctx.get_llm_tool_manager()
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
# select all
|
||||
toolset = tmgr.get_full_tool_set()
|
||||
for tool in toolset:
|
||||
if not tool.active:
|
||||
toolset.remove_tool(tool.name)
|
||||
else:
|
||||
toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
tool = tmgr.get_func(tool_name)
|
||||
if tool and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
req.func_tool = toolset
|
||||
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
||||
|
||||
async def _ensure_img_caption(
|
||||
self, req: ProviderRequest, cfg: dict, img_cap_prov_id: str
|
||||
):
|
||||
try:
|
||||
caption = await self._request_img_caption(
|
||||
img_cap_prov_id, cfg, req.image_urls
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
|
||||
async def _request_img_caption(
|
||||
self, provider_id: str, cfg: dict, image_urls: list[str]
|
||||
) -> str:
|
||||
if prov := self.ctx.get_provider_by_id(provider_id):
|
||||
if isinstance(prov, Provider):
|
||||
img_cap_prompt = cfg.get(
|
||||
"image_caption_prompt", "Please describe the image."
|
||||
)
|
||||
logger.debug(f"Processing image caption with provider: {provider_id}")
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=img_cap_prompt,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
return llm_resp.completion_text
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not exist."
|
||||
)
|
||||
|
||||
async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_settings"
|
||||
]
|
||||
|
||||
# prompt prefix
|
||||
if prefix := cfg.get("prompt_prefix"):
|
||||
# 支持 {{prompt}} 作为用户输入的占位符
|
||||
if "{{prompt}}" in prefix:
|
||||
req.prompt = prefix.replace("{{prompt}}", req.prompt)
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
current_time = None
|
||||
if self.timezone:
|
||||
# 启用时区
|
||||
try:
|
||||
now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone))
|
||||
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
except Exception as e:
|
||||
logger.error(f"时区设置错误: {e}, 使用本地时区")
|
||||
if not current_time:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
# inject persona for this request
|
||||
self._ensure_persona(req, cfg)
|
||||
|
||||
# image caption
|
||||
if img_cap_prov_id and req.image_urls:
|
||||
await self._ensure_img_caption(req, cfg, img_cap_prov_id)
|
||||
|
||||
# quote message processing
|
||||
# 解析引用内容
|
||||
quote = None
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Reply):
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
sender_info = ""
|
||||
if quote.sender_nickname:
|
||||
sender_info = f"(Sent by {quote.sender_nickname})"
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
req.system_prompt += (
|
||||
f"\nUser is quoting a message{sender_info}.\n"
|
||||
f"Here are the information of the quoted message: Text Content: {message_str}.\n"
|
||||
)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
if image_seg:
|
||||
try:
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
req.system_prompt += (
|
||||
f"Image Caption: {llm_resp.completion_text}\n"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning.")
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
@@ -205,13 +205,14 @@ class Main(star.Star):
|
||||
return
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
if comp.file.startswith("http"):
|
||||
file_path = await comp.get_file()
|
||||
if file_path.startswith("http"):
|
||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(comp.file, path)
|
||||
await download_file(file_path, path)
|
||||
else:
|
||||
path = comp.file
|
||||
path = file_path
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(path)
|
||||
logger.debug(f"User {uid} uploaded file: {path}")
|
||||
yield event.plain_result(f"代码执行器: 文件已经上传: {path}")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from googlesearch import search
|
||||
from googlesearch.asearch import asearch
|
||||
|
||||
from . import SearchEngine, SearchResult
|
||||
|
||||
@@ -14,14 +14,14 @@ class Google(SearchEngine):
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = []
|
||||
try:
|
||||
ls = search(
|
||||
ls = asearch(
|
||||
query,
|
||||
advanced=True,
|
||||
num_results=num_results,
|
||||
timeout=3,
|
||||
proxy=self.proxy,
|
||||
)
|
||||
for i in ls:
|
||||
async for i in ls:
|
||||
results.append(
|
||||
SearchResult(title=i.title, url=i.url, snippet=i.description)
|
||||
)
|
||||
|
||||
@@ -46,7 +46,13 @@ class Main(star.Star):
|
||||
|
||||
self.bing_search = Bing()
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
self.google = None
|
||||
try:
|
||||
self.google = Google()
|
||||
except Exception as e:
|
||||
logger.error(f"google search init error: {e}, disable google search")
|
||||
|
||||
self.baidu_initialized = False
|
||||
|
||||
async def _tidy_text(self, text: str) -> str:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
@@ -89,10 +95,11 @@ class Main(star.Star):
|
||||
self, query, num_results: int = 5
|
||||
) -> list[SearchResult]:
|
||||
results = []
|
||||
try:
|
||||
results = await self.google.search(query, num_results)
|
||||
except Exception as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if self.google:
|
||||
try:
|
||||
results = await self.google.search(query, num_results)
|
||||
except Exception as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
@@ -220,6 +227,30 @@ class Main(star.Star):
|
||||
|
||||
return ret
|
||||
|
||||
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
|
||||
if self.baidu_initialized:
|
||||
return
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
key = cfg.get("provider_settings", {}).get(
|
||||
"websearch_baidu_app_builder_key", ""
|
||||
)
|
||||
if not key:
|
||||
raise ValueError(
|
||||
"Error: Baidu AI Search API key is not configured in AstrBot."
|
||||
)
|
||||
func_tool_mgr = self.context.get_llm_tool_manager()
|
||||
await func_tool_mgr.enable_mcp_server(
|
||||
"baidu_ai_search",
|
||||
config={
|
||||
"transport": "sse",
|
||||
"url": f"http://appbuilder.baidu.com/v2/ai_search/mcp/sse?api_key={key}",
|
||||
"headers": {},
|
||||
"timeout": 30,
|
||||
},
|
||||
)
|
||||
self.baidu_initialized = True
|
||||
logger.info("Successfully initialized Baidu AI Search MCP server.")
|
||||
|
||||
@llm_tool(name="fetch_url")
|
||||
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
|
||||
"""fetch the content of a website with the given web url
|
||||
@@ -366,6 +397,7 @@ class Main(star.Star):
|
||||
tool_set.add_tool(fetch_url_t)
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
elif provider == "tavily":
|
||||
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
||||
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
||||
@@ -375,3 +407,17 @@ class Main(star.Star):
|
||||
tool_set.add_tool(tavily_extract_web_page)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
elif provider == "baidu_ai_search":
|
||||
try:
|
||||
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
|
||||
aisearch_tool = func_tool_mgr.get_func("AIsearch")
|
||||
if not aisearch_tool:
|
||||
raise ValueError("Cannot get Baidu AI Search MCP tool.")
|
||||
tool_set.add_tool(aisearch_tool)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
except Exception as e:
|
||||
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.2.1"
|
||||
version = "4.3.5"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"faiss-cpu==1.10.0",
|
||||
"filelock>=3.18.0",
|
||||
"google-genai>=1.14.0",
|
||||
"googlesearch-python>=1.3.0",
|
||||
"mi-googlesearch-python==1.3.0.post1",
|
||||
"lark-oapi>=1.4.15",
|
||||
"lxml-html-clean>=0.4.2",
|
||||
"mcp>=1.8.0",
|
||||
|
||||
+2
-2
@@ -7,7 +7,7 @@ qq-botpy
|
||||
chardet~=5.1.0
|
||||
Pillow
|
||||
beautifulsoup4
|
||||
googlesearch-python
|
||||
mi-googlesearch-python
|
||||
readability-lxml
|
||||
quart
|
||||
lxml_html_clean
|
||||
@@ -43,4 +43,4 @@ pydub
|
||||
sqlmodel
|
||||
deprecated
|
||||
sqlalchemy[asyncio]
|
||||
audioop-lts; python_version>='3.13'
|
||||
audioop-lts; python_version>='3.13'
|
||||
|
||||
Reference in New Issue
Block a user