Compare commits

...

87 Commits

Author SHA1 Message Date
Soulter d6239822db release: v3.5.3.2 2025-04-12 15:27:33 +08:00
Soulter bced9ffff9 🐛 fix: 修复zhipu工具调用问题 2025-04-12 15:24:37 +08:00
Soulter d7d1c1544a 🐛 fix: 修复重启bot时可能发生报错的问题
在 gewechat, wecom 等消息平台没启动成功的情况下重启bot会报错
2025-04-12 15:01:38 +08:00
Soulter e3b0ca8ef6 🐛 fix: 改进版本号比较逻辑以支持任意长度的版本号 2025-04-12 10:00:25 +08:00
Soulter 9e266eb6d5 release: v3.5.3.1 2025-04-12 09:48:49 +08:00
Soulter 7231403e16 🐛 fix: xai missing field parameters 2025-04-12 09:47:11 +08:00
Soulter 344a486fd7 fix: entites 前向兼容 2025-04-12 09:10:54 +08:00
Soulter 4fd831875d Merge pull request #1237 from AstrBotDevs/release/v3.5.3
📦 release: v3.5.3
2025-04-12 01:04:31 +08:00
Soulter 0988d067ea 📦 release: v3.5.3 2025-04-12 00:58:45 +08:00
Soulter 3b6dd7e15a 🐛 fix: 修复 dify 下删除对话的报错问题
fixes: #1226
2025-04-11 17:27:29 +08:00
Soulter 757d2a3947 🐛 fix: 更新 Dify API 类型提示,增加对 Chatflow 应用类型的说明 2025-04-11 17:23:26 +08:00
Soulter 61b71143f2 Merge pull request #1223 from MR-pofeng/tag-msg-seq
feat:为QQ官方接口需要msg_seq的playload添加随机msg_seq
2025-04-11 16:25:46 +08:00
Soulter 1b343a36c9 Merge pull request #1174 from anka-afk/anka-dev
对关闭的#1167提供完整修复, 修复gemini请求content为空的情况, 增加上下文中验证toolcall逻辑
2025-04-11 16:20:30 +08:00
Soulter 8e94937060 🐛 fix: 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具
fixes: #863 #1150
2025-04-11 15:50:36 +08:00
Soulter a4f212a18f 🐛 fix: 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题
fixes: #1060
2025-04-11 00:20:08 +08:00
Soulter caafb73190 🐛 fix: 修复函数调用的一些bug 2025-04-10 23:28:51 +08:00
kuangfeng 09482799c9 feat:为需要msg_seq的playload添加随机msg_seq 2025-04-10 21:43:12 +08:00
Soulter 37f93d1760 Merge pull request #1175 from Raven95676/telegram
feat: 自动注册指令到Telegram
2025-04-10 20:26:54 +08:00
Soulter 725f2e5204 Merge pull request #1212 from AstrBotDevs/feat-lark-active-message
 feat: 支持飞书平台下主动消息发送
2025-04-10 17:14:37 +08:00
Soulter 967198fae0 feat: 支持飞书平台下主动消息发送
fixes: #1177

WARNING:
这个修复会导致开启对话隔离下飞书群组的对话记录丢失(但没有被删除)。
2025-04-10 17:12:26 +08:00
Soulter 43d57f6dcb 🎈 perf: Add type validation for configuration items in validate_config function 2025-04-10 15:56:14 +08:00
Soulter 6afa4db577 Merge pull request #1208 from Rail1bc/fix_begin_dialogs
fix:使 begin_dialogs ,预设对话,不会多次插入
2025-04-10 15:32:10 +08:00
Soulter 3b8c3fb29a Merge pull request #1207 from zsbai/patch-1
修复了 `event.get_sender_id()` 返回值与函数注释不一致的问题
2025-04-10 15:27:14 +08:00
Soulter 921c3b0627 Merge pull request #1203 from Rail1bc/master
将一项优化插件的简单逻辑,适配到Core中
2025-04-10 15:25:00 +08:00
Raila23 c0fadb45ab 添加更详细的描述 2025-04-10 15:20:56 +08:00
Raven95676 a1481fb179 群聊场景命令特殊处理 2025-04-10 14:54:25 +08:00
Soulter 987cd972d3 Merge pull request #1180 from Raven95676/reload
perf: 确保完整处理插件所有模块。
2025-04-10 14:45:28 +08:00
anka bdf25976a3 fix: 少打一个字 2025-04-10 11:28:47 +08:00
anka 87c3aff4ce perf: 简化llm_request工具调用消息成对验证逻辑, 合并两处验证逻辑到一个函数 2025-04-10 11:25:03 +08:00
anka 99350a957a Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-10 11:16:49 +08:00
Soulter 319068dc7e Merge pull request #1179 from zhx8702/feat-platform-plugin-control
feat: 添加插件能针对不同消息平台开启关闭的功能
2025-04-10 11:02:09 +08:00
Soulter cd18806c39 perf: improve platform compatibility checks 2025-04-10 11:01:04 +08:00
Raila23 95b08b2023 fix:使 begin_dialogs ,预设对话,不会多次插入 2025-04-10 09:18:58 +08:00
baiiylu 0e70f76c86 fix: wrong type of sender_id returned in event.get_sender_id() 2025-04-10 08:03:38 +08:00
Raila23 4d414a2994 增加dequeue_context_length的值的判断,只能在1到max_context_length之间 2025-04-09 22:28:33 +08:00
Raila23 3d22772d4e 新增配置项,允许配置:超出最多携带对话数量 时,一次性丢弃多少条旧消息 2025-04-09 22:12:02 +08:00
Raila23 0b381e2570 新增配置项,允许配置:超出最多携带对话数量 时,一次性丢弃多少条旧消息 2025-04-09 22:10:56 +08:00
Raven95676 f2cc4311c5 fix: optional value 2025-04-09 18:55:20 +08:00
Raven95676 e349671fdf format 2025-04-09 18:45:40 +08:00
Raven95676 01c02d5efa perf: 提取模块清理逻辑到 _purge_modules 方法 2025-04-09 18:11:35 +08:00
zhx b62b1f3870 feat: 添加插件能针对不同消息平台开启关闭的功能
Squashed:

chore: merge master branch

chore: merge from master branch

chore: rename updateAllPlatformCompatibility to update_all_platform_compatibility for consistency

Reviewed by:

