diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index e8a9d23a9..540171f1d 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -3,5 +3,18 @@ from astrbot import logger from astrbot.core import html_renderer from astrbot.core import sp from astrbot.core.star.register import register_llm_tool as llm_tool +from astrbot.core.star.register import register_agent as agent +from astrbot.core.agent.tool import ToolSet, FunctionTool +from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor -__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"] +__all__ = [ + "AstrBotConfig", + "logger", + "html_renderer", + "llm_tool", + "agent", + "sp", + "ToolSet", + "FunctionTool", + "BaseFunctionToolExecutor", +] diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 16f108ece..235a8284b 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,5 +1,4 @@ import os -import asyncio from .log import LogManager, LogBroker # noqa from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.shared_preferences import SharedPreferences @@ -21,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url) logger = LogManager.GetLogger(log_name="astrbot") db_helper = SQLiteDatabase(DB_PATH) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 -sp = SharedPreferences() +sp = SharedPreferences(db_helper=db_helper) # 文件令牌服务 file_token_service = FileTokenService() pip_installer = PipInstaller( diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py new file mode 100644 index 000000000..70536ca88 --- /dev/null +++ b/astrbot/core/agent/agent.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from .tool import FunctionTool +from typing import Generic +from .run_context import TContext +from .hooks import BaseAgentRunHooks + + +@dataclass +class Agent(Generic[TContext]): + name: str + instructions: str | None = None + tools: list[str, FunctionTool] | None = None + run_hooks: BaseAgentRunHooks[TContext] | None = None diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py new file mode 100644 index 000000000..d26463147 --- /dev/null +++ b/astrbot/core/agent/handoff.py @@ -0,0 +1,34 @@ +from typing import Generic +from .tool import FunctionTool +from .agent import Agent +from .run_context import TContext + + +class HandoffTool(FunctionTool, Generic[TContext]): + """Handoff tool for delegating tasks to another agent.""" + + def __init__( + self, agent: Agent[TContext], parameters: dict | None = None, **kwargs + ): + self.agent = agent + super().__init__( + name=f"transfer_to_{agent.name}", + parameters=parameters or self.default_parameters(), + description=agent.instructions or self.default_description(agent.name), + **kwargs, + ) + + def default_parameters(self) -> dict: + return { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to be handed off to another agent. This should be a clear and concise request or task.", + }, + }, + } + + def default_description(self, agent_name: str | None) -> str: + agent_name = agent_name or "another" + return f"Delegate tasks to {self.name} agent to handle the request." diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py new file mode 100644 index 000000000..884fe6bd4 --- /dev/null +++ b/astrbot/core/agent/hooks.py @@ -0,0 +1,27 @@ +import mcp +from dataclasses import dataclass +from .run_context import ContextWrapper, TContext +from typing import Generic +from astrbot.core.provider.entities import LLMResponse +from astrbot.core.agent.tool import FunctionTool + + +@dataclass +class BaseAgentRunHooks(Generic[TContext]): + async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_tool_start( + self, + run_context: ContextWrapper[TContext], + tool: FunctionTool, + tool_args: dict | None, + ): ... + async def on_tool_end( + self, + run_context: ContextWrapper[TContext], + tool: FunctionTool, + tool_args: dict | None, + tool_result: mcp.types.CallToolResult | None, + ): ... + async def on_agent_done( + self, run_context: ContextWrapper[TContext], llm_response: LLMResponse + ): ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py new file mode 100644 index 000000000..f22a222a0 --- /dev/null +++ b/astrbot/core/agent/mcp_client.py @@ -0,0 +1,208 @@ +import asyncio +import logging +from datetime import timedelta +from typing import Optional +from contextlib import AsyncExitStack +from astrbot import logger +from astrbot.core.utils.log_pipe import LogPipe + +try: + import mcp + from mcp.client.sse import sse_client +except (ModuleNotFoundError, ImportError): + logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") + +try: + from mcp.client.streamable_http import streamablehttp_client +except (ModuleNotFoundError, ImportError): + logger.warning( + "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。" + ) + + +def _prepare_config(config: dict) -> dict: + """准备配置,处理嵌套格式""" + if "mcpServers" in config and config["mcpServers"]: + first_key = next(iter(config["mcpServers"])) + config = config["mcpServers"][first_key] + config.pop("active", None) + return config + + +async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: + """快速测试 MCP 服务器可达性""" + import aiohttp + + cfg = _prepare_config(config.copy()) + + url = cfg["url"] + headers = cfg.get("headers", {}) + timeout = cfg.get("timeout", 10) + + try: + async with aiohttp.ClientSession() as session: + if cfg.get("transport") == "streamable_http": + test_payload = { + "jsonrpc": "2.0", + "method": "initialize", + "id": 0, + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.2.3"}, + }, + } + async with session.post( + url, + headers={ + **headers, + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + json=test_payload, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + else: + async with session.get( + url, + headers={ + **headers, + "Accept": "application/json, text/event-stream", + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + + except asyncio.TimeoutError: + return False, f"连接超时: {timeout}秒" + except Exception as e: + return False, f"{e!s}" + + +class MCPClient: + def __init__(self): + # Initialize session and client objects + self.session: Optional[mcp.ClientSession] = None + self.exit_stack = AsyncExitStack() + + self.name = None + self.active: bool = True + self.tools: list[mcp.Tool] = [] + self.server_errlogs: list[str] = [] + self.running_event = asyncio.Event() + + async def connect_to_server(self, mcp_server_config: dict, name: str): + """连接到 MCP 服务器 + + 如果 `url` 参数存在: + 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。 + 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。 + 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。 + + Args: + mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server + """ + cfg = _prepare_config(mcp_server_config.copy()) + + def logging_callback(msg: str): + # 处理 MCP 服务的错误日志 + print(f"MCP Server {name} Error: {msg}") + self.server_errlogs.append(msg) + + if "url" in cfg: + success, error_msg = await _quick_test_mcp_connection(cfg) + if not success: + raise Exception(error_msg) + + if cfg.get("transport") != "streamable_http": + # SSE transport method + self._streams_context = sse_client( + url=cfg["url"], + headers=cfg.get("headers", {}), + timeout=cfg.get("timeout", 5), + sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), + ) + streams = await self.exit_stack.enter_async_context( + self._streams_context + ) + + # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) + self.session = await self.exit_stack.enter_async_context( + mcp.ClientSession( + *streams, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) + ) + else: + timeout = timedelta(seconds=cfg.get("timeout", 30)) + sse_read_timeout = timedelta( + seconds=cfg.get("sse_read_timeout", 60 * 5) + ) + self._streams_context = streamablehttp_client( + url=cfg["url"], + headers=cfg.get("headers", {}), + timeout=timeout, + sse_read_timeout=sse_read_timeout, + terminate_on_close=cfg.get("terminate_on_close", True), + ) + read_s, write_s, _ = await self.exit_stack.enter_async_context( + self._streams_context + ) + + # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) + self.session = await self.exit_stack.enter_async_context( + mcp.ClientSession( + read_stream=read_s, + write_stream=write_s, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) + ) + + else: + server_params = mcp.StdioServerParameters( + **cfg, + ) + + def callback(msg: str): + # 处理 MCP 服务的错误日志 + self.server_errlogs.append(msg) + + stdio_transport = await self.exit_stack.enter_async_context( + mcp.stdio_client( + server_params, + errlog=LogPipe( + level=logging.ERROR, + logger=logger, + identifier=f"MCPServer-{name}", + callback=callback, + ), # type: ignore + ), + ) + + # Create a new client session + self.session = await self.exit_stack.enter_async_context( + mcp.ClientSession(*stdio_transport) + ) + await self.session.initialize() + + async def list_tools_and_save(self) -> mcp.ListToolsResult: + """List all tools from the server and save them to self.tools""" + response = await self.session.list_tools() + self.tools = response.tools + return response + + async def cleanup(self): + """Clean up resources""" + await self.exit_stack.aclose() + self.running_event.set() # Set the running event to indicate cleanup is done diff --git a/astrbot/core/agent/response.py b/astrbot/core/agent/response.py new file mode 100644 index 000000000..3f683a233 --- /dev/null +++ b/astrbot/core/agent/response.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +import typing as T +from astrbot.core.message.message_event_result import MessageChain + +class AgentResponseData(T.TypedDict): + chain: MessageChain + + +@dataclass +class AgentResponse: + type: str + data: AgentResponseData diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py new file mode 100644 index 000000000..58ea2ca43 --- /dev/null +++ b/astrbot/core/agent/run_context.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from typing import Any, Generic +from typing_extensions import TypeVar + +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +TContext = TypeVar("TContext", default=Any) + + +@dataclass +class ContextWrapper(Generic[TContext]): + """A context for running an agent, which can be used to pass additional data or state.""" + + context: TContext + event: AstrMessageEvent + +NoContext = ContextWrapper[None] diff --git a/astrbot/core/agent/runners/__init__.py b/astrbot/core/agent/runners/__init__.py new file mode 100644 index 000000000..c13589f51 --- /dev/null +++ b/astrbot/core/agent/runners/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseAgentRunner + +__all__ = ["BaseAgentRunner"] diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/agent/runners/base.py similarity index 56% rename from astrbot/core/pipeline/process_stage/agent_runner/base.py rename to astrbot/core/agent/runners/base.py index 431a95ca6..83821ae29 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/base.py +++ b/astrbot/core/agent/runners/base.py @@ -1,32 +1,33 @@ import abc import typing as T -from dataclasses import dataclass -from astrbot.core.provider.entities import LLMResponse -from ....message.message_event_result import MessageChain from enum import Enum, auto +from ..run_context import ContextWrapper, TContext +from ..response import AgentResponse +from ..hooks import BaseAgentRunHooks +from ..tool_executor import BaseFunctionToolExecutor +from astrbot.core.provider import Provider +from astrbot.core.provider.entities import LLMResponse class AgentState(Enum): - """Agent 状态枚举""" - IDLE = auto() # 初始状态 - RUNNING = auto() # 运行中 - DONE = auto() # 完成 - ERROR = auto() # 错误状态 + """Defines the state of the agent.""" + + IDLE = auto() # Initial state + RUNNING = auto() # Currently processing + DONE = auto() # Completed + ERROR = auto() # Error state -class AgentResponseData(T.TypedDict): - chain: MessageChain - - -@dataclass -class AgentResponse: - type: str - data: AgentResponseData - - -class BaseAgentRunner: +class BaseAgentRunner(T.Generic[TContext]): @abc.abstractmethod - async def reset(self) -> None: + async def reset( + self, + provider: Provider, + run_context: ContextWrapper[TContext], + tool_executor: BaseFunctionToolExecutor[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + **kwargs: T.Any, + ) -> None: """ Reset the agent to its initial state. This method should be called before starting a new run. diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py similarity index 55% rename from astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py rename to astrbot/core/agent/runners/tool_loop_agent_runner.py index c2961ded5..c38285f55 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,10 +1,12 @@ import sys import traceback import typing as T -from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState -from ...context import PipelineContext +from .base import BaseAgentRunner, AgentResponse, AgentState +from ..hooks import BaseAgentRunHooks +from ..tool_executor import BaseFunctionToolExecutor +from ..run_context import ContextWrapper, TContext +from ..response import AgentResponseData from astrbot.core.provider.provider import Provider -from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import ( MessageChain, ) @@ -21,8 +23,8 @@ from mcp.types import ( EmbeddedResource, TextResourceContents, BlobResourceContents, + CallToolResult, ) -from astrbot.core.star.star_handler import EventType from astrbot import logger if sys.version_info >= (3, 12): @@ -31,28 +33,25 @@ else: from typing_extensions import override -# TODO: -# 1. 处理平台不兼容的处理器 - - -class ToolLoopAgent(BaseAgentRunner): - def __init__( - self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext - ) -> None: - self.provider = provider - self.req = None - self.event = event - self.pipeline_ctx = pipeline_ctx - self._state = AgentState.IDLE - self.final_llm_resp = None - self.streaming = False - +class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override - async def reset(self, req: ProviderRequest, streaming: bool) -> None: - self.req = req - self.streaming = streaming + async def reset( + self, + provider: Provider, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + tool_executor: BaseFunctionToolExecutor[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.provider = provider self.final_llm_resp = None self._state = AgentState.IDLE + self.tool_executor = tool_executor + self.agent_hooks = agent_hooks + self.run_context = run_context def _transition_state(self, new_state: AgentState) -> None: """转换 Agent 状态""" @@ -78,6 +77,12 @@ class ToolLoopAgent(BaseAgentRunner): if not self.req: raise ValueError("Request is not set. Please call reset() first.") + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + # 开始处理,转换到运行状态 self._transition_state(AgentState.RUNNING) llm_resp_result = None @@ -124,12 +129,10 @@ class ToolLoopAgent(BaseAgentRunner): # 如果没有工具调用,转换到完成状态 self.final_llm_resp = llm_resp self._transition_state(AgentState.DONE) - - # 执行事件钩子 - if await self.pipeline_ctx.call_event_hook( - self.event, EventType.OnLLMResponseEvent, llm_resp - ): - return + try: + await self.agent_hooks.on_agent_done(self.run_context, llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) # 返回 LLM 结果 if llm_resp.result_chain: @@ -193,50 +196,33 @@ class ToolLoopAgent(BaseAgentRunner): if not req.func_tool: return func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") + + try: + await self.agent_hooks.on_tool_start( + self.run_context, func_tool, func_tool_args ) - client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] - res = await client.session.call_tool(func_tool.name, func_tool_args) - if not res: - continue - if isinstance(res.content[0], TextContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, - ) - ) - yield MessageChain().message(res.content[0].text) - elif isinstance(res.content[0], ImageContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - yield MessageChain(type="tool_direct_result").base64_image( - res.content[0].data - ) - elif isinstance(res.content[0], EmbeddedResource): - resource = res.content[0].resource - if isinstance(resource, TextResourceContents): + except Exception as e: + logger.error(f"Error in on_tool_start hook: {e}", exc_info=True) + + executor = self.tool_executor.execute( + tool=func_tool, + run_context=self.run_context, + **func_tool_args, + ) + async for resp in executor: + if isinstance(resp, CallToolResult): + res = resp + if isinstance(res.content[0], TextContent): tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content=resource.text, + content=res.content[0].text, ) ) - yield MessageChain().message(resource.text) - elif ( - isinstance(resource, BlobResourceContents) - and resource.mimeType - and resource.mimeType.startswith("image/") - ): + yield MessageChain().message(res.content[0].text) + elif isinstance(res.content[0], ImageContent): tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", @@ -247,43 +233,85 @@ class ToolLoopAgent(BaseAgentRunner): yield MessageChain(type="tool_direct_result").base64_image( res.content[0].data ) - else: - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回的数据类型不受支持", - ) - ) - yield MessageChain().message("返回的数据类型不受支持。") - else: - logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") - # 尝试调用工具函数 - wrapper = self.pipeline_ctx.call_handler( - self.event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: - # Tool 返回结果 - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - yield MessageChain().message(resp) - else: - # Tool 直接请求发送消息给用户 - # 这里我们将直接结束 Agent Loop。 - self._transition_state(AgentState.DONE) - if res := self.event.get_result(): - if res.chain: - yield MessageChain( - chain=res.chain, type="tool_direct_result" + elif isinstance(res.content[0], EmbeddedResource): + resource = res.content[0].resource + if isinstance(resource, TextResourceContents): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resource.text, ) + ) + yield MessageChain().message(resource.text) + elif ( + isinstance(resource, BlobResourceContents) + and resource.mimeType + and resource.mimeType.startswith("image/") + ): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + yield MessageChain( + type="tool_direct_result" + ).base64_image(res.content[0].data) + else: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回的数据类型不受支持", + ) + ) + yield MessageChain().message("返回的数据类型不受支持。") - self.event.clear_result() + try: + await self.agent_hooks.on_tool_end( + self.run_context, + func_tool_name, + func_tool_args, + resp, + ) + except Exception as e: + logger.error( + f"Error in on_tool_end hook: {e}", exc_info=True + ) + elif resp is None: + # Tool 直接请求发送消息给用户 + # 这里我们将直接结束 Agent Loop。 + self._transition_state(AgentState.DONE) + if res := self.run_context.event.get_result(): + if res.chain: + yield MessageChain( + chain=res.chain, type="tool_direct_result" + ) + try: + await self.agent_hooks.on_tool_end( + self.run_context, func_tool_name, func_tool_args, None + ) + except Exception as e: + logger.error( + f"Error in on_tool_end hook: {e}", exc_info=True + ) + else: + logger.warning( + f"Tool 返回了不支持的类型: {type(resp)},将忽略。" + ) + + try: + await self.agent_hooks.on_tool_end( + self.run_context, func_tool_name, func_tool_args, None + ) + except Exception as e: + logger.error( + f"Error in on_tool_end hook: {e}", exc_info=True + ) + + self.run_context.event.clear_result() except Exception as e: logger.warning(traceback.format_exc()) tool_call_result_blocks.append( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py new file mode 100644 index 000000000..743deae1f --- /dev/null +++ b/astrbot/core/agent/tool.py @@ -0,0 +1,256 @@ +from dataclasses import dataclass +from deprecated import deprecated +from typing import Awaitable, Literal, Any, Optional +from .mcp_client import MCPClient + + +@dataclass +class FunctionTool: + """A class representing a function tool that can be used in function calling.""" + + name: str | None = None + parameters: dict | None = None + description: str | None = None + handler: Awaitable | None = None + """处理函数, 当 origin 为 mcp 时,这个为空""" + handler_module_path: str | None = None + """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + + 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + """ + active: bool = True + """是否激活""" + + origin: Literal["local", "mcp"] = "local" + """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" + + # MCP 相关字段 + mcp_server_name: str | None = None + """MCP 服务名称,当 origin 为 mcp 时有效""" + mcp_client: MCPClient | None = None + """MCP 客户端,当 origin 为 mcp 时有效""" + + def __repr__(self): + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" + + def __dict__(self) -> dict[str, Any]: + """将 FunctionTool 转换为字典格式""" + return { + "name": self.name, + "parameters": self.parameters, + "description": self.description, + "active": self.active, + "origin": self.origin, + "mcp_server_name": self.mcp_server_name, + } + + +class ToolSet: + """A set of function tools that can be used in function calling. + + This class provides methods to add, remove, and retrieve tools, as well as + convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" + + def __init__(self, tools: list[FunctionTool] = None): + self.tools: list[FunctionTool] = tools or [] + + def empty(self) -> bool: + """Check if the tool set is empty.""" + return len(self.tools) == 0 + + def add_tool(self, tool: FunctionTool): + """Add a tool to the set.""" + # 检查是否已存在同名工具 + for i, existing_tool in enumerate(self.tools): + if existing_tool.name == tool.name: + self.tools[i] = tool + return + self.tools.append(tool) + + def remove_tool(self, name: str): + """Remove a tool by its name.""" + self.tools = [tool for tool in self.tools if tool.name != name] + + def get_tool(self, name: str) -> Optional[FunctionTool]: + """Get a tool by its name.""" + for tool in self.tools: + if tool.name == name: + return tool + return None + + @deprecated(reason="Use add_tool() instead", version="4.0.0") + def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable): + """Add a function tool to the set.""" + params = { + "type": "object", # hard-coded here + "properties": {}, + } + for param in func_args: + params["properties"][param["name"]] = { + "type": param["type"], + "description": param["description"], + } + _func = FunctionTool( + name=name, + parameters=params, + description=desc, + handler=handler, + ) + self.add_tool(_func) + + @deprecated(reason="Use remove_tool() instead", version="4.0.0") + def remove_func(self, name: str): + """Remove a function tool by its name.""" + self.remove_tool(name) + + @deprecated(reason="Use get_tool() instead", version="4.0.0") + def get_func(self, name: str) -> list[FunctionTool]: + """Get all function tools.""" + return self.get_tool(name) + + @property + def func_list(self) -> list[FunctionTool]: + """Get the list of function tools.""" + return self.tools + + def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]: + """Convert tools to OpenAI API function calling schema format.""" + result = [] + for tool in self.tools: + func_def = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + }, + } + + if tool.parameters.get("properties") or not omit_empty_parameter_field: + func_def["function"]["parameters"] = tool.parameters + + result.append(func_def) + return result + + def anthropic_schema(self) -> list[dict]: + """Convert tools to Anthropic API format.""" + result = [] + for tool in self.tools: + tool_def = { + "name": tool.name, + "description": tool.description, + "input_schema": { + "type": "object", + "properties": tool.parameters.get("properties", {}), + "required": tool.parameters.get("required", []), + }, + } + result.append(tool_def) + return result + + def google_schema(self) -> dict: + """Convert tools to Google GenAI API format.""" + + def convert_schema(schema: dict) -> dict: + """Convert schema to Gemini API format.""" + supported_types = { + "string", + "number", + "integer", + "boolean", + "array", + "object", + "null", + } + supported_formats = { + "string": {"enum", "date-time"}, + "integer": {"int32", "int64"}, + "number": {"float", "double"}, + } + + if "anyOf" in schema: + return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]} + + result = {} + + if "type" in schema and schema["type"] in supported_types: + result["type"] = schema["type"] + if "format" in schema and schema["format"] in supported_formats.get( + result["type"], set() + ): + result["format"] = schema["format"] + else: + result["type"] = "null" + + support_fields = { + "title", + "description", + "enum", + "minimum", + "maximum", + "maxItems", + "minItems", + "nullable", + "required", + } + result.update({k: schema[k] for k in support_fields if k in schema}) + + if "properties" in schema: + properties = {} + for key, value in schema["properties"].items(): + prop_value = convert_schema(value) + if "default" in prop_value: + del prop_value["default"] + properties[key] = prop_value + + if properties: + result["properties"] = properties + + if "items" in schema: + result["items"] = convert_schema(schema["items"]) + + return result + + tools = [ + { + "name": tool.name, + "description": tool.description, + "parameters": convert_schema(tool.parameters), + } + for tool in self.tools + ] + + declarations = {} + if tools: + declarations["function_declarations"] = tools + return declarations + + @deprecated(reason="Use openai_schema() instead", version="4.0.0") + def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False): + return self.openai_schema(omit_empty_parameter_field) + + @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") + def get_func_desc_anthropic_style(self): + return self.anthropic_schema() + + @deprecated(reason="Use google_schema() instead", version="4.0.0") + def get_func_desc_google_genai_style(self): + return self.google_schema() + + def names(self) -> list[str]: + """获取所有工具的名称列表""" + return [tool.name for tool in self.tools] + + def __len__(self): + return len(self.tools) + + def __bool__(self): + return len(self.tools) > 0 + + def __iter__(self): + return iter(self.tools) + + def __repr__(self): + return f"ToolSet(tools={self.tools})" + + def __str__(self): + return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py new file mode 100644 index 000000000..34a2f5e77 --- /dev/null +++ b/astrbot/core/agent/tool_executor.py @@ -0,0 +1,11 @@ +import mcp +from typing import Any, Generic, AsyncGenerator +from .run_context import TContext, ContextWrapper +from .tool import FunctionTool + + +class BaseFunctionToolExecutor(Generic[TContext]): + @classmethod + async def execute( + cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args + ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py new file mode 100644 index 000000000..b09d03b3c --- /dev/null +++ b/astrbot/core/astr_agent_context.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from astrbot.core.provider import Provider +from astrbot.core.provider.entities import ProviderRequest + + +@dataclass +class AstrAgentContext: + provider: Provider + first_provider_request: ProviderRequest + curr_provider_request: ProviderRequest + streaming: bool diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py new file mode 100644 index 000000000..51ea8fcd4 --- /dev/null +++ b/astrbot/core/astrbot_config_mgr.py @@ -0,0 +1,276 @@ +import os +import uuid +from astrbot.core import AstrBotConfig, logger +from astrbot.core.utils.shared_preferences import SharedPreferences +from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH +from astrbot.core.config.default import DEFAULT_CONFIG +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.utils.astrbot_path import get_astrbot_config_path +from typing import TypeVar, TypedDict + +_VT = TypeVar("_VT") + + +class ConfInfo(TypedDict): + """Configuration information for a specific session or platform.""" + + id: str # UUID of the configuration or "default" + umop: list[str] # Unified Message Origin Pattern + name: str + path: str # File name to the configuration file + + +DEFAULT_CONFIG_CONF_INFO = ConfInfo( + id="default", + umop=["::"], + name="default", + path=ASTRBOT_CONFIG_PATH, +) + + +class AstrBotConfigManager: + """A class to manage the system configuration of AstrBot, aka ACM""" + + def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences): + self.sp = sp + self.confs: dict[str, AstrBotConfig] = {} + """uuid / "default" -> AstrBotConfig""" + self.confs["default"] = default_config + self._load_all_configs() + + def _load_all_configs(self): + """Load all configurations from the shared preferences.""" + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + for uuid_, meta in abconf_data.items(): + filename = meta["path"] + conf_path = os.path.join(get_astrbot_config_path(), filename) + if os.path.exists(conf_path): + conf = AstrBotConfig(config_path=conf_path) + self.confs[uuid_] = conf + else: + logger.warning( + f"Config file {conf_path} for UUID {uuid_} does not exist, skipping." + ) + continue + + def _is_umo_match(self, p1: str, p2: str) -> bool: + """判断 p2 umo 是否逻辑包含于 p1 umo""" + p1_ls = p1.split(":") + p2_ls = p2.split(":") + + if len(p1_ls) != 3 or len(p2_ls) != 3: + return False # 非法格式 + + return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls)) + + def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo: + """获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default") + + Returns: + ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型 + """ + # uuid -> { "umop": list, "path": str, "name": str } + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + if isinstance(umo, MessageSession): + umo = str(umo) + else: + try: + umo = str(MessageSession.from_str(umo)) # validate + except Exception: + return DEFAULT_CONFIG_CONF_INFO + + for uuid_, meta in abconf_data.items(): + for pattern in meta["umop"]: + if self._is_umo_match(pattern, umo): + return ConfInfo(**meta, id=uuid_) + + return DEFAULT_CONFIG_CONF_INFO + + def _save_conf_mapping( + self, + abconf_path: str, + abconf_id: str, + umo_parts: list[str] | list[MessageSession], + abconf_name: str | None = None, + ) -> None: + """保存配置文件的映射关系""" + for part in umo_parts: + if isinstance(part, MessageSession): + part = str(part) + elif not isinstance(part, str): + raise ValueError( + "umo_parts must be a list of strings or MessageSession instances" + ) + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + random_word = abconf_name or uuid.uuid4().hex[:8] + abconf_data[abconf_id] = { + "umop": umo_parts, + "path": abconf_path, + "name": random_word, + } + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") + + def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig: + """获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。""" + if not umo: + return self.confs["default"] + if isinstance(umo, MessageSession): + umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}" + + uuid_ = self._load_conf_mapping(umo)["id"] + + conf = self.confs.get(uuid_) + if not conf: + conf = self.confs["default"] # default MUST exists + + return conf + + @property + def default_conf(self) -> AstrBotConfig: + """获取默认配置文件""" + return self.confs["default"] + + def get_conf_info(self, umo: str | MessageSession) -> ConfInfo: + """获取指定 umo 的配置文件元数据""" + if isinstance(umo, MessageSession): + umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}" + + return self._load_conf_mapping(umo) + + def get_conf_list(self) -> list[ConfInfo]: + """获取所有配置文件的元数据列表""" + conf_list = [] + conf_list.append(DEFAULT_CONFIG_CONF_INFO) + abconf_mapping = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + for uuid_, meta in abconf_mapping.items(): + conf_list.append(ConfInfo(**meta, id=uuid_)) + return conf_list + + def create_conf( + self, + umo_parts: list[str] | list[MessageSession], + config: dict = DEFAULT_CONFIG, + name: str | None = None, + ) -> str: + """ + umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。 + + umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。 + """ + conf_uuid = str(uuid.uuid4()) + conf_file_name = f"abconf_{conf_uuid}.json" + conf_path = os.path.join(get_astrbot_config_path(), conf_file_name) + conf = AstrBotConfig(config_path=conf_path, default_config=config) + conf.save_config() + self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name) + self.confs[conf_uuid] = conf + return conf_uuid + + def delete_conf(self, conf_id: str) -> bool: + """删除指定配置文件 + + Args: + conf_id: 配置文件的 UUID + + Returns: + bool: 删除是否成功 + + Raises: + ValueError: 如果试图删除默认配置文件 + """ + if conf_id == "default": + raise ValueError("不能删除默认配置文件") + + # 从映射中移除 + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + if conf_id not in abconf_data: + logger.warning(f"配置文件 {conf_id} 不存在于映射中") + return False + + # 获取配置文件路径 + conf_path = os.path.join( + get_astrbot_config_path(), abconf_data[conf_id]["path"] + ) + + # 删除配置文件 + try: + if os.path.exists(conf_path): + os.remove(conf_path) + logger.info(f"已删除配置文件: {conf_path}") + except Exception as e: + logger.error(f"删除配置文件 {conf_path} 失败: {e}") + return False + + # 从内存中移除 + if conf_id in self.confs: + del self.confs[conf_id] + + # 从映射中移除 + del abconf_data[conf_id] + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") + + logger.info(f"成功删除配置文件 {conf_id}") + return True + + def update_conf_info( + self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None + ) -> bool: + """更新配置文件信息 + + Args: + conf_id: 配置文件的 UUID + name: 新的配置文件名称 (可选) + umo_parts: 新的 UMO 部分列表 (可选) + + Returns: + bool: 更新是否成功 + """ + if conf_id == "default": + raise ValueError("不能更新默认配置文件的信息") + + abconf_data = self.sp.get( + "abconf_mapping", {}, scope="global", scope_id="global" + ) + if conf_id not in abconf_data: + logger.warning(f"配置文件 {conf_id} 不存在于映射中") + return False + + # 更新名称 + if name is not None: + abconf_data[conf_id]["name"] = name + + # 更新 UMO 部分 + if umo_parts is not None: + # 验证 UMO 部分格式 + for part in umo_parts: + if isinstance(part, MessageSession): + part = str(part) + elif not isinstance(part, str): + raise ValueError( + "umo_parts must be a list of strings or MessageSession instances" + ) + abconf_data[conf_id]["umop"] = umo_parts + + # 保存更新 + self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global") + logger.info(f"成功更新配置文件 {conf_id} 的信息") + return True + + def g( + self, umo: str | None = None, key: str | None = None, default: _VT = None + ) -> _VT: + """获取配置项。umo 为 None 时使用默认配置""" + if umo is None: + return self.confs["default"].get(key, default) + conf = self.get_conf(umo) + return conf.get(key, default) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 637d30bcb..f9663b7a0 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -6,14 +6,13 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "3.5.26" -DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db") +VERSION = "4.0.0-beta.1" +DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") # 默认配置 DEFAULT_CONFIG = { "config_version": 2, "platform_settings": { - "plugin_enable": {}, "unique_session": False, "rate_limit": { "time": 60, @@ -51,21 +50,26 @@ DEFAULT_CONFIG = { "provider_settings": { "enable": True, "default_provider_id": "", + "default_image_caption_provider_id": "", + "image_caption_prompt": "Please describe the image using Chinese.", + "provider_pool": ["*"], # "*" 表示使用所有可用的提供者 "wake_prefix": "", "web_search": False, + "websearch_provider": "default", + "websearch_tavily_key": "", "web_search_link": False, "display_reasoning_text": False, "identifier": False, "datetime_system_prompt": True, "default_personality": "default", + "persona_pool": ["*"], "prompt_prefix": "", "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, "show_tool_use_status": False, "streaming_segmented": False, - "separate_provider": True, - "max_agent_step": 30 + "max_agent_step": 30, }, "provider_stt_settings": { "enable": False, @@ -81,13 +85,10 @@ DEFAULT_CONFIG = { "group_icl_enable": False, "group_message_max_cnt": 300, "image_caption": False, - "image_caption_provider_id": "", - "image_caption_prompt": "Please describe the image using Chinese.", "active_reply": { "enable": False, "method": "possibility_reply", "possibility_reply": 0.1, - "prompt": "", "whitelist": [], }, }, @@ -117,17 +118,17 @@ DEFAULT_CONFIG = { "log_level": "INFO", "pip_install_arg": "", "pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/", - "knowledge_db": {}, - "persona": [], - "timezone": "", + "persona": [], # deprecated + "timezone": "Asia/Shanghai", "callback_api_base": "", + "default_kb_collection": "", # 默认知识库名称 + "plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件 } # 配置项的中文描述、值类型 CONFIG_METADATA_2 = { "platform_group": { - "name": "消息平台", "metadata": { "platform": { "description": "消息平台适配器", @@ -385,152 +386,117 @@ CONFIG_METADATA_2 = { }, }, "platform_settings": { - "description": "平台设置", "type": "object", "items": { - "plugin_enable": { - "invisible": True, # 隐藏插件启用配置 - }, "unique_session": { - "description": "会话隔离", "type": "bool", - "hint": "启用后,在群组或者频道中,每个人的消息上下文都是独立的。", }, "rate_limit": { - "description": "速率限制", - "hint": "每个会话在 `time` 秒内最多只能发送 `count` 条消息。", "type": "object", "items": { - "time": {"description": "消息速率限制时间", "type": "int"}, - "count": {"description": "消息速率限制计数", "type": "int"}, + "time": {"type": "int"}, + "count": {"type": "int"}, "strategy": { - "description": "速率限制策略", "type": "string", "options": ["stall", "discard"], - "hint": "当消息速率超过限制时的处理策略。stall 为等待,discard 为丢弃。", }, }, }, "no_permission_reply": { - "description": "无权限回复", "type": "bool", "hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。", }, "empty_mention_waiting": { - "description": "只 @ 机器人是否触发等待", "type": "bool", "hint": "启用后,当消息内容只有 @ 机器人时,会触发等待,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。", }, "empty_mention_waiting_need_reply": { - "description": "只 @ 机器人触发等待时是否需要回复提醒", "type": "bool", "hint": "在上面一个配置项中,如果启用了触发等待,启用此项后,机器人会使用 LLM 生成一条回复。否则,将不回复而只是等待。", }, "friend_message_needs_wake_prefix": { - "description": "私聊消息是否需要唤醒前缀", "type": "bool", "hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。", }, "ignore_bot_self_message": { - "description": "是否忽略机器人自身的消息", "type": "bool", "hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人", }, "ignore_at_all": { - "description": "是否忽略 @ 全体成员", "type": "bool", "hint": "启用后,机器人会忽略 @ 全体成员 的消息事件。", }, "segmented_reply": { - "description": "分段回复", "type": "object", "items": { "enable": { - "description": "启用分段回复", "type": "bool", }, "only_llm_result": { - "description": "仅对 LLM 结果分段", "type": "bool", }, "interval_method": { - "description": "间隔时间计算方法", "type": "string", "options": ["random", "log"], "hint": "分段回复的间隔时间计算方法。random 为随机时间,log 为根据消息长度计算,$y=log_(x)$,x为字数,y的单位为秒。", }, "interval": { - "description": "随机间隔时间(秒)", "type": "string", "hint": "`random` 方法用。每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`", }, "log_base": { - "description": "对数函数底数", "type": "float", "hint": "`log` 方法用。对数函数的底数。默认为 2.6", }, "words_count_threshold": { - "description": "字数阈值", "type": "int", "hint": "超过这个字数的消息不会被分段回复。默认为 150", }, "regex": { - "description": "正则表达式", "type": "string", "hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'', text)", }, "content_cleanup_rule": { - "description": "过滤分段后的内容", "type": "string", "hint": "移除分段后的内容中的指定的内容。支持正则表达式。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。re.sub(r'', '', text)", }, }, }, "reply_prefix": { - "description": "回复前缀", "type": "string", "hint": "机器人回复消息时带有的前缀。", }, "forward_threshold": { - "description": "转发消息的字数阈值", "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。", }, "enable_id_white_list": { - "description": "启用 ID 白名单", "type": "bool", }, "id_whitelist": { - "description": "ID 白名单", "type": "list", "items": {"type": "string"}, "hint": "只处理填写的 ID 发来的消息事件,为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单", }, "id_whitelist_log": { - "description": "打印白名单日志", "type": "bool", "hint": "启用后,当一条消息没通过白名单时,会输出 INFO 级别的日志。", }, "wl_ignore_admin_on_group": { - "description": "管理员群组消息无视 ID 白名单", "type": "bool", }, "wl_ignore_admin_on_friend": { - "description": "管理员私聊消息无视 ID 白名单", "type": "bool", }, "reply_with_mention": { - "description": "回复时 @ 发送者", "type": "bool", "hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。", }, "reply_with_quote": { - "description": "回复时引用消息", "type": "bool", "hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。", }, "path_mapping": { - "description": "路径映射", "type": "list", "items": {"type": "string"}, "hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。", @@ -538,41 +504,33 @@ CONFIG_METADATA_2 = { }, }, "content_safety": { - "description": "内容安全", "type": "object", "items": { "also_use_in_response": { - "description": "对大模型响应安全审核", "type": "bool", "hint": "启用后,大模型的响应也会通过内容安全审核。", }, "baidu_aip": { - "description": "百度内容审核配置", "type": "object", "items": { "enable": { - "description": "启用百度内容审核", "type": "bool", "hint": "启用此功能前,您需要手动在设备中安装 baidu-aip 库。一般来说,安装指令如下: `pip3 install baidu-aip`", }, "app_id": {"description": "APP ID", "type": "string"}, "api_key": {"description": "API Key", "type": "string"}, "secret_key": { - "description": "Secret Key", "type": "string", }, }, }, "internal_keywords": { - "description": "内部关键词过滤", "type": "object", "items": { "enable": { - "description": "启用内部关键词过滤", "type": "bool", }, "extra_keywords": { - "description": "额外关键词", "type": "list", "items": {"type": "string"}, "hint": "额外的屏蔽关键词列表,支持正则表达式。", @@ -587,7 +545,6 @@ CONFIG_METADATA_2 = { "name": "服务提供商", "metadata": { "provider": { - "description": "服务提供商配置", "type": "list", "config_template": { "OpenAI": { @@ -599,11 +556,9 @@ CONFIG_METADATA_2 = { "key": [], "api_base": "https://api.openai.com/v1", "timeout": 120, - "model_config": { - "model": "gpt-4o-mini", - "temperature": 0.4 - }, - "hint": "也兼容所有与OpenAI API兼容的服务。" + "model_config": {"model": "gpt-4o-mini", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], + "hint": "也兼容所有与 OpenAI API 兼容的服务。", }, "Azure OpenAI": { "id": "azure", @@ -615,10 +570,8 @@ CONFIG_METADATA_2 = { "key": [], "api_base": "", "timeout": 120, - "model_config": { - "model": "gpt-4o-mini", - "temperature": 0.4 - }, + "model_config": {"model": "gpt-4o-mini", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "xAI": { "id": "xai", @@ -629,10 +582,8 @@ CONFIG_METADATA_2 = { "key": [], "api_base": "https://api.x.ai/v1", "timeout": 120, - "model_config": { - "model": "grok-2-latest", - "temperature": 0.4 - }, + "model_config": {"model": "grok-2-latest", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "Anthropic": { "hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错", @@ -647,11 +598,12 @@ CONFIG_METADATA_2 = { "model_config": { "model": "claude-3-5-sonnet-latest", "max_tokens": 4096, - "temperature": 0.2 + "temperature": 0.2, }, + "modalities": ["text", "image", "tool_use"], }, "Ollama": { - "hint":"启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key", + "hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key", "id": "ollama_default", "provider": "ollama", "type": "openai_chat_completion", @@ -659,10 +611,8 @@ CONFIG_METADATA_2 = { "enable": True, "key": ["ollama"], # ollama 的 key 默认是 ollama "api_base": "http://localhost:11434/v1", - "model_config": { - "model": "llama3.1-8b", - "temperature": 0.4 - }, + "model_config": {"model": "llama3.1-8b", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "LM Studio": { "id": "lm_studio", @@ -675,6 +625,7 @@ CONFIG_METADATA_2 = { "model_config": { "model": "llama-3.1-8b", }, + "modalities": ["text", "image", "tool_use"], }, "Gemini(OpenAI兼容)": { "id": "gemini_default", @@ -687,8 +638,9 @@ CONFIG_METADATA_2 = { "timeout": 120, "model_config": { "model": "gemini-1.5-flash", - "temperature": 0.4 + "temperature": 0.4, }, + "modalities": ["text", "image", "tool_use"], }, "Gemini": { "id": "gemini_default", @@ -701,7 +653,7 @@ CONFIG_METADATA_2 = { "timeout": 120, "model_config": { "model": "gemini-2.0-flash-exp", - "temperature": 0.4 + "temperature": 0.4, }, "gm_resp_image_modal": False, "gm_native_search": False, @@ -716,6 +668,7 @@ CONFIG_METADATA_2 = { "gm_thinking_config": { "budget": 0, }, + "modalities": ["text", "image", "tool_use"], }, "DeepSeek": { "id": "deepseek_default", @@ -726,10 +679,8 @@ CONFIG_METADATA_2 = { "key": [], "api_base": "https://api.deepseek.com/v1", "timeout": 120, - "model_config": { - "model": "deepseek-chat", - "temperature": 0.4 - }, + "model_config": {"model": "deepseek-chat", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "302.AI": { "id": "302ai", @@ -740,10 +691,8 @@ CONFIG_METADATA_2 = { "key": [], "api_base": "https://api.302.ai/v1", "timeout": 120, - "model_config": { - "model": "gpt-4.1-mini", - "temperature": 0.4 - }, + "model_config": {"model": "gpt-4.1-mini", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "硅基流动": { "id": "siliconflow", @@ -756,8 +705,9 @@ CONFIG_METADATA_2 = { "api_base": "https://api.siliconflow.cn/v1", "model_config": { "model": "deepseek-ai/DeepSeek-V3", - "temperature": 0.4 + "temperature": 0.4, }, + "modalities": ["text", "image", "tool_use"], }, "PPIO派欧云": { "id": "ppio", @@ -770,7 +720,7 @@ CONFIG_METADATA_2 = { "timeout": 120, "model_config": { "model": "deepseek/deepseek-r1", - "temperature": 0.4 + "temperature": 0.4, }, }, "优云智算": { @@ -785,6 +735,7 @@ CONFIG_METADATA_2 = { "model_config": { "model": "moonshotai/Kimi-K2-Instruct", }, + "modalities": ["text", "image", "tool_use"], }, "Kimi": { "id": "moonshot", @@ -795,10 +746,8 @@ CONFIG_METADATA_2 = { "key": [], "timeout": 120, "api_base": "https://api.moonshot.cn/v1", - "model_config": { - "model": "moonshot-v1-8k", - "temperature": 0.4 - }, + "model_config": {"model": "moonshot-v1-8k", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "智谱 AI": { "id": "zhipu_default", @@ -812,6 +761,7 @@ CONFIG_METADATA_2 = { "model_config": { "model": "glm-4-flash", }, + "modalities": ["text", "image", "tool_use"], }, "Dify": { "id": "dify_app_default", @@ -826,7 +776,7 @@ CONFIG_METADATA_2 = { "dify_query_input_key": "astrbot_text_query", "variables": {}, "timeout": 60, - "hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!" + "hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!", }, "阿里云百炼应用": { "id": "dashscope", @@ -854,10 +804,8 @@ CONFIG_METADATA_2 = { "key": [], "timeout": 120, "api_base": "https://api-inference.modelscope.cn/v1", - "model_config": { - "model": "Qwen/Qwen3-32B", - "temperature": 0.4 - }, + "model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4}, + "modalities": ["text", "image", "tool_use"], }, "FastGPT": { "id": "fastgpt", @@ -1073,8 +1021,42 @@ CONFIG_METADATA_2 = { "embedding_dimensions": 768, "timeout": 20, }, + "vLLM Rerank": { + "id": "vllm_rerank", + "type": "vllm_rerank", + "provider": "vllm", + "provider_type": "rerank", + "enable": True, + "rerank_api_key": "", + "rerank_api_base": "http://127.0.0.1:8000", + "rerank_model": "BAAI/bge-reranker-base", + "timeout": 20, + }, }, "items": { + "rerank_api_base": { + "description": "重排序模型 API Base URL", + "type": "string", + "hint": "AstrBot 会在请求时在末尾加上 /v1/rerank。", + }, + "rerank_api_key": { + "description": "API Key", + "type": "string", + "hint": "如果不需要 API Key, 请留空。", + }, + "rerank_model": { + "description": "重排序模型名称", + "type": "string", + }, + "modalities": { + "description": "模型能力", + "type": "list", + "items": {"type": "string"}, + "options": ["text", "image", "tool_use"], + "labels": ["文本", "图像", "工具使用"], + "render_type": "checkbox", + "hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。", + }, "provider": { "type": "string", "invisible": True, @@ -1654,88 +1636,52 @@ CONFIG_METADATA_2 = { }, }, "provider_settings": { - "description": "大语言模型设置", "type": "object", "items": { "enable": { - "description": "启用大语言模型聊天", "type": "bool", - "hint": "如需切换大语言模型提供商,请使用 /provider 命令。", - }, - "separate_provider": { - "description": "提供商会话隔离", - "type": "bool", - "hint": "启用后,每个会话支持独立选择文本生成、STT、TTS 等提供商。如果会话在使用 /provider 指令时提示无权限,可以将会话加入管理员名单或者使用 /alter_cmd provider member 将指令设为非管理员指令。", }, "default_provider_id": { - "description": "默认模型提供商 ID", "type": "string", - "hint": "可选。每个聊天会话的默认提供商 ID。", }, "wake_prefix": { - "description": "LLM 聊天额外唤醒前缀", "type": "string", - "hint": "使用 LLM 聊天额外的触发条件。如填写 `chat`,则需要消息前缀加上 `/chat` 才能触发 LLM 聊天,是一个防止滥用的手段。", }, "web_search": { - "description": "启用网页搜索", "type": "bool", - "hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。", }, "web_search_link": { - "description": "网页搜索引用链接", "type": "bool", - "hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。", }, "display_reasoning_text": { - "description": "显示思考内容", "type": "bool", - "hint": "开启后,将在回复中显示模型的思考过程。", }, "identifier": { - "description": "启动识别群员", "type": "bool", - "hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。", }, "datetime_system_prompt": { - "description": "启用日期时间系统提示", "type": "bool", - "hint": "启用后,会在系统提示词中加上当前机器的日期时间。", }, "default_personality": { - "description": "默认采用的人格情景的名称", "type": "string", - "hint": "", }, "prompt_prefix": { - "description": "Prompt 前缀文本", "type": "string", - "hint": "添加之后,会在每次对话的 Prompt 前加上此文本。", }, "max_context_length": { - "description": "最多携带对话数量(条)", "type": "int", - "hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。", }, "dequeue_context_length": { - "description": "丢弃对话数量(条)", "type": "int", - "hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下", }, "streaming_response": { - "description": "启用流式回复", "type": "bool", - "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台", }, "show_tool_use_status": { - "description": "函数调用状态输出", "type": "bool", - "hint": "在触发函数调用时输出其函数名和内容。", }, "streaming_segmented": { - "description": "不支持流式回复的平台分段输出", "type": "bool", - "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项", }, "max_agent_step": { "description": "工具调用轮数上限", @@ -1743,143 +1689,65 @@ CONFIG_METADATA_2 = { }, }, }, - "persona": { - "description": "人格情景设置", - "type": "list", - "config_template": { - "新人格情景": { - "name": "", - "prompt": "", - "begin_dialogs": [], - "mood_imitation_dialogs": [], - } - }, - "tmpl_display_title": "name", - "items": { - "name": { - "description": "人格名称", - "type": "string", - "hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。", - }, - "prompt": { - "description": "设定(系统提示词)", - "type": "text", - "hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。", - }, - "begin_dialogs": { - "description": "预设对话", - "type": "list", - "items": {"type": "string"}, - "hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", - }, - "mood_imitation_dialogs": { - "description": "对话风格模仿", - "type": "list", - "items": {"type": "string"}, - "hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", - }, - }, - }, "provider_stt_settings": { - "description": "语音转文本(STT)", "type": "object", "items": { "enable": { - "description": "启用语音转文本(STT)", "type": "bool", - "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。", }, "provider_id": { - "description": "提供商 ID", "type": "string", - "hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。", }, }, }, "provider_tts_settings": { - "description": "文本转语音(TTS)", "type": "object", "items": { "enable": { - "description": "启用文本转语音(TTS)", "type": "bool", - "hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。", }, "provider_id": { - "description": "提供商 ID", "type": "string", - "hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。", }, "dual_output": { - "description": "启用语音和文字双输出", "type": "bool", - "hint": "启用后,Bot 将同时输出语音和文字消息。", }, "use_file_service": { - "description": "使用文件服务提供 TTS 语音文件", "type": "bool", - "hint": "启用后,如已配置 callback_api_base ,将会使用文件服务提供TTS语音文件", }, }, }, "provider_ltm_settings": { - "description": "聊天记忆增强(Beta)", "type": "object", "items": { "group_icl_enable": { - "description": "群聊内记录各群员对话", "type": "bool", - "hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。", }, "group_message_max_cnt": { - "description": "群聊消息最大数量", "type": "int", - "hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。", }, "image_caption": { - "description": "群聊图像转述(需模型支持)", "type": "bool", - "hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入。", - }, - "image_caption_provider_id": { - "description": "图像转述提供商 ID", - "type": "string", - "hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。", }, "image_caption_prompt": { - "description": "图像转述提示词", "type": "string", }, "active_reply": { - "description": "主动回复", "type": "object", "items": { "enable": { - "description": "启用主动回复", "type": "bool", - "hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用", }, "whitelist": { - "description": "主动回复白名单", "type": "list", "items": {"type": "string"}, - "hint": "启用后,只有在白名单内的群聊会被主动回复。为空时不启用白名单过滤。需要通过 /sid 获取 SID 添加到这里。", }, "method": { - "description": "回复方法", "type": "string", "options": ["possibility_reply"], - "hint": "回复方法。possibility_reply 为根据概率回复", }, "possibility_reply": { - "description": "回复概率", "type": "float", - "hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。", - }, - "prompt": { - "description": "提示词", - "type": "string", - "hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。", }, }, }, @@ -1888,34 +1756,23 @@ CONFIG_METADATA_2 = { }, }, "misc_config_group": { - "name": "其他配置", "metadata": { "wake_prefix": { - "description": "机器人唤醒前缀", "type": "list", "items": {"type": "string"}, - "hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。", }, "t2i": { - "description": "文本转图像", "type": "bool", - "hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。", }, "t2i_word_threshold": { - "description": "文本转图像字数阈值", "type": "int", - "hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。", }, "admins_id": { - "description": "管理员 ID", "type": "list", "items": {"type": "string"}, - "hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。", }, "http_proxy": { - "description": "HTTP 代理", "type": "string", - "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`", }, "no_proxy": { "description": "直连地址列表", @@ -1924,51 +1781,553 @@ CONFIG_METADATA_2 = { "hint": "在此处添加不希望通过代理访问的地址,例如内部服务地址。回车添加,可添加多个,如未设置代理请忽略此配置", }, "timezone": { - "description": "时区", "type": "string", - "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", }, "callback_api_base": { - "description": "对外可达的回调接口地址", "type": "string", - "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。", }, "log_level": { - "description": "控制台日志级别", "type": "string", - "hint": "控制台输出日志的级别。", "options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], }, "t2i_strategy": { - "description": "文本转图像渲染源", "type": "string", - "hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。", "options": ["remote", "local"], }, "t2i_endpoint": { - "description": "文本转图像服务接口", "type": "string", - "hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务", }, "t2i_use_file_service": { - "description": "本地文本转图像使用文件服务提供文件", "type": "bool", - "hint": "当 t2i_strategy 为 local 并且配置 callback_api_base 时生效。是否使用文件服务提供文件。", }, "pip_install_arg": { - "description": "pip 安装参数", "type": "string", - "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。", }, "pypi_index_url": { - "description": "PyPI 软件仓库地址", "type": "string", - "hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/", + }, + "default_kb_collection": { + "type": "string", }, }, }, } + +CONFIG_METADATA_3 = { + "ai_group": { + "name": "AI 配置", + "metadata": { + "ai": { + "description": "模型", + "type": "object", + "items": { + "provider_settings.enable": { + "description": "启用大语言模型聊天", + "type": "bool", + }, + "provider_settings.default_provider_id": { + "description": "默认聊天模型", + "type": "string", + "_special": "select_provider", + "hint": "留空时使用第一个模型。", + }, + "provider_settings.default_image_caption_provider_id": { + "description": "默认图片转述模型", + "type": "string", + "_special": "select_provider", + "hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。", + }, + "provider_stt_settings.provider_id": { + "description": "语音转文本模型", + "type": "string", + "hint": "留空代表不使用。", + "_special": "select_provider_stt", + }, + "provider_tts_settings.provider_id": { + "description": "文本转语音模型", + "type": "string", + "hint": "留空代表不使用。", + "_special": "select_provider_tts", + }, + "provider_settings.image_caption_prompt": { + "description": "图片转述提示词", + "type": "text", + }, + }, + }, + "persona": { + "description": "人格", + "type": "object", + "items": { + "provider_settings.default_personality": { + "description": "默认采用的人格", + "type": "string", + "_special": "select_persona", + }, + }, + }, + "knowledgebase": { + "description": "知识库", + "type": "object", + "items": { + "default_kb_collection": { + "description": "默认使用的知识库", + "type": "string", + "_special": "select_knowledgebase", + }, + }, + }, + "websearch": { + "description": "网页搜索", + "type": "object", + "items": { + "provider_settings.web_search": { + "description": "启用网页搜索", + "type": "bool", + }, + "provider_settings.websearch_provider": { + "description": "网页搜索提供商", + "type": "string", + "options": ["default", "tavily"], + }, + "provider_settings.websearch_tavily_key": { + "description": "Tavily API Key", + "type": "string", + "condition": { + "provider_settings.websearch_provider": "tavily", + }, + }, + "provider_settings.web_search_link": { + "description": "显示来源引用", + "type": "bool", + }, + }, + }, + "others": { + "description": "其他配置", + "type": "object", + "items": { + "provider_settings.display_reasoning_text": { + "description": "显示思考内容", + "type": "bool", + }, + "provider_settings.identifier": { + "description": "用户感知", + "type": "bool", + }, + "provider_settings.datetime_system_prompt": { + "description": "现实世界时间感知", + "type": "bool", + }, + "provider_settings.show_tool_use_status": { + "description": "输出函数调用状态", + "type": "bool", + }, + "provider_settings.max_agent_step": { + "description": "工具调用轮数上限", + "type": "bool", + }, + "provider_settings.streaming_response": { + "description": "流式回复", + "type": "bool", + }, + "provider_settings.streaming_segmented": { + "description": "不支持流式回复的平台采取分段输出", + "type": "bool", + }, + "provider_settings.max_context_length": { + "description": "最多携带对话轮数", + "type": "int", + "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。", + }, + "provider_settings.dequeue_context_length": { + "description": "丢弃对话轮数", + "type": "int", + "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。", + }, + "provider_settings.wake_prefix": { + "description": "LLM 聊天额外唤醒前缀 ", + "type": "string", + }, + "provider_settings.prompt_prefix": { + "description": "额外前缀提示词", + "type": "string", + }, + "provider_settings.dual_output": { + "description": "开启 TTS 时同时输出语音和文字内容", + "type": "bool", + }, + }, + }, + }, + }, + "platform_group": { + "name": "平台配置", + "metadata": { + "general": { + "description": "基本", + "type": "object", + "items": { + "admins_id": { + "description": "管理员 ID", + "type": "list", + "items": {"type": "string"}, + }, + "platform_settings.unique_session": { + "description": "隔离会话", + "type": "bool", + "hint": "启用后,群成员的上下文独立。", + }, + "wake_prefix": { + "description": "唤醒词", + "type": "list", + "items": {"type": "string"}, + }, + "platform_settings.friend_message_needs_wake_prefix": { + "description": "私聊消息需要唤醒词", + "type": "bool", + }, + "platform_settings.reply_prefix": { + "description": "回复时的文本前缀", + "type": "string", + }, + "platform_settings.reply_with_mention": { + "description": "回复时 @ 发送人", + "type": "bool", + }, + "platform_settings.reply_with_quote": { + "description": "回复时引用发送人消息", + "type": "bool", + }, + "platform_settings.forward_threshold": { + "description": "转发消息的字数阈值", + "type": "int", + }, + "platform_settings.empty_mention_waiting": { + "description": "只 @ 机器人是否触发等待", + "type": "bool", + }, + }, + }, + "whitelist": { + "description": "白名单", + "type": "object", + "items": { + "platform_settings.enable_id_white_list": { + "description": "启用白名单", + "type": "bool", + "hint": "启用后,只有在白名单内的会话会被响应。", + }, + "platform_settings.id_whitelist": { + "description": "白名单 ID 列表", + "type": "list", + "items": {"type": "string"}, + "hint": "使用 /sid 获取 ID。", + }, + "platform_settings.id_whitelist_log": { + "description": "输出日志", + "type": "bool", + "hint": "启用后,当一条消息没通过白名单时,会输出 INFO 级别的日志。", + }, + "platform_settings.wl_ignore_admin_on_group": { + "description": "管理员群组消息无视 ID 白名单", + "type": "bool", + }, + "platform_settings.wl_ignore_admin_on_friend": { + "description": "管理员私聊消息无视 ID 白名单", + "type": "bool", + }, + }, + }, + "rate_limit": { + "description": "速率限制", + "type": "object", + "items": { + "platform_settings.rate_limit.time": { + "description": "消息速率限制时间(秒)", + "type": "int", + }, + "platform_settings.rate_limit.count": { + "description": "消息速率限制计数", + "type": "int", + }, + "platform_settings.rate_limit.strategy": { + "description": "速率限制策略", + "type": "string", + "options": ["stall", "discard"], + }, + }, + }, + "content_safety": { + "description": "内容安全", + "type": "object", + "items": { + "platform_settings.content_safety.also_use_in_response": { + "description": "同时检查模型的响应内容", + "type": "bool", + }, + "platform_settings.content_safety.baidu_aip.enable": { + "description": "使用百度内容安全审核", + "type": "bool", + "hint": "您需要手动安装 baidu-aip 库。", + }, + "platform_settings.content_safety.baidu_aip.app_id": { + "description": "App ID", + "type": "string", + "condition": { + "platform_settings.content_safety.baidu_aip.enable": True, + }, + }, + "platform_settings.content_safety.baidu_aip.api_key": { + "description": "API Key", + "type": "string", + "condition": { + "platform_settings.content_safety.baidu_aip.enable": True, + }, + }, + "platform_settings.content_safety.baidu_aip.secret_key": { + "description": "Secret Key", + "type": "string", + "condition": { + "platform_settings.content_safety.baidu_aip.enable": True, + }, + }, + "platform_settings.content_safety.internal_keywords.enable": { + "description": "关键词检查", + "type": "bool", + }, + "platform_settings.content_safety.internal_keywords.extra_keywords": { + "description": "额外关键词", + "type": "list", + "items": {"type": "string"}, + "hint": "额外的屏蔽关键词列表,支持正则表达式。", + }, + }, + }, + "t2i": { + "description": "文本转图像", + "type": "object", + "items": { + "t2i": { + "description": "文本转图像输出", + "type": "bool", + }, + "t2i_word_threshold": { + "description": "文本转图像字数阈值", + "type": "int", + }, + }, + }, + "others": { + "description": "其他配置", + "type": "object", + "items": { + "platform_settings.ignore_bot_self_message": { + "description": "是否忽略机器人自身的消息", + "type": "bool", + }, + "platform_settings.ignore_at_all": { + "description": "是否忽略 @ 全体成员事件", + "type": "bool", + }, + "platform_settings.no_permission_reply": { + "description": "用户权限不足时是否回复", + "type": "bool", + }, + }, + }, + }, + }, + "plugin_group": { + "name": "插件配置", + "metadata": { + "plugin": { + "description": "插件", + "type": "object", + "items": { + "plugin_set": { + "description": "可用插件", + "type": "bool", + "hint": "默认启用全部未被禁用的插件。若插件在插件页面被禁用,则此处的选择不会生效。", + "_special": "select_plugin_set", + }, + }, + }, + }, + }, + "ext_group": { + "name": "扩展功能", + "metadata": { + "segmented_reply": { + "description": "分段回复", + "type": "object", + "items": { + "platform_settings.segmented_reply.enable": { + "description": "启用分段回复", + "type": "bool", + }, + "platform_settings.segmented_reply.only_llm_result": { + "description": "仅对 LLM 结果分段", + "type": "bool", + }, + "platform_settings.segmented_reply.interval_method": { + "description": "间隔方法", + "type": "string", + "options": ["random", "log"], + }, + "platform_settings.segmented_reply.interval": { + "description": "随机间隔时间", + "type": "string", + "hint": "格式:最小值,最大值(如:1.5,3.5)", + "condition": { + "platform_settings.segmented_reply.interval_method": "random", + }, + }, + "platform_settings.segmented_reply.log_base": { + "description": "对数底数", + "type": "float", + "hint": "对数间隔的底数,默认为 2.0。取值范围为 1.0-10.0。", + "condition": { + "platform_settings.segmented_reply.interval_method": "log", + }, + }, + "platform_settings.segmented_reply.words_count_threshold": { + "description": "分段回复字数阈值", + "type": "int", + }, + "platform_settings.segmented_reply.regex": { + "description": "分段正则表达式", + "type": "string", + }, + "platform_settings.segmented_reply.content_cleanup_rule": { + "description": "内容过滤正则表达式", + "type": "string", + "hint": "移除分段后内容中的指定内容。如填写 `[。?!]` 将移除所有的句号、问号、感叹号。", + }, + }, + }, + "ltm": { + "description": "群聊上下文感知(原聊天记忆增强)", + "type": "object", + "items": { + "provider_ltm_settings.group_icl_enable": { + "description": "启用群聊上下文感知", + "type": "bool", + }, + "provider_ltm_settings.group_message_max_cnt": { + "description": "最大消息数量", + "type": "int", + }, + "provider_ltm_settings.image_caption": { + "description": "自动理解图片", + "type": "bool", + "hint": "需要设置默认图片转述模型。", + }, + "provider_ltm_settings.active_reply.enable": { + "description": "主动回复", + "type": "bool", + }, + "provider_ltm_settings.active_reply.method": { + "description": "主动回复方法", + "type": "string", + "options": ["possibility_reply"], + "condition": { + "provider_ltm_settings.active_reply.enable": True, + }, + }, + "provider_ltm_settings.active_reply.possibility_reply": { + "description": "回复概率", + "type": "float", + "hint": "0.0-1.0 之间的数值", + "condition": { + "provider_ltm_settings.active_reply.enable": True, + }, + }, + "provider_ltm_settings.active_reply.whitelist": { + "description": "主动回复白名单", + "type": "list", + "items": {"type": "string"}, + "hint": "为空时不启用白名单过滤。使用 /sid 获取 ID。", + "condition": { + "provider_ltm_settings.active_reply.enable": True, + }, + }, + }, + }, + }, + }, +} + +CONFIG_METADATA_3_SYSTEM = { + "system_group": { + "name": "系统配置", + "metadata": { + "system": { + "description": "系统配置", + "type": "object", + "items": { + "t2i_strategy": { + "description": "文本转图像策略", + "type": "string", + "hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。", + "options": ["remote", "local"], + }, + "t2i_endpoint": { + "description": "文本转图像服务 API 地址", + "type": "string", + "hint": "为空时使用 AstrBot API 服务", + "condition": { + "t2i_strategy": "remote", + }, + }, + "t2i_template": { + "description": "文本转图像自定义模版", + "type": "bool", + "hint": "启用后可自定义 HTML 模板用于文转图渲染。", + "condition": { + "t2i_strategy": "remote", + }, + "_special": "t2i_template" + }, + "log_level": { + "description": "控制台日志级别", + "type": "string", + "hint": "控制台输出日志的级别。", + "options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + }, + "pip_install_arg": { + "description": "pip 安装额外参数", + "type": "string", + "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。", + }, + "pypi_index_url": { + "description": "PyPI 软件仓库地址", + "type": "string", + "hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/", + }, + "callback_api_base": { + "description": "对外可达的回调接口地址", + "type": "string", + "hint": "外部服务可能会通过 AstrBot 生成的回调链接(如文件下载链接)访问 AstrBot 后端。由于 AstrBot 无法自动判断部署环境中对外可达的主机地址(host),因此需要通过此配置项显式指定 “外部服务如何访问 AstrBot” 的地址。如 http://localhost:6185,https://example.com 等。", + }, + "timezone": { + "description": "时区", + "type": "string", + "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", + }, + "http_proxy": { + "description": "HTTP 代理", + "type": "string", + "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`", + }, + }, + } + }, + } +} + + DEFAULT_VALUE_MAP = { "int": 0, "float": 0.0, diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index b665488e4..76112fa60 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -5,40 +5,44 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 在一个会话中可以建立多个对话, 并且支持对话的切换和删除 """ -import uuid import json -import asyncio from astrbot.core import sp from typing import Dict, List from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import Conversation +from astrbot.core.db.po import Conversation, ConversationV2 class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - # session_conversations 字典记录会话ID-对话ID 映射关系 - self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) + self.session_conversations: Dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 - self._start_periodic_save() - def _start_periodic_save(self): - """启动定时保存任务""" - asyncio.create_task(self._periodic_save()) + def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: + """将 ConversationV2 对象转换为 Conversation 对象""" + created_at = int(conv_v2.created_at.timestamp()) + updated_at = int(conv_v2.updated_at.timestamp()) + return Conversation( + platform_id=conv_v2.platform_id, + user_id=conv_v2.user_id, + cid=conv_v2.conversation_id, + history=json.dumps(conv_v2.content or []), + title=conv_v2.title, + persona_id=conv_v2.persona_id, + created_at=created_at, + updated_at=updated_at, + ) - async def _periodic_save(self): - """定时保存会话对话映射关系到存储中""" - while True: - await asyncio.sleep(self.save_interval) - self._save_to_storage() - - def _save_to_storage(self): - """保存会话对话映射关系到存储中""" - sp.put("session_conversation", self.session_conversations) - - async def new_conversation(self, unified_msg_origin: str) -> str: + async def new_conversation( + self, + unified_msg_origin: str, + platform_id: str | None = None, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + ) -> str: """新建对话,并将当前会话的对话转移到新对话 Args: @@ -46,11 +50,23 @@ class ConversationManager: Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - conversation_id = str(uuid.uuid4()) - self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id) - self.session_conversations[unified_msg_origin] = conversation_id - sp.put("session_conversation", self.session_conversations) - return conversation_id + if not platform_id: + # 如果没有提供 platform_id,则从 unified_msg_origin 中解析 + parts = unified_msg_origin.split(":") + if len(parts) >= 3: + platform_id = parts[0] + if not platform_id: + platform_id = "unknown" + conv = await self.db.create_conversation( + user_id=unified_msg_origin, + platform_id=platform_id, + content=content, + title=title, + persona_id=persona_id, + ) + self.session_conversations[unified_msg_origin] = conv.conversation_id + await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) + return conv.conversation_id async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): """切换会话的对话 @@ -60,10 +76,10 @@ class ConversationManager: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ self.session_conversations[unified_msg_origin] = conversation_id - sp.put("session_conversation", self.session_conversations) + await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id) async def delete_conversation( - self, unified_msg_origin: str, conversation_id: str = None + self, unified_msg_origin: str, conversation_id: str | None = None ): """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 @@ -71,13 +87,18 @@ class ConversationManager: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - conversation_id = self.session_conversations.get(unified_msg_origin) + f = False + if not conversation_id: + conversation_id = self.session_conversations.get(unified_msg_origin) + if conversation_id: + f = True if conversation_id: - self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id) - del self.session_conversations[unified_msg_origin] - sp.put("session_conversation", self.session_conversations) + await self.db.delete_conversation(cid=conversation_id) + if f: + self.session_conversations.pop(unified_msg_origin, None) + await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: + async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: """获取会话当前的对话 ID Args: @@ -85,14 +106,19 @@ class ConversationManager: Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ - return self.session_conversations.get(unified_msg_origin, None) + ret = self.session_conversations.get(unified_msg_origin, None) + if not ret: + ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None) + if ret: + self.session_conversations[unified_msg_origin] = ret + return ret async def get_conversation( self, unified_msg_origin: str, conversation_id: str, create_if_not_exists: bool = False, - ) -> Conversation: + ) -> Conversation | None: """获取会话的对话 Args: @@ -101,27 +127,74 @@ class ConversationManager: Returns: conversation (Conversation): 对话对象 """ - conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) + conv = await self.db.get_conversation_by_id(cid=conversation_id) if not conv and create_if_not_exists: # 如果对话不存在且需要创建,则新建一个对话 conversation_id = await self.new_conversation(unified_msg_origin) - return self.db.get_conversation_by_user_id( - unified_msg_origin, conversation_id - ) - return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) + conv = await self.db.get_conversation_by_id(cid=conversation_id) + conv_res = None + if conv: + conv_res = self._convert_conv_from_v2_to_v1(conv) + return conv_res - async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: - """获取会话的所有对话 + async def get_conversations( + self, unified_msg_origin: str | None = None, platform_id: str | None = None + ) -> List[Conversation]: + """获取对话列表 Args: - unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选 + platform_id (str): 平台 ID, 可选参数, 用于过滤对话 Returns: conversations (List[Conversation]): 对话对象列表 """ - return self.db.get_conversations(unified_msg_origin) + convs = await self.db.get_conversations( + user_id=unified_msg_origin, platform_id=platform_id + ) + convs_res = [] + for conv in convs: + conv_res = self._convert_conv_from_v2_to_v1(conv) + convs_res.append(conv_res) + return convs_res + + async def get_filtered_conversations( + self, + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs, + ) -> tuple[list[Conversation], int]: + """获取过滤后的对话列表 + + Args: + page (int): 页码, 默认为 1 + page_size (int): 每页大小, 默认为 20 + platform_ids (list[str]): 平台 ID 列表, 可选 + search_query (str): 搜索查询字符串, 可选 + Returns: + conversations (list[Conversation]): 对话对象列表 + """ + convs, cnt = await self.db.get_filtered_conversations( + page=page, + page_size=page_size, + platform_ids=platform_ids, + search_query=search_query, + **kwargs, + ) + convs_res = [] + for conv in convs: + conv_res = self._convert_conv_from_v2_to_v1(conv) + convs_res.append(conv_res) + return convs_res, cnt async def update_conversation( - self, unified_msg_origin: str, conversation_id: str, history: List[Dict] + self, + unified_msg_origin: str, + conversation_id: str | None = None, + history: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, ): """更新会话的对话 @@ -130,40 +203,55 @@ class ConversationManager: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 """ + if not conversation_id: + # 如果没有提供 conversation_id,则获取当前的 + conversation_id = await self.get_curr_conversation_id(unified_msg_origin) if conversation_id: - self.db.update_conversation( - user_id=unified_msg_origin, + await self.db.update_conversation( cid=conversation_id, - history=json.dumps(history), + title=title, + persona_id=persona_id, + content=history, ) - async def update_conversation_title(self, unified_msg_origin: str, title: str): + async def update_conversation_title( + self, unified_msg_origin: str, title: str, conversation_id: str | None = None + ): """更新会话的对话标题 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id title (str): 对话标题 + + Deprecated: + Use `update_conversation` with `title` parameter instead. """ - conversation_id = self.session_conversations.get(unified_msg_origin) - if conversation_id: - self.db.update_conversation_title( - user_id=unified_msg_origin, cid=conversation_id, title=title - ) + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + title=title, + ) async def update_conversation_persona_id( - self, unified_msg_origin: str, persona_id: str + self, + unified_msg_origin: str, + persona_id: str, + conversation_id: str | None = None, ): """更新会话的对话 Persona ID Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id persona_id (str): 对话 Persona ID + + Deprecated: + Use `update_conversation` with `persona_id` parameter instead. """ - conversation_id = self.session_conversations.get(unified_msg_origin) - if conversation_id: - self.db.update_conversation_persona_id( - user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id - ) + await self.update_conversation( + unified_msg_origin=unified_msg_origin, + conversation_id=conversation_id, + persona_id=persona_id, + ) async def get_human_readable_context( self, unified_msg_origin, conversation_id, page=1, page_size=10 diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 12226d9e1..972a5f4f1 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -15,20 +15,23 @@ import time import threading import os from .event_bus import EventBus -from . import astrbot_config +from . import astrbot_config, html_renderer from asyncio import Queue from typing import List from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext from astrbot.core.star import PluginManager from astrbot.core.platform.manager import PlatformManager from astrbot.core.star.context import Context +from astrbot.core.persona_mgr import PersonaManager from astrbot.core.provider.manager import ProviderManager from astrbot.core import LogBroker from astrbot.core.db import BaseDatabase from astrbot.core.updator import AstrBotUpdator -from astrbot.core import logger +from astrbot.core import logger, sp from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star_handler import star_map @@ -77,11 +80,26 @@ class AstrBotCoreLifecycle: else: logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别 + await self.db.initialize() + + await html_renderer.initialize() + + # 初始化 AstrBot 配置管理器 + self.astrbot_config_mgr = AstrBotConfigManager( + default_config=self.astrbot_config, sp=sp + ) + # 初始化事件队列 self.event_queue = Queue() + # 初始化人格管理器 + self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr) + await self.persona_mgr.initialize() + # 初始化供应商管理器 - self.provider_manager = ProviderManager(self.astrbot_config, self.db) + self.provider_manager = ProviderManager( + self.astrbot_config_mgr, self.db, self.persona_mgr + ) # 初始化平台管理器 self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) @@ -89,6 +107,9 @@ class AstrBotCoreLifecycle: # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化平台消息历史管理器 + self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) + # 初始化提供给插件的上下文 self.star_context = Context( self.event_queue, @@ -97,6 +118,9 @@ class AstrBotCoreLifecycle: self.provider_manager, self.platform_manager, self.conversation_manager, + self.platform_message_history_manager, + self.persona_mgr, + self.astrbot_config_mgr, ) # 初始化插件管理器 @@ -109,16 +133,16 @@ class AstrBotCoreLifecycle: await self.provider_manager.initialize() # 初始化消息事件流水线调度器 - self.pipeline_scheduler = PipelineScheduler( - PipelineContext(self.astrbot_config, self.plugin_manager) - ) - await self.pipeline_scheduler.initialize() + + self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler() # 初始化更新器 self.astrbot_updator = AstrBotUpdator() # 初始化事件总线 - self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) + self.event_bus = EventBus( + self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr + ) # 记录启动时间 self.start_time = int(time.time()) @@ -235,6 +259,39 @@ class AstrBotCoreLifecycle: platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: tasks.append( - asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name) + asyncio.create_task( + platform_inst.run(), + name=f"{platform_inst.meta().id}({platform_inst.meta().name})", + ) ) return tasks + + async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]: + """加载消息事件流水线调度器 + + Returns: + dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 + """ + mapping = {} + for conf_id, ab_config in self.astrbot_config_mgr.confs.items(): + scheduler = PipelineScheduler( + PipelineContext(ab_config, self.plugin_manager, conf_id) + ) + await scheduler.initialize() + mapping[conf_id] = scheduler + return mapping + + async def reload_pipeline_scheduler(self, conf_id: str): + """重新加载消息事件流水线调度器 + + Returns: + dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 + """ + ab_config = self.astrbot_config_mgr.confs.get(conf_id) + if not ab_config: + raise ValueError(f"配置文件 {conf_id} 不存在") + scheduler = PipelineScheduler( + PipelineContext(ab_config, self.plugin_manager, conf_id) + ) + await scheduler.initialize() + self.pipeline_scheduler_mapping[conf_id] = scheduler diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 6688dcced..2de109b7d 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,20 @@ import abc +import datetime +import typing as T +from deprecated import deprecated from dataclasses import dataclass -from typing import List, Dict, Any, Tuple -from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation +from astrbot.core.db.po import ( + Stats, + PlatformStat, + ConversationV2, + PlatformMessageHistory, + Attachment, + Persona, + Preference, +) +from contextlib import asynccontextmanager +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker @dataclass @@ -10,152 +23,262 @@ class BaseDatabase(abc.ABC): 数据库基类 """ + DATABASE_URL = "" + def __init__(self) -> None: + self.engine = create_async_engine( + self.DATABASE_URL, + echo=False, + future=True, + ) + self.AsyncSessionLocal = sessionmaker( + self.engine, class_=AsyncSession, expire_on_commit=False + ) + + async def initialize(self): + """初始化数据库连接""" pass - def insert_base_metrics(self, metrics: dict): - """插入基础指标数据""" - self.insert_platform_metrics(metrics["platform_stats"]) - self.insert_plugin_metrics(metrics["plugin_stats"]) - self.insert_command_metrics(metrics["command_stats"]) - self.insert_llm_metrics(metrics["llm_stats"]) - - @abc.abstractmethod - def insert_platform_metrics(self, metrics: dict): - """插入平台指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_plugin_metrics(self, metrics: dict): - """插入插件指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_command_metrics(self, metrics: dict): - """插入指令指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def insert_llm_metrics(self, metrics: dict): - """插入 LLM 指标数据""" - raise NotImplementedError - - @abc.abstractmethod - def update_llm_history(self, session_id: str, content: str, provider_type: str): - """更新 LLM 历史记录。当不存在 session_id 时插入""" - raise NotImplementedError - - @abc.abstractmethod - def get_llm_history( - self, session_id: str = None, provider_type: str = None - ) -> List[LLMHistory]: - """获取 LLM 历史记录, 如果 session_id 为 None, 返回所有""" - raise NotImplementedError + @asynccontextmanager + async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: + """Get a database session.""" + if not self.inited: + await self.initialize() + self.inited = True + async with self.AsyncSessionLocal() as session: + yield session + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_base_stats(self, offset_sec: int = 86400) -> Stats: """获取基础统计数据""" raise NotImplementedError + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_total_message_count(self) -> int: """获取总消息数""" raise NotImplementedError + @deprecated(version="4.0.0", reason="Use get_platform_stats instead") @abc.abstractmethod def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: """获取基础统计数据(合并)""" raise NotImplementedError - @abc.abstractmethod - def insert_atri_vision_data(self, vision_data: ATRIVision): - """插入 ATRI 视觉数据""" - raise NotImplementedError + # New methods in v4.0.0 @abc.abstractmethod - def get_atri_vision_data(self) -> List[ATRIVision]: - """获取 ATRI 视觉数据""" - raise NotImplementedError + async def insert_platform_stats( + self, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime.datetime | None = None, + ) -> None: + """Insert a new platform statistic record.""" + ... @abc.abstractmethod - def get_atri_vision_data_by_path_or_id( - self, url_or_path: str, id: str - ) -> ATRIVision: - """通过 url 或 path 获取 ATRI 视觉数据""" - raise NotImplementedError + async def count_platform_stats(self) -> int: + """Count the number of platform statistics records.""" + ... @abc.abstractmethod - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: - """通过 user_id 和 cid 获取 Conversation""" - raise NotImplementedError + async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: + """Get platform statistics within the specified offset in seconds and group by platform_id.""" + ... @abc.abstractmethod - def new_conversation(self, user_id: str, cid: str): - """新建 Conversation""" - raise NotImplementedError + async def get_conversations( + self, user_id: str | None = None, platform_id: str | None = None + ) -> list[ConversationV2]: + """Get all conversations for a specific user and platform_id(optional). - @abc.abstractmethod - def get_conversations(self, user_id: str) -> List[Conversation]: - raise NotImplementedError - - @abc.abstractmethod - def update_conversation(self, user_id: str, cid: str, history: str): - """更新 Conversation""" - raise NotImplementedError - - @abc.abstractmethod - def delete_conversation(self, user_id: str, cid: str): - """删除 Conversation""" - raise NotImplementedError - - @abc.abstractmethod - def update_conversation_title(self, user_id: str, cid: str, title: str): - """更新 Conversation 标题""" - raise NotImplementedError - - @abc.abstractmethod - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): - """更新 Conversation Persona ID""" - raise NotImplementedError - - @abc.abstractmethod - def get_all_conversations( - self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: - """获取所有对话,支持分页 - - Args: - page: 页码,从1开始 - page_size: 每页数量 - - Returns: - Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数 + content is not included in the result. """ - raise NotImplementedError + ... @abc.abstractmethod - def get_filtered_conversations( + async def get_conversation_by_id(self, cid: str) -> ConversationV2: + """Get a specific conversation by its ID.""" + ... + + @abc.abstractmethod + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: + """Get all conversations with pagination.""" + ... + + @abc.abstractmethod + async def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, - search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """获取筛选后的对话列表 + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs, + ) -> tuple[list[ConversationV2], int]: + """Get conversations filtered by platform IDs and search query.""" + ... - Args: - page: 页码 - page_size: 每页数量 - platforms: 平台筛选列表 - message_types: 消息类型筛选列表 - search_query: 搜索关键词 - exclude_ids: 排除的用户ID列表 - exclude_platforms: 排除的平台列表 + @abc.abstractmethod + async def create_conversation( + self, + user_id: str, + platform_id: str, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime.datetime | None = None, + updated_at: datetime.datetime | None = None, + ) -> ConversationV2: + """Create a new conversation.""" + ... - Returns: - Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数 - """ - raise NotImplementedError + @abc.abstractmethod + async def update_conversation( + self, + cid: str, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, + ) -> None: + """Update a conversation's history.""" + ... + + @abc.abstractmethod + async def delete_conversation(self, cid: str) -> None: + """Delete a conversation by its ID.""" + ... + + @abc.abstractmethod + async def insert_platform_message_history( + self, + platform_id: str, + user_id: str, + content: list[dict], + sender_id: str | None = None, + sender_name: str | None = None, + ) -> None: + """Insert a new platform message history record.""" + ... + + @abc.abstractmethod + async def delete_platform_message_offset( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: + """Delete platform message history records older than the specified offset.""" + ... + + @abc.abstractmethod + async def get_platform_message_history( + self, + platform_id: str, + user_id: str, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformMessageHistory]: + """Get platform message history for a specific user.""" + ... + + @abc.abstractmethod + async def insert_attachment( + self, + path: str, + type: str, + mime_type: str, + ): + """Insert a new attachment record.""" + ... + + @abc.abstractmethod + async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + """Get an attachment by its ID.""" + ... + + @abc.abstractmethod + async def insert_persona( + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona: + """Insert a new persona record.""" + ... + + @abc.abstractmethod + async def get_persona_by_id(self, persona_id: str) -> Persona: + """Get a persona by its ID.""" + ... + + @abc.abstractmethod + async def get_personas(self) -> list[Persona]: + """Get all personas for a specific bot.""" + ... + + @abc.abstractmethod + async def update_persona( + self, + persona_id: str, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona | None: + """Update a persona's system prompt or begin dialogs.""" + ... + + @abc.abstractmethod + async def delete_persona(self, persona_id: str) -> None: + """Delete a persona by its ID.""" + ... + + @abc.abstractmethod + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: + """Insert a new preference record.""" + ... + + @abc.abstractmethod + async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + """Get a preference by scope ID and key.""" + ... + + @abc.abstractmethod + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """Get all preferences for a specific scope ID or key.""" + ... + + @abc.abstractmethod + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: + """Remove a preference by scope ID and key.""" + ... + + @abc.abstractmethod + async def clear_preferences(self, scope: str, scope_id: str) -> None: + """Clear all preferences for a specific scope ID.""" + ... + + # @abc.abstractmethod + # async def insert_llm_message( + # self, + # cid: str, + # role: str, + # content: list, + # tool_calls: list = None, + # tool_call_id: str = None, + # parent_id: str = None, + # ) -> LLMMessage: + # """Insert a new LLM message into the conversation.""" + # ... + + # @abc.abstractmethod + # async def get_llm_messages(self, cid: str) -> list[LLMMessage]: + # """Get all LLM messages for a specific conversation.""" + # ... diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py new file mode 100644 index 000000000..796a7b336 --- /dev/null +++ b/astrbot/core/db/migration/helper.py @@ -0,0 +1,64 @@ +import os +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.db import BaseDatabase +from astrbot.core.config import AstrBotConfig +from astrbot.api import logger, sp +from .migra_3_to_4 import ( + migration_conversation_table, + migration_platform_table, + migration_webchat_data, + migration_persona_data, + migration_preferences, +) + + +async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: + """ + 检查是否需要进行数据库迁移 + 如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。 + """ + data_v3_exists = os.path.exists(get_astrbot_data_path()) + if not data_v3_exists: + return False + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_v4" + ) + if migration_done: + return False + return True + + +async def do_migration_v4( + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], + astrbot_config: AstrBotConfig, +): + """ + 执行数据库迁移 + 迁移旧的 webchat_conversation 表到新的 conversation 表。 + 迁移旧的 platform 到新的 platform_stats 表。 + """ + if not await check_migration_needed_v4(db_helper): + return + + logger.info("开始执行数据库迁移...") + + # 执行会话表迁移 + await migration_conversation_table(db_helper, platform_id_map) + + # 执行人格数据迁移 + await migration_persona_data(db_helper, astrbot_config) + + # 执行 WebChat 数据迁移 + await migration_webchat_data(db_helper, platform_id_map) + + # 执行偏好设置迁移 + await migration_preferences(db_helper,platform_id_map) + + # 执行平台统计表迁移 + await migration_platform_table(db_helper, platform_id_map) + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_v4", True) + + logger.info("数据库迁移完成。") diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py new file mode 100644 index 000000000..4aa5082db --- /dev/null +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -0,0 +1,338 @@ +import json +import datetime +from .. import BaseDatabase +from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 +from .shared_preferences_v3 import sp as sp_v3 +from astrbot.core.config.default import DB_PATH +from astrbot.api import logger, sp +from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import MessageSesion +from sqlalchemy.ext.asyncio import AsyncSession +from astrbot.core.db.po import ConversationV2, PlatformMessageHistory +from sqlalchemy import text + +""" +1. 迁移旧的 webchat_conversation 表到新的 conversation 表。 +2. 迁移旧的 platform 到新的 platform_stats 表。 +""" + + +def get_platform_id( + platform_id_map: dict[str, dict[str, str]], old_platform_name: str +) -> str: + return platform_id_map.get( + old_platform_name, + {"platform_id": old_platform_name, "platform_type": old_platform_name}, + ).get("platform_id", old_platform_name) + + +def get_platform_type( + platform_id_map: dict[str, dict[str, str]], old_platform_name: str +) -> str: + return platform_id_map.get( + old_platform_name, + {"platform_id": old_platform_name, "platform_type": old_platform_name}, + ).get("platform_type", old_platform_name) + + +async def migration_conversation_table( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + conversations, total_cnt = db_helper_v3.get_all_conversations( + page=1, page_size=10000000 + ) + logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...") + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + for idx, conversation in enumerate(conversations): + if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0: + progress = int((idx + 1) / total_cnt * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})") + try: + conv = db_helper_v3.get_conversation_by_user_id( + user_id=conversation.get("user_id", "unknown"), + cid=conversation.get("cid", "unknown"), + ) + if not conv: + logger.info( + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + ) + if ":" not in conv.user_id: + continue + session = MessageSesion.from_str(session_str=conv.user_id) + platform_id = get_platform_id( + platform_id_map, session.platform_name + ) + session.platform_id = platform_id # 更新平台名称为新的 ID + conv_v2 = ConversationV2( + user_id=str(session), + content=json.loads(conv.history) if conv.history else [], + platform_id=platform_id, + title=conv.title, + persona_id=conv.persona_id, + conversation_id=conv.cid, + created_at=datetime.datetime.fromtimestamp(conv.created_at), + updated_at=datetime.datetime.fromtimestamp(conv.updated_at), + ) + dbsession.add(conv_v2) + except Exception as e: + logger.error( + f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}", + exc_info=True, + ) + logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。") + + +async def migration_platform_table( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + secs_from_2023_4_10_to_now = ( + datetime.datetime.now(datetime.timezone.utc) + - datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc) + ).total_seconds() + offset_sec = int(secs_from_2023_4_10_to_now) + logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。") + stats = db_helper_v3.get_base_stats(offset_sec=offset_sec) + logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...") + platform_stats_v3 = stats.platform + + if not platform_stats_v3: + logger.info("没有找到旧平台数据,跳过迁移。") + return + + first_time_stamp = platform_stats_v3[0].timestamp + end_time_stamp = platform_stats_v3[-1].timestamp + start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时 + end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时 + + idx = 0 + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + total_buckets = (end_time - start_time) // 3600 + for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)): + if bucket_idx % 500 == 0: + progress = int((bucket_idx + 1) / total_buckets * 100) + logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})") + cnt = 0 + while ( + idx < len(platform_stats_v3) + and platform_stats_v3[idx].timestamp < bucket_end + ): + cnt += platform_stats_v3[idx].count + idx += 1 + if cnt == 0: + continue + platform_id = get_platform_id( + platform_id_map, platform_stats_v3[idx].name + ) + platform_type = get_platform_type( + platform_id_map, platform_stats_v3[idx].name + ) + try: + await dbsession.execute( + text(""" + INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) + VALUES (:timestamp, :platform_id, :platform_type, :count) + ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET + count = platform_stats.count + EXCLUDED.count + """), + { + "timestamp": datetime.datetime.fromtimestamp( + bucket_end, tz=datetime.timezone.utc + ), + "platform_id": platform_id, + "platform_type": platform_type, + "count": cnt, + }, + ) + except Exception: + logger.error( + f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}", + exc_info=True, + ) + logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。") + + +async def migration_webchat_data( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" + db_helper_v3 = SQLiteV3DatabaseV3( + db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + ) + conversations, total_cnt = db_helper_v3.get_all_conversations( + page=1, page_size=10000000 + ) + logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...") + + async with db_helper.get_db() as dbsession: + dbsession: AsyncSession + async with dbsession.begin(): + for idx, conversation in enumerate(conversations): + if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0: + progress = int((idx + 1) / total_cnt * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})") + try: + conv = db_helper_v3.get_conversation_by_user_id( + user_id=conversation.get("user_id", "unknown"), + cid=conversation.get("cid", "unknown"), + ) + if not conv: + logger.info( + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + ) + if ":" in conv.user_id: + continue + platform_id = "webchat" + history = json.loads(conv.history) if conv.history else [] + for msg in history: + type_ = msg.get("type") # user type, "bot" or "user" + new_history = PlatformMessageHistory( + platform_id=platform_id, + user_id=conv.cid, # we use conv.cid as user_id for webchat + content=msg, + sender_id=type_, + sender_name=type_, + ) + dbsession.add(new_history) + + except Exception: + logger.error( + f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败", + exc_info=True, + ) + + logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。") + + +async def migration_persona_data( + db_helper: BaseDatabase, astrbot_config: AstrBotConfig +): + """ + 迁移 Persona 数据到新的表中。 + 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 + """ + v3_persona_config: list[dict] = astrbot_config.get("persona", []) + total_personas = len(v3_persona_config) + logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...") + + for idx, persona in enumerate(v3_persona_config): + if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0: + progress = int((idx + 1) / total_personas * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})") + try: + begin_dialogs = persona.get("begin_dialogs", []) + mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) + mood_prompt = "" + user_turn = True + for mood_dialog in mood_imitation_dialogs: + if user_turn: + mood_prompt += f"A: {mood_dialog}\n" + else: + mood_prompt += f"B: {mood_dialog}\n" + user_turn = not user_turn + system_prompt = persona.get("prompt", "") + if mood_prompt: + system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}" + persona_new = await db_helper.insert_persona( + persona_id=persona["name"], + system_prompt=system_prompt, + begin_dialogs=begin_dialogs, + ) + logger.info( + f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。" + ) + except Exception as e: + logger.error(f"解析 Persona 配置失败:{e}") + + +async def migration_preferences( + db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] +): + # 1. global scope migration + keys = [ + "inactivated_llm_tools", + "inactivated_plugins", + "curr_provider", + "curr_provider_tts", + "curr_provider_stt", + "alter_cmd", + ] + for key in keys: + value = sp_v3.get(key) + if value is not None: + await sp.put_async("global", "global", key, value) + logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}") + + # 2. umo scope migration + session_conversation = sp_v3.get("session_conversation", default={}) + for umo, conversation_id in session_conversation.items(): + if not umo or not conversation_id: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + await sp.put_async("umo", str(session), "sel_conv_id", conversation_id) + logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}") + except Exception as e: + logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True) + + session_service_config = sp_v3.get("session_service_config", default={}) + for umo, config in session_service_config.items(): + if not umo or not config: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + + await sp.put_async("umo", str(session), "session_service_config", config) + + logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}") + except Exception as e: + logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True) + + session_variables = sp_v3.get("session_variables", default={}) + for umo, variables in session_variables.items(): + if not umo or not variables: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + await sp.put_async("umo", str(session), "session_variables", variables) + except Exception as e: + logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True) + + session_provider_perf = sp_v3.get("session_provider_perf", default={}) + for umo, perf in session_provider_perf.items(): + if not umo or not perf: + continue + try: + session = MessageSesion.from_str(session_str=umo) + platform_id = get_platform_id(platform_id_map, session.platform_name) + session.platform_id = platform_id + + for provider_type, provider_id in perf.items(): + await sp.put_async( + "umo", str(session), f"provider_perf_{provider_type}", provider_id + ) + logger.info( + f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}" + ) + except Exception as e: + logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True) diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py new file mode 100644 index 000000000..dda2cbcaf --- /dev/null +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -0,0 +1,45 @@ +import json +import os +from typing import TypeVar +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +_VT = TypeVar("_VT") + +class SharedPreferences: + def __init__(self, path=None): + if path is None: + path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") + self.path = path + self._data = self._load_preferences() + + def _load_preferences(self): + if os.path.exists(self.path): + try: + with open(self.path, "r") as f: + return json.load(f) + except json.JSONDecodeError: + os.remove(self.path) + return {} + + def _save_preferences(self): + with open(self.path, "w") as f: + json.dump(self._data, f, indent=4, ensure_ascii=False) + f.flush() + + def get(self, key, default: _VT = None) -> _VT: + return self._data.get(key, default) + + def put(self, key, value): + self._data[key] = value + self._save_preferences() + + def remove(self, key): + if key in self._data: + del self._data[key] + self._save_preferences() + + def clear(self): + self._data.clear() + self._save_preferences() + +sp = SharedPreferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py new file mode 100644 index 000000000..e7e734abd --- /dev/null +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -0,0 +1,493 @@ +import sqlite3 +import time +from astrbot.core.db.po import Platform, Stats +from typing import Tuple, List, Dict, Any +from dataclasses import dataclass + +@dataclass +class Conversation: + """LLM 对话存储 + + 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 + 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 + """ + + user_id: str + cid: str + history: str = "" + """字符串格式的列表。""" + created_at: int = 0 + updated_at: int = 0 + title: str = "" + persona_id: str = "" + + +INIT_SQL = """ +CREATE TABLE IF NOT EXISTS platform( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS llm( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS plugin( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS command( + name VARCHAR(32), + count INTEGER, + timestamp INTEGER +); +CREATE TABLE IF NOT EXISTS llm_history( + provider_type VARCHAR(32), + session_id VARCHAR(32), + content TEXT +); + +-- ATRI +CREATE TABLE IF NOT EXISTS atri_vision( + id TEXT, + url_or_path TEXT, + caption TEXT, + is_meme BOOLEAN, + keywords TEXT, + platform_name VARCHAR(32), + session_id VARCHAR(32), + sender_nickname VARCHAR(32), + timestamp INTEGER +); + +CREATE TABLE IF NOT EXISTS webchat_conversation( + user_id TEXT, -- 会话 id + cid TEXT, -- 对话 id + history TEXT, + created_at INTEGER, + updated_at INTEGER, + title TEXT, + persona_id TEXT +); + +PRAGMA encoding = 'UTF-8'; +""" + + +class SQLiteDatabase(): + def __init__(self, db_path: str) -> None: + super().__init__() + self.db_path = db_path + + sql = INIT_SQL + + # 初始化数据库 + self.conn = self._get_conn(self.db_path) + c = self.conn.cursor() + c.executescript(sql) + self.conn.commit() + + # 检查 webchat_conversation 的 title 字段是否存在 + c.execute( + """ + PRAGMA table_info(webchat_conversation) + """ + ) + res = c.fetchall() + has_title = False + has_persona_id = False + for row in res: + if row[1] == "title": + has_title = True + if row[1] == "persona_id": + has_persona_id = True + if not has_title: + c.execute( + """ + ALTER TABLE webchat_conversation ADD COLUMN title TEXT; + """ + ) + self.conn.commit() + if not has_persona_id: + c.execute( + """ + ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; + """ + ) + self.conn.commit() + + c.close() + + def _get_conn(self, db_path: str) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path) + conn.text_factory = str + return conn + + def _exec_sql(self, sql: str, params: Tuple = None): + conn = self.conn + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + conn = self._get_conn(self.db_path) + c = conn.cursor() + + if params: + c.execute(sql, params) + c.close() + else: + c.execute(sql) + c.close() + + conn.commit() + + def insert_platform_metrics(self, metrics: dict): + for k, v in metrics.items(): + self._exec_sql( + """ + INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?) + """, + (k, v, int(time.time())), + ) + + def insert_llm_metrics(self, metrics: dict): + for k, v in metrics.items(): + self._exec_sql( + """ + INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?) + """, + (k, v, int(time.time())), + ) + + def get_base_stats(self, offset_sec: int = 86400) -> Stats: + """获取 offset_sec 秒前到现在的基础统计数据""" + where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT * FROM platform + """ + + where_clause + ) + + platform = [] + for row in c.fetchall(): + platform.append(Platform(*row)) + + c.close() + + return Stats(platform=platform) + + def get_total_message_count(self) -> int: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT SUM(count) FROM platform + """ + ) + res = c.fetchone() + c.close() + return res[0] + + def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: + """获取 offset_sec 秒前到现在的基础统计数据(合并)""" + where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" + + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT name, SUM(count), timestamp FROM platform + """ + + where_clause + + " GROUP BY name" + ) + + platform = [] + for row in c.fetchall(): + platform.append(Platform(*row)) + + c.close() + + return Stats(platform, [], []) + + def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ? + """, + (user_id, cid), + ) + + res = c.fetchone() + c.close() + + if not res: + return + + return Conversation(*res) + + def new_conversation(self, user_id: str, cid: str): + history = "[]" + updated_at = int(time.time()) + created_at = updated_at + self._exec_sql( + """ + INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?) + """, + (user_id, cid, history, updated_at, created_at), + ) + + def get_conversations(self, user_id: str) -> Tuple: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + """ + SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC + """, + (user_id,), + ) + + res = c.fetchall() + c.close() + conversations = [] + for row in res: + cid = row[0] + created_at = row[1] + updated_at = row[2] + title = row[3] + persona_id = row[4] + conversations.append( + Conversation("", cid, "[]", created_at, updated_at, title, persona_id) + ) + return conversations + + def update_conversation(self, user_id: str, cid: str, history: str): + """更新对话,并且同时更新时间""" + updated_at = int(time.time()) + self._exec_sql( + """ + UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? + """, + (history, updated_at, user_id, cid), + ) + + def update_conversation_title(self, user_id: str, cid: str, title: str): + self._exec_sql( + """ + UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? + """, + (title, user_id, cid), + ) + + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + self._exec_sql( + """ + UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? + """, + (persona_id, user_id, cid), + ) + + def delete_conversation(self, user_id: str, cid: str): + self._exec_sql( + """ + DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? + """, + (user_id, cid), + ) + + def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> Tuple[List[Dict[str, Any]], int]: + """获取所有对话,支持分页,按更新时间降序排序""" + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + try: + # 获取总记录数 + c.execute(""" + SELECT COUNT(*) FROM webchat_conversation + """) + total_count = c.fetchone()[0] + + # 计算偏移量 + offset = (page - 1) * page_size + + # 获取分页数据,按更新时间降序排序 + c.execute( + """ + SELECT user_id, cid, created_at, updated_at, title, persona_id + FROM webchat_conversation + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (page_size, offset), + ) + + rows = c.fetchall() + + conversations = [] + + for row in rows: + user_id, cid, created_at, updated_at, title, persona_id = row + # 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值 + safe_cid = str(cid) if cid else "unknown" + display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid + + conversations.append( + { + "user_id": user_id or "", + "cid": safe_cid, + "title": title or f"对话 {display_cid}", + "persona_id": persona_id or "", + "created_at": created_at or 0, + "updated_at": updated_at or 0, + } + ) + + return conversations, total_count + + except Exception as _: + # 返回空列表和0,确保即使出错也有有效的返回值 + return [], 0 + finally: + c.close() + + def get_filtered_conversations( + self, + page: int = 1, + page_size: int = 20, + platforms: List[str] = None, + message_types: List[str] = None, + search_query: str = None, + exclude_ids: List[str] = None, + exclude_platforms: List[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """获取筛选后的对话列表""" + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + try: + # 构建查询条件 + where_clauses = [] + params = [] + + # 平台筛选 + if platforms and len(platforms) > 0: + platform_conditions = [] + for platform in platforms: + platform_conditions.append("user_id LIKE ?") + params.append(f"{platform}:%") + + if platform_conditions: + where_clauses.append(f"({' OR '.join(platform_conditions)})") + + # 消息类型筛选 + if message_types and len(message_types) > 0: + message_type_conditions = [] + for msg_type in message_types: + message_type_conditions.append("user_id LIKE ?") + params.append(f"%:{msg_type}:%") + + if message_type_conditions: + where_clauses.append(f"({' OR '.join(message_type_conditions)})") + + # 搜索关键词 + if search_query: + search_query = search_query.encode("unicode_escape").decode("utf-8") + where_clauses.append( + "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)" + ) + search_param = f"%{search_query}%" + params.extend([search_param, search_param, search_param, search_param]) + + # 排除特定用户ID + if exclude_ids and len(exclude_ids) > 0: + for exclude_id in exclude_ids: + where_clauses.append("user_id NOT LIKE ?") + params.append(f"{exclude_id}%") + + # 排除特定平台 + if exclude_platforms and len(exclude_platforms) > 0: + for exclude_platform in exclude_platforms: + where_clauses.append("user_id NOT LIKE ?") + params.append(f"{exclude_platform}:%") + + # 构建完整的 WHERE 子句 + where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" + + # 构建计数查询 + count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}" + + # 获取总记录数 + c.execute(count_sql, params) + total_count = c.fetchone()[0] + + # 计算偏移量 + offset = (page - 1) * page_size + + # 构建分页数据查询 + data_sql = f""" + SELECT user_id, cid, created_at, updated_at, title, persona_id + FROM webchat_conversation + {where_sql} + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """ + query_params = params + [page_size, offset] + + # 获取分页数据 + c.execute(data_sql, query_params) + rows = c.fetchall() + + conversations = [] + + for row in rows: + user_id, cid, created_at, updated_at, title, persona_id = row + # 确保 cid 是字符串类型,否则使用一个默认值 + safe_cid = str(cid) if cid else "unknown" + display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid + + conversations.append( + { + "user_id": user_id or "", + "cid": safe_cid, + "title": title or f"对话 {display_cid}", + "persona_id": persona_id or "", + "created_at": created_at or 0, + "updated_at": updated_at or 0, + } + ) + + return conversations, total_count + + except Exception as _: + # 返回空列表和0,确保即使出错也有有效的返回值 + return [], 0 + finally: + c.close() diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 49adb2781..88113d130 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,7 +1,233 @@ -"""指标数据""" +import uuid +from datetime import datetime, timezone from dataclasses import dataclass, field -from typing import List +from sqlmodel import ( + SQLModel, + Text, + JSON, + UniqueConstraint, + Field, +) +from typing import Optional, TypedDict + + +class PlatformStat(SQLModel, table=True): + """This class represents the statistics of bot usage across different platforms. + + Note: In astrbot v4, we moved `platform` table to here. + """ + + __tablename__ = "platform_stats" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + timestamp: datetime = Field(nullable=False) + platform_id: str = Field(nullable=False) + platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc. + count: int = Field(default=0, nullable=False) + + __table_args__ = ( + UniqueConstraint( + "timestamp", + "platform_id", + "platform_type", + name="uix_platform_stats", + ), + ) + + +class ConversationV2(SQLModel, table=True): + __tablename__ = "conversations" + + inner_conversation_id: int = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True} + ) + conversation_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + ) + platform_id: str = Field(nullable=False) + user_id: str = Field(nullable=False) + content: Optional[list] = Field(default=None, sa_type=JSON) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + title: Optional[str] = Field(default=None, max_length=255) + persona_id: Optional[str] = Field(default=None) + + __table_args__ = ( + UniqueConstraint( + "conversation_id", + name="uix_conversation_id", + ), + ) + + +class Persona(SQLModel, table=True): + """Persona is a set of instructions for LLMs to follow. + + It can be used to customize the behavior of LLMs. + """ + + __tablename__ = "personas" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + persona_id: str = Field(max_length=255, nullable=False) + system_prompt: str = Field(sa_type=Text, nullable=False) + begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) + """a list of strings, each representing a dialog to start with""" + tools: Optional[list] = Field(default=None, sa_type=JSON) + """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "persona_id", + name="uix_persona_id", + ), + ) + + +class Preference(SQLModel, table=True): + """This class represents preferences for bots.""" + + __tablename__ = "preferences" + + id: int | None = Field( + default=None, primary_key=True, sa_column_kwargs={"autoincrement": True} + ) + scope: str = Field(nullable=False) + """Scope of the preference, such as 'global', 'umo', 'plugin'.""" + scope_id: str = Field(nullable=False) + """ID of the scope, such as 'global', 'umo', 'plugin_name'.""" + key: str = Field(nullable=False) + value: dict = Field(sa_type=JSON, nullable=False) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "scope", + "scope_id", + "key", + name="uix_preference_scope_scope_id_key", + ), + ) + + +class PlatformMessageHistory(SQLModel, table=True): + """This class represents the message history for a specific platform. + + It is used to store messages that are not LLM-generated, such as user messages + or platform-specific messages. + """ + + __tablename__ = "platform_message_history" + + id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + platform_id: str = Field(nullable=False) + user_id: str = Field(nullable=False) # An id of group, user in platform + sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform + sender_name: Optional[str] = Field( + default=None + ) # Name of the sender in the platform + content: dict = Field(sa_type=JSON, nullable=False) # a message chain list + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + +class Attachment(SQLModel, table=True): + """This class represents attachments for messages in AstrBot. + + Attachments can be images, files, or other media types. + """ + + __tablename__ = "attachments" + + inner_attachment_id: int = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True} + ) + attachment_id: str = Field( + max_length=36, + nullable=False, + unique=True, + default_factory=lambda: str(uuid.uuid4()), + ) + path: str = Field(nullable=False) # Path to the file on disk + type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file') + mime_type: str = Field(nullable=False) # MIME type of the file + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + ) + + __table_args__ = ( + UniqueConstraint( + "attachment_id", + name="uix_attachment_id", + ), + ) + + +@dataclass +class Conversation: + """LLM 对话类 + + 对于 WebChat,history 存储了包括指令、回复、图片等在内的所有消息。 + 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 + + 在 v4.0.0 版本及之后,WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中, + """ + + platform_id: str + user_id: str + cid: str + """对话 ID, 是 uuid 格式的字符串""" + history: str = "" + """字符串格式的对话列表。""" + title: str | None = "" + persona_id: str | None = "" + created_at: int = 0 + updated_at: int = 0 + + +class Personality(TypedDict): + """LLM 人格类。 + + 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。 + """ + + prompt: str = "" + name: str = "" + begin_dialogs: list[str] = [] + mood_imitation_dialogs: list[str] = [] + """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" + tools: list[str] | None = None + """工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" + + # cache + _begin_dialogs_processed: list[dict] = [] + _mood_imitation_dialogs_processed: str = "" + + +# ==== +# Deprecated, and will be removed in future versions. +# ==== @dataclass @@ -13,77 +239,6 @@ class Platform: timestamp: int -@dataclass -class Provider: - """供应商使用统计数据""" - - name: str - count: int - timestamp: int - - -@dataclass -class Plugin: - """插件使用统计数据""" - - name: str - count: int - timestamp: int - - -@dataclass -class Command: - """命令使用统计数据""" - - name: str - count: int - timestamp: int - - @dataclass class Stats: - platform: List[Platform] = field(default_factory=list) - command: List[Command] = field(default_factory=list) - llm: List[Provider] = field(default_factory=list) - - -@dataclass -class LLMHistory: - """LLM 聊天时持久化的信息""" - - provider_type: str - session_id: str - content: str - - -@dataclass -class ATRIVision: - """Deprecated""" - - id: str - url_or_path: str - caption: str - is_meme: bool - keywords: List[str] - platform_name: str - session_id: str - sender_nickname: str - timestamp: int = -1 - - -@dataclass -class Conversation: - """LLM 对话存储 - - 对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。 - 对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。 - """ - - user_id: str - cid: str - history: str = "" - """字符串格式的列表。""" - created_at: int = 0 - updated_at: int = 0 - title: str = "" - persona_id: str = "" + platform: list[Platform] = field(default_factory=list) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 2abba1de9..418b35761 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,567 +1,542 @@ -import sqlite3 -import os -import time -from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation -from . import BaseDatabase -from typing import Tuple, List, Dict, Any +import asyncio +import typing as T +import threading +from datetime import datetime, timedelta +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ( + ConversationV2, + PlatformStat, + PlatformMessageHistory, + Attachment, + Persona, + Preference, + Stats as DeprecatedStats, + Platform as DeprecatedPlatformStat, + SQLModel, +) + +from sqlalchemy import select, update, delete, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import func + +NOT_GIVEN = T.TypeVar("NOT_GIVEN") class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: - super().__init__() self.db_path = db_path + self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" + self.inited = False + super().__init__() - with open( - os.path.dirname(__file__) + "/sqlite_init.sql", "r", encoding="utf-8" - ) as f: - sql = f.read() + async def initialize(self) -> None: + """Initialize the database by creating tables if they do not exist.""" + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await conn.commit() - # 初始化数据库 - self.conn = self._get_conn(self.db_path) - c = self.conn.cursor() - c.executescript(sql) - self.conn.commit() + # ==== + # Platform Statistics + # ==== - # 检查 webchat_conversation 的 title 字段是否存在 - c.execute( - """ - PRAGMA table_info(webchat_conversation) - """ - ) - res = c.fetchall() - has_title = False - has_persona_id = False - for row in res: - if row[1] == "title": - has_title = True - if row[1] == "persona_id": - has_persona_id = True - if not has_title: - c.execute( - """ - ALTER TABLE webchat_conversation ADD COLUMN title TEXT; - """ - ) - self.conn.commit() - if not has_persona_id: - c.execute( - """ - ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; - """ - ) - self.conn.commit() - - c.close() - - def _get_conn(self, db_path: str) -> sqlite3.Connection: - conn = sqlite3.connect(self.db_path) - conn.text_factory = str - return conn - - def _exec_sql(self, sql: str, params: Tuple = None): - conn = self.conn - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - conn = self._get_conn(self.db_path) - c = conn.cursor() - - if params: - c.execute(sql, params) - c.close() - else: - c.execute(sql) - c.close() - - conn.commit() - - def insert_platform_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def insert_plugin_metrics(self, metrics: dict): - pass - - def insert_command_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def insert_llm_metrics(self, metrics: dict): - for k, v in metrics.items(): - self._exec_sql( - """ - INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?) - """, - (k, v, int(time.time())), - ) - - def update_llm_history(self, session_id: str, content: str, provider_type: str): - res = self.get_llm_history(session_id, provider_type) - if res: - self._exec_sql( - """ - UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ? - """, - (content, session_id, provider_type), - ) - else: - self._exec_sql( - """ - INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?) - """, - (provider_type, session_id, content), - ) - - def get_llm_history( - self, session_id: str = None, provider_type: str = None - ) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - conditions = [] - params = [] - - if session_id: - conditions.append("session_id = ?") - params.append(session_id) - - if provider_type: - conditions.append("provider_type = ?") - params.append(provider_type) - - sql = "SELECT * FROM llm_history" - if conditions: - sql += " WHERE " + " AND ".join(conditions) - - c.execute(sql, params) - - res = c.fetchall() - histories = [] - for row in res: - histories.append(LLMHistory(*row)) - c.close() - return histories - - def get_base_stats(self, offset_sec: int = 86400) -> Stats: - """获取 offset_sec 秒前到现在的基础统计数据""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM platform - """ - + where_clause - ) - - platform = [] - for row in c.fetchall(): - platform.append(Platform(*row)) - - # c.execute( - # ''' - # SELECT * FROM command - # ''' + where_clause - # ) - - # command = [] - # for row in c.fetchall(): - # command.append(Command(*row)) - - # c.execute( - # ''' - # SELECT * FROM llm - # ''' + where_clause - # ) - - # llm = [] - # for row in c.fetchall(): - # llm.append(Provider(*row)) - - c.close() - - return Stats(platform, [], []) - - def get_total_message_count(self) -> int: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT SUM(count) FROM platform - """ - ) - res = c.fetchone() - c.close() - return res[0] - - def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: - """获取 offset_sec 秒前到现在的基础统计数据(合并)""" - where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}" - - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT name, SUM(count), timestamp FROM platform - """ - + where_clause - + " GROUP BY name" - ) - - platform = [] - for row in c.fetchall(): - platform.append(Platform(*row)) - - c.close() - - return Stats(platform, [], []) - - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ? - """, - (user_id, cid), - ) - - res = c.fetchone() - c.close() - - if not res: - return - - return Conversation(*res) - - def new_conversation(self, user_id: str, cid: str): - history = "[]" - updated_at = int(time.time()) - created_at = updated_at - self._exec_sql( - """ - INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?) - """, - (user_id, cid, history, updated_at, created_at), - ) - - def get_conversations(self, user_id: str) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC - """, - (user_id,), - ) - - res = c.fetchall() - c.close() - conversations = [] - for row in res: - cid = row[0] - created_at = row[1] - updated_at = row[2] - title = row[3] - persona_id = row[4] - conversations.append( - Conversation("", cid, "[]", created_at, updated_at, title, persona_id) - ) - return conversations - - def update_conversation(self, user_id: str, cid: str, history: str): - """更新对话,并且同时更新时间""" - updated_at = int(time.time()) - self._exec_sql( - """ - UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ? - """, - (history, updated_at, user_id, cid), - ) - - def update_conversation_title(self, user_id: str, cid: str, title: str): - self._exec_sql( - """ - UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? - """, - (title, user_id, cid), - ) - - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): - self._exec_sql( - """ - UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? - """, - (persona_id, user_id, cid), - ) - - def delete_conversation(self, user_id: str, cid: str): - self._exec_sql( - """ - DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? - """, - (user_id, cid), - ) - - def insert_atri_vision_data(self, vision: ATRIVision): - ts = int(time.time()) - keywords = ",".join(vision.keywords) - self._exec_sql( - """ - INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - vision.id, - vision.url_or_path, - vision.caption, - vision.is_meme, - keywords, - vision.platform_name, - vision.session_id, - vision.sender_nickname, - ts, - ), - ) - - def get_atri_vision_data(self) -> Tuple: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM atri_vision - """ - ) - - res = c.fetchall() - visions = [] - for row in res: - visions.append(ATRIVision(*row)) - c.close() - return visions - - def get_atri_vision_data_by_path_or_id( - self, url_or_path: str, id: str - ) -> ATRIVision: - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - c.execute( - """ - SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ? - """, - (url_or_path, id), - ) - - res = c.fetchone() - c.close() - if res: - return ATRIVision(*res) - return None - - def get_all_conversations( - self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: - """获取所有对话,支持分页,按更新时间降序排序""" - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - try: - # 获取总记录数 - c.execute(""" - SELECT COUNT(*) FROM webchat_conversation - """) - total_count = c.fetchone()[0] - - # 计算偏移量 - offset = (page - 1) * page_size - - # 获取分页数据,按更新时间降序排序 - c.execute( - """ - SELECT user_id, cid, created_at, updated_at, title, persona_id - FROM webchat_conversation - ORDER BY updated_at DESC - LIMIT ? OFFSET ? - """, - (page_size, offset), - ) - - rows = c.fetchall() - - conversations = [] - - for row in rows: - user_id, cid, created_at, updated_at, title, persona_id = row - # 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值 - safe_cid = str(cid) if cid else "unknown" - display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid - - conversations.append( - { - "user_id": user_id or "", - "cid": safe_cid, - "title": title or f"对话 {display_cid}", - "persona_id": persona_id or "", - "created_at": created_at or 0, - "updated_at": updated_at or 0, - } - ) - - return conversations, total_count - - except Exception as _: - # 返回空列表和0,确保即使出错也有有效的返回值 - return [], 0 - finally: - c.close() - - def get_filtered_conversations( + async def insert_platform_stats( self, - page: int = 1, - page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, - search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """获取筛选后的对话列表""" - try: - c = self.conn.cursor() - except sqlite3.ProgrammingError: - c = self._get_conn(self.db_path).cursor() - - try: - # 构建查询条件 - where_clauses = [] - params = [] - - # 平台筛选 - if platforms and len(platforms) > 0: - platform_conditions = [] - for platform in platforms: - platform_conditions.append("user_id LIKE ?") - params.append(f"{platform}:%") - - if platform_conditions: - where_clauses.append(f"({' OR '.join(platform_conditions)})") - - # 消息类型筛选 - if message_types and len(message_types) > 0: - message_type_conditions = [] - for msg_type in message_types: - message_type_conditions.append("user_id LIKE ?") - params.append(f"%:{msg_type}:%") - - if message_type_conditions: - where_clauses.append(f"({' OR '.join(message_type_conditions)})") - - # 搜索关键词 - if search_query: - search_query = search_query.encode("unicode_escape").decode("utf-8") - where_clauses.append( - "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)" - ) - search_param = f"%{search_query}%" - params.extend([search_param, search_param, search_param, search_param]) - - # 排除特定用户ID - if exclude_ids and len(exclude_ids) > 0: - for exclude_id in exclude_ids: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_id}%") - - # 排除特定平台 - if exclude_platforms and len(exclude_platforms) > 0: - for exclude_platform in exclude_platforms: - where_clauses.append("user_id NOT LIKE ?") - params.append(f"{exclude_platform}:%") - - # 构建完整的 WHERE 子句 - where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else "" - - # 构建计数查询 - count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}" - - # 获取总记录数 - c.execute(count_sql, params) - total_count = c.fetchone()[0] - - # 计算偏移量 - offset = (page - 1) * page_size - - # 构建分页数据查询 - data_sql = f""" - SELECT user_id, cid, created_at, updated_at, title, persona_id - FROM webchat_conversation - {where_sql} - ORDER BY updated_at DESC - LIMIT ? OFFSET ? - """ - query_params = params + [page_size, offset] - - # 获取分页数据 - c.execute(data_sql, query_params) - rows = c.fetchall() - - conversations = [] - - for row in rows: - user_id, cid, created_at, updated_at, title, persona_id = row - # 确保 cid 是字符串类型,否则使用一个默认值 - safe_cid = str(cid) if cid else "unknown" - display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid - - conversations.append( + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime = None, + ) -> None: + """Insert a new platform statistic record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + if timestamp is None: + timestamp = datetime.now().replace( + minute=0, second=0, microsecond=0 + ) + current_hour = timestamp + await session.execute( + text(""" + INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) + VALUES (:timestamp, :platform_id, :platform_type, :count) + ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET + count = platform_stats.count + EXCLUDED.count + """), { - "user_id": user_id or "", - "cid": safe_cid, - "title": title or f"对话 {display_cid}", - "persona_id": persona_id or "", - "created_at": created_at or 0, - "updated_at": updated_at or 0, - } + "timestamp": current_hour, + "platform_id": platform_id, + "platform_type": platform_type, + "count": count, + }, ) - return conversations, total_count + async def count_platform_stats(self) -> int: + """Count the number of platform statistics records.""" + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.count(PlatformStat.platform_id)).select_from(PlatformStat) + ) + count = result.scalar_one_or_none() + return count if count is not None else 0 - except Exception as _: - # 返回空列表和0,确保即使出错也有有效的返回值 - return [], 0 - finally: - c.close() + async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]: + """Get platform statistics within the specified offset in seconds and group by platform_id.""" + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + text(""" + SELECT * FROM platform_stats + WHERE timestamp >= :start_time + ORDER BY timestamp DESC + GROUP BY platform_id + """), + {"start_time": start_time}, + ) + return result.scalars().all() + + # ==== + # Conversation Management + # ==== + + async def get_conversations(self, user_id=None, platform_id=None): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2) + + if user_id: + query = query.where(ConversationV2.user_id == user_id) + if platform_id: + query = query.where(ConversationV2.platform_id == platform_id) + # order by + query = query.order_by(ConversationV2.created_at.desc()) + result = await session.execute(query) + + return result.scalars().all() + + async def get_conversation_by_id(self, cid): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2).where(ConversationV2.conversation_id == cid) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_all_conversations(self, page=1, page_size=20): + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + result = await session.execute( + select(ConversationV2) + .order_by(ConversationV2.created_at.desc()) + .offset(offset) + .limit(page_size) + ) + return result.scalars().all() + + async def get_filtered_conversations( + self, + page=1, + page_size=20, + platform_ids=None, + search_query="", + **kwargs, + ): + async with self.get_db() as session: + session: AsyncSession + # Build the base query with filters + base_query = select(ConversationV2) + + if platform_ids: + base_query = base_query.where( + ConversationV2.platform_id.in_(platform_ids) + ) + if search_query: + base_query = base_query.where( + ConversationV2.title.ilike(f"%{search_query}%") + ) + + # Get total count matching the filters + count_query = select(func.count()).select_from(base_query.subquery()) + total_count = await session.execute(count_query) + total = total_count.scalar_one() + + # Get paginated results + offset = (page - 1) * page_size + result_query = ( + base_query.order_by(ConversationV2.created_at.desc()) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(result_query) + conversations = result.scalars().all() + + return conversations, total + + async def create_conversation( + self, + user_id, + platform_id, + content=None, + title=None, + persona_id=None, + cid=None, + created_at=None, + updated_at=None, + ): + kwargs = {} + if cid: + kwargs["conversation_id"] = cid + if created_at: + kwargs["created_at"] = created_at + if updated_at: + kwargs["updated_at"] = updated_at + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_conversation = ConversationV2( + user_id=user_id, + content=content or [], + platform_id=platform_id, + title=title, + persona_id=persona_id, + **kwargs, + ) + session.add(new_conversation) + return new_conversation + + async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(ConversationV2).where( + ConversationV2.conversation_id == cid + ) + values = {} + if title is not None: + values["title"] = title + if persona_id is not None: + values["persona_id"] = persona_id + if content is not None: + values["content"] = content + if not values: + return + query = query.values(**values) + await session.execute(query) + return await self.get_conversation_by_id(cid) + + async def delete_conversation(self, cid): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(ConversationV2).where(ConversationV2.conversation_id == cid) + ) + + async def insert_platform_message_history( + self, + platform_id, + user_id, + content, + sender_id=None, + sender_name=None, + ): + """Insert a new platform message history record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_history = PlatformMessageHistory( + platform_id=platform_id, + user_id=user_id, + content=content, + sender_id=sender_id, + sender_name=sender_name, + ) + session.add(new_history) + return new_history + + async def delete_platform_message_offset( + self, platform_id, user_id, offset_sec=86400 + ): + """Delete platform message history records older than the specified offset.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now() + cutoff_time = now - timedelta(seconds=offset_sec) + await session.execute( + delete(PlatformMessageHistory).where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + PlatformMessageHistory.created_at < cutoff_time, + ) + ) + + async def get_platform_message_history( + self, platform_id, user_id, page=1, page_size=20 + ): + """Get platform message history records.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + .order_by(PlatformMessageHistory.created_at.desc()) + ) + result = await session.execute(query.offset(offset).limit(page_size)) + return result.scalars().all() + + async def insert_attachment(self, path, type, mime_type): + """Insert a new attachment record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_attachment = Attachment( + path=path, + type=type, + mime_type=mime_type, + ) + session.add(new_attachment) + return new_attachment + + async def get_attachment_by_id(self, attachment_id): + """Get an attachment by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where(Attachment.id == attachment_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def insert_persona( + self, persona_id, system_prompt, begin_dialogs=None, tools=None + ): + """Insert a new persona record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_persona = Persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs or [], + tools=tools, + ) + session.add(new_persona) + return new_persona + + async def get_persona_by_id(self, persona_id): + """Get a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona).where(Persona.persona_id == persona_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_personas(self): + """Get all personas for a specific bot.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona) + result = await session.execute(query) + return result.scalars().all() + + async def update_persona( + self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN + ): + """Update a persona's system prompt or begin dialogs.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(Persona).where(Persona.persona_id == persona_id) + values = {} + if system_prompt is not None: + values["system_prompt"] = system_prompt + if begin_dialogs is not None: + values["begin_dialogs"] = begin_dialogs + if tools is not NOT_GIVEN: + values["tools"] = tools + if not values: + return + query = query.values(**values) + await session.execute(query) + return await self.get_persona_by_id(persona_id) + + async def delete_persona(self, persona_id): + """Delete a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Persona).where(Persona.persona_id == persona_id) + ) + + async def insert_preference_or_update(self, scope, scope_id, key, value): + """Insert a new preference record or update if it exists.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + result = await session.execute(query) + existing_preference = result.scalar_one_or_none() + if existing_preference: + existing_preference.value = value + else: + new_preference = Preference( + scope=scope, scope_id=scope_id, key=key, value=value + ) + session.add(new_preference) + return existing_preference or new_preference + + async def get_preference(self, scope, scope_id, key): + """Get a preference by key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_preferences(self, scope, scope_id=None, key=None): + """Get all preferences for a specific scope ID or key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where(Preference.scope == scope) + if scope_id is not None: + query = query.where(Preference.scope_id == scope_id) + if key is not None: + query = query.where(Preference.key == key) + result = await session.execute(query) + return result.scalars().all() + + async def remove_preference(self, scope, scope_id, key): + """Remove a preference by scope ID and key.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + ) + await session.commit() + + async def clear_preferences(self, scope, scope_id): + """Clear all preferences for a specific scope ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + Preference.scope == scope, Preference.scope_id == scope_id + ) + ) + await session.commit() + + # ==== + # Deprecated Methods + # ==== + + def get_base_stats(self, offset_sec=86400): + """Get base statistics within the specified offset in seconds.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat).where(PlatformStat.timestamp >= start_time) + ) + all_datas = result.scalars().all() + deprecated_stats = DeprecatedStats() + for data in all_datas: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=data.platform_id, + count=data.count, + timestamp=data.timestamp.timestamp(), + ) + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_total_message_count(self): + """Get the total message count from platform statistics.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.sum(PlatformStat.count)).select_from(PlatformStat) + ) + total_count = result.scalar_one_or_none() + return total_count if total_count is not None else 0 + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_grouped_base_stats(self, offset_sec=86400): + # group by platform_id + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat.platform_id, func.sum(PlatformStat.count)) + .where(PlatformStat.timestamp >= start_time) + .group_by(PlatformStat.platform_id) + ) + grouped_stats = result.all() + deprecated_stats = DeprecatedStats() + for platform_id, count in grouped_stats: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=platform_id, + count=count, + timestamp=start_time.timestamp(), + ) + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result diff --git a/astrbot/core/db/sqlite_init.sql b/astrbot/core/db/sqlite_init.sql deleted file mode 100644 index a1ebc54b5..000000000 --- a/astrbot/core/db/sqlite_init.sql +++ /dev/null @@ -1,50 +0,0 @@ -CREATE TABLE IF NOT EXISTS platform( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS llm( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS plugin( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS command( - name VARCHAR(32), - count INTEGER, - timestamp INTEGER -); -CREATE TABLE IF NOT EXISTS llm_history( - provider_type VARCHAR(32), - session_id VARCHAR(32), - content TEXT -); - --- ATRI -CREATE TABLE IF NOT EXISTS atri_vision( - id TEXT, - url_or_path TEXT, - caption TEXT, - is_meme BOOLEAN, - keywords TEXT, - platform_name VARCHAR(32), - session_id VARCHAR(32), - sender_nickname VARCHAR(32), - timestamp INTEGER -); - -CREATE TABLE IF NOT EXISTS webchat_conversation( - user_id TEXT, -- 会话 id - cid TEXT, -- 对话 id - history TEXT, - created_at INTEGER, - updated_at INTEGER, - title TEXT, - persona_id TEXT -); - -PRAGMA encoding = 'UTF-8'; \ No newline at end of file diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 8d95c2501..bc23922ef 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -5,6 +5,7 @@ from .document_storage import DocumentStorage from .embedding_storage import EmbeddingStorage from ..base import Result, BaseVecDB from astrbot.core.provider.provider import EmbeddingProvider +from astrbot.core.provider.provider import RerankProvider class FaissVecDB(BaseVecDB): @@ -17,6 +18,7 @@ class FaissVecDB(BaseVecDB): doc_store_path: str, index_store_path: str, embedding_provider: EmbeddingProvider, + rerank_provider: RerankProvider | None = None, ): self.doc_store_path = doc_store_path self.index_store_path = index_store_path @@ -26,11 +28,14 @@ class FaissVecDB(BaseVecDB): embedding_provider.get_dim(), index_store_path ) self.embedding_provider = embedding_provider + self.rerank_provider = rerank_provider async def initialize(self): await self.document_storage.initialize() - async def insert(self, content: str, metadata: dict = None, id: str = None) -> int: + async def insert( + self, content: str, metadata: dict | None = None, id: str | None = None + ) -> int: """ 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 """ @@ -53,7 +58,12 @@ class FaissVecDB(BaseVecDB): return int_id async def retrieve( - self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None + self, + query: str, + k: int = 5, + fetch_k: int = 20, + rerank: bool = False, + metadata_filters: dict | None = None, ) -> list[Result]: """ 搜索最相似的文档。 @@ -62,6 +72,7 @@ class FaissVecDB(BaseVecDB): query (str): 查询文本 k (int): 返回的最相似文档的数量 fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量 + rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。 metadata_filters (dict): 元数据过滤器 Returns: @@ -72,7 +83,6 @@ class FaissVecDB(BaseVecDB): vector=np.array([embedding]).astype("float32"), k=fetch_k if metadata_filters else k, ) - # TODO: rerank if len(indices[0]) == 0 or indices[0][0] == -1: return [] # normalize scores @@ -83,7 +93,7 @@ class FaissVecDB(BaseVecDB): ) if not fetched_docs: return [] - result_docs = [] + result_docs: list[Result] = [] idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)} for i, indice_idx in enumerate(indices[0]): @@ -93,7 +103,20 @@ class FaissVecDB(BaseVecDB): fetch_doc = fetched_docs[pos] score = scores[0][i] result_docs.append(Result(similarity=float(score), data=fetch_doc)) - return result_docs[:k] + + top_k_results = result_docs[:k] + + if rerank and self.rerank_provider: + documents = [doc.data["text"] for doc in top_k_results] + reranked_results = await self.rerank_provider.rerank(query, documents) + reranked_results = sorted( + reranked_results, key=lambda x: x.relevance_score, reverse=True + ) + top_k_results = [ + top_k_results[reranked_result.index] for reranked_result in reranked_results + ] + + return top_k_results async def delete(self, doc_id: int): """ diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index d4caa2910..2ae709396 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -16,30 +16,32 @@ from asyncio import Queue from astrbot.core.pipeline.scheduler import PipelineScheduler from astrbot.core import logger from .platform import AstrMessageEvent +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager class EventBus: - """事件总线: 用于处理事件的分发和处理 + """用于处理事件的分发和处理""" - 维护一个异步队列, 来接受各种消息事件 - """ - - def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): + def __init__( + self, + event_queue: Queue, + pipeline_scheduler_mapping: dict[str, PipelineScheduler], + astrbot_config_mgr: AstrBotConfigManager = None, + ): self.event_queue = event_queue # 事件队列 - self.pipeline_scheduler = pipeline_scheduler # 管道调度器 + # abconf uuid -> scheduler + self.pipeline_scheduler_mapping = pipeline_scheduler_mapping + self.astrbot_config_mgr = astrbot_config_mgr async def dispatch(self): - """无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑""" while True: - event: AstrMessageEvent = ( - await self.event_queue.get() - ) # 从事件队列中获取新的事件 - self._print_event(event) # 打印日志 - asyncio.create_task( - self.pipeline_scheduler.execute(event) - ) # 创建新的异步任务来执行管道调度器的处理逻辑 + event: AstrMessageEvent = await self.event_queue.get() + conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) + self._print_event(event, conf_info["name"]) + scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) + asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent): + def _print_event(self, event: AstrMessageEvent, conf_name: str): """用于记录事件信息 Args: @@ -48,10 +50,10 @@ class EventBus: # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( - f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" + f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" ) # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( - f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" + f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}" ) diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py new file mode 100644 index 000000000..add3c74bc --- /dev/null +++ b/astrbot/core/persona_mgr.py @@ -0,0 +1,183 @@ +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Persona, Personality +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.platform.message_session import MessageSession +from astrbot import logger + +DEFAULT_PERSONALITY = Personality( + prompt="You are a helpful and friendly assistant.", + name="default", + begin_dialogs=[], + mood_imitation_dialogs=[], + tools=None, + _begin_dialogs_processed=[], + _mood_imitation_dialogs_processed="", +) + + +class PersonaManager: + def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): + self.db = db_helper + self.acm = acm + default_ps = acm.default_conf.get("provider_settings", {}) + self.default_persona: str = default_ps.get("default_personality", "default") + self.personas: list[Persona] = [] + self.selected_default_persona: Persona | None = None + + self.personas_v3: list[Personality] = [] + self.selected_default_persona_v3: Personality | None = None + self.persona_v3_config: list[dict] = [] + + async def initialize(self): + self.personas = await self.get_all_personas() + self.get_v3_persona_data() + logger.info(f"已加载 {len(self.personas)} 个人格。") + + async def get_persona(self, persona_id: str): + """获取指定 persona 的信息""" + persona = await self.db.get_persona_by_id(persona_id) + if not persona: + raise ValueError(f"Persona with ID {persona_id} does not exist.") + return persona + + async def get_default_persona_v3( + self, umo: str | MessageSession | None = None + ) -> Personality: + """获取默认 persona""" + cfg = self.acm.get_conf(umo) + default_persona_id = cfg.get("provider_settings", {}).get( + "default_personality", "default" + ) + if not default_persona_id or default_persona_id == "default": + return DEFAULT_PERSONALITY + try: + return next(p for p in self.personas_v3 if p["name"] == default_persona_id) + except Exception: + return DEFAULT_PERSONALITY + + async def delete_persona(self, persona_id: str): + """删除指定 persona""" + if not await self.db.get_persona_by_id(persona_id): + raise ValueError(f"Persona with ID {persona_id} does not exist.") + await self.db.delete_persona(persona_id) + self.personas = [p for p in self.personas if p.persona_id != persona_id] + self.get_v3_persona_data() + + async def update_persona( + self, + persona_id: str, + system_prompt: str = None, + begin_dialogs: list[str] = None, + tools: list[str] = None, + ): + """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" + existing_persona = await self.db.get_persona_by_id(persona_id) + if not existing_persona: + raise ValueError(f"Persona with ID {persona_id} does not exist.") + persona = await self.db.update_persona( + persona_id, system_prompt, begin_dialogs, tools=tools + ) + if persona: + for i, p in enumerate(self.personas): + if p.persona_id == persona_id: + self.personas[i] = persona + break + self.get_v3_persona_data() + return persona + + async def get_all_personas(self) -> list[Persona]: + """获取所有 personas""" + return await self.db.get_personas() + + async def create_persona( + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] = None, + tools: list[str] = None, + ) -> Persona: + """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" + if await self.db.get_persona_by_id(persona_id): + raise ValueError(f"Persona with ID {persona_id} already exists.") + new_persona = await self.db.insert_persona( + persona_id, system_prompt, begin_dialogs, tools=tools + ) + self.personas.append(new_persona) + self.get_v3_persona_data() + return new_persona + + def get_v3_persona_data( + self, + ) -> tuple[list[dict], list[Personality], Personality]: + """获取 AstrBot <4.0.0 版本的 persona 数据。 + + Returns: + - list[dict]: 包含 persona 配置的字典列表。 + - list[Personality]: 包含 Personality 对象的列表。 + - Personality: 默认选择的 Personality 对象。 + """ + v3_persona_config = [ + { + "prompt": persona.system_prompt, + "name": persona.persona_id, + "begin_dialogs": persona.begin_dialogs or [], + "mood_imitation_dialogs": [], # deprecated + "tools": persona.tools, + } + for persona in self.personas + ] + + personas_v3: list[Personality] = [] + selected_default_persona: Personality | None = None + + for persona_cfg in v3_persona_config: + begin_dialogs = persona_cfg.get("begin_dialogs", []) + bd_processed = [] + if begin_dialogs: + if len(begin_dialogs) % 2 != 0: + logger.error( + f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。" + ) + begin_dialogs = [] + user_turn = True + for dialog in begin_dialogs: + bd_processed.append( + { + "role": "user" if user_turn else "assistant", + "content": dialog, + "_no_save": None, # 不持久化到 db + } + ) + user_turn = not user_turn + + try: + persona = Personality( + **persona_cfg, + _begin_dialogs_processed=bd_processed, + _mood_imitation_dialogs_processed="", # deprecated + ) + if persona["name"] == self.default_persona: + selected_default_persona = persona + personas_v3.append(persona) + except Exception as e: + logger.error(f"解析 Persona 配置失败:{e}") + + if not selected_default_persona and len(personas_v3) > 0: + # 默认选择第一个 + selected_default_persona = personas_v3[0] + + if not selected_default_persona: + selected_default_persona = DEFAULT_PERSONALITY + personas_v3.append(selected_default_persona) + + self.personas_v3 = personas_v3 + self.selected_default_persona_v3 = selected_default_persona + self.persona_v3_config = v3_persona_config + self.selected_default_persona = Persona( + persona_id=selected_default_persona["name"], + system_prompt=selected_default_persona["prompt"], + begin_dialogs=selected_default_persona["begin_dialogs"], + tools=selected_default_persona["tools"] or None, + ) + + return v3_persona_config, personas_v3, selected_default_persona diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 3501a5271..29a324a1d 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -4,7 +4,6 @@ from astrbot.core.message.message_event_result import ( ) from .content_safety_check.stage import ContentSafetyCheckStage -from .platform_compatibility.stage import PlatformCompatibilityStage from .preprocess_stage.stage import PreProcessStage from .process_stage.stage import ProcessStage from .rate_limit_check.stage import RateLimitStage @@ -21,7 +20,6 @@ STAGES_ORDER = [ "SessionStatusCheckStage", # 检查会话是否整体启用 "RateLimitStage", # 检查会话是否超过频率限制 "ContentSafetyCheckStage", # 检查内容安全 - "PlatformCompatibilityStage", # 检查所有处理器的平台兼容性 "PreProcessStage", # 预处理 "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 @@ -34,7 +32,6 @@ __all__ = [ "SessionStatusCheckStage", "RateLimitStage", "ContentSafetyCheckStage", - "PlatformCompatibilityStage", "PreProcessStage", "ProcessStage", "ResultDecorateStage", diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 0b9d9e533..803626aaa 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,14 +1,7 @@ -import inspect -import traceback -import typing as T from dataclasses import dataclass -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig from astrbot.core.star import PluginManager -from astrbot.api import logger -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from .context_utils import call_handler, call_event_hook @dataclass @@ -17,97 +10,6 @@ class PipelineContext: astrbot_config: AstrBotConfig # AstrBot 配置对象 plugin_manager: PluginManager # 插件管理器对象 - - async def call_event_hook( - self, - event: AstrMessageEvent, - hook_type: EventType, - *args, - ) -> bool: - """调用事件钩子函数 - - Returns: - bool: 如果事件被终止,返回 True - """ - platform_id = event.get_platform_id() - handlers = star_handlers_registry.get_handlers_by_event_type( - hook_type, platform_id=platform_id - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, *args) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - - return event.is_stopped() - - async def call_handler( - self, - event: AstrMessageEvent, - handler: T.Awaitable, - *args, - **kwargs, - ) -> T.AsyncGenerator[None, None]: - """执行事件处理函数并处理其返回结果 - - 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: - 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层 - 2. 协程: 执行一次并处理返回值 - - Args: - ctx (PipelineContext): 消息管道上下文对象 - event (AstrMessageEvent): 事件对象 - handler (Awaitable): 事件处理函数 - - Returns: - AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 - """ - ready_to_call = None # 一个协程或者异步生成器 - - trace_ = None - - try: - ready_to_call = handler(event, *args, **kwargs) - except TypeError as _: - # 向下兼容 - trace_ = traceback.format_exc() - # 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份 - ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs) - - if inspect.isasyncgen(ready_to_call): - _has_yielded = False - try: - async for ret in ready_to_call: - # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 - # 返回值只能是 MessageEventResult 或者 None(无返回值) - _has_yielded = True - if isinstance(ret, (MessageEventResult, CommandResult)): - # 如果返回值是 MessageEventResult, 设置结果并继续 - event.set_result(ret) - yield - else: - # 如果返回值是 None, 则不设置结果并继续 - # 继续执行后续阶段 - yield ret - if not _has_yielded: - # 如果这个异步生成器没有执行到 yield 分支 - yield - except Exception as e: - logger.error(f"Previous Error: {trace_}") - raise e - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个协程, 直接执行 - ret = await ready_to_call - if isinstance(ret, (MessageEventResult, CommandResult)): - event.set_result(ret) - yield - else: - yield ret + astrbot_config_id: str + call_handler = call_handler + call_event_hook = call_event_hook diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py new file mode 100644 index 000000000..02e87e6d0 --- /dev/null +++ b/astrbot/core/pipeline/context_utils.py @@ -0,0 +1,98 @@ +import inspect +import traceback +import typing as T +from astrbot import logger +from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star import star_map +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent + + +async def call_handler( + event: AstrMessageEvent, + handler: T.Awaitable, + *args, + **kwargs, +) -> T.AsyncGenerator[T.Any, None]: + """执行事件处理函数并处理其返回结果 + + 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: + 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层 + 2. 协程: 执行一次并处理返回值 + + Args: + event (AstrMessageEvent): 事件对象 + handler (Awaitable): 事件处理函数 + + Returns: + AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 + """ + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + try: + ready_to_call = handler(event, *args, **kwargs) + except TypeError: + logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + + if inspect.isasyncgen(ready_to_call): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret + + +async def call_event_hook( + event: AstrMessageEvent, + hook_type: EventType, + *args, + **kwargs, +) -> bool: + """调用事件钩子函数 + + Returns: + bool: 如果事件被终止,返回 True + # """ + handlers = star_handlers_registry.get_handlers_by_event_type( + hook_type, plugins_name=event.plugins_name + ) + for handler in handlers: + try: + logger.debug( + f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, *args, **kwargs) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + + return event.is_stopped() diff --git a/astrbot/core/pipeline/platform_compatibility/stage.py b/astrbot/core/pipeline/platform_compatibility/stage.py deleted file mode 100644 index 644912c26..000000000 --- a/astrbot/core/pipeline/platform_compatibility/stage.py +++ /dev/null @@ -1,56 +0,0 @@ -from ..stage import Stage, register_stage -from ..context import PipelineContext -from typing import Union, AsyncGenerator -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.star import star_map -from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core import logger - - -@register_stage -class PlatformCompatibilityStage(Stage): - """检查所有处理器的平台兼容性。 - - 这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。 - """ - - async def initialize(self, ctx: PipelineContext) -> None: - """初始化平台兼容性检查阶段 - - Args: - ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 - """ - self.ctx = ctx - - async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: - # 获取当前平台ID - platform_id = event.get_platform_id() - - # 获取已激活的处理器 - activated_handlers = event.get_extra("activated_handlers") - if activated_handlers is None: - activated_handlers = [] - - # 标记不兼容的处理器 - for handler in activated_handlers: - if not isinstance(handler, StarHandlerMetadata): - continue - # 检查处理器是否在当前平台启用 - enabled = handler.is_enabled_for_platform(platform_id) - if not enabled: - if handler.handler_module_path in star_map: - plugin_name = star_map[handler.handler_module_path].name - logger.debug( - f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容" - ) - # 设置处理器为平台不兼容状态 - # TODO: 更好的标记方式 - handler.platform_compatible = False - else: - # 确保处理器为平台兼容状态 - handler.platform_compatible = True - - # 更新已激活的处理器列表 - event.set_extra("activated_handlers", activated_handlers) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index c81a5df51..c07ba0d70 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -20,12 +20,270 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolSet, FunctionTool +from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.agent.handoff import HandoffTool from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric -from ...context import PipelineContext -from ..agent_runner.tool_loop_agent import ToolLoopAgent +from ...context import PipelineContext, call_event_hook, call_handler from ..stage import Stage +from astrbot.core.provider.register import llm_tools +from astrbot.core.star.star_handler import star_map +from astrbot.core.astr_agent_context import AstrAgentContext + +try: + import mcp +except (ModuleNotFoundError, ImportError): + logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") + + +AgentContextWrapper = ContextWrapper[AstrAgentContext] +AgentRunner = ToolLoopAgentRunner[AgentContextWrapper] + + +class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + @classmethod + async def execute(cls, tool, run_context, **tool_args): + """执行函数调用。 + + Args: + event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。 + **kwargs: 函数调用的参数。 + + Returns: + AsyncGenerator[None | mcp.types.CallToolResult, None] + """ + if isinstance(tool, HandoffTool): + async for r in cls._execute_handoff(tool, run_context, **tool_args): + yield r + return + + if tool.origin == "local": + async for r in cls._execute_local(tool, run_context, **tool_args): + yield r + return + + elif tool.origin == "mcp": + async for r in cls._execute_mcp(tool, run_context, **tool_args): + yield r + return + + raise Exception(f"Unknown function origin: {tool.origin}") + + @classmethod + async def _execute_handoff( + cls, + tool: HandoffTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + input_ = tool_args.get("input", "agent") + agent_runner = AgentRunner() + + # make toolset for the agent + tools = tool.agent.tools + if tools: + toolset = ToolSet() + for t in tools: + if isinstance(t, str): + _t = llm_tools.get_func(t) + if _t: + toolset.add_tool(_t) + elif isinstance(t, FunctionTool): + toolset.add_tool(t) + else: + toolset = None + + request = ProviderRequest( + prompt=input_, + system_prompt=tool.description, + image_urls=[], # 暂时不传递原始 agent 的上下文 + contexts=[], # 暂时不传递原始 agent 的上下文 + func_tool=toolset, + ) + astr_agent_ctx = AstrAgentContext( + provider=run_context.context.provider, + first_provider_request=run_context.context.first_provider_request, + curr_provider_request=request, + streaming=run_context.context.streaming, + ) + + logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") + await run_context.event.send( + MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name) + ) + + await agent_runner.reset( + provider=run_context.context.provider, + request=request, + run_context=AgentContextWrapper( + context=astr_agent_ctx, event=run_context.event + ), + tool_executor=FunctionToolExecutor(), + agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), + streaming=run_context.context.streaming, + ) + + async for _ in run_agent(agent_runner, 15, True): + pass + + if agent_runner.done(): + llm_response = agent_runner.get_final_llm_resp() + logger.debug( + f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}" + ) + + result = ( + f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n" + "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)." + ) + + text_content = mcp.types.TextContent( + type="text", + text=result, + ) + yield mcp.types.CallToolResult(content=[text_content]) + else: + yield mcp.types.TextContent( + type="text", + text=f"error when deligate task to {tool.agent.name}", + ) + yield mcp.types.CallToolResult(content=[text_content]) + return + + @classmethod + async def _execute_local( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + if not run_context.event: + raise ValueError("Event must be provided for local function tools.") + + # 检查 tool 下有没有 run 方法 + if not tool.handler and not hasattr(tool, "run"): + raise ValueError("Tool must have a valid handler or 'run' method.") + awaitable = tool.handler or getattr(tool, "run") + + wrapper = call_handler( + event=run_context.event, + handler=awaitable, + **tool_args, + ) + async for resp in wrapper: + if resp is not None: + if isinstance(resp, mcp.types.CallToolResult): + yield resp + else: + text_content = mcp.types.TextContent( + type="text", + text=str(resp), + ) + yield mcp.types.CallToolResult(content=[text_content]) + else: + # NOTE: Tool 在这里直接请求发送消息给用户 + # TODO: 是否需要判断 event.get_result() 是否为空? + # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" + yield None + + @classmethod + async def _execute_mcp( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args, + ): + if not tool.mcp_client: + raise ValueError("MCP client is not available for MCP function tools.") + res = await tool.mcp_client.session.call_tool( + name=tool.name, + arguments=tool_args, + ) + if not res: + return + yield res + + +class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]): + async def on_agent_done(self, run_context, llm_response): + # 执行事件钩子 + await call_event_hook( + run_context.event, EventType.OnLLMResponseEvent, llm_response + ) + + +MAIN_AGENT_HOOKS = MainAgentHooks() + + +async def run_agent( + agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True +) -> AsyncGenerator[MessageChain, None]: + step_idx = 0 + astr_event = agent_runner.run_context.event + while step_idx < max_step: + step_idx += 1 + try: + async for resp in agent_runner.step(): + if astr_event.is_stopped(): + return + if resp.type == "tool_call_result": + msg_chain = resp.data["chain"] + if msg_chain.type == "tool_direct_result": + # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 + resp.data["chain"].type = "tool_call_result" + await astr_event.send(resp.data["chain"]) + continue + # 对于其他情况,暂时先不处理 + continue + elif resp.type == "tool_call": + if agent_runner.streaming: + # 用来标记流式响应需要分节 + yield MessageChain(chain=[], type="break") + if show_tool_use or astr_event.get_platform_name() == "webchat": + resp.data["chain"].type = "tool_call" + await astr_event.send(resp.data["chain"]) + continue + + if not agent_runner.streaming: + content_typ = ( + ResultContentType.LLM_RESULT + if resp.type == "llm_result" + else ResultContentType.GENERAL_RESULT + ) + astr_event.set_result( + MessageEventResult( + chain=resp.data["chain"].chain, + result_content_type=content_typ, + ) + ) + yield + astr_event.clear_result() + else: + if resp.type == "streaming_delta": + yield resp.data["chain"] # MessageChain + if agent_runner.done(): + break + + except Exception as e: + logger.error(traceback.format_exc()) + astr_event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n" + ) + ) + return + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, + ) + ) class LLMRequestSubStage(Stage): @@ -65,6 +323,20 @@ class LLMRequestSubStage(Stage): return _ctx.get_using_provider(umo=event.unified_msg_origin) + async def _get_session_conv(self, event: AstrMessageEvent): + umo = event.unified_msg_origin + conv_mgr = self.conv_manager + + # 获取对话上下文 + cid = await conv_mgr.get_curr_conversation_id(umo) + if not cid: + cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + conversation = await conv_mgr.get_conversation(umo, cid) + if not conversation: + cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + conversation = await conv_mgr.get_conversation(umo, cid) + return conversation + async def process( self, event: AstrMessageEvent, _nested: bool = False ) -> Union[None, AsyncGenerator[None, None]]: @@ -100,30 +372,14 @@ class LLMRequestSubStage(Stage): if not event.message_str.startswith(self.provider_wake_prefix): return req.prompt = event.message_str[len(self.provider_wake_prefix) :] - req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 + # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): image_path = await comp.convert_to_file_path() req.image_urls.append(image_path) - # 获取对话上下文 - conversation_id = await self.conv_manager.get_curr_conversation_id( - event.unified_msg_origin - ) - if not conversation_id: - conversation_id = await self.conv_manager.new_conversation( - event.unified_msg_origin - ) - conversation = await self.conv_manager.get_conversation( - event.unified_msg_origin, conversation_id - ) - if not conversation: - conversation_id = await self.conv_manager.new_conversation( - event.unified_msg_origin - ) - conversation = await self.conv_manager.get_conversation( - event.unified_msg_origin, conversation_id - ) + conversation = await self._get_session_conv(event) req.conversation = conversation req.contexts = json.loads(conversation.history) @@ -133,7 +389,7 @@ class LLMRequestSubStage(Stage): return # 执行请求 LLM 前事件钩子。 - if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req): + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): return if isinstance(req.contexts, str): @@ -167,92 +423,62 @@ class LLMRequestSubStage(Stage): # fix messages req.contexts = self.fix_messages(req.contexts) - # Call Agent - tool_loop_agent = ToolLoopAgent( - provider=provider, - event=event, - pipeline_ctx=self.ctx, - ) + # check provider modalities + # 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。 + if req.image_urls: + provider_cfg = provider.provider_config.get("modalities", ["image"]) + if "image" not in provider_cfg: + logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。") + req.image_urls = [] + if req.func_tool: + provider_cfg = provider.provider_config.get("modalities", ["tool_use"]) + # 如果模型不支持工具使用,但请求中包含工具列表,则清空。 + if "tool_use" not in provider_cfg: + logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。") + req.func_tool = None + # 插件可用性设置 + if event.plugins_name is not None and req.func_tool: + new_tool_set = ToolSet() + for tool in req.func_tool.tools: + plugin = star_map.get(tool.handler_module_path) + if not plugin: + continue + if plugin.name in event.plugins_name or plugin.reserved: + new_tool_set.add_tool(tool) + req.func_tool = new_tool_set + + # run agent + agent_runner = AgentRunner() logger.debug( f"handle provider[id: {provider.provider_config['id']}] request: {req}" ) - await tool_loop_agent.reset(req=req, streaming=self.streaming_response) - - async def requesting(): - step_idx = 0 - while step_idx < self.max_step: - step_idx += 1 - try: - async for resp in tool_loop_agent.step(): - if event.is_stopped(): - return - if resp.type == "tool_call_result": - msg_chain = resp.data["chain"] - if msg_chain.type == "tool_direct_result": - # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 - resp.data["chain"].type = "tool_call_result" - await event.send(resp.data["chain"]) - continue - # 对于其他情况,暂时先不处理 - continue - elif resp.type == "tool_call": - if self.streaming_response: - # 用来标记流式响应需要分节 - yield MessageChain(chain=[], type="break") - if ( - self.show_tool_use - or event.get_platform_name() == "webchat" - ): - resp.data["chain"].type = "tool_call" - await event.send(resp.data["chain"]) - continue - - if not self.streaming_response: - content_typ = ( - ResultContentType.LLM_RESULT - if resp.type == "llm_result" - else ResultContentType.GENERAL_RESULT - ) - event.set_result( - MessageEventResult( - chain=resp.data["chain"].chain, - result_content_type=content_typ, - ) - ) - yield - event.clear_result() - else: - if resp.type == "streaming_delta": - yield resp.data["chain"] # MessageChain - if tool_loop_agent.done(): - break - - except Exception as e: - logger.error(traceback.format_exc()) - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n" - ) - ) - return - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=provider.get_model(), - provider_type=provider.meta().type, - ) - ) + astr_agent_ctx = AstrAgentContext( + provider=provider, + first_provider_request=req, + curr_provider_request=req, + streaming=self.streaming_response, + ) + await agent_runner.reset( + provider=provider, + request=req, + run_context=AgentContextWrapper(context=astr_agent_ctx, event=event), + tool_executor=FunctionToolExecutor(), + agent_hooks=MAIN_AGENT_HOOKS, + streaming=self.streaming_response, + ) if self.streaming_response: # 流式响应 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(requesting()) + .set_async_stream( + run_agent(agent_runner, self.max_step, self.show_tool_use) + ) ) yield - if tool_loop_agent.done(): - if final_llm_resp := tool_loop_agent.get_final_llm_resp(): + if agent_runner.done(): + if final_llm_resp := agent_runner.get_final_llm_resp(): if final_llm_resp.completion_text: chain = ( MessageChain().message(final_llm_resp.completion_text).chain @@ -266,15 +492,15 @@ class LLMRequestSubStage(Stage): ) ) else: - async for _ in requesting(): + async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use): yield + await self._save_to_history(event, req, agent_runner.get_final_llm_resp()) + # 异步处理 WebChat 特殊情况 if event.get_platform_name() == "webchat": asyncio.create_task(self._handle_webchat(event, req, provider)) - await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp()) - async def _handle_webchat( self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider ): @@ -307,19 +533,10 @@ class LLMRequestSubStage(Stage): if not title or "" in title: return await self.conv_manager.update_conversation_title( - event.unified_msg_origin, title=title + unified_msg_origin=event.unified_msg_origin, + title=title, + conversation_id=req.conversation.cid, ) - # 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题 - # webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}" - # TODO: 优化 WebChat 适配器的对话管理 - if event.session_id: - username, cid = event.session_id.split("!")[1:3] - db_helper = self.ctx.plugin_manager.context._db - db_helper.update_conversation_title( - user_id=username, - cid=cid, - title=title, - ) async def _save_to_history( self, diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 00f58d55b..c5c0f5738 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -2,7 +2,7 @@ 本地 Agent 模式的 AstrBot 插件调用 Stage """ -from ...context import PipelineContext +from ...context import PipelineContext, call_handler from ..stage import Stage from typing import Dict, Any, List, AsyncGenerator, Union from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -33,16 +33,6 @@ class StarRequestSubStage(Stage): handlers_parsed_params = {} for handler in activated_handlers: - # 检查处理器是否在当前平台兼容 - if ( - hasattr(handler, "platform_compatible") - and handler.platform_compatible is False - ): - logger.debug( - f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行" - ) - continue - params = handlers_parsed_params.get(handler.handler_full_name, {}) try: if handler.handler_module_path not in star_map: @@ -50,7 +40,7 @@ class StarRequestSubStage(Stage): logger.debug( f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" ) - wrapper = self.ctx.call_handler(event, handler.handler, **params) + wrapper = call_handler(event, handler.handler, **params) async for ret in wrapper: yield ret event.clear_result() # 清除上一个 handler 的结果 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 77e62ec7c..ebbba7ed3 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -128,7 +128,7 @@ class RespondStage(Stage): use_fallback = self.config.get("provider_settings", {}).get( "streaming_segmented", False ) - logger.info(f"应用流式输出({event.get_platform_name()})") + logger.info(f"应用流式输出({event.get_platform_id()})") await event.send_streaming(result.async_stream, use_fallback) return elif len(result.chain) > 0: @@ -214,7 +214,7 @@ class RespondStage(Stage): ) handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id() + EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name ) for handler in handlers: try: diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index c9b8b4b8a..f87f7bbc0 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -64,9 +64,10 @@ class ResultDecorateStage(Stage): ] self.content_safe_check_stage = None if self.content_safe_check_reply: - for stage in registered_stages: - if stage.__class__.__name__ == "ContentSafetyCheckStage": - self.content_safe_check_stage = stage + for stage_cls in registered_stages: + if stage_cls.__name__ == "ContentSafetyCheckStage": + self.content_safe_check_stage = stage_cls() + await self.content_safe_check_stage.initialize(ctx) async def process( self, event: AstrMessageEvent @@ -98,7 +99,7 @@ class ResultDecorateStage(Stage): # 发送消息前事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id() + EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name ) for handler in handlers: try: diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index a014aae6f..f1c3988a6 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -11,16 +11,17 @@ class PipelineScheduler: def __init__(self, context: PipelineContext): registered_stages.sort( - key=lambda x: STAGES_ORDER.index(x.__class__.__name__) + key=lambda x: STAGES_ORDER.index(x.__name__) ) # 按照顺序排序 self.ctx = context # 上下文对象 + self.stages = [] # 存储阶段实例 async def initialize(self): """初始化管道调度器时, 初始化所有阶段""" - for stage in registered_stages: - # logger.debug(f"初始化阶段 {stage.__class__ .__name__}") - - await stage.initialize(self.ctx) + for stage_cls in registered_stages: + stage_instance = stage_cls() # 创建实例 + await stage_instance.initialize(self.ctx) + self.stages.append(stage_instance) async def _process_stages(self, event: AstrMessageEvent, from_stage=0): """依次执行各个阶段 @@ -29,9 +30,9 @@ class PipelineScheduler: event (AstrMessageEvent): 事件对象 from_stage (int): 从第几个阶段开始执行, 默认从0开始 """ - for i in range(from_stage, len(registered_stages)): - stage = registered_stages[i] # 获取当前要执行的阶段 - # logger.debug(f"执行阶段 {stage.__class__ .__name__}") + for i in range(from_stage, len(self.stages)): + stage = self.stages[i] # 获取当前要执行的阶段 + # logger.debug(f"执行阶段 {stage.__class__.__name__}") coroutine = stage.process( event ) # 调用阶段的process方法, 返回协程或者异步生成器 diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index b41794733..c4550495a 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,15 +1,15 @@ from __future__ import annotations import abc -from typing import List, AsyncGenerator, Union +from typing import List, AsyncGenerator, Union, Type from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext -registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类 +registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 def register_stage(cls): """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" - registered_stages.append(cls()) + registered_stages.append(cls) return cls diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 2345b6466..63bc8b52d 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -112,8 +112,17 @@ class WakingCheckStage(Stage): activated_handlers = [] handlers_parsed_params = {} # 注册了指令的 handler + # 将 plugins_name 设置到 event 中 + enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"]) + if enabled_plugins_name == ["*"]: + # 如果是 *,则表示所有插件都启用 + event.plugins_name = None + else: + event.plugins_name = enabled_plugins_name + logger.debug(f"enabled_plugins_name: {enabled_plugins_name}") + for handler in star_handlers_registry.get_handlers_by_event_type( - EventType.AdapterMessageEvent + EventType.AdapterMessageEvent, plugins_name=event.plugins_name ): # filter 需满足 AND 逻辑关系 passed = True diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 9867a51b3..75ea317ad 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -3,9 +3,10 @@ import asyncio import re import hashlib import uuid -from dataclasses import dataclass + from typing import List, Union, Optional, AsyncGenerator +from astrbot import logger from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( Plain, @@ -23,21 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata - - -@dataclass -class MessageSesion: - platform_name: str - message_type: MessageType - session_id: str - - def __str__(self): - return f"{self.platform_name}:{self.message_type.value}:{self.session_id}" - - @staticmethod - def from_str(session_str: str): - platform_name, message_type, session_id = session_str.split(":") - return MessageSesion(platform_name, MessageType(message_type), session_id) +from .message_session import MessageSession, MessageSesion # noqa class AstrMessageEvent(abc.ABC): @@ -64,7 +51,7 @@ class AstrMessageEvent(abc.ABC): """是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)""" self._extras = {} self.session = MessageSesion( - platform_name=platform_meta.name, + platform_name=platform_meta.id, message_type=message_obj.type, session_id=session_id, ) @@ -78,13 +65,23 @@ class AstrMessageEvent(abc.ABC): self.call_llm = False """是否在此消息事件中禁止默认的 LLM 请求""" + self.plugins_name: list[str] | None = None + """该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。""" + # back_compability self.platform = platform_meta def get_platform_name(self): + """获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。 + + NOTE: 用户可能会同时运行多个相同类型的平台适配器。""" return self.platform_meta.name def get_platform_id(self): + """获取这个事件所属的平台的 ID。 + + NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。 + """ return self.platform_meta.id def get_message_str(self) -> str: @@ -188,6 +185,7 @@ class AstrMessageEvent(abc.ABC): """ 清除额外的信息。 """ + logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() def is_private_chat(self) -> bool: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 23109ca53..62328e881 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -18,6 +18,9 @@ class PlatformManager: self.platforms_config = config["platform"] self.settings = config["platform_settings"] + """NOTE: 这里是 default 的配置文件,以保证最大的兼容性; + 这个配置中的 unique_session 需要特殊处理, + 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue async def initialize(self): diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py new file mode 100644 index 000000000..bf5a72a9a --- /dev/null +++ b/astrbot/core/platform/message_session.py @@ -0,0 +1,28 @@ +from astrbot.core.platform.message_type import MessageType +from dataclasses import dataclass + + +@dataclass +class MessageSession: + """描述一条消息在 AstrBot 中对应的会话的唯一标识。 + 如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。""" + + platform_name: str + """平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。""" + message_type: MessageType + session_id: str + platform_id: str = None + + def __str__(self): + return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" + + def __post_init__(self): + self.platform_id = self.platform_name + + @staticmethod + def from_str(session_str: str): + platform_id, message_type, session_id = session_str.split(":") + return MessageSession(platform_id, MessageType(message_type), session_id) + + +MessageSesion = MessageSession # back compatibility diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 6ed53fe0e..c109f29b4 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -5,7 +5,7 @@ from asyncio import Queue from .platform_metadata import PlatformMetadata from .astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageChain -from .astr_message_event import MessageSesion +from .message_session import MessageSesion from astrbot.core.utils.metrics import Metric diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index dd0e93fec..7fb7f9d3e 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -4,7 +4,7 @@ from dataclasses import dataclass @dataclass class PlatformMetadata: name: str - """平台的名称""" + """平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" description: str """平台的描述""" id: str = None diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index aaac8e289..43da100f4 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -77,7 +77,7 @@ class WebChatAdapter(Platform): os.makedirs(self.imgs_dir, exist_ok=True) self.metadata = PlatformMetadata( - name="webchat", description="webchat", id=self.config.get("id", "") + name="webchat", description="webchat", id="webchat" ) async def send_by_session( diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py new file mode 100644 index 000000000..16e59a5cc --- /dev/null +++ b/astrbot/core/platform_message_history_mgr.py @@ -0,0 +1,47 @@ +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import PlatformMessageHistory + + +class PlatformMessageHistoryManager: + def __init__(self, db_helper: BaseDatabase): + self.db = db_helper + + async def insert( + self, + platform_id: str, + user_id: str, + content: list[dict], # TODO: parse from message chain + sender_id: str = None, + sender_name: str = None, + ): + """Insert a new platform message history record.""" + await self.db.insert_platform_message_history( + platform_id=platform_id, + user_id=user_id, + content=content, + sender_id=sender_id, + sender_name=sender_name, + ) + + async def get( + self, + platform_id: str, + user_id: str, + page: int = 1, + page_size: int = 200, + ) -> list[PlatformMessageHistory]: + """Get platform message history for a specific user.""" + history = await self.db.get_platform_message_history( + platform_id=platform_id, + user_id=user_id, + page=page, + page_size=page_size, + ) + history.reverse() + return history + + async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + """Delete platform message history records older than the specified offset.""" + await self.db.delete_platform_message_offset( + platform_id=platform_id, user_id=user_id, offset_sec=offset_sec + ) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 2d120d7f6..0a31093ae 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type -from .func_tool_manager import FuncCall +from astrbot.core.agent.tool import ToolSet from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, @@ -20,6 +20,7 @@ class ProviderType(enum.Enum): SPEECH_TO_TEXT = "speech_to_text" TEXT_TO_SPEECH = "text_to_speech" EMBEDDING = "embedding" + RERANK = "rerank" @dataclass @@ -97,7 +98,7 @@ class ProviderRequest: """会话 ID""" image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" - func_tool: FuncCall | None = None + func_tool: ToolSet | None = None """可用的函数工具""" contexts: list[dict] = field(default_factory=list) """上下文。格式与 openai 的上下文格式一致: @@ -293,3 +294,10 @@ class LLMResponse: } ) return ret + +@dataclass +class RerankResult: + index: int + """在候选列表中的索引位置""" + relevance_score: float + """相关性分数""" diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 07a0fbd8f..509975556 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,32 +1,17 @@ from __future__ import annotations import json -import textwrap import os import asyncio -import logging -from datetime import timedelta +import aiohttp -from typing import Dict, List, Awaitable, Literal, Any -from dataclasses import dataclass -from typing import Optional -from contextlib import AsyncExitStack +from typing import Dict, List, Awaitable from astrbot import logger -from astrbot.core.utils.log_pipe import LogPipe +from astrbot.core import sp from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.agent.mcp_client import MCPClient +from astrbot.core.agent.tool import ToolSet, FunctionTool -try: - import mcp - from mcp.client.sse import sse_client -except (ModuleNotFoundError, ImportError): - logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。") - -try: - from mcp.client.streamable_http import streamablehttp_client -except (ModuleNotFoundError, ImportError): - logger.warning( - "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。" - ) DEFAULT_MCP_CONFIG = {"mcpServers": {}} @@ -39,6 +24,10 @@ SUPPORTED_TYPES = [ ] # json schema 支持的数据类型 +# alias +FuncTool = FunctionTool + + def _prepare_config(config: dict) -> dict: """准备配置,处理嵌套格式""" if "mcpServers" in config and config["mcpServers"]: @@ -105,181 +94,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" -@dataclass -class FuncTool: - """ - 用于描述一个函数调用工具。 - """ - - name: str - parameters: Dict - description: str - handler: Awaitable = None - """处理函数, 当 origin 为 mcp 时,这个为空""" - handler_module_path: str = None - """处理函数的模块路径,当 origin 为 mcp 时,这个为空 - - 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools - """ - active: bool = True - """是否激活""" - - origin: Literal["local", "mcp"] = "local" - """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" - - # MCP 相关字段 - mcp_server_name: str = None - """MCP 服务名称,当 origin 为 mcp 时有效""" - mcp_client: MCPClient = None - """MCP 客户端,当 origin 为 mcp 时有效""" - - def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" - - async def execute(self, **args) -> Any: - """执行函数调用""" - if self.origin == "local": - if not self.handler: - raise Exception(f"Local function {self.name} has no handler") - return await self.handler(**args) - elif self.origin == "mcp": - if not self.mcp_client or not self.mcp_client.session: - raise Exception(f"MCP client for {self.name} is not available") - # 使用name属性而不是额外的mcp_tool_name - actual_tool_name = ( - self.name.split(":")[-1] if ":" in self.name else self.name - ) - return await self.mcp_client.session.call_tool(actual_tool_name, args) - else: - raise Exception(f"Unknown function origin: {self.origin}") - - -class MCPClient: - def __init__(self): - # Initialize session and client objects - self.session: Optional[mcp.ClientSession] = None - self.exit_stack = AsyncExitStack() - - self.name = None - self.active: bool = True - self.tools: List[mcp.Tool] = [] - self.server_errlogs: List[str] = [] - self.running_event = asyncio.Event() - - async def connect_to_server(self, mcp_server_config: dict, name: str): - """连接到 MCP 服务器 - - 如果 `url` 参数存在: - 1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。 - 1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。 - 2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。 - - Args: - mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server - """ - cfg = _prepare_config(mcp_server_config.copy()) - - def logging_callback(msg: str): - # 处理 MCP 服务的错误日志 - print(f"MCP Server {name} Error: {msg}") - self.server_errlogs.append(msg) - - if "url" in cfg: - success, error_msg = await _quick_test_mcp_connection(cfg) - if not success: - raise Exception(error_msg) - - if cfg.get("transport") != "streamable_http": - # SSE transport method - self._streams_context = sse_client( - url=cfg["url"], - headers=cfg.get("headers", {}), - timeout=cfg.get("timeout", 5), - sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), - ) - streams = await self.exit_stack.enter_async_context( - self._streams_context - ) - - # Create a new client session - read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) - self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession( - *streams, - read_timeout_seconds=read_timeout, - logging_callback=logging_callback, # type: ignore - ) - ) - else: - timeout = timedelta(seconds=cfg.get("timeout", 30)) - sse_read_timeout = timedelta( - seconds=cfg.get("sse_read_timeout", 60 * 5) - ) - self._streams_context = streamablehttp_client( - url=cfg["url"], - headers=cfg.get("headers", {}), - timeout=timeout, - sse_read_timeout=sse_read_timeout, - terminate_on_close=cfg.get("terminate_on_close", True), - ) - read_s, write_s, _ = await self.exit_stack.enter_async_context( - self._streams_context - ) - - # Create a new client session - read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) - self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession( - read_stream=read_s, - write_stream=write_s, - read_timeout_seconds=read_timeout, - logging_callback=logging_callback, # type: ignore - ) - ) - - else: - server_params = mcp.StdioServerParameters( - **cfg, - ) - - def callback(msg: str): - # 处理 MCP 服务的错误日志 - self.server_errlogs.append(msg) - - stdio_transport = await self.exit_stack.enter_async_context( - mcp.stdio_client( - server_params, - errlog=LogPipe( - level=logging.ERROR, - logger=logger, - identifier=f"MCPServer-{name}", - callback=callback, - ), # type: ignore - ), - ) - - # Create a new client session - self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(*stdio_transport) - ) - await self.session.initialize() - - async def list_tools_and_save(self) -> mcp.ListToolsResult: - """List all tools from the server and save them to self.tools""" - response = await self.session.list_tools() - self.tools = response.tools - return response - - async def cleanup(self): - """Clean up resources""" - await self.exit_stack.aclose() - self.running_event.set() # Set the running event to indicate cleanup is done - - -class FuncCall: +class FunctionToolManager: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - """内部加载的 func tools""" self.mcp_client_dict: Dict[str, MCPClient] = {} """MCP 服务列表""" self.mcp_client_event: Dict[str, asyncio.Event] = {} @@ -287,6 +104,29 @@ class FuncCall: def empty(self) -> bool: return len(self.func_list) == 0 + def spec_to_func( + self, + name: str, + func_args: list, + desc: str, + handler: Awaitable, + ) -> FuncTool: + params = { + "type": "object", # hard-coded here + "properties": {}, + } + for param in func_args: + params["properties"][param["name"]] = { + "type": param["type"], + "description": param["description"], + } + return FuncTool( + name=name, + parameters=params, + description=desc, + handler=handler, + ) + def add_func( self, name: str, @@ -304,22 +144,14 @@ class FuncCall: # check if the tool has been added before self.remove_func(name) - params = { - "type": "object", # hard-coded here - "properties": {}, - } - for param in func_args: - params["properties"][param["name"]] = { - "type": param["type"], - "description": param["description"], - } - _func = FuncTool( - name=name, - parameters=params, - description=desc, - handler=handler, + self.func_list.append( + self.spec_to_func( + name=name, + func_args=func_args, + desc=desc, + handler=handler, + ) ) - self.func_list.append(_func) logger.info(f"添加函数调用工具: {name}") def remove_func(self, name: str) -> None: @@ -331,11 +163,15 @@ class FuncCall: self.func_list.pop(i) break - def get_func(self, name) -> FuncTool: + def get_func(self, name) -> FuncTool | None: for f in self.func_list: if f.name == name: return f - return None + + def get_full_tool_set(self) -> ToolSet: + """获取完整工具集""" + tool_set = ToolSet(self.func_list.copy()) + return tool_set async def init_mcp_clients(self) -> None: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: @@ -556,203 +392,179 @@ class FuncCall: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ - _l = [] - # 处理所有工具(包括本地和MCP工具) - for f in self.func_list: - if not f.active: - continue - func_ = { - "type": "function", - "function": { - "name": f.name, - # "parameters": f.parameters, - "description": f.description, - }, - } - func_["function"]["parameters"] = f.parameters - if not f.parameters.get("properties") and omit_empty_parameter_field: - # 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段 - del func_["function"]["parameters"] - _l.append(func_) - return _l + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.openai_schema( + omit_empty_parameter_field=omit_empty_parameter_field + ) def get_func_desc_anthropic_style(self) -> list: """ 获得 Anthropic API 风格的**已经激活**的工具描述 """ - tools = [] - for f in self.func_list: - if not f.active: - continue - - # Convert internal format to Anthropic style - tool = { - "name": f.name, - "description": f.description, - "input_schema": { - "type": "object", - "properties": f.parameters.get("properties", {}), - # Keep the required field from the original parameters if it exists - "required": f.parameters.get("required", []), - }, - } - tools.append(tool) - return tools + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.anthropic_schema() def get_func_desc_google_genai_style(self) -> dict: """ 获得 Google GenAI API 风格的**已经激活**的工具描述 """ + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.google_schema() - # Gemini API 支持的数据类型和格式 - supported_types = { - "string", - "number", - "integer", - "boolean", - "array", - "object", - "null", - } - supported_formats = { - "string": {"enum", "date-time"}, - "integer": {"int32", "int64"}, - "number": {"float", "double"}, - } + def deactivate_llm_tool(self, name: str) -> bool: + """停用一个已经注册的函数调用工具。 - def convert_schema(schema: dict) -> dict: - """转换 schema 为 Gemini API 格式""" + Returns: + 如果没找到,会返回 False""" + func_tool = self.get_func(name) + if func_tool is not None: + func_tool.active = False - # 如果 schema 包含 anyOf,则只返回 anyOf 字段 - if "anyOf" in schema: - return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]} - - result = {} - - if "type" in schema and schema["type"] in supported_types: - result["type"] = schema["type"] - if "format" in schema and schema["format"] in supported_formats.get( - result["type"], set() - ): - result["format"] = schema["format"] - else: - # 暂时指定默认为null - result["type"] = "null" - - support_fields = { - "title", - "description", - "enum", - "minimum", - "maximum", - "maxItems", - "minItems", - "nullable", - "required", - } - result.update({k: schema[k] for k in support_fields if k in schema}) - - if "properties" in schema: - properties = {} - for key, value in schema["properties"].items(): - prop_value = convert_schema(value) - if "default" in prop_value: - del prop_value["default"] - properties[key] = prop_value - - if properties: # 只在有非空属性时添加 - result["properties"] = properties - - if "items" in schema: - result["items"] = convert_schema(schema["items"]) - - return result - - tools = [ - { - "name": f.name, - "description": f.description, - **({"parameters": convert_schema(f.parameters)}), - } - for f in self.func_list - if f.active - ] - - declarations = {} - if tools: - declarations["function_declarations"] = tools - return declarations - - async def func_call(self, question: str, session_id: str, provider) -> tuple: - _l = [] - for f in self.func_list: - if not f.active: - continue - _l.append( - { - "name": f.name, - "parameters": f.parameters, - "description": f.description, - } + inactivated_llm_tools: list = sp.get( + "inactivated_llm_tools", [], scope="global", scope_id="global" ) - func_definition = json.dumps(_l, ensure_ascii=False) + if name not in inactivated_llm_tools: + inactivated_llm_tools.append(name) + sp.put( + "inactivated_llm_tools", + inactivated_llm_tools, + scope="global", + scope_id="global", + ) - prompt = textwrap.dedent(f""" - ROLE: - 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 + return True + return False - TOOLS: - 可用的函数列表: + # 因为不想解决循环引用,所以这里直接传入 star_map 先了... + def activate_llm_tool(self, name: str, star_map: dict) -> bool: + func_tool = self.get_func(name) + if func_tool is not None: + if func_tool.handler_module_path in star_map: + if not star_map[func_tool.handler_module_path].activated: + raise ValueError( + f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" + ) - {func_definition} + func_tool.active = True - LIMIT: - 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。 - 2. 你的 Json 返回的格式如下:`[{{"name": "", "args": }}, ...]`。参数根据上面提供的函数列表中的参数来填写。 - 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。 - 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。 + inactivated_llm_tools: list = sp.get( + "inactivated_llm_tools", [], scope="global", scope_id="global" + ) + if name in inactivated_llm_tools: + inactivated_llm_tools.remove(name) + sp.put( + "inactivated_llm_tools", + inactivated_llm_tools, + scope="global", + scope_id="global", + ) - EXAMPLE: - 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}] + return True + return False - 用户的提问是:{question} - """) + @property + def mcp_config_path(self): + data_dir = get_astrbot_data_path() + return os.path.join(data_dir, "mcp_server.json") - _c = 0 - while _c < 3: - try: - res = await provider.text_chat(prompt, session_id) - if res.find("```") != -1: - res = res[res.find("```json") + 7 : res.rfind("```")] - res = json.loads(res) - break - except Exception as e: - _c += 1 - if _c == 3: - raise e - if "The message you submitted was too long" in str(e): - raise e + def load_mcp_config(self): + if not os.path.exists(self.mcp_config_path): + # 配置文件不存在,创建默认配置 + os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + return DEFAULT_MCP_CONFIG - if "res" in res and not res["res"]: - return "", False + try: + with open(self.mcp_config_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.error(f"加载 MCP 配置失败: {e}") + return DEFAULT_MCP_CONFIG - tool_call_result = [] - for tool in res: - # 说明有函数调用 - func_name = tool["name"] - args = tool["args"] - # 调用函数 - func_tool = self.get_func(func_name) - if not func_tool: - raise Exception(f"Request function {func_name} not found.") + def save_mcp_config(self, config: dict): + try: + with open(self.mcp_config_path, "w", encoding="utf-8") as f: + json.dump(config, f, ensure_ascii=False, indent=4) + return True + except Exception as e: + logger.error(f"保存 MCP 配置失败: {e}") + return False - ret = await func_tool.execute(**args) - if ret: - tool_call_result.append(str(ret)) - return tool_call_result, True + async def sync_modelscope_mcp_servers(self, access_token: str) -> None: + """从 ModelScope 平台同步 MCP 服务器配置""" + base_url = "https://www.modelscope.cn/openapi/v1" + url = f"{base_url}/mcp/servers/operational" + headers = { + "Authorization": f"Bearer {access_token.strip()}", + "Content-Type": "application/json", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + mcp_server_list = data.get("data", {}).get( + "mcp_server_list", [] + ) + local_mcp_config = self.load_mcp_config() + + synced_count = 0 + for server in mcp_server_list: + server_name = server["name"] + operational_urls = server.get("operational_urls", []) + if not operational_urls: + continue + url_info = operational_urls[0] + server_url = url_info.get("url") + if not server_url: + continue + # 添加到配置中(同名会覆盖) + local_mcp_config["mcpServers"][server_name] = { + "url": server_url, + "transport": "sse", + "active": True, + "provider": "modelscope", + } + synced_count += 1 + + if synced_count > 0: + self.save_mcp_config(local_mcp_config) + tasks = [] + for server in mcp_server_list: + name = server["name"] + tasks.append( + self.enable_mcp_server( + name=name, + config=local_mcp_config["mcpServers"][name], + ) + ) + await asyncio.gather(*tasks) + logger.info( + f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器" + ) + else: + logger.warning("没有找到可用的 ModelScope MCP 服务器") + else: + raise Exception( + f"ModelScope API 请求失败: HTTP {response.status}" + ) + + except aiohttp.ClientError as e: + raise Exception(f"网络连接错误: {str(e)}") + except Exception as e: + raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}") def __str__(self): return str(self.func_list) def __repr__(self): return str(self.func_list) + + +# alias +FuncCall = FunctionToolManager diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 370c5322b..19f62edfa 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -3,89 +3,32 @@ import traceback from typing import List from astrbot.core import logger, sp -from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase from .entities import ProviderType -from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider +from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider from .register import llm_tools, provider_cls_map +from ..persona_mgr import PersonaManager class ProviderManager: - def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): + def __init__( + self, + acm: AstrBotConfigManager, + db_helper: BaseDatabase, + persona_mgr: PersonaManager, + ): + self.persona_mgr = persona_mgr + self.acm = acm + config = acm.confs["default"] self.providers_config: List = config["provider"] self.provider_settings: dict = config["provider_settings"] self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) - self.persona_configs: list = config.get("persona", []) - self.astrbot_config = config - # 人格情景管理 - # 目前没有拆成独立的模块 - self.default_persona_name = self.provider_settings.get( - "default_personality", "default" - ) - self.personas: List[Personality] = [] - self.selected_default_persona = None - for persona in self.persona_configs: - begin_dialogs = persona.get("begin_dialogs", []) - mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) - bd_processed = [] - mid_processed = "" - if begin_dialogs: - if len(begin_dialogs) % 2 != 0: - logger.error( - f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。" - ) - begin_dialogs = [] - user_turn = True - for dialog in begin_dialogs: - bd_processed.append( - { - "role": "user" if user_turn else "assistant", - "content": dialog, - "_no_save": None, # 不持久化到 db - } - ) - user_turn = not user_turn - if mood_imitation_dialogs: - if len(mood_imitation_dialogs) % 2 != 0: - logger.error( - f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。" - ) - mood_imitation_dialogs = [] - user_turn = True - for dialog in mood_imitation_dialogs: - role = "A" if user_turn else "B" - mid_processed += f"{role}: {dialog}\n" - if not user_turn: - mid_processed += "\n" - user_turn = not user_turn - - try: - persona = Personality( - **persona, - _begin_dialogs_processed=bd_processed, - _mood_imitation_dialogs_processed=mid_processed, - ) - if persona["name"] == self.default_persona_name: - self.selected_default_persona = persona - self.personas.append(persona) - except Exception as e: - logger.error(f"解析 Persona 配置失败:{e}") - - if not self.selected_default_persona and len(self.personas) > 0: - # 默认选择第一个 - self.selected_default_persona = self.personas[0] - - if not self.selected_default_persona: - self.selected_default_persona = Personality( - prompt="You are a helpful and friendly assistant.", - name="default", - _begin_dialogs_processed=[], - _mood_imitation_dialogs_processed="", - ) - self.personas.append(self.selected_default_persona) + # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager + self.default_persona_name = persona_mgr.default_persona self.provider_insts: List[Provider] = [] """加载的 Provider 的实例""" @@ -100,46 +43,111 @@ class ProviderManager: self.llm_tools = llm_tools self.curr_provider_inst: Provider | None = None - """默认的 Provider 实例""" + """默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.curr_stt_provider_inst: STTProvider | None = None - """默认的 Speech To Text Provider 实例""" + """默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.curr_tts_provider_inst: TTSProvider | None = None - """默认的 Text To Speech Provider 实例""" + """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper - # kdb(experimental) - self.curr_kdb_name = "" - kdb_cfg = config.get("knowledge_db", {}) - if kdb_cfg and len(kdb_cfg): - self.curr_kdb_name = list(kdb_cfg.keys())[0] + @property + def persona_configs(self) -> list: + """动态获取最新的 persona 配置""" + return self.persona_mgr.persona_v3_config + + @property + def personas(self) -> list: + """动态获取最新的 personas 列表""" + return self.persona_mgr.personas_v3 + + @property + def selected_default_persona(self): + """动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()""" + return self.persona_mgr.selected_default_persona_v3 async def set_provider( - self, provider_id: str, provider_type: ProviderType, umo: str = None + self, provider_id: str, provider_type: ProviderType, umo: str | None = None ): """设置提供商。 Args: provider_id (str): 提供商 ID。 provider_type (ProviderType): 提供商类型。 - umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。 + umo (str, optional): 用户会话 ID,用于提供商会话隔离。 + + Version 4.0.0: 这个版本下已经默认隔离提供商 """ if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") - if umo and self.provider_settings["separate_provider"]: - perf = sp.get("session_provider_perf", {}) - session_perf = perf.get(umo, {}) - session_perf[provider_type.value] = provider_id - perf[umo] = session_perf - sp.put("session_provider_perf", perf) + if umo: + await sp.session_put( + umo, + f"provider_perf_{provider_type.value}", + provider_id, + ) return # 不启用提供商会话隔离模式的情况 self.curr_provider_inst = self.inst_map[provider_id] if provider_type == ProviderType.TEXT_TO_SPEECH: - sp.put("curr_provider_tts", provider_id) + sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.SPEECH_TO_TEXT: - sp.put("curr_provider_stt", provider_id) + sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.CHAT_COMPLETION: - sp.put("curr_provider", provider_id) + sp.put("curr_provider", provider_id, scope="global", scope_id="global") + + async def get_provider_by_id(self, provider_id: str) -> Provider | None: + """根据提供商 ID 获取提供商实例""" + return self.inst_map.get(provider_id) + + def get_using_provider(self, provider_type: ProviderType, umo=None): + """获取正在使用的提供商实例。 + + Args: + provider_type (ProviderType): 提供商类型。 + umo (str, optional): 用户会话 ID,用于提供商会话隔离。 + + Returns: + Provider: 正在使用的提供商实例。 + """ + provider = None + if umo: + provider_id = sp.get( + f"provider_perf_{provider_type.value}", + None, + scope="umo", + scope_id=umo, + ) + if provider_id: + provider = self.inst_map.get(provider_id) + if not provider: + # default setting + config = self.acm.get_conf(umo) + if provider_type == ProviderType.CHAT_COMPLETION: + provider_id = config["provider_settings"].get("default_provider_id") + provider = self.inst_map.get(provider_id) + if not provider: + provider = self.provider_insts[0] if self.provider_insts else None + elif provider_type == ProviderType.SPEECH_TO_TEXT: + provider_id = config["provider_stt_settings"].get("provider_id") + if not provider_id: + return None + provider = self.inst_map.get(provider_id) + if not provider: + provider = ( + self.stt_provider_insts[0] if self.stt_provider_insts else None + ) + elif provider_type == ProviderType.TEXT_TO_SPEECH: + provider_id = config["provider_tts_settings"].get("provider_id") + if not provider_id: + return None + provider = self.inst_map.get(provider_id) + if not provider: + provider = ( + self.tts_provider_insts[0] if self.tts_provider_insts else None + ) + else: + raise ValueError(f"Unknown provider type: {provider_type}") + return provider async def initialize(self): # 逐个初始化提供商 @@ -148,13 +156,22 @@ class ProviderManager: # 设置默认提供商 selected_provider_id = sp.get( - "curr_provider", self.provider_settings.get("default_provider_id") + "curr_provider", + self.provider_settings.get("default_provider_id"), + scope="global", + scope_id="global", ) selected_stt_provider_id = sp.get( - "curr_provider_stt", self.provider_stt_settings.get("provider_id") + "curr_provider_stt", + self.provider_stt_settings.get("provider_id"), + scope="global", + scope_id="global", ) selected_tts_provider_id = sp.get( - "curr_provider_tts", self.provider_tts_settings.get("provider_id") + "curr_provider_tts", + self.provider_tts_settings.get("provider_id"), + scope="global", + scope_id="global", ) self.curr_provider_inst = self.inst_map.get(selected_provider_id) if not self.curr_provider_inst and self.provider_insts: @@ -262,6 +279,10 @@ class ProviderManager: from .sources.gemini_embedding_source import ( GeminiEmbeddingProvider as GeminiEmbeddingProvider, ) + case "vllm_rerank": + from .sources.vllm_rerank_source import ( + VLLMRerankProvider as VLLMRerankProvider, + ) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" @@ -345,7 +366,7 @@ class ProviderManager: if not self.curr_provider_inst: self.curr_provider_inst = inst - elif provider_metadata.provider_type == ProviderType.EMBEDDING: + elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]: inst = provider_metadata.cls_type( provider_config, self.provider_settings ) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 36401b089..01618767c 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,23 +1,18 @@ import abc from typing import List -from typing import TypedDict, AsyncGenerator -from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType +from typing import AsyncGenerator +from astrbot.core.agent.tool import ToolSet +from astrbot.core.provider.entities import ( + LLMResponse, + ToolCallsResult, + ProviderType, + RerankResult, +) from astrbot.core.provider.register import provider_cls_map +from astrbot.core.db.po import Personality from dataclasses import dataclass -class Personality(TypedDict): - prompt: str = "" - name: str = "" - begin_dialogs: List[str] = [] - mood_imitation_dialogs: List[str] = [] - - # cache - _begin_dialogs_processed: List[dict] = [] - _mood_imitation_dialogs_processed: str = "" - - @dataclass class ProviderMeta: id: str @@ -90,7 +85,7 @@ class Provider(AbstractProvider): prompt: str, session_id: str = None, image_urls: list[str] = None, - func_tool: FuncCall = None, + func_tool: ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, @@ -119,7 +114,7 @@ class Provider(AbstractProvider): prompt: str, session_id: str = None, image_urls: list[str] = None, - func_tool: FuncCall = None, + func_tool: ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, @@ -206,3 +201,17 @@ class EmbeddingProvider(AbstractProvider): def get_dim(self) -> int: """获取向量的维度""" ... + + +class RerankProvider(AbstractProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config) + self.provider_config = provider_config + self.provider_settings = provider_settings + + @abc.abstractmethod + async def rerank( + self, query: str, documents: list[str], top_n: int | None = None + ) -> list[RerankResult]: + """获取查询和文档的重排序分数""" + ... diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 46b12726b..4e14d20da 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_vars = sp.get("session_variables", {}) - session_var = session_vars.get(session_id, {}) + session_var = await sp.session_get(session_id, "session_variables", default={}) payload_vars.update(session_var) if ( diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 9539227fe..e19e912ac 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -97,8 +97,7 @@ class ProviderDify(Provider): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_vars = sp.get("session_variables", {}) - session_var = session_vars.get(session_id, {}) + session_var = await sp.session_get(session_id, "session_variables", default={}) payload_vars.update(session_var) payload_vars["system_prompt"] = system_prompt diff --git a/astrbot/core/provider/sources/vllm_rerank_source.py b/astrbot/core/provider/sources/vllm_rerank_source.py new file mode 100644 index 000000000..af48e69af --- /dev/null +++ b/astrbot/core/provider/sources/vllm_rerank_source.py @@ -0,0 +1,59 @@ +import aiohttp +from ..provider import RerankProvider +from ..register import register_provider_adapter +from ..entities import ProviderType, RerankResult + + +@register_provider_adapter( + "vllm_rerank", + "VLLM Rerank 适配器", + provider_type=ProviderType.RERANK, +) +class VLLMRerankProvider(RerankProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + self.auth_key = provider_config.get("rerank_api_key", "") + self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000") + self.base_url = self.base_url.rstrip("/") + self.timeout = provider_config.get("timeout", 20) + self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base") + + h = {} + if self.auth_key: + h["Authorization"] = f"Bearer {self.auth_key}" + self.client = aiohttp.ClientSession( + headers=h, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) + + async def rerank( + self, query: str, documents: list[str], top_n: int | None = None + ) -> list[RerankResult]: + payload = { + "query": query, + "documents": documents, + "model": self.model, + } + if top_n is not None: + payload["top_n"] = top_n + async with self.client.post( + f"{self.base_url}/v1/rerank", json=payload + ) as response: + response_data = await response.json() + results = response_data.get("results", []) + + return [ + RerankResult( + index=result["index"], + relevance_score=result["relevance_score"], + ) + for result in results + ] + + async def terminate(self) -> None: + """关闭客户端会话""" + if self.client: + await self.client.close() + self.client = None diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 86318f8b7..fab39294b 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -34,7 +34,7 @@ class Star(CommandParserMixin): @staticmethod async def html_render( - tmpl: str, data: dict, return_url=True, options: dict = None + tmpl: str, data: dict, return_url=True, options: dict | None = None ) -> str: """渲染 HTML""" return await html_renderer.render_custom_template( diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0b14525d3..76db898aa 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,17 +1,24 @@ from asyncio import Queue from typing import List, Union -from astrbot.core import sp -from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider +from astrbot.core.provider.provider import ( + Provider, + TTSProvider, + STTProvider, + EmbeddingProvider, +) from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.func_tool_manager import FunctionToolManager from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager from astrbot.core.platform import Platform from astrbot.core.platform.manager import PlatformManager +from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.persona_mgr import PersonaManager from .star import star_registry, StarMetadata, star_map from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter @@ -22,6 +29,7 @@ from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, ADAPTER_NAME_2_TYPE, ) +from deprecated import deprecated class Context: @@ -29,19 +37,6 @@ class Context: 暴露给插件的接口上下文。 """ - _event_queue: Queue = None - """事件队列。消息平台通过事件队列传递消息事件。""" - - _config: AstrBotConfig = None - """AstrBot 配置信息""" - - _db: BaseDatabase = None - """AstrBot 数据库""" - - provider_manager: ProviderManager = None - - platform_manager: PlatformManager = None - registered_web_apis: list = [] # back compatibility @@ -53,18 +48,27 @@ class Context: event_queue: Queue, config: AstrBotConfig, db: BaseDatabase, - provider_manager: ProviderManager = None, - platform_manager: PlatformManager = None, - conversation_manager: ConversationManager = None, + provider_manager: ProviderManager, + platform_manager: PlatformManager, + conversation_manager: ConversationManager, + message_history_manager: PlatformMessageHistoryManager, + persona_manager: PersonaManager, + astrbot_config_mgr: AstrBotConfigManager, ): self._event_queue = event_queue + """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config + """AstrBot 默认配置""" self._db = db + """AstrBot 数据库""" self.provider_manager = provider_manager self.platform_manager = platform_manager self.conversation_manager = conversation_manager + self.message_history_manager = message_history_manager + self.persona_manager = persona_manager + self.astrbot_config_mgr = astrbot_config_mgr - def get_registered_star(self, star_name: str) -> StarMetadata: + def get_registered_star(self, star_name: str) -> StarMetadata | None: """根据插件名获取插件的 Metadata""" for star in star_registry: if star.name == star_name: @@ -74,7 +78,7 @@ class Context: """获取当前载入的所有插件 Metadata 的列表""" return star_registry - def get_llm_tool_manager(self) -> FuncCall: + def get_llm_tool_manager(self) -> FunctionToolManager: """获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools""" return self.provider_manager.llm_tools @@ -84,40 +88,14 @@ class Context: Returns: 如果没找到,会返回 False """ - func_tool = self.provider_manager.llm_tools.get_func(name) - if func_tool is not None: - if func_tool.handler_module_path in star_map: - if not star_map[func_tool.handler_module_path].activated: - raise ValueError( - f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" - ) - - func_tool.active = True - - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - if name in inactivated_llm_tools: - inactivated_llm_tools.remove(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) - - return True - return False + return self.provider_manager.llm_tools.activate_llm_tool(name, star_map) def deactivate_llm_tool(self, name: str) -> bool: """停用一个已经注册的函数调用工具。 Returns: 如果没找到,会返回 False""" - func_tool = self.provider_manager.llm_tools.get_func(name) - if func_tool is not None: - func_tool.active = False - - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - if name not in inactivated_llm_tools: - inactivated_llm_tools.append(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) - - return True - return False + return self.provider_manager.llm_tools.deactivate_llm_tool(name) def register_provider(self, provider: Provider): """ @@ -125,7 +103,7 @@ class Context: """ self.provider_manager.provider_insts.append(provider) - def get_provider_by_id(self, provider_id: str) -> Provider: + def get_provider_by_id(self, provider_id: str) -> Provider | None: """通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.inst_map.get(provider_id) @@ -145,51 +123,49 @@ class Context: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts - def get_using_provider(self, umo: str = None) -> Provider: + def get_using_provider(self, umo: str | None = None) -> Provider | None: """ 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["separate_provider"]: - perf = sp.get("session_provider_perf", {}) - prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None) - if inst := self.provider_manager.inst_map.get(prov_id, None): - return inst - return self.provider_manager.curr_provider_inst + return self.provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) - def get_using_tts_provider(self, umo: str = None) -> TTSProvider: + def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider: """ 获取当前使用的用于 TTS 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["separate_provider"]: - perf = sp.get("session_provider_perf", {}) - prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None) - if inst := self.provider_manager.inst_map.get(prov_id, None): - return inst - return self.provider_manager.curr_tts_provider_inst + return self.provider_manager.get_using_provider( + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=umo, + ) - def get_using_stt_provider(self, umo: str = None) -> STTProvider: + def get_using_stt_provider(self, umo: str | None = None) -> STTProvider: """ 获取当前使用的用于 STT 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - if umo and self._config["provider_settings"]["separate_provider"]: - perf = sp.get("session_provider_perf", {}) - prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None) - if inst := self.provider_manager.inst_map.get(prov_id, None): - return inst - return self.provider_manager.curr_stt_provider_inst + return self.provider_manager.get_using_provider( + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) - def get_config(self) -> AstrBotConfig: + def get_config(self, umo: str | None = None) -> AstrBotConfig: """获取 AstrBot 的配置。""" - return self._config + if not umo: + # using default config + return self._config + else: + return self.astrbot_config_mgr.get_conf(umo) def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" @@ -201,9 +177,14 @@ class Context: """ return self._event_queue - def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform: + @deprecated(version="4.0.0", reason="Use get_platform_inst instead") + def get_platform( + self, platform_type: Union[PlatformAdapterType, str] + ) -> Platform | None: """ 获取指定类型的平台适配器。 + + 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) """ for platform in self.platform_manager.platform_insts: name = platform.meta().name @@ -217,6 +198,20 @@ class Context: ): return platform + def get_platform_inst(self, platform_id: str) -> Platform | None: + """ + 获取指定 ID 的平台适配器实例。 + + Args: + platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 + + Returns: + Platform: 平台适配器实例,如果未找到则返回 None。 + """ + for platform in self.platform_manager.platform_insts: + if platform.meta().id == platform_id: + return platform + async def send_message( self, session: Union[str, MessageSesion], message_chain: MessageChain ) -> bool: @@ -240,7 +235,7 @@ class Context: raise ValueError("不合法的 session 字符串: " + str(e)) for platform in self.platform_manager.platform_insts: - if platform.meta().name == session.platform_name: + if platform.meta().id == session.platform_name: await platform.send_by_session(session, message_chain) return True return False diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index fa6a730ba..55a4393da 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -11,6 +11,7 @@ from .star_handler import ( register_on_llm_request, register_on_llm_response, register_llm_tool, + register_agent, register_on_decorating_result, register_after_message_sent, ) @@ -28,6 +29,7 @@ __all__ = [ "register_on_llm_request", "register_on_llm_response", "register_llm_tool", + "register_agent", "register_on_decorating_result", "register_after_message_sent", ] diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 0b9f7ad09..101f3a95f 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -15,6 +15,11 @@ from ..filter.regex import RegexFilter from typing import Awaitable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools +from astrbot.core.agent.agent import Agent +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.astr_agent_context import AstrAgentContext def get_handler_full_name(awaitable: Awaitable) -> str: @@ -306,7 +311,7 @@ def register_on_llm_response(**kwargs): return decorator -def register_llm_tool(name: str = None): +def register_llm_tool(name: str = None, **kwargs): """为函数调用(function-calling / tools-use)添加工具。 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) @@ -340,6 +345,9 @@ def register_llm_tool(name: str = None): """ name_ = name + registering_agent = None + if kwargs.get("registering_agent"): + registering_agent = kwargs["registering_agent"] def decorator(awaitable: Awaitable): llm_tool_name = name_ if name_ else awaitable.__name__ @@ -357,15 +365,69 @@ def register_llm_tool(name: str = None): "description": arg.description, } ) - md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func( - llm_tool_name, args, docstring.description.strip(), md.handler - ) + # print(llm_tool_name, registering_agent) + if not registering_agent: + md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) + llm_tools.add_func( + llm_tool_name, args, docstring.description.strip(), md.handler + ) + else: + assert isinstance(registering_agent, RegisteringAgent) + # print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name) + if registering_agent._agent.tools is None: + registering_agent._agent.tools = [] + registering_agent._agent.tools.append(llm_tools.spec_to_func( + llm_tool_name, args, docstring.description.strip(), awaitable + )) + return awaitable return decorator +class RegisteringAgent: + """用于 Agent 注册""" + + def llm_tool(self, *args, **kwargs): + kwargs["registering_agent"] = self + return register_llm_tool(*args, **kwargs) + + def __init__(self, agent: Agent[AstrAgentContext]): + self._agent = agent + + +def register_agent( + name: str, + instruction: str, + tools: list[str | FunctionTool] = None, + run_hooks: BaseAgentRunHooks[AstrAgentContext] = None, +): + """注册一个 Agent + + Args: + name: Agent 的名称 + instruction: Agent 的指令 + tools: Agent 使用的工具列表 + run_hooks: Agent 运行时的钩子函数 + """ + tools_ = tools or [] + + def decorator(awaitable: Awaitable): + AstrAgent = Agent[AstrAgentContext] + agent = AstrAgent( + name=name, + instructions=instruction, + tools=tools_, + run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](), + ) + handoff_tool = HandoffTool(agent=agent) + handoff_tool.handler=awaitable + llm_tools.func_list.append(handoff_tool) + return RegisteringAgent(agent) + + return decorator + + def register_on_decorating_result(**kwargs): """在发送消息前的事件""" diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 4bceb1109..6c5bc994d 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -2,8 +2,6 @@ 会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态 """ -from typing import Dict - from astrbot.core import logger, sp from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -26,8 +24,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的LLM状态,返回该状态 llm_enabled = session_services.get("llm_enabled") @@ -45,16 +44,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置LLM状态 - session_config[session_id]["llm_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["llm_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}" @@ -88,8 +84,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的TTS状态,返回该状态 tts_enabled = session_services.get("tts_enabled") @@ -107,16 +104,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置TTS状态 - session_config[session_id]["tts_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["tts_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}" @@ -150,8 +144,9 @@ class SessionServiceManager: bool: True表示启用,False表示禁用 """ # 获取会话服务配置 - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) # 如果配置了该会话的整体状态,返回该状态 session_enabled = session_services.get("session_enabled") @@ -169,16 +164,13 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置会话整体状态 - session_config[session_id]["session_enabled"] = enabled - - # 保存配置 - sp.put("session_service_config", session_config) + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) + session_config["session_enabled"] = enabled + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}" @@ -202,7 +194,7 @@ class SessionServiceManager: # ============================================================================= @staticmethod - def get_session_custom_name(session_id: str) -> str: + def get_session_custom_name(session_id: str) -> str | None: """获取会话的自定义名称 Args: @@ -211,8 +203,9 @@ class SessionServiceManager: Returns: str: 自定义名称,如果没有设置则返回None """ - session_config = sp.get("session_service_config", {}) or {} - session_services = session_config.get(session_id, {}) + session_services = sp.get( + "session_service_config", {}, scope="umo", scope_id=session_id + ) return session_services.get("custom_name") @staticmethod @@ -223,20 +216,17 @@ class SessionServiceManager: session_id: 会话ID (unified_msg_origin) custom_name: 自定义名称,可以为空字符串来清除名称 """ - # 获取当前配置 - session_config = sp.get("session_service_config", {}) or {} - if session_id not in session_config: - session_config[session_id] = {} - - # 设置自定义名称 + session_config = ( + sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} + ) if custom_name and custom_name.strip(): - session_config[session_id]["custom_name"] = custom_name.strip() + session_config["custom_name"] = custom_name.strip() else: # 如果传入空名称,则删除自定义名称 - session_config[session_id].pop("custom_name", None) - - # 保存配置 - sp.put("session_service_config", session_config) + session_config.pop("custom_name", None) + sp.put( + "session_service_config", session_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}" @@ -258,36 +248,3 @@ class SessionServiceManager: # 如果没有自定义名称,返回session_id的最后一段 return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id - - # ============================================================================= - # 通用配置方法 - # ============================================================================= - - @staticmethod - def get_session_service_config(session_id: str) -> Dict[str, bool]: - """获取指定会话的服务配置 - - Args: - session_id: 会话ID (unified_msg_origin) - - Returns: - Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典 - """ - session_config = sp.get("session_service_config", {}) or {} - return session_config.get( - session_id, - { - "session_enabled": True, # 默认启用 - "llm_enabled": True, # 默认启用 - "tts_enabled": True, # 默认启用 - }, - ) - - @staticmethod - def get_all_session_configs() -> Dict[str, Dict[str, bool]]: - """获取所有会话的服务配置 - - Returns: - Dict[str, Dict[str, bool]]: 所有会话的服务配置 - """ - return sp.get("session_service_config", {}) or {} diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index c0d1bbd73..5c7303e8d 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -22,7 +22,9 @@ class SessionPluginManager: bool: True表示启用,False表示禁用 """ # 获取会话插件配置 - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) session_config = session_plugin_config.get(session_id, {}) enabled_plugins = session_config.get("enabled_plugins", []) @@ -51,7 +53,9 @@ class SessionPluginManager: enabled: True表示启用,False表示禁用 """ # 获取当前配置 - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) if session_id not in session_plugin_config: session_plugin_config[session_id] = { "enabled_plugins": [], @@ -79,7 +83,9 @@ class SessionPluginManager: session_config["enabled_plugins"] = enabled_plugins session_config["disabled_plugins"] = disabled_plugins session_plugin_config[session_id] = session_config - sp.put("session_plugin_config", session_plugin_config) + sp.put( + "session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id + ) logger.info( f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}" @@ -95,7 +101,9 @@ class SessionPluginManager: Returns: Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典 """ - session_plugin_config = sp.get("session_plugin_config", {}) or {} + session_plugin_config = sp.get( + "session_plugin_config", {}, scope="umo", scope_id=session_id + ) return session_plugin_config.get( session_id, {"enabled_plugins": [], "disabled_plugins": []} ) diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index 2fe9dd7f3..0563e8cc8 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -56,32 +56,8 @@ class StarMetadata: star_handler_full_names: list[str] = field(default_factory=list) """注册的 Handler 的全名列表""" - supported_platforms: dict[str, bool] = field(default_factory=dict) - """插件支持的平台ID字典,key为平台ID,value为是否支持""" - def __str__(self) -> str: return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}" def __repr__(self) -> str: return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}" - - def update_platform_compatibility(self, plugin_enable_config: dict) -> None: - """更新插件支持的平台列表 - - Args: - plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项 - """ - if not plugin_enable_config: - return - - # 清空之前的配置 - self.supported_platforms.clear() - - # 遍历所有平台配置 - for platform_id, plugins in plugin_enable_config.items(): - # 检查该插件在当前平台的配置 - if self.name in plugins: - self.supported_platforms[platform_id] = plugins[self.name] - else: - # 如果没有明确配置,默认为启用 - self.supported_platforms[platform_id] = True diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index d375091e5..43a74396a 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -7,6 +7,7 @@ from .star import star_map T = TypeVar("T", bound="StarHandlerMetadata") + class StarHandlerRegistry(Generic[T]): def __init__(self): self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} @@ -26,7 +27,10 @@ class StarHandlerRegistry(Generic[T]): print(handler.handler_full_name) def get_handlers_by_event_type( - self, event_type: EventType, only_activated=True, platform_id=None + self, + event_type: EventType, + only_activated=True, + plugins_name: list[str] | None = None, ) -> List[StarHandlerMetadata]: handlers = [] for handler in self._handlers: @@ -36,8 +40,15 @@ class StarHandlerRegistry(Generic[T]): plugin = star_map.get(handler.handler_module_path) if not (plugin and plugin.activated): continue - if platform_id and event_type != EventType.OnAstrBotLoadedEvent: - if not handler.is_enabled_for_platform(platform_id): + if plugins_name is not None and plugins_name != ["*"]: + plugin = star_map.get(handler.handler_module_path) + if not plugin: + continue + if ( + plugin.name not in plugins_name + and event_type != EventType.OnAstrBotLoadedEvent + and not plugin.reserved + ): continue handlers.append(handler) return handlers @@ -49,7 +60,8 @@ class StarHandlerRegistry(Generic[T]): self, module_name: str ) -> List[StarHandlerMetadata]: return [ - handler for handler in self._handlers + handler + for handler in self._handlers if handler.handler_module_path == module_name ] @@ -67,6 +79,7 @@ class StarHandlerRegistry(Generic[T]): def __len__(self): return len(self._handlers) + star_handlers_registry = StarHandlerRegistry() @@ -119,32 +132,3 @@ class StarHandlerMetadata: return self.extras_configs.get("priority", 0) < other.extras_configs.get( "priority", 0 ) - - def is_enabled_for_platform(self, platform_id: str) -> bool: - """检查插件是否在指定平台启用 - - Args: - platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例 - - Returns: - bool: 是否启用,True表示启用,False表示禁用 - """ - plugin = star_map.get(self.handler_module_path) - - # 如果插件元数据不存在,默认允许执行 - if not plugin or not plugin.name: - return True - - # 先检查插件是否被激活 - if not plugin.activated: - return False - - # 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性 - if ( - hasattr(plugin, "supported_platforms") - and platform_id in plugin.supported_platforms - ): - return plugin.supported_platforms[platform_id] - - # 如果没有缓存数据,默认允许执行 - return True diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index ab98b254e..5fb1b1dfa 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -22,6 +22,7 @@ from astrbot.core.utils.astrbot_path import ( get_astrbot_plugin_path, ) from astrbot.core.utils.io import remove_dir +from astrbot.core.agent.handoff import HandoffTool, FunctionTool from . import StarMetadata from .context import Context @@ -336,30 +337,8 @@ class PluginManager: result = await self.load(specified_module_path) - # 更新所有插件的平台兼容性 - await self.update_all_platform_compatibility() - return result - async def update_all_platform_compatibility(self): - """更新所有插件的平台兼容性设置""" - # 获取最新的平台插件启用配置 - plugin_enable_config = self.config.get("platform_settings", {}).get( - "plugin_enable", {} - ) - logger.debug( - f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}" - ) - - # 遍历所有插件,更新平台兼容性 - for plugin in self.context.get_all_stars(): - plugin.update_platform_compatibility(plugin_enable_config) - logger.debug( - f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}" - ) - - return True - async def load(self, specified_module_path=None, specified_dir_name=None): """载入插件。 当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。 @@ -373,10 +352,9 @@ class PluginManager: - success (bool): 是否全部加载成功 - error_message (str|None): 错误信息,成功时为 None """ - inactivated_plugins: list = sp.get("inactivated_plugins", []) - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - - alter_cmd = sp.get("alter_cmd", {}) + inactivated_plugins = await sp.global_get("inactivated_plugins", []) + inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", []) + alter_cmd = await sp.global_get("alter_cmd", {}) plugin_modules = self._get_plugin_modules() if plugin_modules is None: @@ -480,12 +458,6 @@ class PluginManager: metadata.root_dir_name = root_dir_name metadata.reserved = reserved - # 更新插件的平台兼容性 - plugin_enable_config = self.config.get("platform_settings", {}).get( - "plugin_enable", {} - ) - metadata.update_platform_compatibility(plugin_enable_config) - assert metadata.module_path is not None, ( f"插件 {metadata.name} 的模块路径为空。" ) @@ -503,17 +475,27 @@ class PluginManager: ) # 绑定 llm_tool handler for func_tool in llm_tools.func_list: - if ( - func_tool.handler - and func_tool.handler.__module__ == metadata.module_path - ): - func_tool.handler_module_path = metadata.module_path - func_tool.handler = functools.partial( - func_tool.handler, - metadata.star_cls, # type: ignore - ) - if func_tool.name in inactivated_llm_tools: - func_tool.active = False + if isinstance(func_tool, HandoffTool): + need_apply = [] + sub_tools = func_tool.agent.tools + for sub_tool in sub_tools: + if isinstance(sub_tool, FunctionTool): + need_apply.append(sub_tool) + else: + need_apply = [func_tool] + + for ft in need_apply: + if ( + ft.handler + and ft.handler.__module__ == metadata.module_path + ): + ft.handler_module_path = metadata.module_path + ft.handler = functools.partial( + ft.handler, + metadata.star_cls, # type: ignore + ) + if ft.name in inactivated_llm_tools: + ft.active = False else: # v3.4.0 以前的方式注册插件 @@ -776,12 +758,12 @@ class PluginManager: await self._terminate_plugin(plugin) # 加入到 shared_preferences 中 - inactivated_plugins: list = sp.get("inactivated_plugins", []) + inactivated_plugins: list = await sp.global_get("inactivated_plugins", []) if plugin.module_path not in inactivated_plugins: inactivated_plugins.append(plugin.module_path) inactivated_llm_tools: list = list( - set(sp.get("inactivated_llm_tools", [])) + set(await sp.global_get("inactivated_llm_tools", [])) ) # 后向兼容 # 禁用插件启用的 llm_tool @@ -791,8 +773,8 @@ class PluginManager: if func_tool.name not in inactivated_llm_tools: inactivated_llm_tools.append(func_tool.name) - sp.put("inactivated_plugins", inactivated_plugins) - sp.put("inactivated_llm_tools", inactivated_llm_tools) + await sp.global_put("inactivated_plugins", inactivated_plugins) + await sp.global_put("inactivated_llm_tools", inactivated_llm_tools) plugin.activated = False @@ -818,11 +800,11 @@ class PluginManager: async def turn_on_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) - inactivated_plugins: list = sp.get("inactivated_plugins", []) - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + inactivated_plugins: list = await sp.global_get("inactivated_plugins", []) + inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", []) if plugin.module_path in inactivated_plugins: inactivated_plugins.remove(plugin.module_path) - sp.put("inactivated_plugins", inactivated_plugins) + await sp.global_put("inactivated_plugins", inactivated_plugins) # 启用插件启用的 llm_tool for func_tool in llm_tools.func_list: @@ -832,7 +814,7 @@ class PluginManager: ): inactivated_llm_tools.remove(func_tool.name) func_tool.active = True - sp.put("inactivated_llm_tools", inactivated_llm_tools) + await sp.global_put("inactivated_llm_tools", inactivated_llm_tools) await self.reload(plugin_name) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index a3a73fcc8..7fe9bde05 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -58,9 +58,10 @@ class Metric: pass try: if "adapter_name" in kwargs: - db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1}) - if "llm_name" in kwargs: - db_helper.insert_llm_metrics({kwargs["llm_name"]: 1}) + await db_helper.insert_platform_stats( + platform_id=kwargs["adapter_name"], + platform_type=kwargs.get("adapter_type", "unknown"), + ) except Exception as e: logger.error(f"保存指标到数据库失败: {e}") pass diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 42018d19e..c1368f186 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,43 +1,180 @@ -import json +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Preference +import threading +import asyncio import os -from typing import TypeVar +from typing import TypeVar, Any, overload from .astrbot_path import get_astrbot_data_path + _VT = TypeVar("_VT") + class SharedPreferences: - def __init__(self, path=None): - if path is None: - path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") - self.path = path - self._data = self._load_preferences() + def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + if json_storage_path is None: + json_storage_path = os.path.join( + get_astrbot_data_path(), "shared_preferences.json" + ) + self.path = json_storage_path + self.db_helper = db_helper - def _load_preferences(self): - if os.path.exists(self.path): - try: - with open(self.path, "r") as f: - return json.load(f) - except json.JSONDecodeError: - os.remove(self.path) - return {} + self._sync_loop = asyncio.new_event_loop() + t = threading.Thread(target=self._sync_loop.run_forever, daemon=True) + t.start() - def _save_preferences(self): - with open(self.path, "w") as f: - json.dump(self._data, f, indent=4, ensure_ascii=False) - f.flush() + async def get_async( + self, + scope: str, + scope_id: str, + key: str, + default: _VT = None, + ) -> _VT: + """获取指定范围和键的偏好设置""" + if scope_id is not None and key is not None: + result = await self.db_helper.get_preference(scope, scope_id, key) + if result: + ret = result.value["val"] + else: + ret = default + return ret + else: + raise ValueError( + "scope_id and key cannot be None when getting a specific preference." + ) - def get(self, key, default: _VT = None) -> _VT: - return self._data.get(key, default) + async def range_get_async( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """获取指定范围的偏好设置 + Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。 + """ + ret = await self.db_helper.get_preferences(scope, scope_id, key) + return ret - def put(self, key, value): - self._data[key] = value - self._save_preferences() + @overload + async def session_get( + self, umo: None, key: str, default: Any = None + ) -> list[Preference]: ... - def remove(self, key): - if key in self._data: - del self._data[key] - self._save_preferences() + @overload + async def session_get( + self, umo: str, key: None, default: Any = None + ) -> list[Preference]: ... - def clear(self): - self._data.clear() - self._save_preferences() + @overload + async def session_get( + self, umo: None, key: None, default: Any = None + ) -> list[Preference]: ... + + async def session_get( + self, umo: str | None, key: str | None = None, default: _VT = None + ) -> _VT | list[Preference]: + """获取会话范围的偏好设置 + + Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + """ + if umo is None or key is None: + return await self.range_get_async("umo", umo, key) + return await self.get_async("umo", umo, key, default) + + @overload + async def global_get(self, key: None, default: Any = None) -> list[Preference]: ... + + @overload + async def global_get(self, key: str, default: _VT = None) -> _VT: ... + + async def global_get( + self, key: str | None, default: _VT = None + ) -> _VT | list[Preference]: + """获取全局范围的偏好设置 + + Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + """ + if key is None: + return await self.range_get_async("global", "global", key) + return await self.get_async("global", "global", key, default) + + async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + """设置指定范围和键的偏好设置""" + await self.db_helper.insert_preference_or_update( + scope, scope_id, key, {"val": value} + ) + + async def session_put(self, umo: str, key: str, value: Any): + await self.put_async("umo", umo, key, value) + + async def global_put(self, key: str, value: Any): + await self.put_async("global", "global", key, value) + + async def remove_async(self, scope: str, scope_id: str, key: str): + """删除指定范围和键的偏好设置""" + await self.db_helper.remove_preference(scope, scope_id, key) + + async def session_remove(self, umo: str, key: str): + await self.remove_async("umo", umo, key) + + async def global_remove(self, key: str): + """删除全局偏好设置""" + await self.remove_async("global", "global", key) + + async def clear_async(self, scope: str, scope_id: str): + """清空指定范围的所有偏好设置""" + await self.db_helper.clear_preferences(scope, scope_id) + + # ==== + # DEPRECATED METHODS + # ==== + + def get( + self, + key: str, + default: _VT = None, + scope: str | None = None, + scope_id: str | None = "", + ) -> _VT: + """获取偏好设置(已弃用)""" + if scope_id == "": + scope_id = "unknown" + if scope_id is None or key is None: + # result = asyncio.run(self.range_get_async(scope, scope_id, key)) + raise ValueError( + "scope_id and key cannot be None when getting a specific preference." + ) + result = asyncio.run_coroutine_threadsafe( + self.get_async(scope or "unknown", scope_id or "unknown", key, default), + self._sync_loop, + ).result() + + return result if result is not None else default + + def range_get( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: + """获取指定范围的偏好设置(已弃用)""" + result = asyncio.run_coroutine_threadsafe( + self.range_get_async(scope, scope_id, key), self._sync_loop + ).result() + + return result + + def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + """设置偏好设置(已弃用)""" + asyncio.run_coroutine_threadsafe( + self.put_async(scope or "unknown", scope_id or "unknown", key, value), + self._sync_loop, + ).result() + + def remove(self, key, scope: str | None = None, scope_id: str | None = None): + """删除偏好设置(已弃用)""" + asyncio.run_coroutine_threadsafe( + self.remove_async(scope or "unknown", scope_id or "unknown", key), + self._sync_loop, + ).result() + + def clear(self, scope: str | None = None, scope_id: str | None = None): + """清空偏好设置(已弃用)""" + asyncio.run_coroutine_threadsafe( + self.clear_async(scope or "unknown", scope_id or "unknown"), + self._sync_loop, + ).result() diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 03db6d5e7..2295f051b 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -1,37 +1,76 @@ import aiohttp +import asyncio import os import ssl import certifi - +import logging +import random from . import RenderStrategy from astrbot.core.config import VERSION from astrbot.core.utils.io import download_image_by_url +from astrbot.core.utils.astrbot_path import get_astrbot_data_path ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" +CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html") + +logger = logging.getLogger("astrbot") class NetworkRenderStrategy(RenderStrategy): def __init__(self, base_url: str | None = None) -> None: super().__init__() if not base_url: - base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT - self.BASE_RENDER_URL = base_url - self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template") + self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT + else: + self.BASE_RENDER_URL = self._clean_url(base_url) + self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html") + with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f: + self.DEFAULT_TEMPLATE = f.read() - if self.BASE_RENDER_URL.endswith("/"): - self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1] - if not self.BASE_RENDER_URL.endswith("text2img"): - self.BASE_RENDER_URL += "/text2img" + self.endpoints = [self.BASE_RENDER_URL] - def set_endpoint(self, base_url: str): - if not base_url: - base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT - self.BASE_RENDER_URL = base_url + async def initialize(self): + if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: + asyncio.create_task(self.get_official_endpoints()) - if self.BASE_RENDER_URL.endswith("/"): - self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1] - if not self.BASE_RENDER_URL.endswith("text2img"): - self.BASE_RENDER_URL += "/text2img" + async def get_template(self) -> str: + """获取文转图 HTML 模板 + + Returns: + str: 文转图 HTML 模板字符串 + """ + if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH): + with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f: + return f.read() + return self.DEFAULT_TEMPLATE + + async def get_official_endpoints(self): + """获取官方的 t2i 端点列表。""" + try: + async with aiohttp.ClientSession() as session: + async with session.get( + "https://api.soulter.top/astrbot/t2i-endpoints" + ) as resp: + if resp.status == 200: + data = await resp.json() + all_endpoints: list[dict] = data.get("data", []) + self.endpoints = [ + ep.get("url") + for ep in all_endpoints + if ep.get("active") and ep.get("url") + ] + logger.info( + f"Successfully got {len(self.endpoints)} official T2I endpoints." + ) + except Exception as e: + logger.error(f"Failed to get official endpoints: {e}") + + def _clean_url(self, url: str): + if url.endswith("/"): + url = url[:-1] + if not url.endswith("text2img"): + url += "/text2img" + return url async def render_custom_template( self, @@ -41,6 +80,7 @@ class NetworkRenderStrategy(RenderStrategy): options: dict | None = None, ) -> str: """使用自定义文转图模板""" + default_options = {"full_page": True, "type": "jpeg", "quality": 40} if options: default_options |= options @@ -51,30 +91,44 @@ class NetworkRenderStrategy(RenderStrategy): "tmpldata": tmpl_data, "options": default_options, } - if return_url: - ssl_context = ssl.create_default_context(cafile=certifi.where()) - connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession( - trust_env=True, connector=connector - ) as session: - async with session.post( - f"{self.BASE_RENDER_URL}/generate", json=post_data - ) as resp: - ret = await resp.json() - return f"{self.BASE_RENDER_URL}/{ret['data']['id']}" - return await download_image_by_url( - f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data - ) + + endpoints = self.endpoints.copy() if self.endpoints else [self.BASE_RENDER_URL] + random.shuffle(endpoints) + last_exception = None + for endpoint in endpoints: + try: + if return_url: + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession( + trust_env=True, connector=connector + ) as session: + async with session.post( + f"{endpoint}/generate", json=post_data + ) as resp: + if resp.status == 200: + ret = await resp.json() + return f"{endpoint}/{ret['data']['id']}" + else: + raise Exception(f"HTTP {resp.status}") + else: + # download_image_by_url 失败时抛异常 + return await download_image_by_url( + f"{endpoint}/generate", post=True, post_data=post_data + ) + except Exception as e: + last_exception = e + logger.warning(f"Endpoint {endpoint} failed: {e}, trying next...") + continue + # 全部失败 + logger.error(f"All endpoints failed: {last_exception}") + raise RuntimeError(f"All endpoints failed: {last_exception}") async def render(self, text: str, return_url: bool = False) -> str: """ 返回图像的文件路径 """ - with open( - os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding="utf-8" - ) as f: - tmpl_str = f.read() - assert tmpl_str + tmpl_str = await self.get_template() text = text.replace("`", "\\`") return await self.render_custom_template( tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 9e423be15..a3ceec4ad 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -10,10 +10,8 @@ class HtmlRenderer: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - def set_network_endpoint(self, endpoint_url: str): - """设置 t2i 的网络端点。""" - logger.info("文本转图像服务接口: " + endpoint_url) - self.network_strategy.set_endpoint(endpoint_url) + async def initialize(self): + await self.network_strategy.initialize() async def render_custom_template( self, diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 8d08b9d53..ef2fa3e86 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -6,11 +6,11 @@ from .stat import StatRoute from .log import LogRoute from .static_file import StaticFileRoute from .chat import ChatRoute -from .tools import ToolsRoute # 导入新的ToolsRoute +from .tools import ToolsRoute from .conversation import ConversationRoute from .file import FileRoute from .session_management import SessionManagementRoute - +from .persona import PersonaRoute __all__ = [ "AuthRoute", @@ -25,4 +25,5 @@ __all__ = [ "ConversationRoute", "FileRoute", "SessionManagementRoute", + "PersonaRoute", ] diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 651f1b65c..083647bc3 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -9,6 +9,7 @@ import asyncio from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.platform.astr_message_event import MessageSession class ChatRoute(Route): @@ -29,28 +30,15 @@ class ChatRoute(Route): "/chat/get_file": ("GET", self.get_file), "/chat/post_image": ("POST", self.post_image), "/chat/post_file": ("POST", self.post_file), - "/chat/status": ("GET", self.status), } - self.db = db self.core_lifecycle = core_lifecycle self.register_routes() self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") os.makedirs(self.imgs_dir, exist_ok=True) self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] - - async def status(self): - has_llm_enabled = ( - self.core_lifecycle.provider_manager.curr_provider_inst is not None - ) - has_stt_enabled = ( - self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None - ) - return ( - Response() - .ok(data={"llm_enabled": has_llm_enabled, "stt_enabled": has_stt_enabled}) - .__dict__ - ) + self.conv_mgr = core_lifecycle.conversation_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager async def get_file(self): filename = request.args.get("filename") @@ -131,24 +119,23 @@ class ChatRoute(Route): if not conversation_id: return Response().error("conversation_id is empty").__dict__ - # Get conversation-specific queues - back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id) - # append user message - conversation = self.db.get_conversation_by_user_id(username, conversation_id) - try: - history = json.loads(conversation.history) - except BaseException as e: - logger.error(f"Failed to parse conversation history: {e}") - history = [] + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + + # Get conversation-specific queues + back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) + new_his = {"type": "user", "message": message} if image_url: new_his["image_url"] = image_url if audio_url: new_his["audio_url"] = audio_url - history.append(new_his) - self.db.update_conversation( - username, conversation_id, history=json.dumps(history) + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id=username, + sender_name=username, ) async def stream(): @@ -164,7 +151,6 @@ class ChatRoute(Route): result_text = result["data"] type = result.get("type") - cid = result.get("cid") streaming = result.get("streaming", False) yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" await asyncio.sleep(0.05) @@ -173,17 +159,13 @@ class ChatRoute(Route): break elif (streaming and type == "complete") or not streaming: # append bot message - conversation = self.db.get_conversation_by_user_id( - username, cid - ) - try: - history = json.loads(conversation.history) - except BaseException as e: - logger.error(f"Failed to parse conversation history: {e}") - history = [] - history.append({"type": "bot", "message": result_text}) - self.db.update_conversation( - username, cid, history=json.dumps(history) + new_his = {"type": "bot", "message": result_text} + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", ) except BaseException as _: @@ -191,11 +173,11 @@ class ChatRoute(Route): return # Put message to conversation-specific queue - chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id) + chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) await chat_queue.put( ( username, - conversation_id, + webchat_conv_id, { "message": message, "image_url": image_url, # list @@ -217,25 +199,51 @@ class ChatRoute(Route): ) return response + async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str: + """从对话 ID 中提取 WebChat 会话 ID + + NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 + """ + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin="webchat", conversation_id=conversation_id + ) + if not conversation: + raise ValueError(f"Conversation with ID {conversation_id} not found.") + conv_user_id = conversation.user_id + webchat_session_id = MessageSession.from_str(conv_user_id).session_id + if "!" not in webchat_session_id: + raise ValueError(f"Invalid conv user ID: {conv_user_id}") + return webchat_session_id.split("!")[-1] + async def delete_conversation(self): - username = g.get("username", "guest") conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ + username = g.get("username", "guest") # Clean up queues when deleting conversation webchat_queue_mgr.remove_queues(conversation_id) - self.db.delete_conversation(username, conversation_id) + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) + await self.conv_mgr.delete_conversation( + unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", + conversation_id=conversation_id, + ) + await self.platform_history_mgr.delete( + platform_id="webchat", user_id=webchat_conv_id, offset_sec=99999999 + ) return Response().ok().__dict__ async def new_conversation(self): username = g.get("username", "guest") - conversation_id = str(uuid.uuid4()) - self.db.new_conversation(username, conversation_id) - return Response().ok(data={"conversation_id": conversation_id}).__dict__ + webchat_conv_id = str(uuid.uuid4()) + conv_id = await self.conv_mgr.new_conversation( + unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}", + platform_id="webchat", + content=[], + ) + return Response().ok(data={"conversation_id": conv_id}).__dict__ async def rename_conversation(self): - username = g.get("username", "guest") post_data = await request.json if "conversation_id" not in post_data or "title" not in post_data: return Response().error("Missing key: conversation_id or title").__dict__ @@ -243,20 +251,42 @@ class ChatRoute(Route): conversation_id = post_data["conversation_id"] title = post_data["title"] - self.db.update_conversation_title(username, conversation_id, title=title) + await self.conv_mgr.update_conversation( + unified_msg_origin="webchat", # fake + conversation_id=conversation_id, + title=title, + ) return Response().ok(message="重命名成功!").__dict__ async def get_conversations(self): - username = g.get("username", "guest") - conversations = self.db.get_conversations(username) - return Response().ok(data=conversations).__dict__ + conversations = await self.conv_mgr.get_conversations(platform_id="webchat") + # remove content + conversations_ = [] + for conv in conversations: + conv.history = None + conversations_.append(conv) + return Response().ok(data=conversations_).__dict__ async def get_conversation(self): - username = g.get("username", "guest") conversation_id = request.args.get("conversation_id") if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ - conversation = self.db.get_conversation_by_user_id(username, conversation_id) + webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id) - return Response().ok(data=conversation).__dict__ + # Get platform message history + history_ls = await self.platform_history_mgr.get( + platform_id="webchat", user_id=webchat_conv_id, page=1, page_size=1000 + ) + + history_res = [history.model_dump() for history in history_ls] + + return ( + Response() + .ok( + data={ + "history": history_res, + } + ) + .__dict__ + ) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 7de720a38..8cb548c62 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -4,16 +4,22 @@ import os from .route import Route, Response, RouteContext from astrbot.core.provider.entities import ProviderType from quart import request -from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP +from astrbot.core.config.default import ( + CONFIG_METADATA_2, + DEFAULT_VALUE_MAP, + CONFIG_METADATA_3, + CONFIG_METADATA_3_SYSTEM, +) from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_registry from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core import logger +from astrbot.core import logger, html_renderer from astrbot.core.provider import Provider import asyncio +from astrbot.core.utils.t2i.network_strategy import CUSTOM_T2I_TEMPLATE_PATH def try_cast(value: str, type_: str): @@ -159,33 +165,166 @@ class ConfigRoute(Route): super().__init__(context) self.core_lifecycle = core_lifecycle self.config: AstrBotConfig = core_lifecycle.astrbot_config + self.acm = core_lifecycle.astrbot_config_mgr self.routes = { + "/config/abconf/new": ("POST", self.create_abconf), + "/config/abconf": ("GET", self.get_abconf), + "/config/abconfs": ("GET", self.get_abconf_list), + "/config/abconf/delete": ("POST", self.delete_abconf), + "/config/abconf/update": ("POST", self.update_abconf), "/config/get": ("GET", self.get_configs), "/config/astrbot/update": ("POST", self.post_astrbot_configs), "/config/plugin/update": ("POST", self.post_plugin_configs), "/config/platform/new": ("POST", self.post_new_platform), "/config/platform/update": ("POST", self.post_update_platform), "/config/platform/delete": ("POST", self.post_delete_platform), + "/config/platform/list": ("GET", self.get_platform_list), "/config/provider/new": ("POST", self.post_new_provider), "/config/provider/update": ("POST", self.post_update_provider), "/config/provider/delete": ("POST", self.post_delete_provider), - "/config/llmtools": ("GET", self.get_llm_tools), "/config/provider/check_one": ("GET", self.check_one_provider_status), "/config/provider/list": ("GET", self.get_provider_config_list), "/config/provider/model_list": ("GET", self.get_provider_model_list), - "/config/provider/get_session_seperate": ( - "GET", - lambda: Response() - .ok({"enable": self.config["provider_settings"]["separate_provider"]}) - .__dict__, - ), - "/config/provider/set_session_seperate": ( - "POST", - self.post_session_seperate, - ), + "/config/astrbot/t2i-template/get": ("GET", self.get_t2i_template), + "/config/astrbot/t2i-template/save": ("POST", self.post_t2i_template), + "/config/astrbot/t2i-template/delete": ("DELETE", self.delete_t2i_template), } self.register_routes() + async def get_t2i_template(self): + """获取 T2I 模板""" + try: + template = await html_renderer.network_strategy.get_template() + has_custom_template = os.path.exists(CUSTOM_T2I_TEMPLATE_PATH) + return ( + Response() + .ok({"template": template, "has_custom_template": has_custom_template}) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取模板失败: {str(e)}").__dict__ + + async def post_t2i_template(self): + """保存 T2I 模板""" + try: + post_data = await request.json + if not post_data or "template" not in post_data: + return Response().error("缺少模板内容").__dict__ + + template_content = post_data["template"] + + # 保存自定义模板到文件 + with open(CUSTOM_T2I_TEMPLATE_PATH, "w", encoding="utf-8") as f: + f.write(template_content) + + return Response().ok(message="模板保存成功").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"保存模板失败: {str(e)}").__dict__ + + async def delete_t2i_template(self): + """删除自定义 T2I 模板,恢复默认模板""" + try: + if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH): + os.remove(CUSTOM_T2I_TEMPLATE_PATH) + return Response().ok(message="已恢复默认模板").__dict__ + else: + return Response().ok(message="未找到自定义模板文件").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"删除模板失败: {str(e)}").__dict__ + + async def get_abconf_list(self): + """获取所有 AstrBot 配置文件的列表""" + abconf_list = self.acm.get_conf_list() + return Response().ok({"info_list": abconf_list}).__dict__ + + async def create_abconf(self): + """创建新的 AstrBot 配置文件""" + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + umo_parts = post_data["umo_parts"] + name = post_data.get("name", None) + + try: + conf_id = self.acm.create_conf(umo_parts=umo_parts, name=name) + return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + + async def get_abconf(self): + """获取指定 AstrBot 配置文件""" + abconf_id = request.args.get("id") + system_config = request.args.get("system_config", "0").lower() == "1" + if not abconf_id and not system_config: + return Response().error("缺少配置文件 ID").__dict__ + + try: + if system_config: + abconf = self.acm.confs["default"] + return ( + Response() + .ok({"config": abconf, "metadata": CONFIG_METADATA_3_SYSTEM}) + .__dict__ + ) + abconf = self.acm.confs[abconf_id] + return ( + Response() + .ok({"config": abconf, "metadata": CONFIG_METADATA_3}) + .__dict__ + ) + except ValueError as e: + return Response().error(str(e)).__dict__ + + async def delete_abconf(self): + """删除指定 AstrBot 配置文件""" + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + + conf_id = post_data.get("id") + if not conf_id: + return Response().error("缺少配置文件 ID").__dict__ + + try: + success = self.acm.delete_conf(conf_id) + if success: + return Response().ok(message="删除成功").__dict__ + else: + return Response().error("删除失败").__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"删除配置文件失败: {str(e)}").__dict__ + + async def update_abconf(self): + """更新指定 AstrBot 配置文件信息""" + post_data = await request.json + if not post_data: + return Response().error("缺少配置数据").__dict__ + + conf_id = post_data.get("id") + if not conf_id: + return Response().error("缺少配置文件 ID").__dict__ + + name = post_data.get("name") + umo_parts = post_data.get("umo_parts") + + try: + success = self.acm.update_conf_info(conf_id, name=name, umo_parts=umo_parts) + if success: + return Response().ok(message="更新成功").__dict__ + else: + return Response().error("更新失败").__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"更新配置文件失败: {str(e)}").__dict__ + async def _test_single_provider(self, provider): """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() @@ -210,11 +349,16 @@ class ConfigRoute(Route): response = await asyncio.wait_for( provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 ) - logger.debug(f"Received response from {status_info['name']}: {response}") + logger.debug( + f"Received response from {status_info['name']}: {response}" + ) if response is not None: status_info["status"] = "available" response_text_snippet = "" - if hasattr(response, "completion_text") and response.completion_text: + if ( + hasattr(response, "completion_text") + and response.completion_text + ): response_text_snippet = ( response.completion_text[:70] + "..." if len(response.completion_text) > 70 @@ -233,29 +377,48 @@ class ConfigRoute(Route): f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" ) else: - status_info["error"] = "Test call returned None, but expected an LLMResponse object." - logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.") + status_info["error"] = ( + "Test call returned None, but expected an LLMResponse object." + ) + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None." + ) except asyncio.TimeoutError: - status_info["error"] = "Connection timed out after 45 seconds during test call." - logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.") + status_info["error"] = ( + "Connection timed out after 45 seconds during test call." + ) + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) timed out." + ) except Exception as e: error_message = str(e) status_info["error"] = error_message - logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}") - logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}") + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}" + ) + logger.debug( + f"Traceback for {status_info['name']}:\n{traceback.format_exc()}" + ) elif provider_capability_type == ProviderType.EMBEDDING: try: # For embedding, we can call the get_embedding method with a short prompt. embedding_result = await provider.get_embedding("health_check") - if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)): + if isinstance(embedding_result, list) and ( + not embedding_result or isinstance(embedding_result[0], float) + ): status_info["status"] = "available" else: status_info["status"] = "unavailable" - status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}" + status_info["error"] = ( + f"Embedding test failed: unexpected result type {type(embedding_result)}" + ) except Exception as e: - logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True) + logger.error( + f"Error testing embedding provider {provider_name}: {e}", + exc_info=True, + ) status_info["status"] = "unavailable" status_info["error"] = f"Embedding test failed: {str(e)}" @@ -267,41 +430,71 @@ class ConfigRoute(Route): status_info["status"] = "available" else: status_info["status"] = "unavailable" - status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}" + status_info["error"] = ( + f"TTS test failed: unexpected result type {type(audio_result)}" + ) except Exception as e: - logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True) + logger.error( + f"Error testing TTS provider {provider_name}: {e}", exc_info=True + ) status_info["status"] = "unavailable" status_info["error"] = f"TTS test failed: {str(e)}" elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: try: - logger.debug(f"Sending health check audio to provider: {status_info['name']}") - sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav") + logger.debug( + f"Sending health check audio to provider: {status_info['name']}" + ) + sample_audio_path = os.path.join( + get_astrbot_path(), "samples", "stt_health_check.wav" + ) if not os.path.exists(sample_audio_path): status_info["status"] = "unavailable" - status_info["error"] = "STT test failed: sample audio file not found." - logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}") + status_info["error"] = ( + "STT test failed: sample audio file not found." + ) + logger.warning( + f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}" + ) else: text_result = await provider.get_text(sample_audio_path) if isinstance(text_result, str) and text_result: status_info["status"] = "available" - snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result - logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'") + snippet = ( + text_result[:70] + "..." + if len(text_result) > 70 + else text_result + ) + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'" + ) else: status_info["status"] = "unavailable" - status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}" - logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}") + status_info["error"] = ( + f"STT test failed: unexpected result type {type(text_result)}" + ) + logger.warning( + f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}" + ) except Exception as e: - logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True) + logger.error( + f"Error testing STT provider {provider_name}: {e}", exc_info=True + ) status_info["status"] = "unavailable" status_info["error"] = f"STT test failed: {str(e)}" else: - logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}") + logger.debug( + f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}" + ) status_info["status"] = "available" - status_info["error"] = "This provider type is not tested and is assumed to be available." + status_info["error"] = ( + "This provider type is not tested and is assumed to be available." + ) return status_info - def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error): + def _error_response( + self, message: str, status_code: int = 500, log_fn=logger.error + ): log_fn(message) # 记录更详细的traceback信息,但只在是严重错误时 if status_code == 500: @@ -312,7 +505,9 @@ class ConfigRoute(Route): """API: check a single LLM Provider's status by id""" provider_id = request.args.get("id") if not provider_id: - return self._error_response("Missing provider_id parameter", 400, logger.warning) + return self._error_response( + "Missing provider_id parameter", 400, logger.warning + ) logger.info(f"API call: /config/provider/check_one id={provider_id}") try: @@ -320,16 +515,21 @@ class ConfigRoute(Route): target = prov_mgr.inst_map.get(provider_id) if not target: - logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.") - return Response().error(f"Provider with id '{provider_id}' not found").__dict__ + logger.warning( + f"Provider with id '{provider_id}' not found in provider_manager." + ) + return ( + Response() + .error(f"Provider with id '{provider_id}' not found") + .__dict__ + ) result = await self._test_single_provider(target) return Response().ok(result).__dict__ except Exception as e: return self._error_response( - f"Critical error checking provider {provider_id}: {e}", - 500 + f"Critical error checking provider {provider_id}: {e}", 500 ) async def get_configs(self): @@ -340,29 +540,15 @@ class ConfigRoute(Route): return Response().ok(await self._get_astrbot_config()).__dict__ return Response().ok(await self._get_plugin_config(plugin_name)).__dict__ - async def post_session_seperate(self): - """设置提供商会话隔离""" - post_config = await request.json - enable = post_config.get("enable", None) - if enable is None: - return Response().error("缺少参数 enable").__dict__ - - astrbot_config = self.core_lifecycle.astrbot_config - astrbot_config["provider_settings"]["separate_provider"] = enable - try: - astrbot_config.save_config() - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "设置成功~").__dict__ - async def get_provider_config_list(self): provider_type = request.args.get("provider_type", None) if not provider_type: return Response().error("缺少参数 provider_type").__dict__ + provider_type_ls = provider_type.split(",") provider_list = [] astrbot_config = self.core_lifecycle.astrbot_config for provider in astrbot_config["provider"]: - if provider.get("provider_type", None) == provider_type: + if provider.get("provider_type", None) in provider_type_ls: provider_list.append(provider) return Response().ok(provider_list).__dict__ @@ -388,11 +574,21 @@ class ConfigRoute(Route): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ + async def get_platform_list(self): + """获取所有平台的列表""" + platform_list = [] + for platform in self.config["platform"]: + platform_list.append(platform) + return Response().ok({"platforms": platform_list}).__dict__ + async def post_astrbot_configs(self): - post_configs = await request.json + data = await request.json + config = data.get("config", None) + conf_id = data.get("conf_id", None) try: - await self._save_astrbot_configs(post_configs) - return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ + await self._save_astrbot_configs(config, conf_id) + await self.core_lifecycle.reload_pipeline_scheduler(conf_id) + return Response().ok(None, "保存成功~").__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ @@ -509,12 +705,6 @@ class ConfigRoute(Route): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效~").__dict__ - async def get_llm_tools(self): - """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" - tool_mgr = self.core_lifecycle.provider_manager.llm_tools - tools = tool_mgr.get_func_desc_openai_style() - return Response().ok(tools).__dict__ - async def _get_astrbot_config(self): config = self.config @@ -557,10 +747,12 @@ class ConfigRoute(Route): return ret - async def _save_astrbot_configs(self, post_configs: dict): + async def _save_astrbot_configs(self, post_configs: dict, conf_id: str = None): try: - save_config(post_configs, self.config, is_core=True) - await self.core_lifecycle.restart() + if conf_id not in self.acm.confs: + raise ValueError(f"配置文件 {conf_id} 不存在") + astrbot_config = self.acm.confs[conf_id] + save_config(post_configs, astrbot_config, is_core=True) except Exception as e: raise e diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index dde6f9a5a..fb5d3e10e 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -29,6 +29,7 @@ class ConversationRoute(Route): ), } self.db_helper = db_helper + self.conv_mgr = core_lifecycle.conversation_manager self.core_lifecycle = core_lifecycle self.register_routes() @@ -54,7 +55,6 @@ class ConversationRoute(Route): exclude_platforms.split(",") if exclude_platforms else [] ) - # 限制页面大小,防止请求过大数据 if page < 1: page = 1 if page_size < 1: @@ -62,9 +62,11 @@ class ConversationRoute(Route): if page_size > 100: page_size = 100 - # 使用数据库的分页方法获取会话列表和总数,传入筛选条件 try: - conversations, total_count = self.db_helper.get_filtered_conversations( + ( + conversations, + total_count, + ) = await self.conv_mgr.get_filtered_conversations( page=page, page_size=page_size, platforms=platform_list, @@ -108,7 +110,9 @@ class ConversationRoute(Route): if not user_id or not cid: return Response().error("缺少必要参数: user_id 和 cid").__dict__ - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ @@ -143,14 +147,18 @@ class ConversationRoute(Route): if not user_id or not cid: return Response().error("缺少必要参数: user_id 和 cid").__dict__ - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ - if title is not None: - self.db_helper.update_conversation_title(user_id, cid, title) - if persona_id is not None: - self.db_helper.update_conversation_persona_id(user_id, cid, persona_id) - + if title is not None or persona_id is not None: + await self.conv_mgr.update_conversation( + unified_msg_origin=user_id, + conversation_id=cid, + title=title, + persona_id=persona_id, + ) return Response().ok({"message": "对话信息更新成功"}).__dict__ except Exception as e: @@ -201,11 +209,17 @@ class ConversationRoute(Route): Response().error("history 必须是有效的 JSON 字符串或数组").__dict__ ) - conversation = self.db_helper.get_conversation_by_user_id(user_id, cid) + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=user_id, conversation_id=cid + ) if not conversation: return Response().error("对话不存在").__dict__ - self.db_helper.update_conversation(user_id, cid, history) + history = json.loads(history) if isinstance(history, str) else history + + await self.conv_mgr.update_conversation( + unified_msg_origin=user_id, conversation_id=cid, history=history + ) return Response().ok({"message": "对话历史更新成功"}).__dict__ diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py new file mode 100644 index 000000000..032471ee4 --- /dev/null +++ b/astrbot/dashboard/routes/persona.py @@ -0,0 +1,199 @@ +import traceback +from .route import Route, Response, RouteContext +from astrbot.core import logger +from quart import request +from astrbot.core.db import BaseDatabase +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + + +class PersonaRoute(Route): + def __init__( + self, + context: RouteContext, + db_helper: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.routes = { + "/persona/list": ("GET", self.list_personas), + "/persona/detail": ("POST", self.get_persona_detail), + "/persona/create": ("POST", self.create_persona), + "/persona/update": ("POST", self.update_persona), + "/persona/delete": ("POST", self.delete_persona), + } + self.db_helper = db_helper + self.persona_mgr = core_lifecycle.persona_mgr + self.register_routes() + + async def list_personas(self): + """获取所有人格列表""" + try: + personas = await self.persona_mgr.get_all_personas() + return ( + Response() + .ok( + [ + { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools, + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + } + for persona in personas + ] + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"获取人格列表失败: {str(e)}").__dict__ + + async def get_persona_detail(self): + """获取指定人格的详细信息""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + persona = await self.persona_mgr.get_persona(persona_id) + if not persona: + return Response().error("人格不存在").__dict__ + + return ( + Response() + .ok( + { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools, + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"获取人格详情失败: {str(e)}").__dict__ + + async def create_persona(self): + """创建新人格""" + try: + data = await request.get_json() + persona_id = data.get("persona_id", "").strip() + system_prompt = data.get("system_prompt", "").strip() + begin_dialogs = data.get("begin_dialogs", []) + tools = data.get("tools") + + if not persona_id: + return Response().error("人格ID不能为空").__dict__ + + if not system_prompt: + return Response().error("系统提示词不能为空").__dict__ + + # 验证 begin_dialogs 格式 + if begin_dialogs and len(begin_dialogs) % 2 != 0: + return ( + Response() + .error("预设对话数量必须为偶数(用户和助手轮流对话)") + .__dict__ + ) + + persona = await self.persona_mgr.create_persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs if begin_dialogs else None, + tools=tools if tools else None, + ) + + return ( + Response() + .ok( + { + "message": "人格创建成功", + "persona": { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools or [], + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + }, + } + ) + .__dict__ + ) + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"创建人格失败: {str(e)}").__dict__ + + async def update_persona(self): + """更新人格信息""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + system_prompt = data.get("system_prompt") + begin_dialogs = data.get("begin_dialogs") + tools = data.get("tools") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + # 验证 begin_dialogs 格式 + if begin_dialogs is not None and len(begin_dialogs) % 2 != 0: + return ( + Response() + .error("预设对话数量必须为偶数(用户和助手轮流对话)") + .__dict__ + ) + + await self.persona_mgr.update_persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs, + tools=tools, + ) + + return Response().ok({"message": "人格更新成功"}).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"更新人格失败: {str(e)}").__dict__ + + async def delete_persona(self): + """删除人格""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + await self.persona_mgr.delete_persona(persona_id) + + return Response().ok({"message": "人格删除成功"}).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"删除人格失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 179b45428..849339698 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -40,8 +40,6 @@ class PluginRoute(Route): "/plugin/on": ("POST", self.on_plugin), "/plugin/reload": ("POST", self.reload_plugins), "/plugin/readme": ("GET", self.get_plugin_readme), - "/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable), - "/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable), } self.core_lifecycle = core_lifecycle self.plugin_manager = plugin_manager @@ -286,14 +284,6 @@ class PluginRoute(Route): f"{filter.parent_command_names[0]} {filter.command_name}" ) info["cmd"] = info["cmd"].strip() - if ( - self.core_lifecycle.astrbot_config["wake_prefix"] - and len(self.core_lifecycle.astrbot_config["wake_prefix"]) - > 0 - ): - info["cmd"] = ( - f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" - ) elif isinstance(filter, CommandGroupFilter): info["type"] = "指令组" info["cmd"] = filter.get_complete_command_names()[0] @@ -301,14 +291,6 @@ class PluginRoute(Route): info["sub_command"] = filter.print_cmd_tree( filter.sub_command_filters ) - if ( - self.core_lifecycle.astrbot_config["wake_prefix"] - and len(self.core_lifecycle.astrbot_config["wake_prefix"]) - > 0 - ): - info["cmd"] = ( - f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}" - ) elif isinstance(filter, RegexFilter): info["type"] = "正则匹配" info["cmd"] = filter.regex_str @@ -498,90 +480,3 @@ class PluginRoute(Route): except Exception as e: logger.error(f"/api/plugin/readme: {traceback.format_exc()}") return Response().error(f"读取README文件失败: {str(e)}").__dict__ - - async def get_plugin_platform_enable(self): - """获取插件在各平台的可用性配置""" - try: - platform_enable = self.core_lifecycle.astrbot_config.get( - "platform_settings", {} - ).get("plugin_enable", {}) - - # 获取所有可用平台 - platforms = [] - - for platform in self.core_lifecycle.astrbot_config.get("platform", []): - platform_type = platform.get("type", "") - platform_id = platform.get("id", "") - - platforms.append( - { - "name": platform_id, # 使用type作为name,这是系统内部使用的平台名称 - "id": platform_id, # 保留id字段以便前端可以显示 - "type": platform_type, - "display_name": f"{platform_type}({platform_id})", - } - ) - - adjusted_platform_enable = {} - for platform_id, plugins in platform_enable.items(): - adjusted_platform_enable[platform_id] = plugins - - # 获取所有插件,包括系统内部插件 - plugins = [] - for plugin in self.plugin_manager.context.get_all_stars(): - plugins.append( - { - "name": plugin.name, - "desc": plugin.desc, - "reserved": plugin.reserved, # 添加reserved标志 - } - ) - - logger.debug( - f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}" - ) - - return ( - Response() - .ok( - { - "platforms": platforms, - "plugins": plugins, - "platform_enable": adjusted_platform_enable, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def set_plugin_platform_enable(self): - """设置插件在各平台的可用性配置""" - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - try: - data = await request.json - platform_enable = data.get("platform_enable", {}) - - # 更新配置 - config = self.core_lifecycle.astrbot_config - platform_settings = config.get("platform_settings", {}) - platform_settings["plugin_enable"] = platform_enable - config["platform_settings"] = platform_settings - config.save_config() - - # 更新插件的平台兼容性缓存 - await self.plugin_manager.update_all_platform_compatibility() - - logger.info(f"插件平台可用性配置已更新: {platform_enable}") - - return Response().ok(None, "插件平台可用性配置已更新").__dict__ - except Exception as e: - logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index bd94d9adf..a11fae252 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -21,17 +21,19 @@ class Route: @dataclass class Response: - status: str = None - message: str = None - data: dict = None + status: str | None = None + message: str | None = None + data: dict | list | None = None def error(self, message: str): self.status = "error" self.message = message return self - def ok(self, data: dict = {}, message: str = None): + def ok(self, data: dict | list | None = None, message: str | None = None): self.status = "ok" + if data is None: + data = {} self.data = data self.message = message return self diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index fdcbdbf73..1271a2493 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -24,7 +24,6 @@ class SessionManagementRoute(Route): "/session/list": ("GET", self.list_sessions), "/session/update_persona": ("POST", self.update_session_persona), "/session/update_provider": ("POST", self.update_session_provider), - "/session/get_session_info": ("POST", self.get_session_info), "/session/plugins": ("GET", self.get_session_plugins), "/session/update_plugin": ("POST", self.update_session_plugin), "/session/update_llm": ("POST", self.update_session_llm), @@ -32,24 +31,20 @@ class SessionManagementRoute(Route): "/session/update_name": ("POST", self.update_session_name), "/session/update_status": ("POST", self.update_session_status), } - self.db_helper = db_helper + self.conv_mgr = core_lifecycle.conversation_manager self.core_lifecycle = core_lifecycle self.register_routes() async def list_sessions(self): """获取所有会话的列表,包括 persona 和 provider 信息""" try: - # 获取会话对话映射 - session_conversations = sp.get("session_conversation", {}) or {} - - # 获取会话提供商偏好设置 - session_provider_perf = sp.get("session_provider_perf", {}) or {} - - # 获取可用的 personas - personas = self.core_lifecycle.star_context.provider_manager.personas - - # 获取可用的 providers - provider_manager = self.core_lifecycle.star_context.provider_manager + preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[]) + session_conversations = {} + for pref in preferences: + session_conversations[pref.scope_id] = pref.value["val"] + provider_manager = self.core_lifecycle.provider_manager + persona_mgr = self.core_lifecycle.persona_mgr + personas = persona_mgr.personas_v3 sessions = [] @@ -59,13 +54,9 @@ class SessionManagementRoute(Route): "session_id": session_id, "conversation_id": conversation_id, "persona_id": None, - "persona_name": None, "chat_provider_id": None, - "chat_provider_name": None, "stt_provider_id": None, - "stt_provider_name": None, "tts_provider_id": None, - "tts_provider_name": None, "session_enabled": SessionServiceManager.is_session_enabled( session_id ), @@ -90,74 +81,46 @@ class SessionManagementRoute(Route): } # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id( - session_id, conversation_id + conversation = await self.conv_mgr.get_conversation( + unified_msg_origin=session_id, conversation_id=conversation_id ) if conversation: session_info["persona_id"] = conversation.persona_id + # 查找 persona 名称 if conversation.persona_id and conversation.persona_id != "[%None]": for persona in personas: if persona["name"] == conversation.persona_id: - session_info["persona_name"] = persona["name"] + session_info["persona_id"] = persona["name"] break elif conversation.persona_id == "[%None]": - session_info["persona_name"] = "无人格" + session_info["persona_id"] = "无人格" else: # 使用默认人格 - default_persona = provider_manager.selected_default_persona + default_persona = persona_mgr.selected_default_persona_v3 if default_persona: session_info["persona_id"] = default_persona["name"] - session_info["persona_name"] = default_persona["name"] - # 获取会话的 provider 偏好设置 - session_perf = session_provider_perf.get(session_id, {}) - - # Chat completion provider - chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) - if chat_provider_id: - chat_provider = provider_manager.inst_map.get(chat_provider_id) - if chat_provider: - session_info["chat_provider_id"] = chat_provider_id - session_info["chat_provider_name"] = chat_provider.meta().id - else: - # 使用默认 provider - default_provider = provider_manager.curr_provider_inst - if default_provider: - session_info["chat_provider_id"] = default_provider.meta().id - session_info["chat_provider_name"] = default_provider.meta().id - - # STT provider - stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) - if stt_provider_id: - stt_provider = provider_manager.inst_map.get(stt_provider_id) - if stt_provider: - session_info["stt_provider_id"] = stt_provider_id - session_info["stt_provider_name"] = stt_provider.meta().id - else: - # 使用默认 STT provider - default_stt_provider = provider_manager.curr_stt_provider_inst - if default_stt_provider: - session_info["stt_provider_id"] = default_stt_provider.meta().id - session_info["stt_provider_name"] = ( - default_stt_provider.meta().id - ) - - # TTS provider - tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) - if tts_provider_id: - tts_provider = provider_manager.inst_map.get(tts_provider_id) - if tts_provider: - session_info["tts_provider_id"] = tts_provider_id - session_info["tts_provider_name"] = tts_provider.meta().id - else: - # 使用默认 TTS provider - default_tts_provider = provider_manager.curr_tts_provider_inst - if default_tts_provider: - session_info["tts_provider_id"] = default_tts_provider.meta().id - session_info["tts_provider_name"] = ( - default_tts_provider.meta().id - ) + # 获取 provider 信息 + provider_manager = self.core_lifecycle.provider_manager + chat_provider = provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, umo=session_id + ) + tts_provider = provider_manager.get_using_provider( + provider_type=ProviderType.TEXT_TO_SPEECH, umo=session_id + ) + stt_provider = provider_manager.get_using_provider( + provider_type=ProviderType.SPEECH_TO_TEXT, umo=session_id + ) + if chat_provider: + meta = chat_provider.meta() + session_info["chat_provider_id"] = meta.id + if tts_provider: + meta = tts_provider.meta() + session_info["tts_provider_id"] = meta.id + if stt_provider: + meta = stt_provider.meta() + session_info["stt_provider_id"] = meta.id sessions.append(session_info) @@ -311,133 +274,6 @@ class SessionManagementRoute(Route): logger.error(error_msg) return Response().error(f"更新会话提供商失败: {str(e)}").__dict__ - async def get_session_info(self): - """获取指定会话的详细信息""" - try: - data = await request.get_json() - session_id = data.get("session_id") - - if not session_id: - return Response().error("缺少必要参数: session_id").__dict__ - # 获取会话对话信息 - session_conversations = sp.get("session_conversation", {}) or {} - conversation_id = session_conversations.get(session_id) - - if not conversation_id: - return Response().error(f"会话 {session_id} 未找到对话").__dict__ - - session_info = { - "session_id": session_id, - "conversation_id": conversation_id, - "persona_id": None, - "persona_name": None, - "chat_provider_id": None, - "chat_provider_name": None, - "stt_provider_id": None, - "stt_provider_name": None, - "tts_provider_id": None, - "tts_provider_name": None, - "llm_enabled": SessionServiceManager.is_llm_enabled_for_session( - session_id - ), - "tts_enabled": None, # 将在下面设置 - "platform": session_id.split(":")[0] - if ":" in session_id - else "unknown", - "message_type": session_id.split(":")[1] - if session_id.count(":") >= 1 - else "unknown", - "session_name": session_id.split(":")[2] - if session_id.count(":") >= 2 - else session_id, - } - - # 获取TTS状态 - session_info["tts_enabled"] = ( - SessionServiceManager.is_tts_enabled_for_session(session_id) - ) - - # 获取对话信息 - conversation = self.db_helper.get_conversation_by_user_id( - session_id, conversation_id - ) - if conversation: - session_info["persona_id"] = conversation.persona_id - - # 查找 persona 名称 - provider_manager = self.core_lifecycle.star_context.provider_manager - personas = provider_manager.personas - - if conversation.persona_id and conversation.persona_id != "[%None]": - for persona in personas: - if persona["name"] == conversation.persona_id: - session_info["persona_name"] = persona["name"] - break - elif conversation.persona_id == "[%None]": - session_info["persona_name"] = "无人格" - else: - # 使用默认人格 - default_persona = provider_manager.selected_default_persona - if default_persona: - session_info["persona_id"] = default_persona["name"] - session_info["persona_name"] = default_persona["name"] - - # 获取会话的 provider 偏好设置 - session_provider_perf = sp.get("session_provider_perf", {}) or {} - session_perf = session_provider_perf.get(session_id, {}) - - # 获取 provider 信息 - provider_manager = self.core_lifecycle.star_context.provider_manager - - # Chat completion provider - chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) - if chat_provider_id: - chat_provider = provider_manager.inst_map.get(chat_provider_id) - if chat_provider: - session_info["chat_provider_id"] = chat_provider_id - session_info["chat_provider_name"] = chat_provider.meta().id - else: - # 使用默认 provider - default_provider = provider_manager.curr_provider_inst - if default_provider: - session_info["chat_provider_id"] = default_provider.meta().id - session_info["chat_provider_name"] = default_provider.meta().id - - # STT provider - stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) - if stt_provider_id: - stt_provider = provider_manager.inst_map.get(stt_provider_id) - if stt_provider: - session_info["stt_provider_id"] = stt_provider_id - session_info["stt_provider_name"] = stt_provider.meta().id - else: - # 使用默认 STT provider - default_stt_provider = provider_manager.curr_stt_provider_inst - if default_stt_provider: - session_info["stt_provider_id"] = default_stt_provider.meta().id - session_info["stt_provider_name"] = default_stt_provider.meta().id - - # TTS provider - tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) - if tts_provider_id: - tts_provider = provider_manager.inst_map.get(tts_provider_id) - if tts_provider: - session_info["tts_provider_id"] = tts_provider_id - session_info["tts_provider_name"] = tts_provider.meta().id - else: - # 使用默认 TTS provider - default_tts_provider = provider_manager.curr_tts_provider_inst - if default_tts_provider: - session_info["tts_provider_id"] = default_tts_provider.meta().id - session_info["tts_provider_name"] = default_tts_provider.meta().id - - return Response().ok(session_info).__dict__ - - except Exception as e: - error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}" - logger.error(error_msg) - return Response().error(f"获取会话信息失败: {str(e)}").__dict__ - async def get_session_plugins(self): """获取指定会话的插件配置信息""" try: diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 2a8389396..d13eb802c 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -11,6 +11,7 @@ from astrbot.core.db import BaseDatabase from astrbot.core.config import VERSION from astrbot.core.utils.io import get_dashboard_version from astrbot.core import DEMO_MODE +from astrbot.core.db.migration.helper import check_migration_needed_v4 class StatRoute(Route): @@ -59,6 +60,8 @@ class StatRoute(Route): ) async def get_version(self): + need_migration = await check_migration_needed_v4(self.core_lifecycle.db) + return ( Response() .ok( @@ -66,6 +69,7 @@ class StatRoute(Route): "version": VERSION, "dashboard_version": await get_dashboard_version(), "change_pwd_hint": self.is_default_cred(), + "need_migration": need_migration, } ) .__dict__ @@ -84,7 +88,7 @@ class StatRoute(Route): message_time_based_stats = [] idx = 0 - for bucket_end in range(start_time, now, 1800): + for bucket_end in range(start_time, now, 3600): cnt = 0 while ( idx < len(stat.platform) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 5dad2576b..79a601b25 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -1,5 +1,3 @@ -import json -import os import traceback import aiohttp @@ -7,7 +5,7 @@ from quart import request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.star import star_map from .route import Response, Route, RouteContext @@ -25,44 +23,17 @@ class ToolsRoute(Route): "/tools/mcp/add": ("POST", self.add_mcp_server), "/tools/mcp/update": ("POST", self.update_mcp_server), "/tools/mcp/delete": ("POST", self.delete_mcp_server), - "/tools/mcp/market": ("GET", self.get_mcp_markets), "/tools/mcp/test": ("POST", self.test_mcp_connection), + "/tools/list": ("GET", self.get_tool_list), + "/tools/toggle-tool": ("POST", self.toggle_tool), + "/tools/mcp/sync-provider": ("POST", self.sync_provider), } self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools - @property - def mcp_config_path(self): - data_dir = get_astrbot_data_path() - return os.path.join(data_dir, "mcp_server.json") - - def load_mcp_config(self): - if not os.path.exists(self.mcp_config_path): - # 配置文件不存在,创建默认配置 - os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True) - with open(self.mcp_config_path, "w", encoding="utf-8") as f: - json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) - return DEFAULT_MCP_CONFIG - - try: - with open(self.mcp_config_path, "r", encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.error(f"加载 MCP 配置失败: {e}") - return DEFAULT_MCP_CONFIG - - def save_mcp_config(self, config): - try: - with open(self.mcp_config_path, "w", encoding="utf-8") as f: - json.dump(config, f, ensure_ascii=False, indent=4) - return True - except Exception as e: - logger.error(f"保存 MCP 配置失败: {e}") - return False - async def get_mcp_servers(self): try: - config = self.load_mcp_config() + config = self.tool_mgr.load_mcp_config() servers = [] # 获取所有服务器并添加它们的工具列表 @@ -125,14 +96,14 @@ class ToolsRoute(Route): if not has_valid_config: return Response().error("必须提供有效的服务器配置").__dict__ - config = self.load_mcp_config() + config = self.tool_mgr.load_mcp_config() if name in config["mcpServers"]: return Response().error(f"服务器 {name} 已存在").__dict__ config["mcpServers"][name] = server_config - if self.save_mcp_config(config): + if self.tool_mgr.save_mcp_config(config): try: await self.tool_mgr.enable_mcp_server( name, server_config, timeout=30 @@ -162,7 +133,7 @@ class ToolsRoute(Route): if not name: return Response().error("服务器名称不能为空").__dict__ - config = self.load_mcp_config() + config = self.tool_mgr.load_mcp_config() if name not in config["mcpServers"]: return Response().error(f"服务器 {name} 不存在").__dict__ @@ -198,7 +169,7 @@ class ToolsRoute(Route): config["mcpServers"][name] = server_config - if self.save_mcp_config(config): + if self.tool_mgr.save_mcp_config(config): # 处理MCP客户端状态变化 if active: if name in self.tool_mgr.mcp_client_dict or not only_update_active: @@ -266,14 +237,14 @@ class ToolsRoute(Route): if not name: return Response().error("服务器名称不能为空").__dict__ - config = self.load_mcp_config() + config = self.tool_mgr.load_mcp_config() if name not in config["mcpServers"]: return Response().error(f"服务器 {name} 不存在").__dict__ del config["mcpServers"][name] - if self.save_mcp_config(config): + if self.tool_mgr.save_mcp_config(config): if name in self.tool_mgr.mcp_client_dict: try: await self.tool_mgr.disable_mcp_server(name, timeout=10) @@ -295,31 +266,6 @@ class ToolsRoute(Route): logger.error(traceback.format_exc()) return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__ - async def get_mcp_markets(self): - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 10, type=int) - BASE_URL = ( - "https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format( - page, - page_size, - ) - ) - try: - async with aiohttp.ClientSession() as session: - async with session.get(f"{BASE_URL}") as response: - if response.status == 200: - data = await response.json() - return Response().ok(data["data"]).__dict__ - else: - return ( - Response() - .error(f"获取市场数据失败: HTTP {response.status}") - .__dict__ - ) - except Exception as _: - logger.error(traceback.format_exc()) - return Response().error("获取市场数据失败").__dict__ - async def test_mcp_connection(self): """ 测试 MCP 服务器连接 @@ -336,3 +282,57 @@ class ToolsRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__ + + async def get_tool_list(self): + """获取所有注册的工具列表""" + try: + tools = self.tool_mgr.func_list + tools_dict = [tool.__dict__() for tool in tools] + return Response().ok(data=tools_dict).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取工具列表失败: {str(e)}").__dict__ + + async def toggle_tool(self): + """启用或停用指定的工具""" + try: + data = await request.json + tool_name = data.get("name") + action = data.get("activate") # True or False + + if not tool_name or action is None: + return Response().error("缺少必要参数: name 或 action").__dict__ + + if action: + try: + ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) + except ValueError as e: + return Response().error(f"启用工具失败: {str(e)}").__dict__ + else: + ok = self.tool_mgr.deactivate_llm_tool(tool_name) + + if ok: + return Response().ok(None, "操作成功。").__dict__ + else: + return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__ + + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"操作工具失败: {str(e)}").__dict__ + + async def sync_provider(self): + """同步 MCP 提供者配置""" + try: + data = await request.json + provider_name = data.get("name") # modelscope, or others + match provider_name: + case "modelscope": + access_token = data.get("access_token", "") + await self.tool_mgr.sync_modelscope_mcp_servers(access_token) + case _: + return Response().error(f"未知: {provider_name}").__dict__ + + return Response().ok(message="同步成功").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"同步失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index 79aa56bc8..a7b9dc5e5 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -7,6 +7,7 @@ from astrbot.core import logger, pip_installer from astrbot.core.utils.io import download_dashboard, get_dashboard_version from astrbot.core.config.default import VERSION from astrbot.core import DEMO_MODE +from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4 class UpdateRoute(Route): @@ -23,11 +24,27 @@ class UpdateRoute(Route): "/update/do": ("POST", self.update_project), "/update/dashboard": ("POST", self.update_dashboard), "/update/pip-install": ("POST", self.install_pip_package), + "/update/migration": ("POST", self.do_migration), } self.astrbot_updator = astrbot_updator self.core_lifecycle = core_lifecycle self.register_routes() + async def do_migration(self): + need_migration = await check_migration_needed_v4(self.core_lifecycle.db) + if not need_migration: + return Response().ok(None, "不需要进行迁移。").__dict__ + try: + data = await request.json + pim = data.get("platform_id_map", {}) + await do_migration_v4( + self.core_lifecycle.db, pim, self.core_lifecycle.astrbot_config + ) + return Response().ok(None, "迁移成功。").__dict__ + except Exception as e: + logger.error(f"迁移失败: {traceback.format_exc()}") + return Response().error(f"迁移失败: {str(e)}").__dict__ + async def check_update(self): type_ = request.args.get("type", None) diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 06f6f8e60..e22b20524 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -60,6 +60,9 @@ class AstrBotDashboard: self.session_management_route = SessionManagementRoute( self.context, db, core_lifecycle ) + self.persona_route = PersonaRoute( + self.context, db, core_lifecycle + ) self.app.add_url_rule( "/api/plug/", diff --git a/changelogs/v4.0.0-beta.1.md b/changelogs/v4.0.0-beta.1.md new file mode 100644 index 000000000..461605c84 --- /dev/null +++ b/changelogs/v4.0.0-beta.1.md @@ -0,0 +1,3 @@ +# What's Changed + +> **这是 v4.0.0 的测试版本(beta.1),功能尚未完全稳定和加入**。v4.0.0 被设计为向前兼容,如有任何插件兼容性问题或者其他异常请在 GitHub 提交 [Issue](https://github.com/AstrBotDevs/AstrBot/issues)。在测试版本期间,您可以无缝回退到旧版本的 AstrBot,并且数据不受影响。 diff --git a/dashboard/src/assets/images/astrbot_banner.png b/dashboard/src/assets/images/astrbot_banner.png new file mode 100644 index 000000000..f837d9fb4 Binary files /dev/null and b/dashboard/src/assets/images/astrbot_banner.png differ diff --git a/dashboard/src/assets/images/logo-normal.svg b/dashboard/src/assets/images/logo-normal.svg deleted file mode 100644 index eb8373044..000000000 --- a/dashboard/src/assets/images/logo-normal.svg +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/dashboard/src/assets/images/logo-waifu.png b/dashboard/src/assets/images/logo-waifu.png deleted file mode 100644 index cc9375e42..000000000 Binary files a/dashboard/src/assets/images/logo-waifu.png and /dev/null differ diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 12de2561b..4374c99a0 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -2,6 +2,9 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor' import { ref, computed } from 'vue' import ListConfigItem from './ListConfigItem.vue' +import ProviderSelector from './ProviderSelector.vue' +import PersonaSelector from './PersonaSelector.vue' +import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue' import { useI18n } from '@/i18n/composables' const props = defineProps({ @@ -48,6 +51,47 @@ function openEditorDialog(key, value, theme, language) { function saveEditedContent() { dialog.value = false } + +function getValueBySelector(obj, selector) { + const keys = selector.split('.') + let current = obj + for (const key of keys) { + if (current && typeof current === 'object' && key in current) { + current = current[key] + } else { + return undefined + } + } + return current +} + +function shouldShowItem(itemMeta, itemKey) { + if (!itemMeta?.condition) { + return true + } + for (const [conditionKey, expectedValue] of Object.entries(itemMeta.condition)) { + const actualValue = getValueBySelector(props.iterable, conditionKey) + if (actualValue !== expectedValue) { + return false + } + } + return true +} + +function hasVisibleItemsAfter(items, currentIndex) { + const itemEntries = Object.entries(items) + + // 检查当前索引之后是否还有可见的配置项 + for (let i = currentIndex + 1; i < itemEntries.length; i++) { + const [itemKey, itemValue] = itemEntries[i] + const itemMeta = props.metadata[props.metadataKey].items[itemKey] + if (!itemMeta?.invisible && shouldShowItem(itemMeta, itemKey)) { + return true + } + } + + return false +}