Merge branch 'master' into gemini
This commit is contained in:
@@ -50,6 +50,7 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"prompt_prefix": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
@@ -248,6 +249,9 @@ CONFIG_METADATA_2 = {
|
||||
"description": "平台设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"plugin_enable": {
|
||||
"invisible": True, # 隐藏插件启用配置
|
||||
},
|
||||
"unique_session": {
|
||||
"description": "会话隔离",
|
||||
"type": "bool",
|
||||
@@ -924,8 +928,8 @@ CONFIG_METADATA_2 = {
|
||||
"dify_api_type": {
|
||||
"description": "Dify 应用类型",
|
||||
"type": "string",
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
||||
"options": ["chat", "agent", "workflow"],
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, chatflow, agent, workflow 三种应用类型。",
|
||||
"options": ["chat", "chatflow", "agent", "workflow"],
|
||||
},
|
||||
"dify_workflow_output_key": {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
@@ -994,6 +998,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话数量(条)",
|
||||
"type": "int",
|
||||
"hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下",
|
||||
},
|
||||
"streaming_response": {
|
||||
"description": "启用流式回复",
|
||||
"type": "bool",
|
||||
|
||||
@@ -175,7 +175,15 @@ class ConversationManager:
|
||||
if record["role"] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record["role"] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
if "content" in record and record["content"]:
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
elif "tool_calls" in record:
|
||||
tool_calls_str = json.dumps(
|
||||
record["tool_calls"], ensure_ascii=False
|
||||
)
|
||||
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
||||
else:
|
||||
temp_contexts.append("Assistant: [未知的内容]")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
@@ -18,6 +19,7 @@ STAGES_ORDER = [
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
@@ -29,6 +31,7 @@ __all__ = [
|
||||
"WhitelistCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class PlatformCompatibilityStage(Stage):
|
||||
"""检查所有处理器的平台兼容性。
|
||||
|
||||
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化平台兼容性检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
|
||||
# 获取已激活的处理器
|
||||
activated_handlers = event.get_extra("activated_handlers")
|
||||
if activated_handlers is None:
|
||||
activated_handlers = []
|
||||
|
||||
# 标记不兼容的处理器
|
||||
for handler in activated_handlers:
|
||||
if not isinstance(handler, StarHandlerMetadata):
|
||||
continue
|
||||
# 检查处理器是否在当前平台启用
|
||||
enabled = handler.is_enabled_for_platform(platform_id)
|
||||
if not enabled:
|
||||
if handler.handler_module_path in star_map:
|
||||
plugin_name = star_map[handler.handler_module_path].name
|
||||
logger.debug(
|
||||
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||
)
|
||||
# 设置处理器为平台不兼容状态
|
||||
# TODO: 更好的标记方式
|
||||
handler.platform_compatible = False
|
||||
else:
|
||||
# 确保处理器为平台兼容状态
|
||||
handler.platform_compatible = True
|
||||
|
||||
# 更新已激活的处理器列表
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
@@ -38,6 +38,10 @@ class LLMRequestSubStage(Stage):
|
||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||
"max_context_length"
|
||||
] # int
|
||||
self.dequeue_context_length = min(
|
||||
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
) # int
|
||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||
"streaming_response"
|
||||
] # bool
|
||||
@@ -62,12 +66,16 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
assert isinstance(
|
||||
req, ProviderRequest
|
||||
), "provider_request 必须是 ProviderRequest 类型。"
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
req.contexts = self._process_tool_message_pairs(
|
||||
all_contexts, remove_tags=True
|
||||
)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.provider_wake_prefix:
|
||||
@@ -108,8 +116,10 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMRequestEvent
|
||||
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
@@ -135,7 +145,9 @@ class LLMRequestSubStage(Stage):
|
||||
and len(req.contexts) // 2 > self.max_context_length
|
||||
):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[-self.max_context_length * 2 :]
|
||||
req.contexts = req.contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length) * 2 :
|
||||
]
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
@@ -368,6 +380,20 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md and
|
||||
platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||
)
|
||||
# 直接跳过,不添加任何消息到tool_call_result
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
@@ -425,12 +451,22 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts
|
||||
contexts = req.contexts.copy()
|
||||
contexts.append(await req.assemble_context())
|
||||
|
||||
# tool calls result
|
||||
# 记录并标记函数调用结果
|
||||
if req.tool_calls_result:
|
||||
contexts.extend(req.tool_calls_result.to_openai_messages())
|
||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||
|
||||
# 添加标记
|
||||
for message in tool_calls_messages:
|
||||
message["_tool_call_history"] = True
|
||||
|
||||
processed_tool_messages = self._process_tool_message_pairs(
|
||||
tool_calls_messages, remove_tags=False
|
||||
)
|
||||
|
||||
contexts.extend(processed_tool_messages)
|
||||
|
||||
contexts.append(
|
||||
{"role": "assistant", "content": llm_response.completion_text}
|
||||
@@ -441,3 +477,59 @@ class LLMRequestSubStage(Stage):
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||
)
|
||||
|
||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||
|
||||
Args:
|
||||
messages (list): 消息列表
|
||||
remove_tags (bool): 是否移除_tool_call_history标记
|
||||
|
||||
Returns:
|
||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
current_msg = messages[i]
|
||||
|
||||
# 普通消息直接添加
|
||||
if "_tool_call_history" not in current_msg:
|
||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息成对处理
|
||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||
assistant_msg = current_msg.copy()
|
||||
|
||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(messages)
|
||||
and messages[j].get("role") == "tool"
|
||||
and "_tool_call_history" in messages[j]
|
||||
):
|
||||
tool_msg = messages[j].copy()
|
||||
|
||||
if remove_tags:
|
||||
del tool_msg["_tool_call_history"]
|
||||
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 成对的时候添加到结果
|
||||
if related_tools:
|
||||
result.append(assistant_msg)
|
||||
result.extend(related_tools)
|
||||
|
||||
i = j # 跳过已处理
|
||||
else:
|
||||
# 单独的tool消息
|
||||
i += 1
|
||||
|
||||
return result
|
||||
|
||||
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
|
||||
)
|
||||
if not handlers_parsed_params:
|
||||
handlers_parsed_params = {}
|
||||
|
||||
for handler in activated_handlers:
|
||||
# 检查处理器是否在当前平台兼容
|
||||
if (
|
||||
hasattr(handler, "platform_compatible")
|
||||
and handler.platform_compatible is False
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||
)
|
||||
continue
|
||||
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
|
||||
@@ -198,7 +198,7 @@ class RespondStage(Stage):
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent
|
||||
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -96,7 +96,7 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnDecoratingResultEvent
|
||||
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -81,6 +81,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
def get_platform_name(self):
|
||||
return self.platform_meta.name
|
||||
|
||||
def get_platform_id(self):
|
||||
return self.platform_meta.id
|
||||
|
||||
def get_message_str(self) -> str:
|
||||
"""
|
||||
获取消息字符串。
|
||||
|
||||
@@ -7,6 +7,8 @@ class PlatformMetadata:
|
||||
"""平台的名称"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str = None
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
"""平台的默认配置模板"""
|
||||
|
||||
@@ -39,8 +39,9 @@ class AiocqhttpAdapter(Platform):
|
||||
self.port = platform_config["ws_reverse_port"]
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"aiocqhttp",
|
||||
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
@@ -109,7 +110,7 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if "group_id" in event and event["group_id"]:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
@@ -129,7 +130,7 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 通知类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if "group_id" in event and event["group_id"]:
|
||||
abm.group_id = str(event.group_id)
|
||||
|
||||
@@ -73,8 +73,9 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"dingtalk",
|
||||
"钉钉机器人官方 API 适配器",
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def convert_msg(
|
||||
|
||||
@@ -60,8 +60,9 @@ class GewechatPlatformAdapter(Platform):
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"gewechat",
|
||||
"基于 gewechat 的 Wechat 适配器",
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from astrbot.api.platform import (
|
||||
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
raise NotImplementedError("Lark 适配器不支持 send_by_session")
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
}
|
||||
}
|
||||
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
if "%" in session.session_id:
|
||||
session.session_id = session.session_id.split("%")[1]
|
||||
else:
|
||||
id_type = "open_id"
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(session.session_id)
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = await self.lark_api.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"lark",
|
||||
"飞书机器人官方 API 适配器",
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
logger.debug(abm)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -11,6 +11,7 @@ from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
from botpy.types import message
|
||||
import random
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
@@ -68,7 +69,6 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
async def _post_send(self, stream: dict = None):
|
||||
"""QQ 官方 API 仅支持回复一次"""
|
||||
if not self.send_buffer:
|
||||
return
|
||||
|
||||
@@ -97,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
|
||||
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
|
||||
payload["msg_seq"] = random.randint(1, 10000)
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
if image_base64:
|
||||
|
||||
@@ -126,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"qq_official",
|
||||
"QQ 机器人官方 API 适配器",
|
||||
name="qq_official",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -99,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"qq_official_webhook",
|
||||
"QQ 机器人官方 API 适配器",
|
||||
name="qq_official_webhook",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
@@ -1,26 +1,31 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from telegram import BotCommand, Update
|
||||
from telegram.constants import ChatType
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
|
||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, filters
|
||||
from telegram.constants import ChatType
|
||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||
from .tg_event import TelegramPlatformEvent
|
||||
from astrbot.api import logger
|
||||
from telegram.ext import ExtBot
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -67,6 +72,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
self.client = self.application.bot
|
||||
logger.debug(f"Telegram base url: {self.client.base_url}")
|
||||
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
@@ -80,18 +87,94 @@ class TelegramPlatformAdapter(Platform):
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"telegram",
|
||||
"telegram 适配器",
|
||||
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
await self.register_commands()
|
||||
|
||||
# TODO 使用更优雅的方式重新注册命令
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
minutes=5,
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
await queue
|
||||
|
||||
async def register_commands(self):
|
||||
"""收集所有注册的指令并注册到 Telegram"""
|
||||
try:
|
||||
await self.client.delete_my_commands()
|
||||
commands = self.collect_commands()
|
||||
|
||||
if commands:
|
||||
await self.client.set_my_commands(commands)
|
||||
for cmd in commands:
|
||||
logger.debug(f"已注册指令: /{cmd.command} - {cmd.description}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||
|
||||
def collect_commands(self) -> list[BotCommand]:
|
||||
"""从注册的处理器中收集所有指令"""
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
cmd_info = self._extract_command_info(
|
||||
event_filter, handler_metadata, skip_commands
|
||||
)
|
||||
if cmd_info:
|
||||
cmd_name, description = cmd_info
|
||||
command_dict.setdefault(cmd_name, description)
|
||||
|
||||
commands_a = sorted(command_dict.keys())
|
||||
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
|
||||
|
||||
@staticmethod
|
||||
def _extract_command_info(
|
||||
event_filter, handler_metadata, skip_commands: set
|
||||
) -> tuple[str, str] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
is_group = False
|
||||
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
|
||||
if (
|
||||
event_filter.parent_command_names
|
||||
and event_filter.parent_command_names != [""]
|
||||
):
|
||||
return None
|
||||
cmd_name = event_filter.command_name
|
||||
elif isinstance(event_filter, CommandGroupFilter):
|
||||
if event_filter.parent_group:
|
||||
return None
|
||||
cmd_name = event_filter.group_name
|
||||
is_group = True
|
||||
|
||||
if not cmd_name or cmd_name in skip_commands:
|
||||
return None
|
||||
|
||||
# Build description.
|
||||
description = handler_metadata.desc or (
|
||||
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||
)
|
||||
if len(description) > 30:
|
||||
description = description[:30] + "..."
|
||||
return cmd_name, description
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
await context.bot.send_message(
|
||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||
@@ -163,6 +246,16 @@ class TelegramPlatformAdapter(Platform):
|
||||
# 处理文本消息
|
||||
plain_text = update.message.text
|
||||
|
||||
# 群聊场景命令特殊处理
|
||||
if plain_text.startswith("/"):
|
||||
command_parts = plain_text.split(" ", 1)
|
||||
if "@" in command_parts[0]:
|
||||
command, bot_name = command_parts[0].split("@")
|
||||
if bot_name == self.client.username:
|
||||
plain_text = command + (
|
||||
f" {command_parts[1]}" if len(command_parts) > 1 else ""
|
||||
)
|
||||
|
||||
if update.message.entities:
|
||||
for entity in update.message.entities:
|
||||
if entity.type == "mention":
|
||||
@@ -242,7 +335,11 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
try:
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown()
|
||||
|
||||
await self.application.stop()
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
# 保险起见先判断是否存在updater对象
|
||||
if self.application.updater is not None:
|
||||
|
||||
@@ -43,8 +43,7 @@ class WebChatAdapter(Platform):
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"webchat",
|
||||
"webchat",
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
)
|
||||
|
||||
async def send_by_session(
|
||||
|
||||
@@ -212,9 +212,9 @@ class LLMResponse:
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
result_chain: MessageChain = None,
|
||||
tools_call_args: List[Dict[str, any]] = [],
|
||||
tools_call_name: List[str] = [],
|
||||
tools_call_ids: List[str] = [],
|
||||
tools_call_args: List[Dict[str, any]] = None,
|
||||
tools_call_name: List[str] = None,
|
||||
tools_call_ids: List[str] = None,
|
||||
raw_completion: ChatCompletion = None,
|
||||
_new_record: Dict[str, any] = None,
|
||||
is_chunk: bool = False,
|
||||
@@ -229,6 +229,13 @@ class LLMResponse:
|
||||
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
"""
|
||||
if tools_call_args is None:
|
||||
tools_call_args = []
|
||||
if tools_call_name is None:
|
||||
tools_call_name = []
|
||||
if tools_call_ids is None:
|
||||
tools_call_ids = []
|
||||
|
||||
self.role = role
|
||||
self.completion_text = completion_text
|
||||
self.result_chain = result_chain
|
||||
|
||||
@@ -348,16 +348,17 @@ class FuncCall:
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
)
|
||||
func_ = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
# "parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
if f.parameters.get("properties"):
|
||||
func_["function"]["parameters"] = f.parameters
|
||||
_l.append(func_)
|
||||
return _l
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
|
||||
@@ -102,7 +102,7 @@ class ProviderDify(Provider):
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
|
||||
Regular → Executable
Regular → Executable
@@ -47,5 +47,29 @@ class StarMetadata:
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
"""注册的 Handler 的全名列表"""
|
||||
|
||||
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
|
||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||
"""更新插件支持的平台列表
|
||||
|
||||
Args:
|
||||
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
|
||||
"""
|
||||
if not plugin_enable_config:
|
||||
return
|
||||
|
||||
# 清空之前的配置
|
||||
self.supported_platforms.clear()
|
||||
|
||||
# 遍历所有平台配置
|
||||
for platform_id, plugins in plugin_enable_config.items():
|
||||
# 检查该插件在当前平台的配置
|
||||
if self.name in plugins:
|
||||
self.supported_platforms[platform_id] = plugins[self.name]
|
||||
else:
|
||||
# 如果没有明确配置,默认为启用
|
||||
self.supported_platforms[platform_id] = True
|
||||
|
||||
@@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]):
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler"""
|
||||
handlers = [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
if handler.event_type == event_type
|
||||
and (
|
||||
not only_activated
|
||||
or (
|
||||
star_map[handler.handler_module_path]
|
||||
and star_map[handler.handler_module_path].activated
|
||||
)
|
||||
)
|
||||
]
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
@@ -139,3 +154,32 @@ class StarHandlerMetadata:
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
"priority", 0
|
||||
)
|
||||
|
||||
def is_enabled_for_platform(self, platform_id: str) -> bool:
|
||||
"""检查插件是否在指定平台启用
|
||||
|
||||
Args:
|
||||
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
|
||||
|
||||
Returns:
|
||||
bool: 是否启用,True表示启用,False表示禁用
|
||||
"""
|
||||
plugin = star_map.get(self.handler_module_path)
|
||||
|
||||
# 如果插件元数据不存在,默认允许执行
|
||||
if not plugin or not plugin.name:
|
||||
return True
|
||||
|
||||
# 先检查插件是否被激活
|
||||
if not plugin.activated:
|
||||
return False
|
||||
|
||||
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
|
||||
if (
|
||||
hasattr(plugin, "supported_platforms")
|
||||
and platform_id in plugin.supported_platforms
|
||||
):
|
||||
return plugin.supported_platforms[platform_id]
|
||||
|
||||
# 如果没有缓存数据,默认允许执行
|
||||
return True
|
||||
|
||||
@@ -166,8 +166,71 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
def _get_plugin_related_modules(
|
||||
self, plugin_root_dir: str, is_reserved: bool = False
|
||||
) -> list[str]:
|
||||
"""获取与指定插件相关的所有已加载模块名
|
||||
|
||||
根据插件根目录名和是否为保留插件,从 sys.modules 中筛选出相关的模块名
|
||||
|
||||
Args:
|
||||
plugin_root_dir: 插件根目录名
|
||||
is_reserved: 是否是保留插件,影响模块路径前缀
|
||||
|
||||
Returns:
|
||||
list[str]: 与该插件相关的模块名列表
|
||||
"""
|
||||
prefix = "packages." if is_reserved else "data.plugins."
|
||||
return [
|
||||
key
|
||||
for key in list(sys.modules.keys())
|
||||
if key.startswith(f"{prefix}{plugin_root_dir}")
|
||||
]
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
|
||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||
|
||||
Args:
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||
"""
|
||||
if module_patterns:
|
||||
for pattern in module_patterns:
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith(pattern):
|
||||
del sys.modules[key]
|
||||
logger.debug(f"删除模块 {key}")
|
||||
|
||||
if root_dir_name:
|
||||
for module_name in self._get_plugin_related_modules(
|
||||
root_dir_name, is_reserved
|
||||
):
|
||||
try:
|
||||
del sys.modules[module_name]
|
||||
logger.debug(f"删除模块 {module_name}")
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {module_name} 未载入")
|
||||
|
||||
async def reload(self, specified_plugin_name=None):
|
||||
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
|
||||
"""重新加载插件
|
||||
|
||||
Args:
|
||||
specified_plugin_name (str, optional): 要重载的特定插件名称。
|
||||
如果为 None,则重载所有插件。
|
||||
|
||||
Returns:
|
||||
tuple: 返回 load() 方法的结果,包含 (success, error_message)
|
||||
- success (bool): 重载是否成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
@@ -192,9 +255,6 @@ class PluginManager:
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
@@ -209,11 +269,44 @@ class PluginManager:
|
||||
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
return await self.load(specified_module_path)
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
return result
|
||||
|
||||
async def update_all_platform_compatibility(self):
|
||||
"""更新所有插件的平台兼容性设置"""
|
||||
# 获取最新的平台插件启用配置
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
logger.debug(
|
||||
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
|
||||
)
|
||||
|
||||
# 遍历所有插件,更新平台兼容性
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin.update_platform_compatibility(plugin_enable_config)
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def load(self, specified_module_path=None, specified_dir_name=None):
|
||||
"""载入插件。
|
||||
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
|
||||
|
||||
Args:
|
||||
specified_module_path (str, optional): 指定要加载的插件模块路径。例如: "data.plugins.my_plugin.main"
|
||||
specified_dir_name (str, optional): 指定要加载的插件目录名。例如: "my_plugin"
|
||||
|
||||
Returns:
|
||||
tuple: (success, error_message)
|
||||
- success (bool): 是否全部加载成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
@@ -320,6 +413,12 @@ class PluginManager:
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 更新插件的平台兼容性
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -447,6 +546,20 @@ class PluginManager:
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str, proxy=""):
|
||||
"""从仓库 URL 安装插件
|
||||
|
||||
从指定的仓库 URL 下载并安装插件,然后加载该插件到系统中
|
||||
|
||||
Args:
|
||||
repo_url (str): 要安装的插件仓库 URL
|
||||
proxy (str, optional): 用于下载的代理服务器。默认为空字符串。
|
||||
|
||||
Returns:
|
||||
dict | None: 安装成功时返回包含插件信息的字典:
|
||||
- repo: 插件的仓库 URL
|
||||
- readme: README.md 文件的内容(如果存在)
|
||||
如果找不到插件元数据则返回 None。
|
||||
"""
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
@@ -481,6 +594,14 @@ class PluginManager:
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
"""卸载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称
|
||||
|
||||
Raises:
|
||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -509,9 +630,17 @@ class PluginManager:
|
||||
)
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
"""解绑并移除一个插件。
|
||||
|
||||
Args:
|
||||
plugin_name: 要解绑的插件名称
|
||||
plugin_module_path: 插件的完整模块路径
|
||||
"""
|
||||
plugin = None
|
||||
del star_map[plugin_module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
plugin = p
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -521,21 +650,17 @@ class PluginManager:
|
||||
f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})"
|
||||
)
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [
|
||||
k
|
||||
for k, v in star_handlers_registry.star_handlers_map.items()
|
||||
if k.startswith(plugin_module_path)
|
||||
]
|
||||
for k in keys_to_delete:
|
||||
try:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
del sys.modules[plugin_module_path]
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {plugin_module_path} 未载入")
|
||||
for k in [
|
||||
k
|
||||
for k in star_handlers_registry.star_handlers_map
|
||||
if k.startswith(plugin_module_path)
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
|
||||
async def update_plugin(self, plugin_name: str, proxy=""):
|
||||
"""升级一个插件"""
|
||||
|
||||
@@ -15,7 +15,7 @@ class SharedPreferences:
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4)
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
|
||||
@@ -60,11 +60,13 @@ def validate_config(
|
||||
data[key] = False
|
||||
continue
|
||||
meta = metadata[key]
|
||||
if "type" not in meta:
|
||||
logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验")
|
||||
continue
|
||||
# null 转换
|
||||
if value is None:
|
||||
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
|
||||
continue
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(
|
||||
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}"
|
||||
|
||||
@@ -38,6 +38,8 @@ class PluginRoute(Route):
|
||||
"/plugin/on": ("POST", self.on_plugin),
|
||||
"/plugin/reload": ("POST", self.reload_plugins),
|
||||
"/plugin/readme": ("GET", self.get_plugin_readme),
|
||||
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
|
||||
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
@@ -323,38 +325,131 @@ class PluginRoute(Route):
|
||||
async def get_plugin_readme(self):
|
||||
plugin_name = request.args.get("name")
|
||||
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
|
||||
|
||||
|
||||
if not plugin_name:
|
||||
logger.warning("插件名称为空")
|
||||
return Response().error("插件名称不能为空").__dict__
|
||||
|
||||
|
||||
plugin_obj = None
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
if plugin.name == plugin_name:
|
||||
plugin_obj = plugin
|
||||
break
|
||||
|
||||
|
||||
if not plugin_obj:
|
||||
logger.warning(f"插件 {plugin_name} 不存在")
|
||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||
|
||||
plugin_dir = os.path.join(self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name)
|
||||
|
||||
|
||||
plugin_dir = os.path.join(
|
||||
self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name
|
||||
)
|
||||
|
||||
if not os.path.isdir(plugin_dir):
|
||||
logger.warning(f"无法找到插件目录: {plugin_dir}")
|
||||
return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__
|
||||
|
||||
|
||||
readme_path = os.path.join(plugin_dir, "README.md")
|
||||
|
||||
|
||||
if not os.path.isfile(readme_path):
|
||||
logger.warning(f"插件 {plugin_name} 没有README文件")
|
||||
return Response().error(f"插件 {plugin_name} 没有README文件").__dict__
|
||||
|
||||
|
||||
try:
|
||||
with open(readme_path, 'r', encoding='utf-8') as f:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
|
||||
return Response().ok({"content": readme_content}, "成功获取README内容").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"content": readme_content}, "成功获取README内容")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
|
||||
return Response().error(f"读取README文件失败: {str(e)}").__dict__
|
||||
|
||||
async def get_plugin_platform_enable(self):
|
||||
"""获取插件在各平台的可用性配置"""
|
||||
try:
|
||||
platform_enable = self.core_lifecycle.astrbot_config.get(
|
||||
"platform_settings", {}
|
||||
).get("plugin_enable", {})
|
||||
|
||||
# 获取所有可用平台
|
||||
platforms = []
|
||||
|
||||
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
|
||||
platform_type = platform.get("type", "")
|
||||
platform_id = platform.get("id", "")
|
||||
|
||||
platforms.append(
|
||||
{
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
}
|
||||
)
|
||||
|
||||
adjusted_platform_enable = {}
|
||||
for platform_id, plugins in platform_enable.items():
|
||||
adjusted_platform_enable[platform_id] = plugins
|
||||
|
||||
# 获取所有插件,包括系统内部插件
|
||||
plugins = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
plugins.append(
|
||||
{
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def set_plugin_platform_enable(self):
|
||||
"""设置插件在各平台的可用性配置"""
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
data = await request.json
|
||||
platform_enable = data.get("platform_enable", {})
|
||||
|
||||
# 更新配置
|
||||
config = self.core_lifecycle.astrbot_config
|
||||
platform_settings = config.get("platform_settings", {})
|
||||
platform_settings["plugin_enable"] = platform_enable
|
||||
config["platform_settings"] = platform_settings
|
||||
config.save_config()
|
||||
|
||||
# 更新插件的平台兼容性缓存
|
||||
await self.plugin_manager.update_all_platform_compatibility()
|
||||
|
||||
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
|
||||
|
||||
return Response().ok(None, "插件平台可用性配置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -41,6 +41,14 @@ const readmeDialog = reactive({
|
||||
pluginName: '',
|
||||
repoUrl: null
|
||||
});
|
||||
// 平台插件配置
|
||||
const platformEnableDialog = ref(false);
|
||||
const platformEnableData = reactive({
|
||||
platforms: [],
|
||||
plugins: [],
|
||||
platform_enable: {}
|
||||
});
|
||||
const loadingPlatformData = ref(false);
|
||||
|
||||
const plugin_handler_info_headers = [
|
||||
{ title: '行为类型', key: 'event_type_h' },
|
||||
@@ -238,6 +246,101 @@ const viewReadme = (plugin) => {
|
||||
readmeDialog.show = true;
|
||||
};
|
||||
|
||||
// 获取插件平台可用性配置
|
||||
const getPlatformEnableConfig = async () => {
|
||||
loadingPlatformData.value = true;
|
||||
try {
|
||||
const res = await axios.get('/api/plugin/platform_enable/get');
|
||||
if (res.data.status === "error") {
|
||||
toast(res.data.message, "error");
|
||||
return;
|
||||
}
|
||||
|
||||
platformEnableData.platforms = res.data.data.platforms;
|
||||
platformEnableData.plugins = res.data.data.plugins;
|
||||
platformEnableData.platform_enable = res.data.data.platform_enable;
|
||||
|
||||
// 如果没有平台,给出提示但仍显示对话框
|
||||
if (platformEnableData.platforms.length === 0) {
|
||||
toast("未添加任何平台适配器,请先在平台管理中添加平台", "warning");
|
||||
} else {
|
||||
// 确保每个平台都有一个配置对象
|
||||
platformEnableData.platforms.forEach(platform => {
|
||||
if (!platformEnableData.platform_enable[platform.name]) {
|
||||
platformEnableData.platform_enable[platform.name] = {};
|
||||
}
|
||||
|
||||
// 确保每个插件在每个平台都有一个配置项
|
||||
platformEnableData.plugins.forEach(plugin => {
|
||||
if (platformEnableData.platform_enable[platform.name][plugin.name] === undefined) {
|
||||
platformEnableData.platform_enable[platform.name][plugin.name] = true; // 默认启用
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
platformEnableDialog.value = true;
|
||||
} catch (err) {
|
||||
toast("获取平台插件配置失败: " + err, "error");
|
||||
} finally {
|
||||
loadingPlatformData.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
// 保存插件平台可用性配置
|
||||
const savePlatformEnableConfig = async () => {
|
||||
loadingPlatformData.value = true;
|
||||
try {
|
||||
const res = await axios.post('/api/plugin/platform_enable/set', {
|
||||
platform_enable: platformEnableData.platform_enable
|
||||
});
|
||||
|
||||
if (res.data.status === "error") {
|
||||
toast(res.data.message, "error");
|
||||
return;
|
||||
}
|
||||
|
||||
toast(res.data.message, "success");
|
||||
platformEnableDialog.value = false;
|
||||
} catch (err) {
|
||||
toast("保存平台插件配置失败: " + err, "error");
|
||||
} finally {
|
||||
loadingPlatformData.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
// 全选指定平台的所有插件
|
||||
const selectAllPluginsForPlatform = (platformName, isSelected, onlyReserved = null) => {
|
||||
// 确保平台存在于platform_enable中
|
||||
if (!platformEnableData.platform_enable[platformName]) {
|
||||
platformEnableData.platform_enable[platformName] = {};
|
||||
}
|
||||
|
||||
// 为所有插件设置相同的状态
|
||||
platformEnableData.plugins.forEach(plugin => {
|
||||
// 如果onlyReserved为null,处理所有插件
|
||||
// 如果onlyReserved为true,只处理系统插件
|
||||
// 如果onlyReserved为false,只处理非系统插件
|
||||
if (onlyReserved === null || plugin.reserved === onlyReserved) {
|
||||
platformEnableData.platform_enable[platformName][plugin.name] = isSelected;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// 反选指定平台的所有插件
|
||||
const toggleAllPluginsForPlatform = (platformName) => {
|
||||
// 确保平台存在于platform_enable中
|
||||
if (!platformEnableData.platform_enable[platformName]) {
|
||||
platformEnableData.platform_enable[platformName] = {};
|
||||
}
|
||||
|
||||
// 对每个插件进行反选操作
|
||||
platformEnableData.plugins.forEach(plugin => {
|
||||
const currentState = platformEnableData.platform_enable[platformName][plugin.name];
|
||||
platformEnableData.platform_enable[platformName][plugin.name] = !currentState;
|
||||
});
|
||||
};
|
||||
|
||||
// 生命周期
|
||||
onMounted(async () => {
|
||||
await getExtensions();
|
||||
@@ -261,6 +364,9 @@ onMounted(async () => {
|
||||
<v-btn class="text-none ml-2" size="small" variant="flat" border @click="toggleShowReserved">
|
||||
{{ showReserved ? '隐藏系统保留插件' : '显示系统保留插件' }}
|
||||
</v-btn>
|
||||
<v-btn class="text-none ml-2" size="small" variant="flat" color="primary" border @click="getPlatformEnableConfig">
|
||||
平台命令配置
|
||||
</v-btn>
|
||||
<v-dialog max-width="500px" v-if="extension_data.message">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon size="small" color="error" style="margin-left: auto;" variant="plain">
|
||||
@@ -298,6 +404,105 @@ onMounted(async () => {
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<!-- 插件平台配置对话框 -->
|
||||
<v-dialog v-model="platformEnableDialog" max-width="800" persistent>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="headline">平台命令可用性配置</span>
|
||||
</v-card-title>
|
||||
<v-card-subtitle>
|
||||
设置每个插件在不同平台上的可用性,勾选表示启用
|
||||
</v-card-subtitle>
|
||||
<v-card-text>
|
||||
<v-overlay
|
||||
:model-value="loadingPlatformData"
|
||||
class="align-center justify-center"
|
||||
persistent
|
||||
>
|
||||
<v-progress-circular
|
||||
color="primary"
|
||||
indeterminate
|
||||
size="64"
|
||||
></v-progress-circular>
|
||||
</v-overlay>
|
||||
|
||||
<div v-if="platformEnableData.platforms.length === 0" class="text-center pa-4">
|
||||
<v-icon icon="mdi-alert" color="warning" size="64" class="mb-4"></v-icon>
|
||||
<div class="text-h6 mb-2">未找到平台适配器</div>
|
||||
<div class="text-body-1 mb-4">请先在 <strong>平台管理</strong> 中添加并配置平台适配器,然后再设置插件的平台可用性</div>
|
||||
<v-btn color="primary" to="/platforms">前往平台管理</v-btn>
|
||||
</div>
|
||||
|
||||
<v-table v-else>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>插件名称</th>
|
||||
<th v-for="platform in platformEnableData.platforms" :key="platform.name">
|
||||
<div class="d-flex align-center">
|
||||
{{ platform.display_name }}
|
||||
<v-menu>
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn
|
||||
icon
|
||||
density="compact"
|
||||
variant="text"
|
||||
size="small"
|
||||
v-bind="props"
|
||||
class="ms-1"
|
||||
>
|
||||
<v-icon>mdi-dots-vertical</v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-list>
|
||||
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true)">
|
||||
<v-list-item-title>全选</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, false)">
|
||||
<v-list-item-title>全选普通插件</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, true)">
|
||||
<v-list-item-title>全选系统插件</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click="selectAllPluginsForPlatform(platform.name, false)">
|
||||
<v-list-item-title>全不选</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click="toggleAllPluginsForPlatform(platform.name)">
|
||||
<v-list-item-title>反选</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</div>
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="plugin in platformEnableData.plugins" :key="plugin.name">
|
||||
<td>
|
||||
<div class="d-flex align-center">
|
||||
{{ plugin.name }}
|
||||
<v-chip v-if="plugin.reserved" color="primary" size="x-small" class="ml-2">系统</v-chip>
|
||||
</div>
|
||||
<div class="text-caption text-grey">{{ plugin.desc }}</div>
|
||||
</td>
|
||||
<td v-for="platform in platformEnableData.platforms" :key="platform.name">
|
||||
<v-checkbox
|
||||
v-model="platformEnableData.platform_enable[platform.name][plugin.name]"
|
||||
hide-details
|
||||
density="compact"
|
||||
></v-checkbox>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</v-table>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" text @click="platformEnableDialog = false">关闭</v-btn>
|
||||
<v-btn v-if="platformEnableData.platforms.length > 0" color="primary" @click="savePlatformEnableConfig">保存</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 配置对话框 -->
|
||||
<v-dialog v-model="configDialog" width="1000">
|
||||
<v-card>
|
||||
@@ -385,4 +590,13 @@ onMounted(async () => {
|
||||
:plugin-name="readmeDialog.pluginName"
|
||||
:repo-url="readmeDialog.repoUrl"
|
||||
/>
|
||||
</template>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.plugin-handler-item {
|
||||
margin-bottom: 10px;
|
||||
padding: 5px;
|
||||
border-radius: 5px;
|
||||
background-color: #f5f5f5;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -880,8 +880,9 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
provider = self.context.get_using_provider()
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
await provider.api_client.delete_chat_conv(message.unified_msg_origin)
|
||||
provider.conversation_ids.pop(message.unified_msg_origin, None)
|
||||
dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None)
|
||||
if dify_cid:
|
||||
await provider.api_client.delete_chat_conv(message.unified_msg_origin, dify_cid)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。"
|
||||
@@ -1232,7 +1233,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
if mood_dialogs := persona["_mood_imitation_dialogs_processed"]:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
if (begin_dialogs := persona["_begin_dialogs_processed"]) and not req.contexts:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
if quote and quote.message_str:
|
||||
|
||||
Reference in New Issue
Block a user