@Raven95676 @Soulter
2025-04-09 17:27:44 +08:00
Soulter 8844830859 Merge pull request #1194 from Raven95676/tools
feat: StarTools添加数据目录获取接口
2025-04-09 16:53:22 +08:00
Soulter 0c51ee4b64 chore: 依赖顺序 2025-04-09 16:53:06 +08:00
Soulter 11920d5e31 docs: add a badge to show plugins num 2025-04-09 16:41:32 +08:00
Raven95676 848ea1eb63 提升健壮性 2025-04-09 16:37:19 +08:00
渡鸦95676 a216519486 Merge branch 'AstrBotDevs:master' into tools 2025-04-09 16:16:26 +08:00
Raven95676 b04606c38e 新增获取数据目录的StarTool 2025-04-09 16:13:48 +08:00
Soulter 38072beea7 🎈 perf: 优化插件市场显示 2025-04-09 15:47:44 +08:00
Soulter b843f1fa03 Update PULL_REQUEST_TEMPLATE.md 2025-04-09 15:28:18 +08:00
Soulter 560d40e571 Merge pull request #1184 from kterna/master
feat:查看本地插件readme和市场插件star数
2025-04-09 15:23:50 +08:00
Soulter 5f0b8161b7 perf: 优化 WebUI Chat 的流式传输性能 2025-04-09 15:22:35 +08:00
kterna 062d482917 fix 2025-04-09 08:43:16 +08:00
anka 7cd1eeac30 fix: 直接把空字符串改为" "一条消息的content是空字符串 2025-04-08 15:57:38 +00:00
Soulter bafa473c8e Merge pull request #1157 from AstrBotDevs/feat-streaming
feature: 支持流式输出
2025-04-08 22:53:38 +08:00
Soulter 750cf46b2e 🎈 perf: better ChatPage UI 2025-04-08 17:33:46 +08:00
kterna 68885a4bbc Update astrbot/dashboard/routes/plugin.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-08 16:30:36 +08:00
Soulter bcc99a8904 🐛 fix: 修复 permission 过滤算子的 raise_error 参数失效的问题 2025-04-08 14:42:05 +08:00
kterna 59fbd98db3 1 2025-04-08 14:31:35 +08:00
kterna b70ed425f1 Merge branch 'master' of https://github.com/kterna/AstrBot 2025-04-08 14:05:43 +08:00
kterna 45ef5811c8 1 2025-04-08 14:02:59 +08:00
kterna 3b137ac762 插件管理中查看本地插件的readme 2025-04-08 14:01:14 +08:00
kterna 1ddb0caf73 star显示 2025-04-08 10:47:59 +08:00
Raven95676 ae4c6fe2dd 优化,确保完整处理插件所有模块。为核心方法添加文档。 2025-04-08 10:41:47 +08:00
Raven95676 db257af58e 提升代码可读性 2025-04-07 22:29:50 +08:00
Raven95676 735368c71b 保证变量名可读性 2025-04-07 22:16:02 +08:00
Raven95676 9e04e3679b 保证内置插件指令被注册 2025-04-07 22:08:29 +08:00
Raven95676 43b8414727 初步实现指令注册 2025-04-07 21:51:41 +08:00
anka 5a00187147 fix: 对历史记录的toolcall验证是否成对, 参考:
https://github.com/run-llama/llama_index/issues/13715
https://github.com/run-llama/llama_index/pull/16214
2025-04-07 18:14:30 +08:00
Raven95676 cb525c7c84 更新下hint( 2025-04-07 17:56:10 +08:00
anka d88420dd03 fix: 修改获取人类可读的上下文的逻辑, 区分函数调用(无contents)和一般消息 2025-04-07 17:55:12 +08:00
anka b9a983f8e0 fix: 为函数调用历史记录增加标记, 不读取入上下文 2025-04-07 17:45:35 +08:00
Raven95676 42431ea7db 统一text_chat_stream fallback 2025-04-07 17:43:35 +08:00
Raven95676 f9459e4abb 修复无法通过yield发送消息的问题 2025-04-07 17:38:23 +08:00
anka 72f917d611 fix: gemini只在content不为空的时候加入上下文 2025-04-07 17:31:57 +08:00
Raven95676 9fd1d19e93 分离流式与非流式响应处理 2025-04-07 11:52:29 +08:00
Raven95676 41bd76e091 tg适配器最后一次编辑转换markdown 2025-04-07 00:47:52 +08:00
Raven95676 cfd3f4b199 流式输出完成后,将完整的LLM响应设置为事件结果 2025-04-07 00:17:53 +08:00
Soulter b3866559e1 📦release: v3.5.2 2025-04-06 22:35:10 +08:00
anka 8ed3d5f3db fix: 将openai_source的结果消息链的构造方式和其他统一 2025-04-06 09:12:52 +00:00
anka f0c8f39b6d 对tg的通过编辑消息的流式传输完善错误捕获 2025-04-06 08:57:18 +00:00
anka 431db8fc9b 对流式输出做错误捕获 2025-04-06 08:47:17 +00:00
anka ba252c5356 fix: 修正一个偶然发现的命名错误() 2025-04-06 08:12:00 +00:00
Raven95676 a2812c39c0 修正文档注释 2025-04-06 16:05:21 +08:00
Raven95676 0490758820 替换原地修改和删除索引的旧逻辑 2025-04-06 15:36:05 +08:00
Soulter 9b36a5c8a6 feat: 增加全平台对流式输出的处理逻辑 2025-04-06 13:43:23 +08:00
Soulter c1cf2be533 feat: 完善流式处理 2025-04-06 11:56:06 +08:00
Soulter 109650faf3 feat: 支持流式输出 2025-04-06 00:56:33 +08:00
78 changed files with 3299 additions and 824 deletions
+4
View File
@@ -8,3 +8,7 @@
### Modifications
<!--简单解释你的改动-->
### Check
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 我新增/修复/优化的功能经过良好的测试
+2
View File
@@ -16,6 +16,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=7200)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
+1 -1
View File
@@ -1,5 +1,5 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import (
from astrbot.core.provider.entities import (
ProviderRequest,
ProviderType,
ProviderMetaData,
+18 -3
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.5.2"
VERSION = "3.5.3.2"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -50,6 +50,8 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
},
"provider_stt_settings": {
"enable": False,
@@ -247,6 +249,9 @@ CONFIG_METADATA_2 = {
"description": "平台设置",
"type": "object",
"items": {
"plugin_enable": {
"invisible": True, # 隐藏插件启用配置
},
"unique_session": {
"description": "会话隔离",
"type": "bool",
@@ -923,8 +928,8 @@ CONFIG_METADATA_2 = {
"dify_api_type": {
"description": "Dify 应用类型",
"type": "string",
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
"options": ["chat", "agent", "workflow"],
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, chatflow, agent, workflow 三种应用类型",
"options": ["chat", "chatflow", "agent", "workflow"],
},
"dify_workflow_output_key": {
"description": "Dify Workflow 输出变量名",
@@ -993,6 +998,16 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
"dequeue_context_length": {
"description": "丢弃对话数量(条)",
"type": "int",
"hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下",
},
"streaming_response": {
"description": "启用流式回复",
"type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
},
},
},
"persona": {
+9 -1
View File
@@ -175,7 +175,15 @@ class ConversationManager:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
if "content" in record and record["content"]:
temp_contexts.append(f"Assistant: {record['content']}")
elif "tool_calls" in record:
tool_calls_str = json.dumps(
record["tool_calls"], ensure_ascii=False
)
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
else:
temp_contexts.append("Assistant: [未知的内容]")
contexts.insert(0, temp_contexts)
temp_contexts = []
+7 -5
View File
@@ -141,11 +141,13 @@ class LogQueueHandler(logging.Handler):
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish({
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
})
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
}
)
class LogManager:
+37 -1
View File
@@ -1,6 +1,6 @@
import enum
from typing import List, Optional, Union
from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
@@ -111,6 +111,30 @@ class MessageChain:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return
new_chain = []
first_plain = None
plain_texts = []
for comp in self.chain:
if isinstance(comp, Plain):
if first_plain is None:
first_plain = comp
new_chain.append(comp)
plain_texts.append(comp.text)
else:
new_chain.append(comp)
if first_plain is not None:
first_plain.text = "".join(plain_texts)
self.chain = new_chain
return self
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
@@ -131,6 +155,10 @@ class ResultContentType(enum.Enum):
"""调用 LLM 产生的结果"""
GENERAL_RESULT = enum.auto()
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH= enum.auto()
"""流式输出完成"""
@dataclass
@@ -152,6 +180,9 @@ class MessageEventResult(MessageChain):
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
"""终止事件传播。"""
self.result_type = EventResultType.STOP
@@ -168,6 +199,11 @@ class MessageEventResult(MessageChain):
"""
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
"""设置异步流。"""
self.async_stream = stream
return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。
+3
View File
@@ -7,6 +7,7 @@ from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
@@ -18,6 +19,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
@@ -29,6 +31,7 @@ __all__ = [
"WhitelistCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
@@ -0,0 +1,56 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)
@@ -12,11 +12,12 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
MessageChain,
)
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import (
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
@@ -37,6 +38,13 @@ class LLMRequestSubStage(Stage):
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
self.dequeue_context_length = min(
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
self.max_context_length - 1,
) # int
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -58,12 +66,16 @@ class LLMRequestSubStage(Stage):
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
assert isinstance(
req, ProviderRequest
), "provider_request 必须是 ProviderRequest 类型。"
if req.conversation:
req.contexts = json.loads(req.conversation.history)
all_contexts = json.loads(req.conversation.history)
req.contexts = self._process_tool_message_pairs(
all_contexts, remove_tags=True
)
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.provider_wake_prefix:
@@ -104,8 +116,10 @@ class LLMRequestSubStage(Stage):
# 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能
# 获取当前平台ID
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent
EventType.OnLLMRequestEvent, platform_id=platform_id
)
for handler in handlers:
try:
@@ -131,76 +145,135 @@ class LLMRequestSubStage(Stage):
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[-self.max_context_length * 2 :]
req.contexts = req.contexts[
-(self.max_context_length - self.dequeue_context_length) * 2 :
]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
async def requesting(req: ProviderRequest):
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
final_llm_response = None
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async for result in self._handle_llm_response(event, req, llm_response):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
if self.streaming_response:
stream = provider.text_chat_stream(**req.__dict__)
async for llm_response in stream:
if llm_response.is_chunk:
if llm_response.result_chain:
yield llm_response.result_chain # MessageChain
else:
yield MessageChain().message(
llm_response.completion_text
)
else:
final_llm_response = llm_response
else:
yield
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if self.streaming_response:
# 流式输出的处理
async for result in self._handle_llm_stream_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
else:
# 非流式输出的处理
async for result in self._handle_llm_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
)
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
# 保存到历史记录
await self._save_to_history(event, req, final_llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
event.set_extra("tool_call_result", None)
async for _ in requesting(req):
yield
else:
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req))
)
return
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
yield
if event.get_extra("tool_call_result"):
event.set_result(event.get_extra("tool_call_result"))
event.set_extra("tool_call_result", None)
yield
async def _handle_llm_response(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
) -> AsyncGenerator[None, None]:
"""处理 LLM 响应。
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理非流式 LLM 响应。
Returns:
bool: 是否需要继续调用 LLM
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
Yields:
Iterator[bool]: 将 event 交付给下一个 stage
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
@@ -223,83 +296,152 @@ class LLMRequestSubStage(Stage):
)
)
elif llm_response.role == "tool":
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_llm_stream_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理流式 LLM 响应。
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.STREAMING_FINISH)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.STREAMING_FINISH)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_function_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理函数工具调用。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
"""
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
client = req.func_tool.mcp_client_dict[
func_tool.mcp_server_name
]
res = await client.session.call_tool(
func_tool.name, func_tool_args
else:
# 获取处理器,过滤掉平台不兼容的处理器
platform_id = event.get_platform_id()
star_md = star_map.get(func_tool.handler_module_path)
if (
star_md and
platform_id in star_md.supported_platforms
and not star_md.supported_platforms[platform_id]
):
logger.debug(
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
# 直接跳过,不添加任何消息到tool_call_result
continue
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
content=resp,
)
)
else:
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
else:
res = event.get_result()
if res and res.chain:
event.set_extra("tool_call_result", res)
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
@@ -309,12 +451,22 @@ class LLMRequestSubStage(Stage):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
contexts = req.contexts.copy()
contexts.append(await req.assemble_context())
# tool calls result
# 记录并标记函数调用结果
if req.tool_calls_result:
contexts.extend(req.tool_calls_result.to_openai_messages())
tool_calls_messages = req.tool_calls_result.to_openai_messages()
# 添加标记
for message in tool_calls_messages:
message["_tool_call_history"] = True
processed_tool_messages = self._process_tool_message_pairs(
tool_calls_messages, remove_tags=False
)
contexts.extend(processed_tool_messages)
contexts.append(
{"role": "assistant", "content": llm_response.completion_text}
@@ -325,3 +477,59 @@ class LLMRequestSubStage(Stage):
await self.conv_manager.update_conversation(
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
)
def _process_tool_message_pairs(self, messages, remove_tags=True):
"""处理工具调用消息,确保assistant和tool消息成对出现
Args:
messages (list): 消息列表
remove_tags (bool): 是否移除_tool_call_history标记
Returns:
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
"""
result = []
i = 0
while i < len(messages):
current_msg = messages[i]
# 普通消息直接添加
if "_tool_call_history" not in current_msg:
result.append(current_msg.copy() if remove_tags else current_msg)
i += 1
continue
# 工具调用消息成对处理
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
assistant_msg = current_msg.copy()
if remove_tags and "_tool_call_history" in assistant_msg:
del assistant_msg["_tool_call_history"]
related_tools = []
j = i + 1
while (
j < len(messages)
and messages[j].get("role") == "tool"
and "_tool_call_history" in messages[j]
):
tool_msg = messages[j].copy()
if remove_tags:
del tool_msg["_tool_call_history"]
related_tools.append(tool_msg)
j += 1
# 成对的时候添加到结果
if related_tools:
result.append(assistant_msg)
result.extend(related_tools)
i = j # 跳过已处理
else:
# 单独的tool消息
i += 1
return result
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
)
if not handlers_parsed_params:
handlers_parsed_params = {}
for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
+1 -1
View File
@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core import logger
+21 -6
View File
@@ -7,7 +7,7 @@ from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -18,7 +18,9 @@ from astrbot.core.star.star import star_map
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
Comp.Plain: lambda comp: bool(
comp.text and comp.text.strip()
), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
@@ -31,13 +33,17 @@ class RespondStage(Stage):
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
Comp.Contact: lambda comp: True, # 联系人(未完成)
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
Comp.Music: lambda comp: bool(comp._type)
and bool(comp.url)
and bool(comp.audio), # 音乐
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.RedBag: lambda comp: bool(comp.title), # 红包
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
Comp.Node: lambda comp: bool(comp.name)
and comp.uin != 0
and bool(comp.content), # 一个转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
Comp.Json: lambda comp: bool(comp.data), # JSON
@@ -132,8 +138,17 @@ class RespondStage(Stage):
result = event.get_result()
if result is None:
return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
if len(result.chain) > 0:
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream)
await event._post_send()
return
elif len(result.chain) > 0:
await event._pre_send()
# 检查消息链是否为空
@@ -183,7 +198,7 @@ class RespondStage(Stage):
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
+17 -1
View File
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
@@ -72,11 +73,17 @@ class ResultDecorateStage(Stage):
if result is None or not result.chain:
return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
return
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
# 回复时检查内容安全
if (
self.content_safe_check_reply
and self.content_safe_check_stage
and result.is_llm_result()
and not is_stream # 流式输出不检查内容安全
):
text = ""
for comp in result.chain:
@@ -89,13 +96,17 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
if is_stream:
logger.warning(
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(
@@ -110,6 +121,11 @@ class ResultDecorateStage(Stage):
)
return
# 流式输出不执行下面的逻辑
if is_stream:
logger.info("流式输出已启用,跳过结果装饰阶段")
return
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()
if result is None:
@@ -1,5 +1,6 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
@@ -93,6 +94,7 @@ class WakingCheckStage(Stage):
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
permission_filter_raise_error = False
if len(handler.event_filters) == 0:
continue
@@ -101,6 +103,7 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
permission_filter_raise_error = filter.raise_error
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -117,6 +120,9 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if not permission_filter_raise_error:
# 跳过
continue
if self.no_permission_reply:
await event.send(
MessageChain().message(
@@ -124,6 +130,9 @@ class WakingCheckStage(Stage):
)
)
await event._post_send()
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
event.stop_event()
return
+14 -2
View File
@@ -1,7 +1,7 @@
import abc
import asyncio
from dataclasses import dataclass
from typing import List, Union, Optional
from typing import List, Union, Optional, AsyncGenerator
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
@@ -16,7 +16,7 @@ from astrbot.core.message.components import (
)
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@@ -81,6 +81,9 @@ class AstrMessageEvent(abc.ABC):
def get_platform_name(self):
return self.platform_meta.name
def get_platform_id(self):
return self.platform_meta.id
def get_message_str(self) -> str:
"""
获取消息字符串。
@@ -202,6 +205,15 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
@@ -7,6 +7,8 @@ class PlatformMetadata:
"""平台的名称"""
description: str
"""平台的描述"""
id: str = None
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict = None
"""平台的默认配置模板"""
@@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
@@ -39,8 +39,9 @@ class AiocqhttpAdapter(Platform):
self.port = platform_config["ws_reverse_port"]
self.metadata = PlatformMetadata(
"aiocqhttp",
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
)
self.bot = CQHttp(
@@ -109,7 +110,7 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]:
abm.type = MessageType.GROUP_MESSAGE
@@ -129,7 +130,7 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 通知类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.type = MessageType.OTHER_MESSAGE
if "group_id" in event and event["group_id"]:
abm.group_id = str(event.group_id)
@@ -73,8 +73,9 @@ class DingtalkPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"dingtalk",
"钉钉机器人官方 API 适配器",
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(
@@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
None,
client.reply_markdown,
"AstrBot",
segment.text,
self.message_obj.raw_message,
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent):
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -60,13 +60,17 @@ class GewechatPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"gewechat",
"基于 gewechat 的 Wechat 适配器",
name="gewechat",
description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
)
async def terminate(self):
self.client.shutdown_event.set()
await self.client.server.shutdown()
try:
await self.client.server.shutdown()
except Exception as _:
pass
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
@@ -2,6 +2,7 @@ import base64
import asyncio
import json
import re
import uuid
import astrbot.api.message_components as Comp
from astrbot.api.platform import (
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
raise NotImplementedError("Lark 适配器不支持 send_by_session")
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
"title": "",
"content": res,
}
}
if session.message_type == MessageType.GROUP_MESSAGE:
id_type = "chat_id"
if "%" in session.session_id:
session.session_id = session.session_id.split("%")[1]
else:
id_type = "open_id"
request = (
CreateMessageRequest.builder()
.receive_id_type(id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(session.session_id)
.content(json.dumps(wrapped))
.msg_type("post")
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
response = await self.lark_api.im.v1.message.acreate(request)
if not response.success():
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"lark",
"飞书机器人官方 API 适配器",
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.sender.user_id
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
else:
abm.session_id = abm.sender.user_id
logger.debug(abm)
await self.handle_msg(abm)
@@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -2,6 +2,7 @@ import botpy
import botpy.message
import botpy.types
import botpy.types.message
import asyncio
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
@@ -9,6 +10,8 @@ from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
from botpy.types import message
import random
class QQOfficialMessageEvent(AstrMessageEvent):
@@ -30,8 +33,45 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
"""QQ 官方 API 仅支持回复一次"""
async def send_streaming(self, generator):
"""流式输出仅支持消息列表私聊"""
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
try:
async for chain in generator:
source = self.message_obj.raw_message
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
self.send_buffer = None
return await super().send_streaming(generator)
async def _post_send(self, stream: dict = None):
if not self.send_buffer:
return
source = self.message_obj.raw_message
assert isinstance(
source,
@@ -57,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"msg_id": self.message_obj.message_id,
}
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
payload["msg_seq"] = random.randint(1, 10000)
match type(source):
case botpy.message.GroupMessage:
if image_base64:
@@ -65,7 +108,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
await self.bot.api.post_group_message(
ret = await self.bot.api.post_group_message(
group_openid=source.group_openid, **payload
)
case botpy.message.C2CMessage:
@@ -75,22 +118,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
await self.bot.api.post_c2c_message(
openid=source.author.user_openid, **payload
)
if stream:
ret = await self.post_c2c_message(
openid=source.author.user_openid,
**payload,
stream=stream,
)
else:
ret = await self.post_c2c_message(
openid=source.author.user_openid, **payload
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
if image_path:
payload["file_image"] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
ret = await self.bot.api.post_message(
channel_id=source.channel_id, **payload
)
case botpy.message.DirectMessage:
if image_path:
payload["file_image"] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer)
self.send_buffer = None
return ret
async def upload_group_and_c2c_image(
self, image_base64: str, file_type: int, **kwargs
) -> botpy.types.message.Media:
@@ -112,6 +167,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
return await self.bot.api._http.request(route, json=payload)
async def post_c2c_message(
self,
openid: str,
msg_type: int = 0,
content: str = None,
embed: message.Embed = None,
ark: message.Ark = None,
message_reference: message.Reference = None,
media: message.Media = None,
msg_id: str = None,
msg_seq: str = 1,
event_id: str = None,
markdown: message.MarkdownPayload = None,
keyboard: message.Keyboard = None,
stream: dict = None,
) -> message.Message:
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
plain_text = ""
@@ -126,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"qq_official",
"QQ 机器人官方 API 适配器",
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
)
@staticmethod
@@ -99,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"qq_official_webhook",
"QQ 机器人官方 API 适配器",
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
)
async def run(self):
@@ -116,5 +117,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
async def terminate(self):
self.webhook_helper.shutdown_event.set()
await self.client.close()
await self.webhook_helper.server.shutdown()
try:
await self.webhook_helper.server.shutdown()
except Exception as _:
pass
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
@@ -1,26 +1,31 @@
import asyncio
import sys
import uuid
import asyncio
import astrbot.api.message_components as Comp
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from telegram import BotCommand, Update
from telegram.constants import ChatType
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
from telegram.ext import MessageHandler as TelegramMessageHandler
import astrbot.api.message_components as Comp
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
PlatformMetadata,
MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.api.event import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import star_handlers_registry
from telegram import Update
from telegram.ext import ApplicationBuilder, ContextTypes, filters
from telegram.constants import ChatType
from telegram.ext import MessageHandler as TelegramMessageHandler
from .tg_event import TelegramPlatformEvent
from astrbot.api import logger
from telegram.ext import ExtBot
if sys.version_info >= (3, 12):
from typing import override
@@ -67,6 +72,8 @@ class TelegramPlatformAdapter(Platform):
self.client = self.application.bot
logger.debug(f"Telegram base url: {self.client.base_url}")
self.scheduler = AsyncIOScheduler()
@override
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
@@ -80,18 +87,94 @@ class TelegramPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"telegram",
"telegram 适配器",
name="telegram", description="telegram 适配器", id=self.config.get("id")
)
@override
async def run(self):
await self.application.initialize()
await self.application.start()
await self.register_commands()
# TODO 使用更优雅的方式重新注册命令
self.scheduler.add_job(
self.register_commands,
"interval",
minutes=5,
id="telegram_command_register",
misfire_grace_time=60,
)
self.scheduler.start()
queue = self.application.updater.start_polling()
logger.info("Telegram Platform Adapter is running.")
await queue
async def register_commands(self):
"""收集所有注册的指令并注册到 Telegram"""
try:
await self.client.delete_my_commands()
commands = self.collect_commands()
if commands:
await self.client.set_my_commands(commands)
for cmd in commands:
logger.debug(f"已注册指令: /{cmd.command} - {cmd.description}")
except Exception as e:
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
def collect_commands(self) -> list[BotCommand]:
"""从注册的处理器中收集所有指令"""
command_dict = {}
skip_commands = {"start"}
for handler_md in star_handlers_registry._handlers:
handler_metadata = handler_md[1]
if not star_map[handler_metadata.handler_module_path].activated:
continue
for event_filter in handler_metadata.event_filters:
cmd_info = self._extract_command_info(
event_filter, handler_metadata, skip_commands
)
if cmd_info:
cmd_name, description = cmd_info
command_dict.setdefault(cmd_name, description)
commands_a = sorted(command_dict.keys())
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
@staticmethod
def _extract_command_info(
event_filter, handler_metadata, skip_commands: set
) -> tuple[str, str] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
is_group = False
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
if (
event_filter.parent_command_names
and event_filter.parent_command_names != [""]
):
return None
cmd_name = event_filter.command_name
elif isinstance(event_filter, CommandGroupFilter):
if event_filter.parent_group:
return None
cmd_name = event_filter.group_name
is_group = True
if not cmd_name or cmd_name in skip_commands:
return None
# Build description.
description = handler_metadata.desc or (
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
)
if len(description) > 30:
description = description[:30] + "..."
return cmd_name, description
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
await context.bot.send_message(
chat_id=update.effective_chat.id, text=self.config["start_message"]
@@ -163,6 +246,16 @@ class TelegramPlatformAdapter(Platform):
# 处理文本消息
plain_text = update.message.text
# 群聊场景命令特殊处理
if plain_text.startswith("/"):
command_parts = plain_text.split(" ", 1)
if "@" in command_parts[0]:
command, bot_name = command_parts[0].split("@")
if bot_name == self.client.username:
plain_text = command + (
f" {command_parts[1]}" if len(command_parts) > 1 else ""
)
if update.message.entities:
for entity in update.message.entities:
if entity.type == "mention":
@@ -242,7 +335,11 @@ class TelegramPlatformAdapter(Platform):
async def terminate(self):
try:
if self.scheduler.running:
self.scheduler.shutdown()
await self.application.stop()
await self.client.delete_my_commands()
# 保险起见先判断是否存在updater对象
if self.application.updater is not None:
@@ -1,7 +1,15 @@
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
Reply,
At,
File,
Record,
)
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
@@ -82,3 +90,109 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def send_streaming(self, generator):
message_thread_id = None
if self.get_message_type() == MessageType.GROUP_MESSAGE:
user_name = self.message_obj.group_id
else:
user_name = self.get_sender_id()
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
payload = {
"chat_id": user_name,
}
if message_thread_id:
payload["reply_to_message_id"] = message_thread_id
delta = ""
current_content = ""
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
async for chain in generator:
if isinstance(chain, MessageChain):
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
path = "data/temp/" + i.name
await download_file(i.file, path)
i.file = path
await self.client.send_document(
document=i.file, filename=i.name, **payload
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
# Plain
if not message_id:
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
else:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
# 编辑消息
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
try:
if delta and current_content != delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta, max_line_length=None, normalize_whitespace=False
)
await self.client.edit_message_text(
text=markdown_text,
chat_id=payload["chat_id"],
message_id=message_id,
parse_mode="MarkdownV2"
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator)
@@ -43,8 +43,7 @@ class WebChatAdapter(Platform):
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
name="webchat", description="webchat", id=self.config.get("id")
)
async def send_by_session(
@@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True)
@staticmethod
async def _send(message: MessageChain, session_id: str):
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
if not message:
web_chat_back_queue.put_nowait(None)
return
await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False}
)
return ""
cid = session_id.split("!")[-1]
data = ""
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid))
data = comp.text
await web_chat_back_queue.put(
{
"type": "plain",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
@@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
"type": "image",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
@@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
"type": "record",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
return data
async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message)
async def send_streaming(self, generator):
final_data = ""
async for chain in generator:
final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True
)
await web_chat_back_queue.put(
{
"type": "end",
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
}
)
await super().send_streaming(generator)
@@ -237,5 +237,8 @@ class WecomPlatformAdapter(Platform):
async def terminate(self):
self.server.shutdown_event.set()
await self.server.server.shutdown()
try:
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("企业微信 适配器已被优雅地关闭")
@@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent):
)
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
+1 -1
View File
@@ -1,5 +1,5 @@
from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData
from .entities import ProviderMetaData
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
+17 -267
View File
@@ -1,269 +1,19 @@
import enum
import base64
import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
from astrbot.core.provider.entities import (
ProviderRequest,
ProviderType,
ProviderMetaData,
ToolCallsResult,
AssistantMessageSegment,
ToolCallMessageSegment,
LLMResponse,
)
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData:
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
"""平台的默认配置模板"""
provider_display_name: str = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
role: str = "assistant"
def to_dict(self):
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
tool_calls_info: AssistantMessageSegment
"""函数调用的信息"""
tool_calls_result: List[ToolCallMessageSegment]
"""函数调用的结果"""
def to_openai_messages(self) -> List[Dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
]
return ret
@dataclass
class ProviderRequest:
prompt: str
"""提示词"""
session_id: str = ""
"""会话 ID"""
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""可用的函数工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation = None
tool_calls_result: ToolCallsResult = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
result_parts = []
for ctx in self.contexts:
role = ctx.get("role", "unknown")
content = ctx.get("content", "")
if isinstance(content, str):
result_parts.append(f"{role}: {content}")
elif isinstance(content, list):
msg_parts = []
image_count = 0
for item in content:
item_type = item.get("type", "")
if item_type == "text":
msg_parts.append(item.get("text", ""))
elif item_type == "image_url":
image_count += 1
if image_count > 0:
if msg_parts:
msg_parts.append(f"[+{image_count} images]")
else:
msg_parts.append(f"[{image_count} images]")
result_parts.append(f"{role}: {''.join(msg_parts)}")
return result_parts
async def assemble_context(self) -> Dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt}],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
)
return user_content
else:
return {"role": "user", "content": self.prompt}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
result_chain: MessageChain = None
"""返回的消息链"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
def __init__(
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = [],
tools_call_name: List[str] = [],
tools_call_ids: List[str] = [],
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
):
"""初始化 LLMResponse
Args:
role (str): 角色, assistant, tool, err
completion_text (str, optional): 返回的结果文本已经过时推荐使用 result_chain. Defaults to "".
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
self.role = role
self.completion_text = completion_text
self.result_chain = result_chain
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
@property
def completion_text(self):
if self.result_chain:
return self.result_chain.get_plain_text()
return self._completion_text
@completion_text.setter
def completion_text(self, value):
if self.result_chain:
self.result_chain.chain = [
comp
for comp in self.result_chain.chain
if not isinstance(comp, Comp.Plain)
] # 清空 Plain 组件
self.result_chain.chain.insert(0, Comp.Plain(value))
else:
self._completion_text = value
def to_openai_tool_calls(self) -> List[Dict]:
"""将工具调用信息转换为 OpenAI 格式"""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
)
return ret
__all__ = [
"ProviderRequest",
"ProviderType",
"ProviderMetaData",
"ToolCallsResult",
"AssistantMessageSegment",
"ToolCallMessageSegment",
"LLMResponse",
]
+281
View File
@@ -0,0 +1,281 @@
import enum
import base64
import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData:
type: str
"""提供商适配器名称,如 openai, ollama"""
desc: str = ""
"""提供商适配器描述."""
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
"""平台的默认配置模板"""
provider_display_name: str = None
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
role: str = "assistant"
def to_dict(self):
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
tool_calls_info: AssistantMessageSegment
"""函数调用的信息"""
tool_calls_result: List[ToolCallMessageSegment]
"""函数调用的结果"""
def to_openai_messages(self) -> List[Dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
]
return ret
@dataclass
class ProviderRequest:
prompt: str
"""提示词"""
session_id: str = ""
"""会话 ID"""
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""可用的函数工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
conversation: Conversation = None
tool_calls_result: ToolCallsResult = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
result_parts = []
for ctx in self.contexts:
role = ctx.get("role", "unknown")
content = ctx.get("content", "")
if isinstance(content, str):
result_parts.append(f"{role}: {content}")
elif isinstance(content, list):
msg_parts = []
image_count = 0
for item in content:
item_type = item.get("type", "")
if item_type == "text":
msg_parts.append(item.get("text", ""))
elif item_type == "image_url":
image_count += 1
if image_count > 0:
if msg_parts:
msg_parts.append(f"[+{image_count} images]")
else:
msg_parts.append(f"[{image_count} images]")
result_parts.append(f"{role}: {''.join(msg_parts)}")
return result_parts
async def assemble_context(self) -> Dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt}],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
)
return user_content
else:
return {"role": "user", "content": self.prompt}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
result_chain: MessageChain = None
"""返回的消息链"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
is_chunk: bool = False
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = None,
tools_call_name: List[str] = None,
tools_call_ids: List[str] = None,
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
Args:
role (str): 角色, assistant, tool, err
completion_text (str, optional): 返回的结果文本已经过时推荐使用 result_chain. Defaults to "".
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
if tools_call_args is None:
tools_call_args = []
if tools_call_name is None:
tools_call_name = []
if tools_call_ids is None:
tools_call_ids = []
self.role = role
self.completion_text = completion_text
self.result_chain = result_chain
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property
def completion_text(self):
if self.result_chain:
return self.result_chain.get_plain_text()
return self._completion_text
@completion_text.setter
def completion_text(self, value):
if self.result_chain:
self.result_chain.chain = [
comp
for comp in self.result_chain.chain
if not isinstance(comp, Comp.Plain)
] # 清空 Plain 组件
self.result_chain.chain.insert(0, Comp.Plain(value))
else:
self._completion_text = value
def to_openai_tool_calls(self) -> List[Dict]:
"""将工具调用信息转换为 OpenAI 格式"""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
)
return ret
+14 -11
View File
@@ -339,7 +339,7 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
def get_func_desc_openai_style(self) -> list:
def get_func_desc_openai_style(self, omit_empty_parameter_field = False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
@@ -348,16 +348,19 @@ class FuncCall:
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"type": "function",
"function": {
"name": f.name,
"parameters": f.parameters,
"description": f.description,
},
}
)
func_ = {
"type": "function",
"function": {
"name": f.name,
# "parameters": f.parameters,
"description": f.description,
},
}
func_["function"]["parameters"] = f.parameters
if not f.parameters.get("properties") and omit_empty_parameter_field:
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
del func_["function"]["parameters"]
_l.append(func_)
return _l
def get_func_desc_anthropic_style(self) -> list:
+1 -1
View File
@@ -2,7 +2,7 @@ import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from .entities import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from .register import provider_cls_map, llm_tools
+31 -3
View File
@@ -1,9 +1,9 @@
import abc
from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from dataclasses import dataclass
@@ -108,7 +108,35 @@ class Provider(AbstractProvider):
- 如果传入了 image_urls将会在对话时附上图片如果模型不支持图片输入将会抛出错误
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
raise NotImplementedError()
...
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
Args:
prompt: 提示词
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
Notes:
- 如果传入了 image_urls将会在对话时附上图片如果模型不支持图片输入将会抛出错误
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
...
async def pop_record(self, context: List):
"""
+1 -1
View File
@@ -1,5 +1,5 @@
from typing import List, Dict
from .entites import ProviderMetaData, ProviderType
from .entities import ProviderMetaData, ProviderType
from astrbot.core import logger
from .func_tool_manager import FuncCall
@@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
@@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text
# llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
@@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
messages=context_query, **model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.result_chain = MessageChain().message(response.content[0].text)
llm_response.raw_completion = response
return llm_response
except Exception as e:
@@ -160,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
@@ -3,10 +3,11 @@ import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
result_chain=MessageChain().message(
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
),
)
output_text = response.output.get("text", "")
@@ -141,11 +144,45 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(output_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def forget(self, session_id):
return True
@@ -3,7 +3,7 @@ import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -20,7 +20,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model", None))
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key
self.synthesizer = SpeechSynthesizer(
+29 -2
View File
@@ -2,7 +2,7 @@ import astrbot.core.message.components as Comp
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@@ -102,7 +102,7 @@ class ProviderDify(Provider):
try:
match self.api_type:
case "chat" | "agent":
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
@@ -189,6 +189,33 @@ class ProviderDify(Provider):
return LLMResponse(role="assistant", result_chain=chain)
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
@@ -4,7 +4,7 @@ import edge_tts
import subprocess
import asyncio
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -4,7 +4,7 @@ from pydantic import BaseModel, conint
from httpx import AsyncClient
from typing import Annotated, Literal
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
+63 -3
View File
@@ -12,7 +12,7 @@ from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
class SimpleGoogleGenAIClient:
@@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient:
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
async def stream_generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
"stream": True,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
async for line in resp.content:
if line:
yield line
@register_provider_adapter(
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
@@ -147,7 +180,7 @@ class ProviderGoogleGenAI(Provider):
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = ""
message["content"] = " "
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
@@ -176,7 +209,7 @@ class ProviderGoogleGenAI(Provider):
elif message["role"] == "assistant":
if "content" in message:
if not message["content"]:
message["content"] = ""
message["content"] = " "
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
)
@@ -338,6 +371,33 @@ class ProviderGoogleGenAI(Provider):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
def get_current_key(self) -> str:
return self.client.api_key
@@ -2,7 +2,7 @@ import uuid
import aiohttp
import urllib.parse
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -2,7 +2,7 @@ import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def get_current_key(self):
return "none"
+273 -61
View File
@@ -4,19 +4,24 @@ import os
import inspect
import random
import asyncio
import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from typing import List, AsyncGenerator
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
@register_provider_adapter(
@@ -82,7 +87,11 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_openai_style()
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
omit_empty_parameter_field=omit_empty_param_field
)
if tool_list:
payloads["tools"] = tool_list
@@ -107,16 +116,76 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
llm_response = await self.parse_openai_completion(completion, tools)
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
omit_empty_parameter_field=omit_empty_param_field
)
if tool_list:
payloads["tools"] = tool_list
# 不在默认参数中的参数放在 extra_body 中
extra_body = {}
to_del = []
for key in payloads.keys():
if key not in self.default_params:
extra_body[key] = payloads[key]
to_del.append(key)
for key in to_del:
del payloads[key]
stream = await self.client.chat.completions.create(
**payloads, stream=True, extra_body=extra_body
)
llm_response = LLMResponse("assistant", is_chunk=True)
state = ChatCompletionStreamState()
async for chunk in stream:
try:
state.handle_chunk(chunk)
except Exception as e:
logger.warning("Saving chunk state error: " + str(e))
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# 处理文本内容
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)]
)
yield llm_response
final_completion = state.get_final_completion()
llm_response = await self.parse_openai_completion(final_completion, tools)
yield llm_response
async def parse_openai_completion(
self, completion: ChatCompletion, tools: FuncCall
):
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
llm_response = LLMResponse("assistant")
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
if choice.message.tool_calls:
# tools call (function calling)
@@ -148,7 +217,7 @@ class ProviderOpenAIOfficial(Provider):
return llm_response
async def text_chat(
async def _prepare_chat_payload(
self,
prompt: str,
session_id: str = None,
@@ -158,7 +227,8 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
@@ -177,8 +247,117 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, **model_config}
llm_response = None
return payloads, context_query, func_tool
async def _handle_api_error(
self,
e: Exception,
payloads: dict,
context_query: list,
func_tool: FuncCall,
chosen_key: str,
available_api_keys: List[str],
retry_cnt: int,
max_retries: int,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
payloads["messages"] = context_query
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
return False, chosen_key, available_api_keys, payloads, context_query, None
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
payloads, context_query, func_tool = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
**kwargs,
)
llm_response = None
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
@@ -197,64 +376,97 @@ class ProviderOpenAIOfficial(Provider):
payloads["messages"] = new_contexts
context_query = new_contexts
except Exception as e:
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
continue
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
func_tool = None
else:
logger.error(
f"发生了错误。Provider 配置如下: {self.provider_config}"
)
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error(
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
)
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
return llm_response
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, context_query, func_tool = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
**kwargs,
)
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
e = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
break
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
except Exception as e:
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
async def _remove_image_from_context(self, contexts: List):
"""
从上下文中删除所有带有 image 的记录
@@ -1,7 +1,7 @@
import uuid
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -11,7 +11,7 @@ import re
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -2,7 +2,7 @@ import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -3,7 +3,7 @@ import os
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -3,7 +3,7 @@ from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
from .openai_source import ProviderOpenAIOfficial
View File
View File
+24
View File
@@ -47,5 +47,29 @@ class StarMetadata:
star_handler_full_names: List[str] = field(default_factory=list)
"""注册的 Handler 的全名列表"""
supported_platforms: Dict[str, bool] = field(default_factory=dict)
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
"""更新插件支持的平台列表
Args:
plugin_enable_config: 平台插件启用配置即platform_settings.plugin_enable配置项
"""
if not plugin_enable_config:
return
# 清空之前的配置
self.supported_platforms.clear()
# 遍历所有平台配置
for platform_id, plugins in plugin_enable_config.items():
# 检查该插件在当前平台的配置
if self.name in plugins:
self.supported_platforms[platform_id] = plugins[self.name]
else:
# 如果没有明确配置,默认为启用
self.supported_platforms[platform_id] = True
+58 -14
View File
@@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]):
print(handler.handler_full_name)
def get_handlers_by_event_type(
self, event_type: EventType, only_activated=True
self, event_type: EventType, only_activated=True, platform_id=None
) -> List[StarHandlerMetadata]:
"""通过事件类型获取 Handler"""
handlers = [
handler
for _, handler in self._handlers
if handler.event_type == event_type
and (
not only_activated
or (
star_map[handler.handler_module_path]
and star_map[handler.handler_module_path].activated
)
)
]
"""通过事件类型获取 Handler
Args:
event_type: 事件类型
only_activated: 是否只返回已激活的插件的处理器
platform_id: 平台ID如果提供此参数将过滤掉在此平台不兼容的处理器
Returns:
List[StarHandlerMetadata]: 处理器列表
"""
handlers = []
for _, handler in self._handlers:
if handler.event_type != event_type:
continue
# 只激活的插件处理器
if only_activated:
plugin = star_map.get(handler.handler_module_path)
if not (plugin and plugin.activated):
continue
# 平台兼容性过滤
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
if not handler.is_enabled_for_platform(platform_id):
continue
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
@@ -139,3 +154,32 @@ class StarHandlerMetadata:
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
"priority", 0
)
def is_enabled_for_platform(self, platform_id: str) -> bool:
"""检查插件是否在指定平台启用
Args:
platform_id: 平台ID这是从event.get_platform_id()获取的用于唯一标识平台实例
Returns:
bool: 是否启用True表示启用False表示禁用
"""
plugin = star_map.get(self.handler_module_path)
# 如果插件元数据不存在,默认允许执行
if not plugin or not plugin.name:
return True
# 先检查插件是否被激活
if not plugin.activated:
return False
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
if (
hasattr(plugin, "supported_platforms")
and platform_id in plugin.supported_platforms
):
return plugin.supported_platforms[platform_id]
# 如果没有缓存数据,默认允许执行
return True
+144 -19
View File
@@ -166,8 +166,71 @@ class PluginManager:
return metadata
def _get_plugin_related_modules(
self, plugin_root_dir: str, is_reserved: bool = False
) -> list[str]:
"""获取与指定插件相关的所有已加载模块名
根据插件根目录名和是否为保留插件 sys.modules 中筛选出相关的模块名
Args:
plugin_root_dir: 插件根目录名
is_reserved: 是否是保留插件影响模块路径前缀
Returns:
list[str]: 与该插件相关的模块名列表
"""
prefix = "packages." if is_reserved else "data.plugins."
return [
key
for key in list(sys.modules.keys())
if key.startswith(f"{prefix}{plugin_root_dir}")
]
def _purge_modules(
self,
module_patterns: list[str] = None,
root_dir_name: str = None,
is_reserved: bool = False,
):
"""从 sys.modules 中移除指定的模块
可以基于模块名模式或插件目录名移除模块用于清理插件相关的模块缓存
Args:
module_patterns: 要移除的模块名模式列表例如 ["data.plugins", "packages"]
root_dir_name: 插件根目录名用于移除与该插件相关的所有模块
is_reserved: 插件是否为保留插件影响模块路径前缀
"""
if module_patterns:
for pattern in module_patterns:
for key in list(sys.modules.keys()):
if key.startswith(pattern):
del sys.modules[key]
logger.debug(f"删除模块 {key}")
if root_dir_name:
for module_name in self._get_plugin_related_modules(
root_dir_name, is_reserved
):
try:
del sys.modules[module_name]
logger.debug(f"删除模块 {module_name}")
except KeyError:
logger.warning(f"模块 {module_name} 未载入")
async def reload(self, specified_plugin_name=None):
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
"""重新加载插件
Args:
specified_plugin_name (str, optional): 要重载的特定插件名称
如果为 None则重载所有插件
Returns:
tuple: 返回 load() 方法的结果包含 (success, error_message)
- success (bool): 重载是否成功
- error_message (str|None): 错误信息成功时为 None
"""
specified_module_path = None
if specified_plugin_name:
for smd in star_registry:
@@ -192,9 +255,6 @@ class PluginManager:
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
@@ -209,11 +269,44 @@ class PluginManager:
await self._unbind_plugin(smd.name, specified_module_path)
return await self.load(specified_module_path)
result = await self.load(specified_module_path)
# 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
return result
async def update_all_platform_compatibility(self):
"""更新所有插件的平台兼容性设置"""
# 获取最新的平台插件启用配置
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
logger.debug(
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
)
# 遍历所有插件,更新平台兼容性
for plugin in self.context.get_all_stars():
plugin.update_platform_compatibility(plugin_enable_config)
logger.debug(
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
)
return True
async def load(self, specified_module_path=None, specified_dir_name=None):
"""载入插件。
specified_module_path 或者 specified_dir_name 不为 None 只载入指定的插件
Args:
specified_module_path (str, optional): 指定要加载的插件模块路径例如: "data.plugins.my_plugin.main"
specified_dir_name (str, optional): 指定要加载的插件目录名例如: "my_plugin"
Returns:
tuple: (success, error_message)
- success (bool): 是否全部加载成功
- error_message (str|None): 错误信息成功时为 None
"""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
@@ -320,6 +413,12 @@ class PluginManager:
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 更新插件的平台兼容性
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
metadata.update_platform_compatibility(plugin_enable_config)
# 绑定 handler
related_handlers = (
star_handlers_registry.get_handlers_by_module_name(
@@ -447,6 +546,20 @@ class PluginManager:
return False, fail_rec
async def install_plugin(self, repo_url: str, proxy=""):
"""从仓库 URL 安装插件
从指定的仓库 URL 下载并安装插件然后加载该插件到系统中
Args:
repo_url (str): 要安装的插件仓库 URL
proxy (str, optional): 用于下载的代理服务器默认为空字符串
Returns:
dict | None: 安装成功时返回包含插件信息的字典:
- repo: 插件的仓库 URL
- readme: README.md 文件的内容(如果存在)
如果找不到插件元数据则返回 None
"""
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
dir_name = os.path.basename(plugin_path)
@@ -481,6 +594,14 @@ class PluginManager:
return plugin_info
async def uninstall_plugin(self, plugin_name: str):
"""卸载指定的插件。
Args:
plugin_name (str): 要卸载的插件名称
Raises:
Exception: 当插件不存在是保留插件时或删除插件文件夹失败时抛出异常
"""
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -509,9 +630,17 @@ class PluginManager:
)
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
"""解绑并移除一个插件。
Args:
plugin_name: 要解绑的插件名称
plugin_module_path: 插件的完整模块路径
"""
plugin = None
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
plugin = p
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(
@@ -521,21 +650,17 @@ class PluginManager:
f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})"
)
star_handlers_registry.remove(handler)
keys_to_delete = [
k
for k, v in star_handlers_registry.star_handlers_map.items()
if k.startswith(plugin_module_path)
]
for k in keys_to_delete:
try:
del star_handlers_registry.star_handlers_map[k]
except KeyError:
pass
try:
del sys.modules[plugin_module_path]
except KeyError:
logger.warning(f"模块 {plugin_module_path} 未载入")
for k in [
k
for k in star_handlers_registry.star_handlers_map
if k.startswith(plugin_module_path)
]:
del star_handlers_registry.star_handlers_map[k]
self._purge_modules(
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
)
async def update_plugin(self, plugin_name: str, proxy=""):
"""升级一个插件"""
+48
View File
@@ -1,9 +1,12 @@
import inspect
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map
from pathlib import Path
class StarTools:
@@ -142,3 +145,48 @@ class StarTools:
name (str): 工具名称
"""
cls._context.unregister_llm_tool(name)
@classmethod
def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path:
"""
返回插件数据目录的绝对路径
此方法会在 data/plugin_data 目录下为插件创建一个专属的数据目录如果未提供插件名称
会自动从调用栈中获取插件信息
Args:
plugin_name: 可选的插件名称如果为None将自动检测调用者的插件名称
Returns:
Path (Path): 插件数据目录的绝对路径位于 data/plugin_data/{plugin_name}
Raises:
RuntimeError: 当出现以下情况时抛出:
- 无法获取调用者模块信息
- 无法获取模块的元数据信息
- 创建目录失败权限不足或其他IO错误
"""
if not plugin_name:
frame = inspect.currentframe().f_back
module = inspect.getmodule(frame)
if not module:
raise RuntimeError("无法获取调用者模块信息")
metadata = star_map.get(module.__name__, None)
if not metadata:
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息")
plugin_name = metadata.name
data_dir = Path("data/plugin_data") / plugin_name
try:
data_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}{e!s}") from e
return data_dir.resolve()
+1 -1
View File
@@ -15,7 +15,7 @@ class SharedPreferences:
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default=None):
+13 -5
View File
@@ -105,16 +105,24 @@ class RepoZipUpdator:
"""
比较两个版本号的大小
返回 1 表示 v1 > v2返回 -1 表示 v1 < v2返回 0 表示 v1 = v2
支持任意长度的版本号如v1.2.3或v3.5.3.1
"""
v1 = v1.replace("v", "")
v2 = v2.replace("v", "")
v1 = v1.split(".")
v2 = v2.split(".")
v1_parts = v1.split(".")
v2_parts = v2.split(".")
for i in range(3):
if int(v1[i]) > int(v2[i]):
# 获取最长的版本号长度
length = max(len(v1_parts), len(v2_parts))
# 将短版本号补0以便比较
v1_parts.extend(["0"] * (length - len(v1_parts)))
v2_parts.extend(["0"] * (length - len(v2_parts)))
for i in range(length):
if int(v1_parts[i]) > int(v2_parts[i]):
return 1
elif int(v1[i]) < int(v2[i]):
elif int(v1_parts[i]) < int(v2_parts[i]):
return -1
return 0
+27 -16
View File
@@ -161,42 +161,53 @@ class ChatRoute(Route):
username = g.get("username", "guest")
if username in self.curr_chat_sse:
return "[ERROR]\n"
return Response().error("Already connected").__dict__
self.curr_chat_sse[username] = None
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
async def stream():
try:
yield "[HB]\n"
yield f"data: {heartbeat}\n\n" # 心跳包
while True:
try:
result = await asyncio.wait_for(
web_chat_back_queue.get(), timeout=10
) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield "[HB]\n" # 心跳包
yield f"data: {heartbeat}\n\n" # 心跳包
continue
if not result:
continue
result_text, cid = result
result_text = result["data"]
type = result.get("type")
cid = result.get("cid")
streaming = result.get("streaming", False)
if cid != self.curr_user_cid.get(username):
# 丢弃
continue
yield result_text + "\n"
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
conversation = self.db.get_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
if streaming and type != "end":
continue
await asyncio.sleep(0.5)
if result_text:
conversation = self.db.get_conversation_by_user_id(
username, cid
)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
self.curr_chat_sse.pop(username)
+4 -2
View File
@@ -60,11 +60,13 @@ def validate_config(
data[key] = False
continue
meta = metadata[key]
if "type" not in meta:
logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验")
continue
# null 转换
if value is None:
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
continue
# 递归验证
if meta["type"] == "list" and not isinstance(value, list):
errors.append(
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}"
@@ -179,7 +181,7 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
logger.error(e)
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def post_plugin_configs(self):
+1 -1
View File
@@ -20,7 +20,7 @@ class LogRoute(Route):
message = await queue.get()
payload = {
"type": "log",
**message # see astrbot/core/log.py
**message, # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
+136
View File
@@ -1,5 +1,6 @@
import traceback
import aiohttp
import os
import ssl
import certifi
@@ -36,6 +37,9 @@ class PluginRoute(Route):
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -317,3 +321,135 @@ class PluginRoute(Route):
except Exception as e:
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def get_plugin_readme(self):
plugin_name = request.args.get("name")
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
if not plugin_name:
logger.warning("插件名称为空")
return Response().error("插件名称不能为空").__dict__
plugin_obj = None
for plugin in self.plugin_manager.context.get_all_stars():
if plugin.name == plugin_name:
plugin_obj = plugin
break
if not plugin_obj:
logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__
plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name
)
if not os.path.isdir(plugin_dir):
logger.warning(f"无法找到插件目录: {plugin_dir}")
return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__
readme_path = os.path.join(plugin_dir, "README.md")
if not os.path.isfile(readme_path):
logger.warning(f"插件 {plugin_name} 没有README文件")
return Response().error(f"插件 {plugin_name} 没有README文件").__dict__
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
return (
Response()
.ok({"content": readme_content}, "成功获取README内容")
.__dict__
)
except Exception as e:
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
return Response().error(f"读取README文件失败: {str(e)}").__dict__
async def get_plugin_platform_enable(self):
"""获取插件在各平台的可用性配置"""
try:
platform_enable = self.core_lifecycle.astrbot_config.get(
"platform_settings", {}
).get("plugin_enable", {})
# 获取所有可用平台
platforms = []
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
platform_type = platform.get("type", "")
platform_id = platform.get("id", "")
platforms.append(
{
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
"id": platform_id, # 保留id字段以便前端可以显示
"type": platform_type,
"display_name": f"{platform_type}({platform_id})",
}
)
adjusted_platform_enable = {}
for platform_id, plugins in platform_enable.items():
adjusted_platform_enable[platform_id] = plugins
# 获取所有插件,包括系统内部插件
plugins = []
for plugin in self.plugin_manager.context.get_all_stars():
plugins.append(
{
"name": plugin.name,
"desc": plugin.desc,
"reserved": plugin.reserved, # 添加reserved标志
}
)
logger.debug(
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
)
return (
Response()
.ok(
{
"platforms": platforms,
"plugins": plugins,
"platform_enable": adjusted_platform_enable,
}
)
.__dict__
)
except Exception as e:
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def set_plugin_platform_enable(self):
"""设置插件在各平台的可用性配置"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
data = await request.json
platform_enable = data.get("platform_enable", {})
# 更新配置
config = self.core_lifecycle.astrbot_config
platform_settings = config.get("platform_settings", {})
platform_settings["plugin_enable"] = platform_enable
config["platform_settings"] = platform_settings
config.save_config()
# 更新插件的平台兼容性缓存
await self.plugin_manager.update_all_platform_compatibility()
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
return Response().ok(None, "插件平台可用性配置已更新").__dict__
except Exception as e:
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+35
View File
@@ -0,0 +1,35 @@
# What's Changed
> 📢 在升级前,请完整阅读本次更新日志。
> 此版本为针对 `v3.5.3` 的紧急修复版本
## ✨ 新增的功能
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
4. 飞书平台支持主动消息发送 @Soulter
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
7. StarTool 新增获取插件数据目录接口 @Raven95676
## 🎈 功能性优化
1. 优化 /his 指令对函数调用的显示 @anka-afk
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
## 🐛 修复的 Bug
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
7. 修复 dify 下删除对话的报错问题 @Soulter
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
## 🧩 新增的插件
待补充
+41
View File
@@ -0,0 +1,41 @@
# What's Changed
> 📢 在升级前,请完整阅读本次更新日志。
> 此版本为针对 `v3.5.3` 的紧急修复版本
> 修复以下 BUG
> 1. 智谱 GLM 在函数工具有空参数时报错的问题。
---
v3.5.3
## ✨ 新增的功能
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
4. 飞书平台支持主动消息发送 @Soulter
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
7. StarTool 新增获取插件数据目录接口 @Raven95676
## 🎈 功能性优化
1. 优化 /his 指令对函数调用的显示 @anka-afk
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
## 🐛 修复的 Bug
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
7. 修复 dify 下删除对话的报错问题 @Soulter
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
## 🧩 新增的插件
待补充
+34
View File
@@ -0,0 +1,34 @@
# What's Changed
> 📢 在升级前,请完整阅读本次更新日志。
## ✨ 新增的功能
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
4. 飞书平台支持主动消息发送 @Soulter
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
7. StarTool 新增获取插件数据目录接口 @Raven95676
## 🎈 功能性优化
1. 优化 /his 指令对函数调用的显示 @anka-afk
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
## 🐛 修复的 Bug
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
7. 修复 dify 下删除对话的报错问题 @Soulter
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
## 🧩 新增的插件
待补充
@@ -24,13 +24,10 @@ const emit = defineEmits([
'install',
'uninstall',
'toggle-activation',
'view-handlers'
'view-handlers',
'view-readme'
]);
const open = (link: string | undefined) => {
window.open(link, '_blank');
};
const reveal = ref(false);
//
@@ -70,6 +67,10 @@ const toggleActivation = () => {
const viewHandlers = () => {
emit('view-handlers', props.extension);
};
const viewReadme = () => {
emit('view-readme', props.extension);
};
</script>
<template>
@@ -128,7 +129,7 @@ const viewHandlers = () => {
</v-card-text>
<v-card-actions style="padding: 0px; margin-top: auto;">
<v-btn color="teal-accent-4" text="帮助" variant="text" @click="open(extension.repo)"></v-btn>
<v-btn color="teal-accent-4" text="查看文档" variant="text" @click="viewReadme"></v-btn>
<v-btn v-if="!marketMode" color="teal-accent-4" text="操作" variant="text" @click="reveal = true"></v-btn>
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" text="安装" variant="text"
@click="emit('install', extension)"></v-btn>
@@ -0,0 +1,302 @@
<script setup>
import { ref, watch, onMounted } from 'vue';
import axios from 'axios';
import { marked } from 'marked';
import hljs from 'highlight.js';
import 'highlight.js/styles/github.css';
const props = defineProps({
show: {
type: Boolean,
default: false
},
pluginName: {
type: String,
default: ''
},
repoUrl: {
type: String,
default: null
}
});
const emit = defineEmits(['update:show']);
const content = ref(null);
const error = ref(null);
const loading = ref(false);
// show
watch(() => props.show, (newVal) => {
if (newVal && props.pluginName) {
fetchReadme();
}
});
// pluginName
watch(() => props.pluginName, (newVal) => {
if (props.show && newVal) {
fetchReadme();
}
});
// README
async function fetchReadme() {
if (!props.pluginName) return;
loading.value = true;
content.value = null;
error.value = null;
try {
// README
const res = await axios.get(`/api/plugin/readme?name=${props.pluginName}`);
if (res.data.status === 'ok') {
content.value = res.data.data.content;
} else {
error.value = res.data.message || '获取README失败';
}
} catch (err) {
error.value = err.message || '获取README时发生错误';
} finally {
loading.value = false;
}
}
// GitHub
function openRepoInNewTab() {
if (props.repoUrl) {
window.open(props.repoUrl, '_blank');
}
}
// Markdown
function renderMarkdown(content) {
if (!content) return '';
// marked使highlight.js
marked.setOptions({
highlight: function(code, lang) {
if (lang && hljs.getLanguage(lang)) {
try {
return hljs.highlight(code, { language: lang }).value;
} catch (e) {
console.error(e);
}
}
return hljs.highlightAuto(code).value;
},
gfm: true, // GitHub Flavored Markdown
breaks: true, // Convert \n to <br>
headerIds: true, // Add id attributes to headers
mangle: false // Don't mangle email addresses
});
return marked(content);
}
// README
function refreshReadme() {
fetchReadme();
}
</script>
<template>
<v-dialog v-model="_show" width="800" persistent>
<v-card>
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h5">插件说明文档</span>
<v-btn icon @click="$emit('update:show', false)">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-card-text style="height: 70vh; overflow-y: auto;">
<div class="d-flex justify-space-between mb-4">
<v-btn
v-if="repoUrl"
color="primary"
prepend-icon="mdi-github"
@click="openRepoInNewTab()"
>
在GitHub中查看仓库
</v-btn>
<v-btn
color="secondary"
prepend-icon="mdi-refresh"
@click="refreshReadme()"
>
刷新文档
</v-btn>
</div>
<!-- 加载中 -->
<div v-if="loading" class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<p class="text-body-1 text-center">正在加载README文档...</p>
</div>
<!-- 内容显示 -->
<div v-else-if="content" class="markdown-body" v-html="renderMarkdown(content)"></div>
<!-- 错误提示 -->
<div v-else-if="error" class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle-outline</v-icon>
<p class="text-body-1 text-center mb-4">{{ error }}</p>
</div>
<!-- 无内容提示 -->
<div v-else class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-icon size="64" color="warning" class="mb-4">mdi-file-question-outline</v-icon>
<p class="text-body-1 text-center mb-4">该插件未提供文档链接或GitHub仓库地址<br>请查看插件市场或联系插件作者获取更多信息</p>
</div>
</v-card-text>
<v-divider></v-divider>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="primary" variant="tonal" @click="$emit('update:show', false)">
关闭
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<style>
.markdown-body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif;
line-height: 1.6;
padding: 8px 0;
color: #24292e;
}
.markdown-body h1,
.markdown-body h2,
.markdown-body h3,
.markdown-body h4,
.markdown-body h5,
.markdown-body h6 {
margin-top: 24px;
margin-bottom: 16px;
font-weight: 600;
line-height: 1.25;
}
.markdown-body h1 {
font-size: 2em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
.markdown-body h2 {
font-size: 1.5em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
.markdown-body p {
margin-top: 0;
margin-bottom: 16px;
}
.markdown-body code {
padding: 0.2em 0.4em;
margin: 0;
background-color: rgba(27, 31, 35, 0.05);
border-radius: 3px;
font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
font-size: 85%;
}
.markdown-body pre {
padding: 16px;
overflow: auto;
font-size: 85%;
line-height: 1.45;
background-color: #f6f8fa;
border-radius: 3px;
margin-bottom: 16px;
}
.markdown-body pre code {
background-color: transparent;
padding: 0;
}
.markdown-body ul,
.markdown-body ol {
padding-left: 2em;
margin-bottom: 16px;
}
.markdown-body img {
max-width: 100%;
margin: 8px 0;
box-sizing: border-box;
background-color: #fff;
border-radius: 3px;
}
.markdown-body blockquote {
padding: 0 1em;
color: #6a737d;
border-left: 0.25em solid #dfe2e5;
margin-bottom: 16px;
}
.markdown-body a {
color: #0366d6;
text-decoration: none;
}
.markdown-body a:hover {
text-decoration: underline;
}
.markdown-body table {
border-spacing: 0;
border-collapse: collapse;
width: 100%;
overflow: auto;
margin-bottom: 16px;
}
.markdown-body table th,
.markdown-body table td {
padding: 6px 13px;
border: 1px solid #dfe2e5;
}
.markdown-body table tr {
background-color: #fff;
border-top: 1px solid #c6cbd1;
}
.markdown-body table tr:nth-child(2n) {
background-color: #f6f8fa;
}
.markdown-body hr {
height: 0.25em;
padding: 0;
margin: 24px 0;
background-color: #e1e4e8;
border: 0;
}
</style>
<script>
export default {
name: 'ReadmeDialog',
computed: {
_show: {
get() {
return this.show;
},
set(value) {
this.$emit('update:show', value);
}
}
}
}
</script>
+272 -177
View File
@@ -1,6 +1,7 @@
<script setup>
import axios from 'axios';
import { marked } from 'marked';
import { ref } from 'vue';
marked.setOptions({
breaks: true
@@ -11,37 +12,73 @@ marked.setOptions({
<v-card class="chat-page-card">
<v-card-text class="chat-page-container">
<div class="chat-layout">
<!-- 左侧对话列表面板 -->
<!-- 左侧对话列表面板 - 优化版 -->
<div class="sidebar-panel">
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC"
:disabled="!currCid">
<v-icon class="mr-2">mdi-plus</v-icon>创建对话
</v-btn>
<v-card class="conversation-list-card" v-if="conversations.length > 0">
<v-list density="compact" nav class="conversation-list" @update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl" class="conversation-item">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<div class="status-chips">
<v-chip class="status-chip" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
LLM
</v-chip>
<v-chip class="status-chip" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
语音转文本
</v-chip>
<div class="sidebar-header">
<v-btn variant="elevated" rounded="lg" class="new-chat-btn" @click="newC" :disabled="!currCid"
prepend-icon="mdi-plus">
创建对话
</v-btn>
</div>
<v-btn variant="tonal" rounded="xl" class="delete-chat-btn" v-if="currCid"
@click="deleteConversation(currCid)" color="error">
<v-icon class="mr-2">mdi-delete</v-icon>删除此对话
</v-btn>
<div class="conversations-container">
<div class="sidebar-section-title" v-if="conversations.length > 0">
对话历史
</div>
<v-card class="conversation-list-card" v-if="conversations.length > 0" flat>
<v-list density="compact" nav class="conversation-list"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="lg" class="conversation-item" active-color="primary">
<template v-slot:prepend>
<v-icon size="small" icon="mdi-message-text-outline"></v-icon>
</template>
<v-list-item-title class="conversation-title">新对话</v-list-item-title>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at)
}}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<v-fade-transition>
<div class="no-conversations" v-if="conversations.length === 0">
<v-icon icon="mdi-message-text-outline" size="large" color="grey-lighten-1"></v-icon>
<div class="no-conversations-text">暂无对话历史</div>
</div>
</v-fade-transition>
</div>
<div class="sidebar-footer">
<div class="sidebar-section-title">
系统状态
</div>
<div class="status-chips">
<v-chip class="status-chip" :color="status?.llm_enabled ? 'primary' : 'grey-lighten-2'"
variant="elevated" size="small">
<template v-slot:prepend>
<v-icon :icon="status?.llm_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
size="x-small"></v-icon>
</template>
LLM 服务
</v-chip>
<v-chip class="status-chip" :color="status?.stt_enabled ? 'success' : 'grey-lighten-2'"
variant="elevated" size="small">
<template v-slot:prepend>
<v-icon :icon="status?.stt_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
size="x-small"></v-icon>
</template>
语音转文本
</v-chip>
</div>
<v-btn variant="tonal" rounded="lg" class="delete-chat-btn" v-if="currCid"
@click="deleteConversation(currCid)" color="error" density="comfortable" size="small">
<v-icon start size="small">mdi-delete</v-icon>
删除此对话
</v-btn>
</div>
</div>
<!-- 右侧聊天内容区域 -->
@@ -77,14 +114,15 @@ marked.setOptions({
<div v-if="msg.type == 'user'" class="user-message">
<div class="message-bubble user-bubble">
<span>{{ msg.message }}</span>
<!-- 图片附件 -->
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index" class="image-attachment">
<div v-for="(img, index) in msg.image_url" :key="index"
class="image-attachment">
<img :src="img" class="attached-image" />
</div>
</div>
<!-- 音频附件 -->
<div class="audio-attachment" v-if="msg.audio_url && msg.audio_url.length > 0">
<audio controls class="audio-player">
@@ -97,7 +135,7 @@ marked.setOptions({
<v-icon icon="mdi-account" />
</v-avatar>
</div>
<!-- 机器人消息 -->
<div v-else class="bot-message">
<v-avatar class="bot-avatar" color="deep-purple" size="36">
@@ -113,49 +151,30 @@ marked.setOptions({
<!-- 输入区域 -->
<div class="input-area fade-in">
<v-text-field
id="input-field"
variant="outlined"
v-model="prompt"
:label="inputFieldLabel"
placeholder="开始输入..."
:loading="loadingChat"
clear-icon="mdi-close-circle"
clearable
@click:clear="clearMessage"
class="message-input"
@keydown="handleInputKeyDown"
hide-details
>
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
placeholder="开始输入..." :loading="loadingChat" clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" class="message-input" @keydown="handleInputKeyDown"
hide-details>
<template v-slot:loader>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple" indeterminate></v-progress-linear>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple"
indeterminate></v-progress-linear>
</template>
<template v-slot:append>
<v-tooltip text="发送">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="sendMessage"
class="send-btn"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl"
/>
<v-btn v-bind="props" @click="sendMessage" class="send-btn" icon="mdi-send"
variant="text" color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl" />
</template>
</v-tooltip>
<v-tooltip text="语音输入">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'"
variant="text"
:color="isRecording ? 'error' : 'deep-purple'"
/>
<v-btn v-bind="props" @click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
:color="isRecording ? 'error' : 'deep-purple'" />
</template>
</v-tooltip>
</template>
@@ -165,15 +184,17 @@ marked.setOptions({
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview">
<img :src="img" class="preview-image" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close"
size="small" color="error" variant="text" />
</div>
<div v-if="stagedAudioUrl" class="audio-preview">
<v-chip color="deep-purple-lighten-4" class="audio-chip">
<v-icon start icon="mdi-microphone" size="small"></v-icon>
新录音
</v-chip>
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small"
color="error" variant="text" />
</div>
</div>
</div>
@@ -206,9 +227,9 @@ export default {
status: {},
statusText: '',
eventSource: null,
// Ctrl
ctrlKeyDown: false,
ctrlKeyTimer: null,
@@ -228,18 +249,17 @@ export default {
this.sendMessage();
}
}.bind(this));
// keyup
document.addEventListener('keyup', this.handleInputKeyUp);
},
beforeUnmount() {
console.log("111")
if (this.eventSource) {
this.eventSource.cancel();
console.log('SSE连接已断开');
}
// keyup
document.removeEventListener('keyup', this.handleInputKeyUp);
},
@@ -265,6 +285,9 @@ export default {
this.eventSource = reader
let in_streaming = false
let message_obj = null
while (true) {
const { done, value } = await reader.read();
if (done) {
@@ -273,40 +296,67 @@ export default {
}
const chunk = decoder.decode(value, { stream: true });
console.log("!!!!", chunk);
if (chunk === '[HB]\n') {
continue; //
}
if (chunk === '[ERROR]\n') {
continue;
}
//
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
let lines = chunk.split('\n\n');
console.log('SSE数据:', lines);
for (let i = 0; i < lines.length; i++) {
let line = lines[i].trim();
if (!line) {
continue;
}
this.messages.push(bot_resp);
} else if (chunk.startsWith('[RECORD]')) {
let audio = chunk.replace('[RECORD]', '');
let bot_resp = {
type: 'bot',
message: `<audio controls class="audio-player">
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
您的浏览器不支持音频播放
</audio>`
console.log(line)
// data: {"type": "plain", "data": "helloworld"}
let chunk_json = JSON.parse(line.replace('data: ', ''));
if (chunk_json.type === 'heartbeat') {
continue; //
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
if (chunk_json.type === 'error') {
console.error('Error received:', chunk_json.data);
continue;
}
this.messages.push(bot_resp);
if (chunk_json.type === 'image') {
let img = chunk_json.data.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else if (chunk_json.type === 'record') {
let audio = chunk_json.data.replace('[RECORD]', '');
let bot_resp = {
type: 'bot',
message: `<audio controls class="audio-player">
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
您的浏览器不支持音频播放
</audio>`
}
this.messages.push(bot_resp);
} else if (chunk_json.type === 'plain') {
if (!in_streaming) {
message_obj = {
type: 'bot',
message: ref(chunk_json.data),
}
this.messages.push(message_obj);
in_streaming = true;
} else {
message_obj.message.value += chunk_json.data;
}
} else if (chunk_json.type === 'end') {
in_streaming = false;
continue;
}
this.scrollToBottom();
}
this.scrollToBottom();
}
},
@@ -526,42 +576,6 @@ export default {
this.stagedAudioUrl = "";
this.loadingChat = false;
// const reader = response.body.getReader(); // Reader
// const decoder = new TextDecoder();
// const readStream = async () => {
// const { done, value } = await reader.read(); //
// if (done) {
// console.log("Stream finished.");
// return;
// }
// const chunk = decoder.decode(value, { stream: true });
// // bot_resp.message.value += chunk;
// console.log("!!!!", chunk);
// if (chunk.startsWith('[IMAGE]')) {
// let img = chunk.replace('[IMAGE]', '');
// let bot_resp = {
// type: 'bot',
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
// }
// this.messages.push(bot_resp);
// } else {
// let bot_resp = {
// type: 'bot',
// message: chunk
// }
// this.messages.push(bot_resp);
// }
// this.scrollToBottom();
// readStream(); //
// };
// readStream();
})
.catch(err => {
console.error(err);
@@ -578,9 +592,9 @@ export default {
if (e.keyCode === 17) { // Ctrl
//
if (this.ctrlKeyDown) return;
this.ctrlKeyDown = true;
//
this.ctrlKeyTimer = setTimeout(() => {
if (this.ctrlKeyDown && !this.isRecording) {
@@ -589,17 +603,17 @@ export default {
}, this.ctrlKeyLongPressThreshold);
}
},
handleInputKeyUp(e) {
if (e.keyCode === 17) { // Ctrl
this.ctrlKeyDown = false;
//
if (this.ctrlKeyTimer) {
clearTimeout(this.ctrlKeyTimer);
this.ctrlKeyTimer = null;
}
//
if (this.isRecording) {
this.stopRecording();
@@ -613,19 +627,41 @@ export default {
<style>
/* 基础动画 */
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
0% {
transform: scale(1);
}
50% {
transform: scale(1.05);
}
100% {
transform: scale(1);
}
}
@keyframes slideIn {
from { transform: translateX(20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
from {
transform: translateX(20px);
opacity: 0;
}
to {
transform: translateX(0);
opacity: 1;
}
}
/* 聊天页面布局 */
@@ -650,84 +686,140 @@ export default {
gap: 24px;
}
/* 侧边栏样式 */
/* 侧边栏样式 - 优化版 */
.sidebar-panel {
max-width: 240px;
min-width: 200px;
max-width: 270px;
min-width: 240px;
display: flex;
flex-direction: column;
padding: 16px 8px;
border-right: 1px solid #f0f0f0;
padding: 0;
border-right: 1px solid rgba(0, 0, 0, 0.05);
background-color: #fcfcfc;
height: 100%;
position: relative;
}
.sidebar-header {
padding: 16px;
border-bottom: 1px solid rgba(0, 0, 0, 0.04);
}
.conversations-container {
flex-grow: 1;
overflow-y: auto;
padding: 16px;
}
.sidebar-footer {
padding: 16px;
border-top: 1px solid rgba(0, 0, 0, 0.04);
}
.sidebar-section-title {
font-size: 12px;
font-weight: 500;
color: #666;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 12px;
padding-left: 4px;
}
.new-chat-btn {
margin-bottom: 16px;
min-width: 200px;
background-color: #f5f0ff !important;
color: #673ab7 !important;
width: 100%;
background-color: #673ab7 !important;
color: white !important;
font-weight: 500;
box-shadow: none !important;
box-shadow: 0 2px 8px rgba(103, 58, 183, 0.25) !important;
transition: all 0.2s ease;
text-transform: none;
letter-spacing: 0.25px;
}
.new-chat-btn:hover {
background-color: #ede7f6 !important;
background-color: #7e57c2 !important;
box-shadow: 0 4px 12px rgba(103, 58, 183, 0.3) !important;
transform: translateY(-1px);
}
.conversation-list-card {
border-radius: 12px;
box-shadow: none !important;
border: 1px solid #f0f0f0;
background-color: #fafafa;
background-color: transparent;
}
.conversation-list {
max-height: 500px;
overflow-y: auto;
padding: 4px;
max-height: none;
overflow-y: visible;
padding: 0;
}
.conversation-item {
margin-bottom: 4px;
border-radius: 8px !important;
transition: all 0.2s ease;
height: auto !important;
min-height: 56px;
padding: 8px 12px !important;
}
.conversation-item:hover {
background-color: #f5f0ff;
background-color: rgba(103, 58, 183, 0.05);
}
.conversation-title {
font-weight: 500;
font-size: 14px;
line-height: 1.3;
margin-bottom: 2px;
}
.timestamp {
font-size: 11px;
color: #999;
margin-top: 4px;
line-height: 1;
}
.status-chips {
margin-top: 16px;
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.status-chip {
font-size: 12px;
height: 24px !important;
}
.delete-chat-btn {
position: fixed;
bottom: 24px;
margin-bottom: 16px;
min-width: 200px;
background-color: #feecec !important;
width: 100%;
color: #d32f2f !important;
font-weight: 500;
box-shadow: none !important;
margin-top: 8px;
text-transform: none;
letter-spacing: 0.25px;
font-size: 12px;
}
.delete-chat-btn:hover {
background-color: #ffebee !important;
background-color: rgba(211, 47, 47, 0.1) !important;
}
.no-conversations {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 150px;
opacity: 0.6;
gap: 12px;
}
.no-conversations-text {
font-size: 14px;
color: #999;
}
/* 聊天内容区域 */
@@ -828,7 +920,8 @@ export default {
border-top-left-radius: 4px;
}
.user-avatar, .bot-avatar {
.user-avatar,
.bot-avatar {
align-self: flex-end;
}
@@ -881,7 +974,8 @@ export default {
margin: 0 auto;
}
.send-btn, .record-btn {
.send-btn,
.record-btn {
margin-left: 4px;
}
@@ -895,7 +989,8 @@ export default {
flex-wrap: wrap;
}
.image-preview, .audio-preview {
.image-preview,
.audio-preview {
position: relative;
display: inline-flex;
}
@@ -1003,7 +1098,7 @@ export default {
margin: 16px 0;
}
.markdown-content th,
.markdown-content th,
.markdown-content td {
border: 1px solid #eee;
padding: 8px 12px;
+32 -11
View File
@@ -22,7 +22,7 @@ import 'highlight.js/styles/github.css';
<v-card-title>
<div class="pl-2 pt-2 d-flex align-center pe-2">
<h2> 插件市场</h2>
<h2> 插件市场</h2>
<v-btn icon size="small" style="margin-left: 8px" variant="plain" @click="jumpToPluginMarket()">
<v-icon size="small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">
@@ -52,6 +52,7 @@ import 'highlight.js/styles/github.css';
<v-card-text>
<small style="color: #bbb;">每个插件都是作者无偿提供的的劳动成果如果您喜欢某个插件 Star</small>
<div v-if="pinnedPlugins.length > 0" class="mt-4">
<h2>🥳 推荐</h2>
@@ -71,7 +72,7 @@ import 'highlight.js/styles/github.css';
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys">
<template v-slot:item.name="{ item }">
<div class="d-flex align-center">
<div class="d-flex align-center" style="overflow-x: scroll;">
<img v-if="item.logo" :src="item.logo"
style="height: 80px; width: 80px; margin-right: 8px; border-radius: 8px; margin-top: 8px; margin-bottom: 8px;"
alt="logo">
@@ -83,24 +84,43 @@ import 'highlight.js/styles/github.css';
</div>
</template>
<template v-slot:item.desc="{ item }">
<div style="font-size: 13px;">
{{ item.desc }}
</div>
</template>
<template v-slot:item.author="{ item }">
<span v-if="item?.social_link"><a :href="item?.social_link">{{ item.author
<div style="font-size: 12px;">
<span v-if="item?.social_link"><a :href="item?.social_link">{{ item.author
}}</a></span>
<span v-else>{{ item.author }}</span>
</div>
</template>
<template v-slot:item.stars="{ item }">
<a :href="item.repo">
<img v-if="item.repo"
:src="`https://img.shields.io/github/stars/${item.repo.split('/').slice(-2).join('/')}.svg`"
:alt="`Stars for ${item.name}`"
style="height: 20px;"
/>
</a>
</template>
<template v-slot:item.tags="{ item }">
<span v-if="item.tags.length === 0"></span>
<v-chip v-for="tag in item.tags" :key="tag" color="primary" size="small">{{ tag
<v-chip v-for="tag in item.tags" :key="tag" color="primary" size="x-small">{{ tag
}}</v-chip>
</template>
<template v-slot:item.actions="{ item }">
<v-btn v-if="!item.installed" class="text-none mr-2" size="small"
<v-btn v-if="!item.installed" class="text-none mr-2" size="x-small"
variant="flat" border
@click="extension_url = item.repo; newExtension()">安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" variant="flat" border
<v-btn v-else class="text-none mr-2" size="x-small" variant="flat" border
disabled>已安装</v-btn>
<v-btn class="text-none mr-2" size="small" variant="flat" border
@click="open(item.repo)">查看帮助</v-btn>
<v-btn class="text-none mr-2" size="x-small" variant="flat" border
@click="open(item.repo)">帮助</v-btn>
</template>
</v-data-table>
</v-col>
@@ -259,10 +279,11 @@ export default {
announcement: "",
isListView: true,
pluginMarketHeaders: [
{ title: '名称', key: 'name', maxWidth: '150px' },
{ title: '名称', key: 'name', maxWidth: '200px' },
{ title: '描述', key: 'desc', maxWidth: '250px' },
{ title: '作者', key: 'author', maxWidth: '60px' },
{ title: '标签', key: 'tags', maxWidth: '60px' },
{ title: '作者', key: 'author', maxWidth: '70px' },
{ title: 'Star数', key: 'stars', maxWidth: '100px' },
{ title: '标签', key: 'tags', maxWidth: '100px' },
{ title: '操作', key: 'actions', sortable: false }
],
marketSearch: "",
+236 -2
View File
@@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
import axios from 'axios';
import { useCommonStore } from '@/stores/common';
@@ -35,6 +36,20 @@ const selectedPlugin = ref({});
const curr_namespace = ref("");
const wfr = ref(null);
const readmeDialog = reactive({
show: false,
pluginName: '',
repoUrl: null
});
//
const platformEnableDialog = ref(false);
const platformEnableData = reactive({
platforms: [],
plugins: [],
platform_enable: {}
});
const loadingPlatformData = ref(false);
const plugin_handler_info_headers = [
{ title: '行为类型', key: 'event_type_h' },
{ title: '描述', key: 'desc', maxWidth: '250px' },
@@ -225,6 +240,107 @@ const reloadPlugin = async (plugin_name) => {
}
};
const viewReadme = (plugin) => {
readmeDialog.pluginName = plugin.name;
readmeDialog.repoUrl = plugin.repo;
readmeDialog.show = true;
};
//
const getPlatformEnableConfig = async () => {
loadingPlatformData.value = true;
try {
const res = await axios.get('/api/plugin/platform_enable/get');
if (res.data.status === "error") {
toast(res.data.message, "error");
return;
}
platformEnableData.platforms = res.data.data.platforms;
platformEnableData.plugins = res.data.data.plugins;
platformEnableData.platform_enable = res.data.data.platform_enable;
//
if (platformEnableData.platforms.length === 0) {
toast("未添加任何平台适配器,请先在平台管理中添加平台", "warning");
} else {
//
platformEnableData.platforms.forEach(platform => {
if (!platformEnableData.platform_enable[platform.name]) {
platformEnableData.platform_enable[platform.name] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
if (platformEnableData.platform_enable[platform.name][plugin.name] === undefined) {
platformEnableData.platform_enable[platform.name][plugin.name] = true; //
}
});
});
}
platformEnableDialog.value = true;
} catch (err) {
toast("获取平台插件配置失败: " + err, "error");
} finally {
loadingPlatformData.value = false;
}
};
//
const savePlatformEnableConfig = async () => {
loadingPlatformData.value = true;
try {
const res = await axios.post('/api/plugin/platform_enable/set', {
platform_enable: platformEnableData.platform_enable
});
if (res.data.status === "error") {
toast(res.data.message, "error");
return;
}
toast(res.data.message, "success");
platformEnableDialog.value = false;
} catch (err) {
toast("保存平台插件配置失败: " + err, "error");
} finally {
loadingPlatformData.value = false;
}
};
//
const selectAllPluginsForPlatform = (platformName, isSelected, onlyReserved = null) => {
// platform_enable
if (!platformEnableData.platform_enable[platformName]) {
platformEnableData.platform_enable[platformName] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
// onlyReservednull
// onlyReservedtrue
// onlyReservedfalse
if (onlyReserved === null || plugin.reserved === onlyReserved) {
platformEnableData.platform_enable[platformName][plugin.name] = isSelected;
}
});
};
//
const toggleAllPluginsForPlatform = (platformName) => {
// platform_enable
if (!platformEnableData.platform_enable[platformName]) {
platformEnableData.platform_enable[platformName] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
const currentState = platformEnableData.platform_enable[platformName][plugin.name];
platformEnableData.platform_enable[platformName][plugin.name] = !currentState;
});
};
//
onMounted(async () => {
await getExtensions();
@@ -248,6 +364,9 @@ onMounted(async () => {
<v-btn class="text-none ml-2" size="small" variant="flat" border @click="toggleShowReserved">
{{ showReserved ? '隐藏系统保留插件' : '显示系统保留插件' }}
</v-btn>
<v-btn class="text-none ml-2" size="small" variant="flat" color="primary" border @click="getPlatformEnableConfig">
平台命令配置
</v-btn>
<v-dialog max-width="500px" v-if="extension_data.message">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon size="small" color="error" style="margin-left: auto;" variant="plain">
@@ -279,11 +398,111 @@ onMounted(async () => {
@update="updateExtension(extension.name)"
@reload="reloadPlugin(extension.name)"
@toggle-activation="extension.activated ? pluginOff(extension) : pluginOn(extension)"
@view-handlers="showPluginInfo(extension)">
@view-handlers="showPluginInfo(extension)"
@view-readme="viewReadme(extension)">
</ExtensionCard>
</v-col>
</v-row>
<!-- 插件平台配置对话框 -->
<v-dialog v-model="platformEnableDialog" max-width="800" persistent>
<v-card>
<v-card-title>
<span class="headline">平台命令可用性配置</span>
</v-card-title>
<v-card-subtitle>
设置每个插件在不同平台上的可用性勾选表示启用
</v-card-subtitle>
<v-card-text>
<v-overlay
:model-value="loadingPlatformData"
class="align-center justify-center"
persistent
>
<v-progress-circular
color="primary"
indeterminate
size="64"
></v-progress-circular>
</v-overlay>
<div v-if="platformEnableData.platforms.length === 0" class="text-center pa-4">
<v-icon icon="mdi-alert" color="warning" size="64" class="mb-4"></v-icon>
<div class="text-h6 mb-2">未找到平台适配器</div>
<div class="text-body-1 mb-4">请先在 <strong>平台管理</strong> 中添加并配置平台适配器然后再设置插件的平台可用性</div>
<v-btn color="primary" to="/platforms">前往平台管理</v-btn>
</div>
<v-table v-else>
<thead>
<tr>
<th>插件名称</th>
<th v-for="platform in platformEnableData.platforms" :key="platform.name">
<div class="d-flex align-center">
{{ platform.display_name }}
<v-menu>
<template v-slot:activator="{ props }">
<v-btn
icon
density="compact"
variant="text"
size="small"
v-bind="props"
class="ms-1"
>
<v-icon>mdi-dots-vertical</v-icon>
</v-btn>
</template>
<v-list>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true)">
<v-list-item-title>全选</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, false)">
<v-list-item-title>全选普通插件</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, true)">
<v-list-item-title>全选系统插件</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, false)">
<v-list-item-title>全不选</v-list-item-title>
</v-list-item>
<v-list-item @click="toggleAllPluginsForPlatform(platform.name)">
<v-list-item-title>反选</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
</div>
</th>
</tr>
</thead>
<tbody>
<tr v-for="plugin in platformEnableData.plugins" :key="plugin.name">
<td>
<div class="d-flex align-center">
{{ plugin.name }}
<v-chip v-if="plugin.reserved" color="primary" size="x-small" class="ml-2">系统</v-chip>
</div>
<div class="text-caption text-grey">{{ plugin.desc }}</div>
</td>
<td v-for="platform in platformEnableData.platforms" :key="platform.name">
<v-checkbox
v-model="platformEnableData.platform_enable[platform.name][plugin.name]"
hide-details
density="compact"
></v-checkbox>
</td>
</tr>
</tbody>
</v-table>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" text @click="platformEnableDialog = false">关闭</v-btn>
<v-btn v-if="platformEnableData.platforms.length > 0" color="primary" @click="savePlatformEnableConfig">保存</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 配置对话框 -->
<v-dialog v-model="configDialog" width="1000">
<v-card>
@@ -365,4 +584,19 @@ onMounted(async () => {
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<ReadmeDialog
v-model:show="readmeDialog.show"
:plugin-name="readmeDialog.pluginName"
:repo-url="readmeDialog.repoUrl"
/>
</template>
<style scoped>
.plugin-handler-item {
margin-bottom: 10px;
padding: 5px;
border-radius: 5px;
background-color: #f5f5f5;
}
</style>
+4 -3
View File
@@ -880,8 +880,9 @@ UID: {user_id} 此 ID 可用于设置管理员。
provider = self.context.get_using_provider()
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
await provider.api_client.delete_chat_conv(message.unified_msg_origin)
provider.conversation_ids.pop(message.unified_msg_origin, None)
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
if dify_cid:
await provider.api_client.delete_chat_conv(message.unified_msg_origin, dify_cid)
message.set_result(
MessageEventResult().message(
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
@@ -1232,7 +1233,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
if mood_dialogs := persona["_mood_imitation_dialogs_processed"]:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
if (begin_dialogs := persona["_begin_dialogs_processed"]) and not req.contexts:
req.contexts[:0] = begin_dialogs
if quote and quote.message_str:
+1 -1
View File
@@ -22,7 +22,7 @@ class Main(star.Star):
if not self.timezone:
self.timezone = None
try:
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
except Exception as e:
logger.error(f"时区设置错误: {e}, 使用本地时区")
self.timezone = None