Merge remote-tracking branch 'origin/HEAD' into anka-dev

This commit is contained in:
anka
2025-04-10 11:16:49 +08:00
69 changed files with 2465 additions and 483 deletions
+4
View File
@@ -8,3 +8,7 @@
### Modifications
<!--简单解释你的改动-->
### Check
- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
- [ ] 我新增/修复/优化的功能经过良好的测试
+2
View File
@@ -16,6 +16,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=7200)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
+1 -1
View File
@@ -1,5 +1,5 @@
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import (
from astrbot.core.provider.entities import (
ProviderRequest,
ProviderType,
ProviderMetaData,
+9
View File
@@ -50,6 +50,7 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
"streaming_response": False,
},
"provider_stt_settings": {
"enable": False,
@@ -247,6 +248,9 @@ CONFIG_METADATA_2 = {
"description": "平台设置",
"type": "object",
"items": {
"plugin_enable": {
"invisible": True, # 隐藏插件启用配置
},
"unique_session": {
"description": "会话隔离",
"type": "bool",
@@ -993,6 +997,11 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
"streaming_response": {
"description": "启用流式回复",
"type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
},
},
},
"persona": {
+7 -5
View File
@@ -141,11 +141,13 @@ class LogQueueHandler(logging.Handler):
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish({
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
})
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
}
)
class LogManager:
+37 -1
View File
@@ -1,6 +1,6 @@
import enum
from typing import List, Optional, Union
from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
@@ -111,6 +111,30 @@ class MessageChain:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
def squash_plain(self):
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
if not self.chain:
return
new_chain = []
first_plain = None
plain_texts = []
for comp in self.chain:
if isinstance(comp, Plain):
if first_plain is None:
first_plain = comp
new_chain.append(comp)
plain_texts.append(comp.text)
else:
new_chain.append(comp)
if first_plain is not None:
first_plain.text = "".join(plain_texts)
self.chain = new_chain
return self
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
@@ -131,6 +155,10 @@ class ResultContentType(enum.Enum):
"""调用 LLM 产生的结果"""
GENERAL_RESULT = enum.auto()
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
STREAMING_FINISH= enum.auto()
"""流式输出完成"""
@dataclass
@@ -152,6 +180,9 @@ class MessageEventResult(MessageChain):
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
"""终止事件传播。"""
self.result_type = EventResultType.STOP
@@ -168,6 +199,11 @@ class MessageEventResult(MessageChain):
"""
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
"""设置异步流。"""
self.async_stream = stream
return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。
+3
View File
@@ -7,6 +7,7 @@ from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
@@ -18,6 +19,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
@@ -29,6 +31,7 @@ __all__ = [
"WhitelistCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PlatformCompatibilityStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
@@ -0,0 +1,56 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
@register_stage
class PlatformCompatibilityStage(Stage):
"""检查所有处理器的平台兼容性。
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化平台兼容性检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
# 获取当前平台ID
platform_id = event.get_platform_id()
# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
activated_handlers = []
# 标记不兼容的处理器
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
handler.platform_compatible = False
else:
# 确保处理器为平台兼容状态
handler.platform_compatible = True
# 更新已激活的处理器列表
event.set_extra("activated_handlers", activated_handlers)
@@ -12,11 +12,12 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
MessageChain,
)
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import (
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
@@ -37,6 +38,9 @@ class LLMRequestSubStage(Stage):
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -146,8 +150,10 @@ class LLMRequestSubStage(Stage):
# 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能
# 获取当前平台ID
platform_id = event.get_platform_id()
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent
EventType.OnLLMRequestEvent, platform_id=platform_id
)
for handler in handlers:
try:
@@ -179,70 +185,127 @@ class LLMRequestSubStage(Stage):
if not req.session_id:
req.session_id = event.unified_msg_origin
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
async def requesting(req: ProviderRequest):
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
final_llm_response = None
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async for result in self._handle_llm_response(event, req, llm_response):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
if self.streaming_response:
stream = provider.text_chat_stream(**req.__dict__)
async for llm_response in stream:
if llm_response.is_chunk:
if llm_response.result_chain:
yield llm_response.result_chain # MessageChain
else:
yield MessageChain().message(
llm_response.completion_text
)
else:
final_llm_response = llm_response
else:
yield
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if self.streaming_response:
# 流式输出的处理
async for result in self._handle_llm_stream_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
else:
# 非流式输出的处理
async for result in self._handle_llm_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
)
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
# 保存到历史记录
await self._save_to_history(event, req, final_llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
event.set_extra("tool_call_result", None)
async for _ in requesting(req):
yield
else:
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req))
)
return
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
yield
if event.get_extra("tool_call_result"):
event.set_result(event.get_extra("tool_call_result"))
event.set_extra("tool_call_result", None)
yield
async def _handle_llm_response(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
) -> AsyncGenerator[None, None]:
"""处理 LLM 响应。
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理非流式 LLM 响应。
Returns:
bool: 是否需要继续调用 LLM
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
Yields:
Iterator[bool]: 将 event 交付给下一个 stage
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
@@ -265,83 +328,147 @@ class LLMRequestSubStage(Stage):
)
)
elif llm_response.role == "tool":
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_llm_stream_response(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理流式 LLM 响应。
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
Yields:
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.STREAMING_FINISH)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.STREAMING_FINISH)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
elif llm_response.role == "tool":
# 处理函数工具调用
async for result in self._handle_function_tools(event, req, llm_response):
yield result
async def _handle_function_tools(
self,
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse,
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
"""处理函数工具调用。
Returns:
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
"""
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
client = req.func_tool.mcp_client_dict[
func_tool.mcp_server_name
]
res = await client.session.call_tool(
func_tool.name, func_tool_args
else:
# 获取处理器,过滤掉平台不兼容的处理器
platform_id = event.get_platform_id()
if not func_tool.handler.is_enabled_for_platform(platform_id):
logger.debug(
f"处理器 {func_tool_name} 在当前平台不兼容,跳过执行"
)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
# 直接跳过,不添加任何消息到tool_call_result
continue
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
content=resp,
)
)
else:
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
else:
res = event.get_result()
if res and res.chain:
event.set_extra("tool_call_result", res)
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
)
if not handlers_parsed_params:
handlers_parsed_params = {}
for handler in activated_handlers:
# 检查处理器是否在当前平台兼容
if (
hasattr(handler, "platform_compatible")
and handler.platform_compatible is False
):
logger.debug(
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
)
continue
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
+1 -1
View File
@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core import logger
+21 -6
View File
@@ -7,7 +7,7 @@ from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -18,7 +18,9 @@ from astrbot.core.star.star import star_map
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
Comp.Plain: lambda comp: bool(
comp.text and comp.text.strip()
), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
@@ -31,13 +33,17 @@ class RespondStage(Stage):
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
Comp.Contact: lambda comp: True, # 联系人(未完成)
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
Comp.Music: lambda comp: bool(comp._type)
and bool(comp.url)
and bool(comp.audio), # 音乐
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.RedBag: lambda comp: bool(comp.title), # 红包
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
Comp.Node: lambda comp: bool(comp.name)
and comp.uin != 0
and bool(comp.content), # 一个转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
Comp.Json: lambda comp: bool(comp.data), # JSON
@@ -132,8 +138,17 @@ class RespondStage(Stage):
result = event.get_result()
if result is None:
return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
return
if len(result.chain) > 0:
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream)
await event._post_send()
return
elif len(result.chain) > 0:
await event._pre_send()
# 检查消息链是否为空
@@ -183,7 +198,7 @@ class RespondStage(Stage):
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
+17 -1
View File
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
@@ -72,11 +73,17 @@ class ResultDecorateStage(Stage):
if result is None or not result.chain:
return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
return
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
# 回复时检查内容安全
if (
self.content_safe_check_reply
and self.content_safe_check_stage
and result.is_llm_result()
and not is_stream # 流式输出不检查内容安全
):
text = ""
for comp in result.chain:
@@ -89,13 +96,17 @@ class ResultDecorateStage(Stage):
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnDecoratingResultEvent
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
)
for handler in handlers:
try:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
if is_stream:
logger.warning(
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(
@@ -110,6 +121,11 @@ class ResultDecorateStage(Stage):
)
return
# 流式输出不执行下面的逻辑
if is_stream:
logger.info("流式输出已启用,跳过结果装饰阶段")
return
# 需要再获取一次。插件可能直接对 chain 进行了替换。
result = event.get_result()
if result is None:
@@ -1,5 +1,6 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
@@ -93,6 +94,7 @@ class WakingCheckStage(Stage):
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
permission_filter_raise_error = False
if len(handler.event_filters) == 0:
continue
@@ -101,6 +103,7 @@ class WakingCheckStage(Stage):
if isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
permission_filter_raise_error = filter.raise_error
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -117,6 +120,9 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if not permission_filter_raise_error:
# 跳过
continue
if self.no_permission_reply:
await event.send(
MessageChain().message(
@@ -124,6 +130,9 @@ class WakingCheckStage(Stage):
)
)
await event._post_send()
logger.info(
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
)
event.stop_event()
return
+14 -2
View File
@@ -1,7 +1,7 @@
import abc
import asyncio
from dataclasses import dataclass
from typing import List, Union, Optional
from typing import List, Union, Optional, AsyncGenerator
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
@@ -16,7 +16,7 @@ from astrbot.core.message.components import (
)
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@@ -81,6 +81,9 @@ class AstrMessageEvent(abc.ABC):
def get_platform_name(self):
return self.platform_meta.name
def get_platform_id(self):
return self.platform_meta.id
def get_message_str(self) -> str:
"""
获取消息字符串
@@ -202,6 +205,15 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
@@ -7,6 +7,8 @@ class PlatformMetadata:
"""平台的名称"""
description: str
"""平台的描述"""
id: str = None
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict = None
"""平台的默认配置模板"""
@@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
@@ -39,8 +39,9 @@ class AiocqhttpAdapter(Platform):
self.port = platform_config["ws_reverse_port"]
self.metadata = PlatformMetadata(
"aiocqhttp",
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
)
self.bot = CQHttp(
@@ -73,8 +73,9 @@ class DingtalkPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"dingtalk",
"钉钉机器人官方 API 适配器",
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(
@@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
None,
client.reply_markdown,
"AstrBot",
segment.text,
self.message_obj.raw_message,
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent):
group_owner=data.get("chatRoomOwner"),
members=members,
)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -60,8 +60,9 @@ class GewechatPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"gewechat",
"基于 gewechat 的 Wechat 适配器",
name="gewechat",
description="基于 gewechat 的 Wechat 适配器",
id=self.config.get("id"),
)
async def terminate(self):
@@ -70,8 +70,9 @@ class LarkPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"lark",
"飞书机器人官方 API 适配器",
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
@@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
@@ -2,6 +2,7 @@ import botpy
import botpy.message
import botpy.types
import botpy.types.message
import asyncio
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
@@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
from astrbot.api import logger
from botpy.types import message
class QQOfficialMessageEvent(AstrMessageEvent):
@@ -30,8 +32,46 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
async def send_streaming(self, generator):
"""流式输出仅支持消息列表私聊"""
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
try:
async for chain in generator:
source = self.message_obj.raw_message
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
self.send_buffer = None
return await super().send_streaming(generator)
async def _post_send(self, stream: dict = None):
"""QQ 官方 API 仅支持回复一次"""
if not self.send_buffer:
return
source = self.message_obj.raw_message
assert isinstance(
source,
@@ -65,7 +105,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
await self.bot.api.post_group_message(
ret = await self.bot.api.post_group_message(
group_openid=source.group_openid, **payload
)
case botpy.message.C2CMessage:
@@ -75,22 +115,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
payload["media"] = media
payload["msg_type"] = 7
await self.bot.api.post_c2c_message(
openid=source.author.user_openid, **payload
)
if stream:
ret = await self.post_c2c_message(
openid=source.author.user_openid,
**payload,
stream=stream,
)
else:
ret = await self.post_c2c_message(
openid=source.author.user_openid, **payload
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
if image_path:
payload["file_image"] = image_path
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
ret = await self.bot.api.post_message(
channel_id=source.channel_id, **payload
)
case botpy.message.DirectMessage:
if image_path:
payload["file_image"] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer)
self.send_buffer = None
return ret
async def upload_group_and_c2c_image(
self, image_base64: str, file_type: int, **kwargs
) -> botpy.types.message.Media:
@@ -112,6 +164,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
return await self.bot.api._http.request(route, json=payload)
async def post_c2c_message(
self,
openid: str,
msg_type: int = 0,
content: str = None,
embed: message.Embed = None,
ark: message.Ark = None,
message_reference: message.Reference = None,
media: message.Media = None,
msg_id: str = None,
msg_seq: str = 1,
event_id: str = None,
markdown: message.MarkdownPayload = None,
keyboard: message.Keyboard = None,
stream: dict = None,
) -> message.Message:
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
plain_text = ""
@@ -126,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"qq_official",
"QQ 机器人官方 API 适配器",
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
)
@staticmethod
@@ -99,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"qq_official_webhook",
"QQ 机器人官方 API 适配器",
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
)
async def run(self):
@@ -116,5 +117,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
async def terminate(self):
self.webhook_helper.shutdown_event.set()
await self.client.close()
await self.webhook_helper.server.shutdown()
try:
await self.webhook_helper.server.shutdown()
except Exception as _:
pass
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
@@ -80,8 +80,7 @@ class TelegramPlatformAdapter(Platform):
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"telegram",
"telegram 适配器",
name="telegram", description="telegram 适配器", id=self.config.get("id")
)
@override
@@ -1,7 +1,15 @@
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
Reply,
At,
File,
Record,
)
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
@@ -82,3 +90,109 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def send_streaming(self, generator):
message_thread_id = None
if self.get_message_type() == MessageType.GROUP_MESSAGE:
user_name = self.message_obj.group_id
else:
user_name = self.get_sender_id()
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
payload = {
"chat_id": user_name,
}
if message_thread_id:
payload["reply_to_message_id"] = message_thread_id
delta = ""
current_content = ""
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
async for chain in generator:
if isinstance(chain, MessageChain):
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
path = "data/temp/" + i.name
await download_file(i.file, path)
i.file = path
await self.client.send_document(
document=i.file, filename=i.name, **payload
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
# Plain
if not message_id:
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
else:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
# 编辑消息
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
try:
if delta and current_content != delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta, max_line_length=None, normalize_whitespace=False
)
await self.client.edit_message_text(
text=markdown_text,
chat_id=payload["chat_id"],
message_id=message_id,
parse_mode="MarkdownV2"
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator)
@@ -43,8 +43,7 @@ class WebChatAdapter(Platform):
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
name="webchat", description="webchat", id=self.config.get("id")
)
async def send_by_session(
@@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True)
@staticmethod
async def _send(message: MessageChain, session_id: str):
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
if not message:
web_chat_back_queue.put_nowait(None)
await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False}
)
return
cid = session_id.split("!")[-1]
data = ""
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid))
data = comp.text
await web_chat_back_queue.put(
{
"type": "plain",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
@@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
"type": "image",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
@@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
"type": "record",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
return data
async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await web_chat_back_queue.put(
{
"type": "end",
"data": "",
"streaming": False,
"cid": self.session_id.split("!")[-1],
}
)
await super().send(message)
async def send_streaming(self, generator):
final_data = ""
async for chain in generator:
final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True
)
await web_chat_back_queue.put(
{
"type": "end",
"data": final_data,
"streaming": True,
"cid": self.session_id.split("!")[-1],
}
)
await super().send_streaming(generator)
@@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent):
)
await super().send(message)
async def send_streaming(self, generator):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
+1 -1
View File
@@ -1,5 +1,5 @@
from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData
from .entities import ProviderMetaData
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
@@ -204,6 +204,9 @@ class LLMResponse:
_completion_text: str = ""
is_chunk: bool = False
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
role: str,
@@ -214,6 +217,7 @@ class LLMResponse:
tools_call_ids: List[str] = [],
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -233,6 +237,7 @@ class LLMResponse:
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property
def completion_text(self):
+1 -1
View File
@@ -2,7 +2,7 @@ import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from .entities import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from .register import provider_cls_map, llm_tools
+31 -3
View File
@@ -1,9 +1,9 @@
import abc
from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from dataclasses import dataclass
@@ -108,7 +108,35 @@ class Provider(AbstractProvider):
- 如果传入了 image_urls将会在对话时附上图片如果模型不支持图片输入将会抛出错误
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
raise NotImplementedError()
...
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
Args:
prompt: 提示词
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
Notes:
- 如果传入了 image_urls将会在对话时附上图片如果模型不支持图片输入将会抛出错误
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
...
async def pop_record(self, context: List):
"""
+1 -1
View File
@@ -1,5 +1,5 @@
from typing import List, Dict
from .entites import ProviderMetaData, ProviderType
from .entities import ProviderMetaData, ProviderType
from astrbot.core import logger
from .func_tool_manager import FuncCall
@@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
@@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text
# llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
@@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
messages=context_query, **model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.result_chain = MessageChain().message(response.content[0].text)
llm_response.raw_completion = response
return llm_response
except Exception as e:
@@ -160,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
@@ -3,10 +3,11 @@ import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
result_chain=MessageChain().message(
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
),
)
output_text = response.output.get("text", "")
@@ -141,11 +144,45 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(output_text)
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def forget(self, session_id):
return True
@@ -3,7 +3,7 @@ import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -20,7 +20,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model", None))
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key
self.synthesizer = SpeechSynthesizer(
+28 -1
View File
@@ -2,7 +2,7 @@ import astrbot.core.message.components as Comp
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@@ -189,6 +189,33 @@ class ProviderDify(Provider):
return LLMResponse(role="assistant", result_chain=chain)
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
@@ -4,7 +4,7 @@ import edge_tts
import subprocess
import asyncio
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -4,7 +4,7 @@ from pydantic import BaseModel, conint
from httpx import AsyncClient
from typing import Annotated, Literal
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
+61 -1
View File
@@ -12,7 +12,7 @@ from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
class SimpleGoogleGenAIClient:
@@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient:
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
async def stream_generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
"stream": True,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
async for line in resp.content:
if line:
yield line
@register_provider_adapter(
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
@@ -338,6 +371,33 @@ class ProviderGoogleGenAI(Provider):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
def get_current_key(self) -> str:
return self.client.api_key
@@ -2,7 +2,7 @@ import uuid
import aiohttp
import urllib.parse
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -2,7 +2,7 @@ import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
@@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider):
return llm_response
async def text_chat_stream(
self,
prompt,
session_id=None,
image_urls=...,
func_tool=None,
contexts=...,
system_prompt=None,
tool_calls_result=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def get_current_key(self):
return "none"
+264 -60
View File
@@ -4,19 +4,24 @@ import os
import inspect
import random
import asyncio
import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from typing import List, AsyncGenerator
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
@register_provider_adapter(
@@ -107,16 +112,72 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
llm_response = await self.parse_openai_completion(completion, tools)
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
tool_list = tools.get_func_desc_openai_style()
if tool_list:
payloads["tools"] = tool_list
# 不在默认参数中的参数放在 extra_body 中
extra_body = {}
to_del = []
for key in payloads.keys():
if key not in self.default_params:
extra_body[key] = payloads[key]
to_del.append(key)
for key in to_del:
del payloads[key]
stream = await self.client.chat.completions.create(
**payloads, stream=True, extra_body=extra_body
)
llm_response = LLMResponse("assistant", is_chunk=True)
state = ChatCompletionStreamState()
async for chunk in stream:
try:
state.handle_chunk(chunk)
except Exception as e:
logger.warning("Saving chunk state error: " + str(e))
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# 处理文本内容
if delta.content:
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)]
)
yield llm_response
final_completion = state.get_final_completion()
llm_response = await self.parse_openai_completion(final_completion, tools)
yield llm_response
async def parse_openai_completion(
self, completion: ChatCompletion, tools: FuncCall
):
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
llm_response = LLMResponse("assistant")
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
if choice.message.tool_calls:
# tools call (function calling)
@@ -148,7 +209,7 @@ class ProviderOpenAIOfficial(Provider):
return llm_response
async def text_chat(
async def _prepare_chat_payload(
self,
prompt: str,
session_id: str = None,
@@ -158,7 +219,8 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
@@ -177,8 +239,117 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, **model_config}
llm_response = None
return payloads, context_query, func_tool
async def _handle_api_error(
self,
e: Exception,
payloads: dict,
context_query: list,
func_tool: FuncCall,
chosen_key: str,
available_api_keys: List[str],
retry_cnt: int,
max_retries: int,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
payloads["messages"] = context_query
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
return False, chosen_key, available_api_keys, payloads, context_query, None
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
payloads, context_query, func_tool = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
**kwargs,
)
llm_response = None
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
@@ -197,64 +368,97 @@ class ProviderOpenAIOfficial(Provider):
payloads["messages"] = new_contexts
context_query = new_contexts
except Exception as e:
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
continue
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
func_tool = None
else:
logger.error(
f"发生了错误。Provider 配置如下: {self.provider_config}"
)
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error(
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
)
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
return llm_response
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, context_query, func_tool = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
**kwargs,
)
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
e = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
break
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
except Exception as e:
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
async def _remove_image_from_context(self, contexts: List):
"""
从上下文中删除所有带有 image 的记录
@@ -1,7 +1,7 @@
import uuid
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..entities import ProviderType
from ..register import register_provider_adapter
@@ -11,7 +11,7 @@ import re
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -2,7 +2,7 @@ import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -3,7 +3,7 @@ import os
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from ..entities import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@@ -3,7 +3,7 @@ from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entities import LLMResponse
from .openai_source import ProviderOpenAIOfficial
View File
View File
+24
View File
@@ -47,5 +47,29 @@ class StarMetadata:
star_handler_full_names: List[str] = field(default_factory=list)
"""注册的 Handler 的全名列表"""
supported_platforms: Dict[str, bool] = field(default_factory=dict)
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
"""更新插件支持的平台列表
Args:
plugin_enable_config: 平台插件启用配置即platform_settings.plugin_enable配置项
"""
if not plugin_enable_config:
return
# 清空之前的配置
self.supported_platforms.clear()
# 遍历所有平台配置
for platform_id, plugins in plugin_enable_config.items():
# 检查该插件在当前平台的配置
if self.name in plugins:
self.supported_platforms[platform_id] = plugins[self.name]
else:
# 如果没有明确配置,默认为启用
self.supported_platforms[platform_id] = True
+58 -14
View File
@@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]):
print(handler.handler_full_name)
def get_handlers_by_event_type(
self, event_type: EventType, only_activated=True
self, event_type: EventType, only_activated=True, platform_id=None
) -> List[StarHandlerMetadata]:
"""通过事件类型获取 Handler"""
handlers = [
handler
for _, handler in self._handlers
if handler.event_type == event_type
and (
not only_activated
or (
star_map[handler.handler_module_path]
and star_map[handler.handler_module_path].activated
)
)
]
"""通过事件类型获取 Handler
Args:
event_type: 事件类型
only_activated: 是否只返回已激活的插件的处理器
platform_id: 平台ID如果提供此参数将过滤掉在此平台不兼容的处理器
Returns:
List[StarHandlerMetadata]: 处理器列表
"""
handlers = []
for _, handler in self._handlers:
if handler.event_type != event_type:
continue
# 只激活的插件处理器
if only_activated:
plugin = star_map.get(handler.handler_module_path)
if not (plugin and plugin.activated):
continue
# 平台兼容性过滤
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
if not handler.is_enabled_for_platform(platform_id):
continue
handlers.append(handler)
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
@@ -139,3 +154,32 @@ class StarHandlerMetadata:
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
"priority", 0
)
def is_enabled_for_platform(self, platform_id: str) -> bool:
"""检查插件是否在指定平台启用
Args:
platform_id: 平台ID这是从event.get_platform_id()获取的用于唯一标识平台实例
Returns:
bool: 是否启用True表示启用False表示禁用
"""
plugin = star_map.get(self.handler_module_path)
# 如果插件元数据不存在,默认允许执行
if not plugin or not plugin.name:
return True
# 先检查插件是否被激活
if not plugin.activated:
return False
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
if (
hasattr(plugin, "supported_platforms")
and platform_id in plugin.supported_platforms
):
return plugin.supported_platforms[platform_id]
# 如果没有缓存数据,默认允许执行
return True
+31 -1
View File
@@ -209,7 +209,31 @@ class PluginManager:
await self._unbind_plugin(smd.name, specified_module_path)
return await self.load(specified_module_path)
result = await self.load(specified_module_path)
# 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
return result
async def update_all_platform_compatibility(self):
"""更新所有插件的平台兼容性设置"""
# 获取最新的平台插件启用配置
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
logger.debug(
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
)
# 遍历所有插件,更新平台兼容性
for plugin in self.context.get_all_stars():
plugin.update_platform_compatibility(plugin_enable_config)
logger.debug(
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
)
return True
async def load(self, specified_module_path=None, specified_dir_name=None):
"""载入插件。
@@ -320,6 +344,12 @@ class PluginManager:
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 更新插件的平台兼容性
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
metadata.update_platform_compatibility(plugin_enable_config)
# 绑定 handler
related_handlers = (
star_handlers_registry.get_handlers_by_module_name(
+48
View File
@@ -1,9 +1,12 @@
import inspect
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map
from pathlib import Path
class StarTools:
@@ -142,3 +145,48 @@ class StarTools:
name (str): 工具名称
"""
cls._context.unregister_llm_tool(name)
@classmethod
def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path:
"""
返回插件数据目录的绝对路径
此方法会在 data/plugin_data 目录下为插件创建一个专属的数据目录如果未提供插件名称
会自动从调用栈中获取插件信息
Args:
plugin_name: 可选的插件名称如果为None将自动检测调用者的插件名称
Returns:
Path (Path): 插件数据目录的绝对路径位于 data/plugin_data/{plugin_name}
Raises:
RuntimeError: 当出现以下情况时抛出:
- 无法获取调用者模块信息
- 无法获取模块的元数据信息
- 创建目录失败权限不足或其他IO错误
"""
if not plugin_name:
frame = inspect.currentframe().f_back
module = inspect.getmodule(frame)
if not module:
raise RuntimeError("无法获取调用者模块信息")
metadata = star_map.get(module.__name__, None)
if not metadata:
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息")
plugin_name = metadata.name
data_dir = Path("data/plugin_data") / plugin_name
try:
data_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}{e!s}") from e
return data_dir.resolve()
+1 -1
View File
@@ -15,7 +15,7 @@ class SharedPreferences:
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
def get(self, key, default=None):
+27 -16
View File
@@ -161,42 +161,53 @@ class ChatRoute(Route):
username = g.get("username", "guest")
if username in self.curr_chat_sse:
return "[ERROR]\n"
return Response().error("Already connected").__dict__
self.curr_chat_sse[username] = None
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
async def stream():
try:
yield "[HB]\n"
yield f"data: {heartbeat}\n\n" # 心跳包
while True:
try:
result = await asyncio.wait_for(
web_chat_back_queue.get(), timeout=10
) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield "[HB]\n" # 心跳包
yield f"data: {heartbeat}\n\n" # 心跳包
continue
if not result:
continue
result_text, cid = result
result_text = result["data"]
type = result.get("type")
cid = result.get("cid")
streaming = result.get("streaming", False)
if cid != self.curr_user_cid.get(username):
# 丢弃
continue
yield result_text + "\n"
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
conversation = self.db.get_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
if streaming and type != "end":
continue
await asyncio.sleep(0.5)
if result_text:
conversation = self.db.get_conversation_by_user_id(
username, cid
)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
self.curr_chat_sse.pop(username)
+1 -1
View File
@@ -179,7 +179,7 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
logger.error(e)
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def post_plugin_configs(self):
+1 -1
View File
@@ -20,7 +20,7 @@ class LogRoute(Route):
message = await queue.get()
payload = {
"type": "log",
**message # see astrbot/core/log.py
**message, # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
+136
View File
@@ -1,5 +1,6 @@
import traceback
import aiohttp
import os
import ssl
import certifi
@@ -36,6 +37,9 @@ class PluginRoute(Route):
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -317,3 +321,135 @@ class PluginRoute(Route):
except Exception as e:
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def get_plugin_readme(self):
plugin_name = request.args.get("name")
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
if not plugin_name:
logger.warning("插件名称为空")
return Response().error("插件名称不能为空").__dict__
plugin_obj = None
for plugin in self.plugin_manager.context.get_all_stars():
if plugin.name == plugin_name:
plugin_obj = plugin
break
if not plugin_obj:
logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__
plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name
)
if not os.path.isdir(plugin_dir):
logger.warning(f"无法找到插件目录: {plugin_dir}")
return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__
readme_path = os.path.join(plugin_dir, "README.md")
if not os.path.isfile(readme_path):
logger.warning(f"插件 {plugin_name} 没有README文件")
return Response().error(f"插件 {plugin_name} 没有README文件").__dict__
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
return (
Response()
.ok({"content": readme_content}, "成功获取README内容")
.__dict__
)
except Exception as e:
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
return Response().error(f"读取README文件失败: {str(e)}").__dict__
async def get_plugin_platform_enable(self):
"""获取插件在各平台的可用性配置"""
try:
platform_enable = self.core_lifecycle.astrbot_config.get(
"platform_settings", {}
).get("plugin_enable", {})
# 获取所有可用平台
platforms = []
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
platform_type = platform.get("type", "")
platform_id = platform.get("id", "")
platforms.append(
{
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
"id": platform_id, # 保留id字段以便前端可以显示
"type": platform_type,
"display_name": f"{platform_type}({platform_id})",
}
)
adjusted_platform_enable = {}
for platform_id, plugins in platform_enable.items():
adjusted_platform_enable[platform_id] = plugins
# 获取所有插件,包括系统内部插件
plugins = []
for plugin in self.plugin_manager.context.get_all_stars():
plugins.append(
{
"name": plugin.name,
"desc": plugin.desc,
"reserved": plugin.reserved, # 添加reserved标志
}
)
logger.debug(
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
)
return (
Response()
.ok(
{
"platforms": platforms,
"plugins": plugins,
"platform_enable": adjusted_platform_enable,
}
)
.__dict__
)
except Exception as e:
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def set_plugin_platform_enable(self):
"""设置插件在各平台的可用性配置"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
data = await request.json
platform_enable = data.get("platform_enable", {})
# 更新配置
config = self.core_lifecycle.astrbot_config
platform_settings = config.get("platform_settings", {})
platform_settings["plugin_enable"] = platform_enable
config["platform_settings"] = platform_settings
config.save_config()
# 更新插件的平台兼容性缓存
await self.plugin_manager.update_all_platform_compatibility()
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
return Response().ok(None, "插件平台可用性配置已更新").__dict__
except Exception as e:
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
@@ -24,13 +24,10 @@ const emit = defineEmits([
'install',
'uninstall',
'toggle-activation',
'view-handlers'
'view-handlers',
'view-readme'
]);
const open = (link: string | undefined) => {
window.open(link, '_blank');
};
const reveal = ref(false);
//
@@ -70,6 +67,10 @@ const toggleActivation = () => {
const viewHandlers = () => {
emit('view-handlers', props.extension);
};
const viewReadme = () => {
emit('view-readme', props.extension);
};
</script>
<template>
@@ -128,7 +129,7 @@ const viewHandlers = () => {
</v-card-text>
<v-card-actions style="padding: 0px; margin-top: auto;">
<v-btn color="teal-accent-4" text="帮助" variant="text" @click="open(extension.repo)"></v-btn>
<v-btn color="teal-accent-4" text="查看文档" variant="text" @click="viewReadme"></v-btn>
<v-btn v-if="!marketMode" color="teal-accent-4" text="操作" variant="text" @click="reveal = true"></v-btn>
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" text="安装" variant="text"
@click="emit('install', extension)"></v-btn>
@@ -0,0 +1,302 @@
<script setup>
import { ref, watch, onMounted } from 'vue';
import axios from 'axios';
import { marked } from 'marked';
import hljs from 'highlight.js';
import 'highlight.js/styles/github.css';
const props = defineProps({
show: {
type: Boolean,
default: false
},
pluginName: {
type: String,
default: ''
},
repoUrl: {
type: String,
default: null
}
});
const emit = defineEmits(['update:show']);
const content = ref(null);
const error = ref(null);
const loading = ref(false);
// show
watch(() => props.show, (newVal) => {
if (newVal && props.pluginName) {
fetchReadme();
}
});
// pluginName
watch(() => props.pluginName, (newVal) => {
if (props.show && newVal) {
fetchReadme();
}
});
// README
async function fetchReadme() {
if (!props.pluginName) return;
loading.value = true;
content.value = null;
error.value = null;
try {
// README
const res = await axios.get(`/api/plugin/readme?name=${props.pluginName}`);
if (res.data.status === 'ok') {
content.value = res.data.data.content;
} else {
error.value = res.data.message || '获取README失败';
}
} catch (err) {
error.value = err.message || '获取README时发生错误';
} finally {
loading.value = false;
}
}
// GitHub
function openRepoInNewTab() {
if (props.repoUrl) {
window.open(props.repoUrl, '_blank');
}
}
// Markdown
function renderMarkdown(content) {
if (!content) return '';
// marked使highlight.js
marked.setOptions({
highlight: function(code, lang) {
if (lang && hljs.getLanguage(lang)) {
try {
return hljs.highlight(code, { language: lang }).value;
} catch (e) {
console.error(e);
}
}
return hljs.highlightAuto(code).value;
},
gfm: true, // GitHub Flavored Markdown
breaks: true, // Convert \n to <br>
headerIds: true, // Add id attributes to headers
mangle: false // Don't mangle email addresses
});
return marked(content);
}
// README
function refreshReadme() {
fetchReadme();
}
</script>
<template>
<v-dialog v-model="_show" width="800" persistent>
<v-card>
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h5">插件说明文档</span>
<v-btn icon @click="$emit('update:show', false)">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-card-text style="height: 70vh; overflow-y: auto;">
<div class="d-flex justify-space-between mb-4">
<v-btn
v-if="repoUrl"
color="primary"
prepend-icon="mdi-github"
@click="openRepoInNewTab()"
>
在GitHub中查看仓库
</v-btn>
<v-btn
color="secondary"
prepend-icon="mdi-refresh"
@click="refreshReadme()"
>
刷新文档
</v-btn>
</div>
<!-- 加载中 -->
<div v-if="loading" class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
<p class="text-body-1 text-center">正在加载README文档...</p>
</div>
<!-- 内容显示 -->
<div v-else-if="content" class="markdown-body" v-html="renderMarkdown(content)"></div>
<!-- 错误提示 -->
<div v-else-if="error" class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle-outline</v-icon>
<p class="text-body-1 text-center mb-4">{{ error }}</p>
</div>
<!-- 无内容提示 -->
<div v-else class="d-flex flex-column align-center justify-center" style="height: 100%;">
<v-icon size="64" color="warning" class="mb-4">mdi-file-question-outline</v-icon>
<p class="text-body-1 text-center mb-4">该插件未提供文档链接或GitHub仓库地址<br>请查看插件市场或联系插件作者获取更多信息</p>
</div>
</v-card-text>
<v-divider></v-divider>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="primary" variant="tonal" @click="$emit('update:show', false)">
关闭
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<style>
.markdown-body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif;
line-height: 1.6;
padding: 8px 0;
color: #24292e;
}
.markdown-body h1,
.markdown-body h2,
.markdown-body h3,
.markdown-body h4,
.markdown-body h5,
.markdown-body h6 {
margin-top: 24px;
margin-bottom: 16px;
font-weight: 600;
line-height: 1.25;
}
.markdown-body h1 {
font-size: 2em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
.markdown-body h2 {
font-size: 1.5em;
border-bottom: 1px solid #eaecef;
padding-bottom: 0.3em;
}
.markdown-body p {
margin-top: 0;
margin-bottom: 16px;
}
.markdown-body code {
padding: 0.2em 0.4em;
margin: 0;
background-color: rgba(27, 31, 35, 0.05);
border-radius: 3px;
font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
font-size: 85%;
}
.markdown-body pre {
padding: 16px;
overflow: auto;
font-size: 85%;
line-height: 1.45;
background-color: #f6f8fa;
border-radius: 3px;
margin-bottom: 16px;
}
.markdown-body pre code {
background-color: transparent;
padding: 0;
}
.markdown-body ul,
.markdown-body ol {
padding-left: 2em;
margin-bottom: 16px;
}
.markdown-body img {
max-width: 100%;
margin: 8px 0;
box-sizing: border-box;
background-color: #fff;
border-radius: 3px;
}
.markdown-body blockquote {
padding: 0 1em;
color: #6a737d;
border-left: 0.25em solid #dfe2e5;
margin-bottom: 16px;
}
.markdown-body a {
color: #0366d6;
text-decoration: none;
}
.markdown-body a:hover {
text-decoration: underline;
}
.markdown-body table {
border-spacing: 0;
border-collapse: collapse;
width: 100%;
overflow: auto;
margin-bottom: 16px;
}
.markdown-body table th,
.markdown-body table td {
padding: 6px 13px;
border: 1px solid #dfe2e5;
}
.markdown-body table tr {
background-color: #fff;
border-top: 1px solid #c6cbd1;
}
.markdown-body table tr:nth-child(2n) {
background-color: #f6f8fa;
}
.markdown-body hr {
height: 0.25em;
padding: 0;
margin: 24px 0;
background-color: #e1e4e8;
border: 0;
}
</style>
<script>
export default {
name: 'ReadmeDialog',
computed: {
_show: {
get() {
return this.show;
},
set(value) {
this.$emit('update:show', value);
}
}
}
}
</script>
+272 -177
View File
@@ -1,6 +1,7 @@
<script setup>
import axios from 'axios';
import { marked } from 'marked';
import { ref } from 'vue';
marked.setOptions({
breaks: true
@@ -11,37 +12,73 @@ marked.setOptions({
<v-card class="chat-page-card">
<v-card-text class="chat-page-container">
<div class="chat-layout">
<!-- 左侧对话列表面板 -->
<!-- 左侧对话列表面板 - 优化版 -->
<div class="sidebar-panel">
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC"
:disabled="!currCid">
<v-icon class="mr-2">mdi-plus</v-icon>创建对话
</v-btn>
<v-card class="conversation-list-card" v-if="conversations.length > 0">
<v-list density="compact" nav class="conversation-list" @update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl" class="conversation-item">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<div class="status-chips">
<v-chip class="status-chip" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
LLM
</v-chip>
<v-chip class="status-chip" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
语音转文本
</v-chip>
<div class="sidebar-header">
<v-btn variant="elevated" rounded="lg" class="new-chat-btn" @click="newC" :disabled="!currCid"
prepend-icon="mdi-plus">
创建对话
</v-btn>
</div>
<v-btn variant="tonal" rounded="xl" class="delete-chat-btn" v-if="currCid"
@click="deleteConversation(currCid)" color="error">
<v-icon class="mr-2">mdi-delete</v-icon>删除此对话
</v-btn>
<div class="conversations-container">
<div class="sidebar-section-title" v-if="conversations.length > 0">
对话历史
</div>
<v-card class="conversation-list-card" v-if="conversations.length > 0" flat>
<v-list density="compact" nav class="conversation-list"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="lg" class="conversation-item" active-color="primary">
<template v-slot:prepend>
<v-icon size="small" icon="mdi-message-text-outline"></v-icon>
</template>
<v-list-item-title class="conversation-title">新对话</v-list-item-title>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at)
}}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<v-fade-transition>
<div class="no-conversations" v-if="conversations.length === 0">
<v-icon icon="mdi-message-text-outline" size="large" color="grey-lighten-1"></v-icon>
<div class="no-conversations-text">暂无对话历史</div>
</div>
</v-fade-transition>
</div>
<div class="sidebar-footer">
<div class="sidebar-section-title">
系统状态
</div>
<div class="status-chips">
<v-chip class="status-chip" :color="status?.llm_enabled ? 'primary' : 'grey-lighten-2'"
variant="elevated" size="small">
<template v-slot:prepend>
<v-icon :icon="status?.llm_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
size="x-small"></v-icon>
</template>
LLM 服务
</v-chip>
<v-chip class="status-chip" :color="status?.stt_enabled ? 'success' : 'grey-lighten-2'"
variant="elevated" size="small">
<template v-slot:prepend>
<v-icon :icon="status?.stt_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
size="x-small"></v-icon>
</template>
语音转文本
</v-chip>
</div>
<v-btn variant="tonal" rounded="lg" class="delete-chat-btn" v-if="currCid"
@click="deleteConversation(currCid)" color="error" density="comfortable" size="small">
<v-icon start size="small">mdi-delete</v-icon>
删除此对话
</v-btn>
</div>
</div>
<!-- 右侧聊天内容区域 -->
@@ -77,14 +114,15 @@ marked.setOptions({
<div v-if="msg.type == 'user'" class="user-message">
<div class="message-bubble user-bubble">
<span>{{ msg.message }}</span>
<!-- 图片附件 -->
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index" class="image-attachment">
<div v-for="(img, index) in msg.image_url" :key="index"
class="image-attachment">
<img :src="img" class="attached-image" />
</div>
</div>
<!-- 音频附件 -->
<div class="audio-attachment" v-if="msg.audio_url && msg.audio_url.length > 0">
<audio controls class="audio-player">
@@ -97,7 +135,7 @@ marked.setOptions({
<v-icon icon="mdi-account" />
</v-avatar>
</div>
<!-- 机器人消息 -->
<div v-else class="bot-message">
<v-avatar class="bot-avatar" color="deep-purple" size="36">
@@ -113,49 +151,30 @@ marked.setOptions({
<!-- 输入区域 -->
<div class="input-area fade-in">
<v-text-field
id="input-field"
variant="outlined"
v-model="prompt"
:label="inputFieldLabel"
placeholder="开始输入..."
:loading="loadingChat"
clear-icon="mdi-close-circle"
clearable
@click:clear="clearMessage"
class="message-input"
@keydown="handleInputKeyDown"
hide-details
>
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
placeholder="开始输入..." :loading="loadingChat" clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" class="message-input" @keydown="handleInputKeyDown"
hide-details>
<template v-slot:loader>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple" indeterminate></v-progress-linear>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple"
indeterminate></v-progress-linear>
</template>
<template v-slot:append>
<v-tooltip text="发送">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="sendMessage"
class="send-btn"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl"
/>
<v-btn v-bind="props" @click="sendMessage" class="send-btn" icon="mdi-send"
variant="text" color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl" />
</template>
</v-tooltip>
<v-tooltip text="语音输入">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'"
variant="text"
:color="isRecording ? 'error' : 'deep-purple'"
/>
<v-btn v-bind="props" @click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
:color="isRecording ? 'error' : 'deep-purple'" />
</template>
</v-tooltip>
</template>
@@ -165,15 +184,17 @@ marked.setOptions({
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview">
<img :src="img" class="preview-image" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close"
size="small" color="error" variant="text" />
</div>
<div v-if="stagedAudioUrl" class="audio-preview">
<v-chip color="deep-purple-lighten-4" class="audio-chip">
<v-icon start icon="mdi-microphone" size="small"></v-icon>
新录音
</v-chip>
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small"
color="error" variant="text" />
</div>
</div>
</div>
@@ -206,9 +227,9 @@ export default {
status: {},
statusText: '',
eventSource: null,
// Ctrl
ctrlKeyDown: false,
ctrlKeyTimer: null,
@@ -228,18 +249,17 @@ export default {
this.sendMessage();
}
}.bind(this));
// keyup
document.addEventListener('keyup', this.handleInputKeyUp);
},
beforeUnmount() {
console.log("111")
if (this.eventSource) {
this.eventSource.cancel();
console.log('SSE连接已断开');
}
// keyup
document.removeEventListener('keyup', this.handleInputKeyUp);
},
@@ -265,6 +285,9 @@ export default {
this.eventSource = reader
let in_streaming = false
let message_obj = null
while (true) {
const { done, value } = await reader.read();
if (done) {
@@ -273,40 +296,67 @@ export default {
}
const chunk = decoder.decode(value, { stream: true });
console.log("!!!!", chunk);
if (chunk === '[HB]\n') {
continue; //
}
if (chunk === '[ERROR]\n') {
continue;
}
//
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
let lines = chunk.split('\n\n');
console.log('SSE数据:', lines);
for (let i = 0; i < lines.length; i++) {
let line = lines[i].trim();
if (!line) {
continue;
}
this.messages.push(bot_resp);
} else if (chunk.startsWith('[RECORD]')) {
let audio = chunk.replace('[RECORD]', '');
let bot_resp = {
type: 'bot',
message: `<audio controls class="audio-player">
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
您的浏览器不支持音频播放
</audio>`
console.log(line)
// data: {"type": "plain", "data": "helloworld"}
let chunk_json = JSON.parse(line.replace('data: ', ''));
if (chunk_json.type === 'heartbeat') {
continue; //
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
if (chunk_json.type === 'error') {
console.error('Error received:', chunk_json.data);
continue;
}
this.messages.push(bot_resp);
if (chunk_json.type === 'image') {
let img = chunk_json.data.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else if (chunk_json.type === 'record') {
let audio = chunk_json.data.replace('[RECORD]', '');
let bot_resp = {
type: 'bot',
message: `<audio controls class="audio-player">
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
您的浏览器不支持音频播放
</audio>`
}
this.messages.push(bot_resp);
} else if (chunk_json.type === 'plain') {
if (!in_streaming) {
message_obj = {
type: 'bot',
message: ref(chunk_json.data),
}
this.messages.push(message_obj);
in_streaming = true;
} else {
message_obj.message.value += chunk_json.data;
}
} else if (chunk_json.type === 'end') {
in_streaming = false;
continue;
}
this.scrollToBottom();
}
this.scrollToBottom();
}
},
@@ -526,42 +576,6 @@ export default {
this.stagedAudioUrl = "";
this.loadingChat = false;
// const reader = response.body.getReader(); // Reader
// const decoder = new TextDecoder();
// const readStream = async () => {
// const { done, value } = await reader.read(); //
// if (done) {
// console.log("Stream finished.");
// return;
// }
// const chunk = decoder.decode(value, { stream: true });
// // bot_resp.message.value += chunk;
// console.log("!!!!", chunk);
// if (chunk.startsWith('[IMAGE]')) {
// let img = chunk.replace('[IMAGE]', '');
// let bot_resp = {
// type: 'bot',
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
// }
// this.messages.push(bot_resp);
// } else {
// let bot_resp = {
// type: 'bot',
// message: chunk
// }
// this.messages.push(bot_resp);
// }
// this.scrollToBottom();
// readStream(); //
// };
// readStream();
})
.catch(err => {
console.error(err);
@@ -578,9 +592,9 @@ export default {
if (e.keyCode === 17) { // Ctrl
//
if (this.ctrlKeyDown) return;
this.ctrlKeyDown = true;
//
this.ctrlKeyTimer = setTimeout(() => {
if (this.ctrlKeyDown && !this.isRecording) {
@@ -589,17 +603,17 @@ export default {
}, this.ctrlKeyLongPressThreshold);
}
},
handleInputKeyUp(e) {
if (e.keyCode === 17) { // Ctrl
this.ctrlKeyDown = false;
//
if (this.ctrlKeyTimer) {
clearTimeout(this.ctrlKeyTimer);
this.ctrlKeyTimer = null;
}
//
if (this.isRecording) {
this.stopRecording();
@@ -613,19 +627,41 @@ export default {
<style>
/* 基础动画 */
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
0% {
transform: scale(1);
}
50% {
transform: scale(1.05);
}
100% {
transform: scale(1);
}
}
@keyframes slideIn {
from { transform: translateX(20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
from {
transform: translateX(20px);
opacity: 0;
}
to {
transform: translateX(0);
opacity: 1;
}
}
/* 聊天页面布局 */
@@ -650,84 +686,140 @@ export default {
gap: 24px;
}
/* 侧边栏样式 */
/* 侧边栏样式 - 优化版 */
.sidebar-panel {
max-width: 240px;
min-width: 200px;
max-width: 270px;
min-width: 240px;
display: flex;
flex-direction: column;
padding: 16px 8px;
border-right: 1px solid #f0f0f0;
padding: 0;
border-right: 1px solid rgba(0, 0, 0, 0.05);
background-color: #fcfcfc;
height: 100%;
position: relative;
}
.sidebar-header {
padding: 16px;
border-bottom: 1px solid rgba(0, 0, 0, 0.04);
}
.conversations-container {
flex-grow: 1;
overflow-y: auto;
padding: 16px;
}
.sidebar-footer {
padding: 16px;
border-top: 1px solid rgba(0, 0, 0, 0.04);
}
.sidebar-section-title {
font-size: 12px;
font-weight: 500;
color: #666;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 12px;
padding-left: 4px;
}
.new-chat-btn {
margin-bottom: 16px;
min-width: 200px;
background-color: #f5f0ff !important;
color: #673ab7 !important;
width: 100%;
background-color: #673ab7 !important;
color: white !important;
font-weight: 500;
box-shadow: none !important;
box-shadow: 0 2px 8px rgba(103, 58, 183, 0.25) !important;
transition: all 0.2s ease;
text-transform: none;
letter-spacing: 0.25px;
}
.new-chat-btn:hover {
background-color: #ede7f6 !important;
background-color: #7e57c2 !important;
box-shadow: 0 4px 12px rgba(103, 58, 183, 0.3) !important;
transform: translateY(-1px);
}
.conversation-list-card {
border-radius: 12px;
box-shadow: none !important;
border: 1px solid #f0f0f0;
background-color: #fafafa;
background-color: transparent;
}
.conversation-list {
max-height: 500px;
overflow-y: auto;
padding: 4px;
max-height: none;
overflow-y: visible;
padding: 0;
}
.conversation-item {
margin-bottom: 4px;
border-radius: 8px !important;
transition: all 0.2s ease;
height: auto !important;
min-height: 56px;
padding: 8px 12px !important;
}
.conversation-item:hover {
background-color: #f5f0ff;
background-color: rgba(103, 58, 183, 0.05);
}
.conversation-title {
font-weight: 500;
font-size: 14px;
line-height: 1.3;
margin-bottom: 2px;
}
.timestamp {
font-size: 11px;
color: #999;
margin-top: 4px;
line-height: 1;
}
.status-chips {
margin-top: 16px;
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.status-chip {
font-size: 12px;
height: 24px !important;
}
.delete-chat-btn {
position: fixed;
bottom: 24px;
margin-bottom: 16px;
min-width: 200px;
background-color: #feecec !important;
width: 100%;
color: #d32f2f !important;
font-weight: 500;
box-shadow: none !important;
margin-top: 8px;
text-transform: none;
letter-spacing: 0.25px;
font-size: 12px;
}
.delete-chat-btn:hover {
background-color: #ffebee !important;
background-color: rgba(211, 47, 47, 0.1) !important;
}
.no-conversations {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
height: 150px;
opacity: 0.6;
gap: 12px;
}
.no-conversations-text {
font-size: 14px;
color: #999;
}
/* 聊天内容区域 */
@@ -828,7 +920,8 @@ export default {
border-top-left-radius: 4px;
}
.user-avatar, .bot-avatar {
.user-avatar,
.bot-avatar {
align-self: flex-end;
}
@@ -881,7 +974,8 @@ export default {
margin: 0 auto;
}
.send-btn, .record-btn {
.send-btn,
.record-btn {
margin-left: 4px;
}
@@ -895,7 +989,8 @@ export default {
flex-wrap: wrap;
}
.image-preview, .audio-preview {
.image-preview,
.audio-preview {
position: relative;
display: inline-flex;
}
@@ -1003,7 +1098,7 @@ export default {
margin: 16px 0;
}
.markdown-content th,
.markdown-content th,
.markdown-content td {
border: 1px solid #eee;
padding: 8px 12px;
+32 -11
View File
@@ -22,7 +22,7 @@ import 'highlight.js/styles/github.css';
<v-card-title>
<div class="pl-2 pt-2 d-flex align-center pe-2">
<h2> 插件市场</h2>
<h2> 插件市场</h2>
<v-btn icon size="small" style="margin-left: 8px" variant="plain" @click="jumpToPluginMarket()">
<v-icon size="small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">
@@ -52,6 +52,7 @@ import 'highlight.js/styles/github.css';
<v-card-text>
<small style="color: #bbb;">每个插件都是作者无偿提供的的劳动成果如果您喜欢某个插件 Star</small>
<div v-if="pinnedPlugins.length > 0" class="mt-4">
<h2>🥳 推荐</h2>
@@ -71,7 +72,7 @@ import 'highlight.js/styles/github.css';
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys">
<template v-slot:item.name="{ item }">
<div class="d-flex align-center">
<div class="d-flex align-center" style="overflow-x: scroll;">
<img v-if="item.logo" :src="item.logo"
style="height: 80px; width: 80px; margin-right: 8px; border-radius: 8px; margin-top: 8px; margin-bottom: 8px;"
alt="logo">
@@ -83,24 +84,43 @@ import 'highlight.js/styles/github.css';
</div>
</template>
<template v-slot:item.desc="{ item }">
<div style="font-size: 13px;">
{{ item.desc }}
</div>
</template>
<template v-slot:item.author="{ item }">
<span v-if="item?.social_link"><a :href="item?.social_link">{{ item.author
<div style="font-size: 12px;">
<span v-if="item?.social_link"><a :href="item?.social_link">{{ item.author
}}</a></span>
<span v-else>{{ item.author }}</span>
</div>
</template>
<template v-slot:item.stars="{ item }">
<a :href="item.repo">
<img v-if="item.repo"
:src="`https://img.shields.io/github/stars/${item.repo.split('/').slice(-2).join('/')}.svg`"
:alt="`Stars for ${item.name}`"
style="height: 20px;"
/>
</a>
</template>
<template v-slot:item.tags="{ item }">
<span v-if="item.tags.length === 0"></span>
<v-chip v-for="tag in item.tags" :key="tag" color="primary" size="small">{{ tag
<v-chip v-for="tag in item.tags" :key="tag" color="primary" size="x-small">{{ tag
}}</v-chip>
</template>
<template v-slot:item.actions="{ item }">
<v-btn v-if="!item.installed" class="text-none mr-2" size="small"
<v-btn v-if="!item.installed" class="text-none mr-2" size="x-small"
variant="flat" border
@click="extension_url = item.repo; newExtension()">安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" variant="flat" border
<v-btn v-else class="text-none mr-2" size="x-small" variant="flat" border
disabled>已安装</v-btn>
<v-btn class="text-none mr-2" size="small" variant="flat" border
@click="open(item.repo)">查看帮助</v-btn>
<v-btn class="text-none mr-2" size="x-small" variant="flat" border
@click="open(item.repo)">帮助</v-btn>
</template>
</v-data-table>
</v-col>
@@ -259,10 +279,11 @@ export default {
announcement: "",
isListView: true,
pluginMarketHeaders: [
{ title: '名称', key: 'name', maxWidth: '150px' },
{ title: '名称', key: 'name', maxWidth: '200px' },
{ title: '描述', key: 'desc', maxWidth: '250px' },
{ title: '作者', key: 'author', maxWidth: '60px' },
{ title: '标签', key: 'tags', maxWidth: '60px' },
{ title: '作者', key: 'author', maxWidth: '70px' },
{ title: 'Star数', key: 'stars', maxWidth: '100px' },
{ title: '标签', key: 'tags', maxWidth: '100px' },
{ title: '操作', key: 'actions', sortable: false }
],
marketSearch: "",
+236 -2
View File
@@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import ReadmeDialog from '@/components/shared/ReadmeDialog.vue';
import axios from 'axios';
import { useCommonStore } from '@/stores/common';
@@ -35,6 +36,20 @@ const selectedPlugin = ref({});
const curr_namespace = ref("");
const wfr = ref(null);
const readmeDialog = reactive({
show: false,
pluginName: '',
repoUrl: null
});
//
const platformEnableDialog = ref(false);
const platformEnableData = reactive({
platforms: [],
plugins: [],
platform_enable: {}
});
const loadingPlatformData = ref(false);
const plugin_handler_info_headers = [
{ title: '行为类型', key: 'event_type_h' },
{ title: '描述', key: 'desc', maxWidth: '250px' },
@@ -225,6 +240,107 @@ const reloadPlugin = async (plugin_name) => {
}
};
const viewReadme = (plugin) => {
readmeDialog.pluginName = plugin.name;
readmeDialog.repoUrl = plugin.repo;
readmeDialog.show = true;
};
//
const getPlatformEnableConfig = async () => {
loadingPlatformData.value = true;
try {
const res = await axios.get('/api/plugin/platform_enable/get');
if (res.data.status === "error") {
toast(res.data.message, "error");
return;
}
platformEnableData.platforms = res.data.data.platforms;
platformEnableData.plugins = res.data.data.plugins;
platformEnableData.platform_enable = res.data.data.platform_enable;
//
if (platformEnableData.platforms.length === 0) {
toast("未添加任何平台适配器,请先在平台管理中添加平台", "warning");
} else {
//
platformEnableData.platforms.forEach(platform => {
if (!platformEnableData.platform_enable[platform.name]) {
platformEnableData.platform_enable[platform.name] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
if (platformEnableData.platform_enable[platform.name][plugin.name] === undefined) {
platformEnableData.platform_enable[platform.name][plugin.name] = true; //
}
});
});
}
platformEnableDialog.value = true;
} catch (err) {
toast("获取平台插件配置失败: " + err, "error");
} finally {
loadingPlatformData.value = false;
}
};
//
const savePlatformEnableConfig = async () => {
loadingPlatformData.value = true;
try {
const res = await axios.post('/api/plugin/platform_enable/set', {
platform_enable: platformEnableData.platform_enable
});
if (res.data.status === "error") {
toast(res.data.message, "error");
return;
}
toast(res.data.message, "success");
platformEnableDialog.value = false;
} catch (err) {
toast("保存平台插件配置失败: " + err, "error");
} finally {
loadingPlatformData.value = false;
}
};
//
const selectAllPluginsForPlatform = (platformName, isSelected, onlyReserved = null) => {
// platform_enable
if (!platformEnableData.platform_enable[platformName]) {
platformEnableData.platform_enable[platformName] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
// onlyReservednull
// onlyReservedtrue
// onlyReservedfalse
if (onlyReserved === null || plugin.reserved === onlyReserved) {
platformEnableData.platform_enable[platformName][plugin.name] = isSelected;
}
});
};
//
const toggleAllPluginsForPlatform = (platformName) => {
// platform_enable
if (!platformEnableData.platform_enable[platformName]) {
platformEnableData.platform_enable[platformName] = {};
}
//
platformEnableData.plugins.forEach(plugin => {
const currentState = platformEnableData.platform_enable[platformName][plugin.name];
platformEnableData.platform_enable[platformName][plugin.name] = !currentState;
});
};
//
onMounted(async () => {
await getExtensions();
@@ -248,6 +364,9 @@ onMounted(async () => {
<v-btn class="text-none ml-2" size="small" variant="flat" border @click="toggleShowReserved">
{{ showReserved ? '隐藏系统保留插件' : '显示系统保留插件' }}
</v-btn>
<v-btn class="text-none ml-2" size="small" variant="flat" color="primary" border @click="getPlatformEnableConfig">
平台命令配置
</v-btn>
<v-dialog max-width="500px" v-if="extension_data.message">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon size="small" color="error" style="margin-left: auto;" variant="plain">
@@ -279,11 +398,111 @@ onMounted(async () => {
@update="updateExtension(extension.name)"
@reload="reloadPlugin(extension.name)"
@toggle-activation="extension.activated ? pluginOff(extension) : pluginOn(extension)"
@view-handlers="showPluginInfo(extension)">
@view-handlers="showPluginInfo(extension)"
@view-readme="viewReadme(extension)">
</ExtensionCard>
</v-col>
</v-row>
<!-- 插件平台配置对话框 -->
<v-dialog v-model="platformEnableDialog" max-width="800" persistent>
<v-card>
<v-card-title>
<span class="headline">平台命令可用性配置</span>
</v-card-title>
<v-card-subtitle>
设置每个插件在不同平台上的可用性勾选表示启用
</v-card-subtitle>
<v-card-text>
<v-overlay
:model-value="loadingPlatformData"
class="align-center justify-center"
persistent
>
<v-progress-circular
color="primary"
indeterminate
size="64"
></v-progress-circular>
</v-overlay>
<div v-if="platformEnableData.platforms.length === 0" class="text-center pa-4">
<v-icon icon="mdi-alert" color="warning" size="64" class="mb-4"></v-icon>
<div class="text-h6 mb-2">未找到平台适配器</div>
<div class="text-body-1 mb-4">请先在 <strong>平台管理</strong> 中添加并配置平台适配器然后再设置插件的平台可用性</div>
<v-btn color="primary" to="/platforms">前往平台管理</v-btn>
</div>
<v-table v-else>
<thead>
<tr>
<th>插件名称</th>
<th v-for="platform in platformEnableData.platforms" :key="platform.name">
<div class="d-flex align-center">
{{ platform.display_name }}
<v-menu>
<template v-slot:activator="{ props }">
<v-btn
icon
density="compact"
variant="text"
size="small"
v-bind="props"
class="ms-1"
>
<v-icon>mdi-dots-vertical</v-icon>
</v-btn>
</template>
<v-list>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true)">
<v-list-item-title>全选</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, false)">
<v-list-item-title>全选普通插件</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, true, true)">
<v-list-item-title>全选系统插件</v-list-item-title>
</v-list-item>
<v-list-item @click="selectAllPluginsForPlatform(platform.name, false)">
<v-list-item-title>全不选</v-list-item-title>
</v-list-item>
<v-list-item @click="toggleAllPluginsForPlatform(platform.name)">
<v-list-item-title>反选</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
</div>
</th>
</tr>
</thead>
<tbody>
<tr v-for="plugin in platformEnableData.plugins" :key="plugin.name">
<td>
<div class="d-flex align-center">
{{ plugin.name }}
<v-chip v-if="plugin.reserved" color="primary" size="x-small" class="ml-2">系统</v-chip>
</div>
<div class="text-caption text-grey">{{ plugin.desc }}</div>
</td>
<td v-for="platform in platformEnableData.platforms" :key="platform.name">
<v-checkbox
v-model="platformEnableData.platform_enable[platform.name][plugin.name]"
hide-details
density="compact"
></v-checkbox>
</td>
</tr>
</tbody>
</v-table>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" text @click="platformEnableDialog = false">关闭</v-btn>
<v-btn v-if="platformEnableData.platforms.length > 0" color="primary" @click="savePlatformEnableConfig">保存</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 配置对话框 -->
<v-dialog v-model="configDialog" width="1000">
<v-card>
@@ -365,4 +584,19 @@ onMounted(async () => {
</v-snackbar>
<WaitingForRestart ref="wfr"></WaitingForRestart>
</template>
<ReadmeDialog
v-model:show="readmeDialog.show"
:plugin-name="readmeDialog.pluginName"
:repo-url="readmeDialog.repoUrl"
/>
</template>
<style scoped>
.plugin-handler-item {
margin-bottom: 10px;
padding: 5px;
border-radius: 5px;
background-color: #f5f5f5;
}
</style>
+1 -1
View File
@@ -22,7 +22,7 @@ class Main(star.Star):
if not self.timezone:
self.timezone = None
try:
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
except Exception as e:
logger.error(f"时区设置错误: {e}, 使用本地时区")
self.timezone = None