Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 38486bc5aa | |||
| 7b8800c4eb | |||
| 8f4625f53b | |||
| d2d5ef1c5c | |||
| 5073f21002 | |||
| 69aaf09ac8 | |||
| 6e61ee81d8 | |||
| cfd05a8d17 | |||
| e204b180a8 | |||
| 563972fd29 | |||
| aa6f73574d | |||
| cefd2d7f49 | |||
| 81e1e545fb | |||
| d516920e72 | |||
| 2171372246 | |||
| d2df4d0cce | |||
| 6ab90fc123 | |||
| 1a84ebbb1e | |||
| c9c0352369 | |||
| 9903b028a3 | |||
| fbc4f8527b | |||
| ac71d9f034 | |||
| 64bcbc9fc0 | |||
| 9e7d46f956 | |||
| e911896cfb | |||
| 9c6d66093f | |||
| b2e39b9701 | |||
| e95ad4049b | |||
| 1df49d1d6f | |||
| 47e6ed455e | |||
| 92592fb9d9 | |||
| be8a0991ed | |||
| 61aac9c80c | |||
| 60af83cfee | |||
| cf64e6c231 | |||
| b711140f26 | |||
| 1d766001bb | |||
| 0759a11a85 | |||
| cb749a38ab | |||
| 369eab18ab | |||
| 6c1f540170 | |||
| d026a9f009 | |||
| a8e7dadd39 | |||
| 2f8d921adf | |||
| 0c6e526f94 | |||
| b1e3018b6b | |||
| 87f05fce66 | |||
| 1b37530c96 | |||
| 842c3c8ea9 |
@@ -27,17 +27,6 @@ jobs:
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }}
|
||||
|
||||
- name: Build Dashboard
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
mkdir -p dist/assets
|
||||
echo $(git rev-parse HEAD) > dist/assets/version
|
||||
cd ..
|
||||
mkdir -p data
|
||||
cp -r dashboard/dist data/
|
||||
|
||||
- name: Set QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
|
||||
+14
-1
@@ -3,5 +3,18 @@ from astrbot import logger
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_agent as agent
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
|
||||
__all__ = ["AstrBotConfig", "logger", "html_renderer", "llm_tool", "sp"]
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"agent",
|
||||
"sp",
|
||||
"ToolSet",
|
||||
"FunctionTool",
|
||||
"BaseFunctionToolExecutor",
|
||||
]
|
||||
|
||||
@@ -37,10 +37,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
@@ -53,10 +50,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip",
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
@@ -65,10 +59,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
|
||||
click.echo("初始化管理面板目录...")
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"),
|
||||
extract_path=str(astrbot_root),
|
||||
version=f"v{VERSION}",
|
||||
latest=False,
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker # noqa
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
@@ -21,7 +20,7 @@ html_renderer = HtmlRenderer(t2i_base_url)
|
||||
logger = LogManager.GetLogger(log_name="astrbot")
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
sp = SharedPreferences()
|
||||
sp = SharedPreferences(db_helper=db_helper)
|
||||
# 文件令牌服务
|
||||
file_token_service = FileTokenService()
|
||||
pip_installer = PipInstaller(
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from .tool import FunctionTool
|
||||
from typing import Generic
|
||||
from .run_context import TContext
|
||||
from .hooks import BaseAgentRunHooks
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
name: str
|
||||
instructions: str | None = None
|
||||
tools: list[str, FunctionTool] | None = None
|
||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import Generic
|
||||
from .tool import FunctionTool
|
||||
from .agent import Agent
|
||||
from .run_context import TContext
|
||||
|
||||
|
||||
class HandoffTool(FunctionTool, Generic[TContext]):
|
||||
"""Handoff tool for delegating tasks to another agent."""
|
||||
|
||||
def __init__(
|
||||
self, agent: Agent[TContext], parameters: dict | None = None, **kwargs
|
||||
):
|
||||
self.agent = agent
|
||||
super().__init__(
|
||||
name=f"transfer_to_{agent.name}",
|
||||
parameters=parameters or self.default_parameters(),
|
||||
description=agent.instructions or self.default_description(agent.name),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def default_parameters(self) -> dict:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": "The input to be handed off to another agent. This should be a clear and concise request or task.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def default_description(self, agent_name: str | None) -> str:
|
||||
agent_name = agent_name or "another"
|
||||
return f"Delegate tasks to {self.name} agent to handle the request."
|
||||
@@ -0,0 +1,27 @@
|
||||
import mcp
|
||||
from dataclasses import dataclass
|
||||
from .run_context import ContextWrapper, TContext
|
||||
from typing import Generic
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseAgentRunHooks(Generic[TContext]):
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
|
||||
async def on_tool_start(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
): ...
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
tool_result: mcp.types.CallToolResult | None,
|
||||
): ...
|
||||
async def on_agent_done(
|
||||
self, run_context: ContextWrapper[TContext], llm_response: LLMResponse
|
||||
): ...
|
||||
@@ -0,0 +1,208 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
first_key = next(iter(config["mcpServers"]))
|
||||
config = config["mcpServers"][first_key]
|
||||
config.pop("active", None)
|
||||
return config
|
||||
|
||||
|
||||
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
"""快速测试 MCP 服务器可达性"""
|
||||
import aiohttp
|
||||
|
||||
cfg = _prepare_config(config.copy())
|
||||
|
||||
url = cfg["url"]
|
||||
headers = cfg.get("headers", {})
|
||||
timeout = cfg.get("timeout", 10)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if cfg.get("transport") == "streamable_http":
|
||||
test_payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"id": 0,
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "test-client", "version": "1.2.3"},
|
||||
},
|
||||
}
|
||||
async with session.post(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
json=test_payload,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
else:
|
||||
async with session.get(
|
||||
url,
|
||||
headers={
|
||||
**headers,
|
||||
"Accept": "application/json, text/event-stream",
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, f"HTTP {response.status}: {response.reason}"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return False, f"连接超时: {timeout}秒"
|
||||
except Exception as e:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: list[mcp.Tool] = []
|
||||
self.server_errlogs: list[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
@@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
import typing as T
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
@@ -0,0 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
context: TContext
|
||||
event: AstrMessageEvent
|
||||
|
||||
NoContext = ContextWrapper[None]
|
||||
@@ -0,0 +1,3 @@
|
||||
from .base import BaseAgentRunner
|
||||
|
||||
__all__ = ["BaseAgentRunner"]
|
||||
+21
-20
@@ -1,32 +1,33 @@
|
||||
import abc
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
from enum import Enum, auto
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponse
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
"""Defines the state of the agent."""
|
||||
|
||||
IDLE = auto() # Initial state
|
||||
RUNNING = auto() # Currently processing
|
||||
DONE = auto() # Completed
|
||||
ERROR = auto() # Error state
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
class BaseAgentRunner:
|
||||
class BaseAgentRunner(T.Generic[TContext]):
|
||||
@abc.abstractmethod
|
||||
async def reset(self) -> None:
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
"""
|
||||
Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
+129
-101
@@ -1,10 +1,12 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||
from ...context import PipelineContext
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentState
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..response import AgentResponseData
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -21,8 +23,8 @@ from mcp.types import (
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
@@ -31,28 +33,25 @@ else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO:
|
||||
# 1. 处理平台不兼容的处理器
|
||||
|
||||
|
||||
class ToolLoopAgent(BaseAgentRunner):
|
||||
def __init__(
|
||||
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.req = None
|
||||
self.event = event
|
||||
self.pipeline_ctx = pipeline_ctx
|
||||
self._state = AgentState.IDLE
|
||||
self.final_llm_resp = None
|
||||
self.streaming = False
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||
self.req = req
|
||||
self.streaming = streaming
|
||||
async def reset(
|
||||
self,
|
||||
provider: Provider,
|
||||
request: ProviderRequest,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
self.tool_executor = tool_executor
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
@@ -78,6 +77,12 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
if self._state == AgentState.IDLE:
|
||||
try:
|
||||
await self.agent_hooks.on_agent_begin(self.run_context)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
@@ -124,12 +129,10 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
# 执行事件钩子
|
||||
if await self.pipeline_ctx.call_event_hook(
|
||||
self.event, EventType.OnLLMResponseEvent, llm_resp
|
||||
):
|
||||
return
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
@@ -193,50 +196,33 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if not res:
|
||||
continue
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
|
||||
executor = self.tool_executor.execute(
|
||||
tool=func_tool,
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
@@ -247,43 +233,85 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
wrapper = self.pipeline_ctx.call_handler(
|
||||
self.event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
# Tool 返回结果
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
self.event.clear_result()
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context,
|
||||
func_tool_name,
|
||||
func_tool_args,
|
||||
resp,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.run_context.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_end(
|
||||
self.run_context, func_tool_name, func_tool_args, None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
||||
)
|
||||
|
||||
self.run_context.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
@@ -0,0 +1,256 @@
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
from typing import Awaitable, Literal, Any, Optional
|
||||
from .mcp_client import MCPClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str | None = None
|
||||
parameters: dict | None = None
|
||||
description: str | None = None
|
||||
handler: Awaitable | None = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str | None = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str | None = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient | None = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
def __dict__(self) -> dict[str, Any]:
|
||||
"""将 FunctionTool 转换为字典格式"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"parameters": self.parameters,
|
||||
"description": self.description,
|
||||
"active": self.active,
|
||||
"origin": self.origin,
|
||||
"mcp_server_name": self.mcp_server_name,
|
||||
}
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: list[FunctionTool] = None):
|
||||
self.tools: list[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self.tools) == 0
|
||||
|
||||
def add_tool(self, tool: FunctionTool):
|
||||
"""Add a tool to the set."""
|
||||
# 检查是否已存在同名工具
|
||||
for i, existing_tool in enumerate(self.tools):
|
||||
if existing_tool.name == tool.name:
|
||||
self.tools[i] = tool
|
||||
return
|
||||
self.tools.append(tool)
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
"""Remove a tool by its name."""
|
||||
self.tools = [tool for tool in self.tools if tool.name != name]
|
||||
|
||||
def get_tool(self, name: str) -> Optional[FunctionTool]:
|
||||
"""Get a tool by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add_tool(_func)
|
||||
|
||||
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
|
||||
def remove_func(self, name: str):
|
||||
"""Remove a function tool by its name."""
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> list[FunctionTool]:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
@property
|
||||
def func_list(self) -> list[FunctionTool]:
|
||||
"""Get the list of function tools."""
|
||||
return self.tools
|
||||
|
||||
def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
func_def = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> list[dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""Convert schema to Gemini API format."""
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False):
|
||||
return self.openai_schema(omit_empty_parameter_field)
|
||||
|
||||
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
|
||||
def get_func_desc_anthropic_style(self):
|
||||
return self.anthropic_schema()
|
||||
|
||||
@deprecated(reason="Use google_schema() instead", version="4.0.0")
|
||||
def get_func_desc_google_genai_style(self):
|
||||
return self.google_schema()
|
||||
|
||||
def names(self) -> list[str]:
|
||||
"""获取所有工具的名称列表"""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tools)
|
||||
|
||||
def __bool__(self):
|
||||
return len(self.tools) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
|
||||
def __str__(self):
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
@@ -0,0 +1,11 @@
|
||||
import mcp
|
||||
from typing import Any, Generic, AsyncGenerator
|
||||
from .run_context import TContext, ContextWrapper
|
||||
from .tool import FunctionTool
|
||||
|
||||
|
||||
class BaseFunctionToolExecutor(Generic[TContext]):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args
|
||||
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ...
|
||||
@@ -0,0 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
|
||||
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
provider: Provider
|
||||
first_provider_request: ProviderRequest
|
||||
curr_provider_request: ProviderRequest
|
||||
streaming: bool
|
||||
@@ -0,0 +1,276 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.core import AstrBotConfig, logger
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
|
||||
from typing import TypeVar, TypedDict
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class ConfInfo(TypedDict):
|
||||
"""Configuration information for a specific session or platform."""
|
||||
|
||||
id: str # UUID of the configuration or "default"
|
||||
umop: list[str] # Unified Message Origin Pattern
|
||||
name: str
|
||||
path: str # File name to the configuration file
|
||||
|
||||
|
||||
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
id="default",
|
||||
umop=["::"],
|
||||
name="default",
|
||||
path=ASTRBOT_CONFIG_PATH,
|
||||
)
|
||||
|
||||
|
||||
class AstrBotConfigManager:
|
||||
"""A class to manage the system configuration of AstrBot, aka ACM"""
|
||||
|
||||
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
|
||||
self.sp = sp
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
"""uuid / "default" -> AstrBotConfig"""
|
||||
self.confs["default"] = default_config
|
||||
self._load_all_configs()
|
||||
|
||||
def _load_all_configs(self):
|
||||
"""Load all configurations from the shared preferences."""
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_data.items():
|
||||
filename = meta["path"]
|
||||
conf_path = os.path.join(get_astrbot_config_path(), filename)
|
||||
if os.path.exists(conf_path):
|
||||
conf = AstrBotConfig(config_path=conf_path)
|
||||
self.confs[uuid_] = conf
|
||||
else:
|
||||
logger.warning(
|
||||
f"Config file {conf_path} for UUID {uuid_} does not exist, skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
|
||||
Returns:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
try:
|
||||
umo = str(MessageSession.from_str(umo)) # validate
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return ConfInfo(**meta, id=uuid_)
|
||||
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
def _save_conf_mapping(
|
||||
self,
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||
if not umo:
|
||||
return self.confs["default"]
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
uuid_ = self._load_conf_mapping(umo)["id"]
|
||||
|
||||
conf = self.confs.get(uuid_)
|
||||
if not conf:
|
||||
conf = self.confs["default"] # default MUST exists
|
||||
|
||||
return conf
|
||||
|
||||
@property
|
||||
def default_conf(self) -> AstrBotConfig:
|
||||
"""获取默认配置文件"""
|
||||
return self.confs["default"]
|
||||
|
||||
def get_conf_info(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件元数据"""
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
return self._load_conf_mapping(umo)
|
||||
|
||||
def get_conf_list(self) -> list[ConfInfo]:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
abconf_mapping = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
return conf_list
|
||||
|
||||
def create_conf(
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
|
||||
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
"""
|
||||
conf_uuid = str(uuid.uuid4())
|
||||
conf_file_name = f"abconf_{conf_uuid}.json"
|
||||
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
|
||||
conf = AstrBotConfig(config_path=conf_path, default_config=config)
|
||||
conf.save_config()
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
|
||||
self.confs[conf_uuid] = conf
|
||||
return conf_uuid
|
||||
|
||||
def delete_conf(self, conf_id: str) -> bool:
|
||||
"""删除指定配置文件
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
|
||||
Raises:
|
||||
ValueError: 如果试图删除默认配置文件
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能删除默认配置文件")
|
||||
|
||||
# 从映射中移除
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 获取配置文件路径
|
||||
conf_path = os.path.join(
|
||||
get_astrbot_config_path(), abconf_data[conf_id]["path"]
|
||||
)
|
||||
|
||||
# 删除配置文件
|
||||
try:
|
||||
if os.path.exists(conf_path):
|
||||
os.remove(conf_path)
|
||||
logger.info(f"已删除配置文件: {conf_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除配置文件 {conf_path} 失败: {e}")
|
||||
return False
|
||||
|
||||
# 从内存中移除
|
||||
if conf_id in self.confs:
|
||||
del self.confs[conf_id]
|
||||
|
||||
# 从映射中移除
|
||||
del abconf_data[conf_id]
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
name: 新的配置文件名称 (可选)
|
||||
umo_parts: 新的 UMO 部分列表 (可选)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
if conf_id == "default":
|
||||
raise ValueError("不能更新默认配置文件的信息")
|
||||
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
if conf_id not in abconf_data:
|
||||
logger.warning(f"配置文件 {conf_id} 不存在于映射中")
|
||||
return False
|
||||
|
||||
# 更新名称
|
||||
if name is not None:
|
||||
abconf_data[conf_id]["name"] = name
|
||||
|
||||
# 更新 UMO 部分
|
||||
if umo_parts is not None:
|
||||
# 验证 UMO 部分格式
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
logger.info(f"成功更新配置文件 {conf_id} 的信息")
|
||||
return True
|
||||
|
||||
def g(
|
||||
self, umo: str | None = None, key: str | None = None, default: _VT = None
|
||||
) -> _VT:
|
||||
"""获取配置项。umo 为 None 时使用默认配置"""
|
||||
if umo is None:
|
||||
return self.confs["default"].get(key, default)
|
||||
conf = self.get_conf(umo)
|
||||
return conf.get(key, default)
|
||||
+609
-200
File diff suppressed because it is too large
Load Diff
@@ -5,40 +5,44 @@ AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json
|
||||
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.db.po import Conversation, ConversationV2
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.session_conversations: Dict[str, str] = {}
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
"""启动定时保存任务"""
|
||||
asyncio.create_task(self._periodic_save())
|
||||
def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
|
||||
"""将 ConversationV2 对象转换为 Conversation 对象"""
|
||||
created_at = int(conv_v2.created_at.timestamp())
|
||||
updated_at = int(conv_v2.updated_at.timestamp())
|
||||
return Conversation(
|
||||
platform_id=conv_v2.platform_id,
|
||||
user_id=conv_v2.user_id,
|
||||
cid=conv_v2.conversation_id,
|
||||
history=json.dumps(conv_v2.content or []),
|
||||
title=conv_v2.title,
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
|
||||
async def _periodic_save(self):
|
||||
"""定时保存会话对话映射关系到存储中"""
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
"""保存会话对话映射关系到存储中"""
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
async def new_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
platform_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
) -> str:
|
||||
"""新建对话,并将当前会话的对话转移到新对话
|
||||
|
||||
Args:
|
||||
@@ -46,11 +50,23 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
if not platform_id:
|
||||
# 如果没有提供 platform_id,则从 unified_msg_origin 中解析
|
||||
parts = unified_msg_origin.split(":")
|
||||
if len(parts) >= 3:
|
||||
platform_id = parts[0]
|
||||
if not platform_id:
|
||||
platform_id = "unknown"
|
||||
conv = await self.db.create_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
platform_id=platform_id,
|
||||
content=content,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conv.conversation_id
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
||||
return conv.conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
"""切换会话的对话
|
||||
@@ -60,10 +76,10 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)
|
||||
|
||||
async def delete_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str = None
|
||||
self, unified_msg_origin: str, conversation_id: str | None = None
|
||||
):
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
@@ -71,13 +87,18 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
Args:
|
||||
@@ -85,14 +106,19 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
ret = self.session_conversations.get(unified_msg_origin, None)
|
||||
if not ret:
|
||||
ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
|
||||
if ret:
|
||||
self.session_conversations[unified_msg_origin] = ret
|
||||
return ret
|
||||
|
||||
async def get_conversation(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str,
|
||||
create_if_not_exists: bool = False,
|
||||
) -> Conversation:
|
||||
) -> Conversation | None:
|
||||
"""获取会话的对话
|
||||
|
||||
Args:
|
||||
@@ -101,27 +127,74 @@ class ConversationManager:
|
||||
Returns:
|
||||
conversation (Conversation): 对话对象
|
||||
"""
|
||||
conv = self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
if not conv and create_if_not_exists:
|
||||
# 如果对话不存在且需要创建,则新建一个对话
|
||||
conversation_id = await self.new_conversation(unified_msg_origin)
|
||||
return self.db.get_conversation_by_user_id(
|
||||
unified_msg_origin, conversation_id
|
||||
)
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
conv = await self.db.get_conversation_by_id(cid=conversation_id)
|
||||
conv_res = None
|
||||
if conv:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
return conv_res
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
"""获取会话的所有对话
|
||||
async def get_conversations(
|
||||
self, unified_msg_origin: str | None = None, platform_id: str | None = None
|
||||
) -> List[Conversation]:
|
||||
"""获取对话列表
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
|
||||
platform_id (str): 平台 ID, 可选参数, 用于过滤对话
|
||||
Returns:
|
||||
conversations (List[Conversation]): 对话对象列表
|
||||
"""
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
convs = await self.db.get_conversations(
|
||||
user_id=unified_msg_origin, platform_id=platform_id
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res
|
||||
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""获取过滤后的对话列表
|
||||
|
||||
Args:
|
||||
page (int): 页码, 默认为 1
|
||||
page_size (int): 每页大小, 默认为 20
|
||||
platform_ids (list[str]): 平台 ID 列表, 可选
|
||||
search_query (str): 搜索查询字符串, 可选
|
||||
Returns:
|
||||
conversations (list[Conversation]): 对话对象列表
|
||||
"""
|
||||
convs, cnt = await self.db.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platform_ids=platform_ids,
|
||||
search_query=search_query,
|
||||
**kwargs,
|
||||
)
|
||||
convs_res = []
|
||||
for conv in convs:
|
||||
conv_res = self._convert_conv_from_v2_to_v1(conv)
|
||||
convs_res.append(conv_res)
|
||||
return convs_res, cnt
|
||||
|
||||
async def update_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str | None = None,
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话
|
||||
|
||||
@@ -130,40 +203,55 @@ class ConversationManager:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
"""
|
||||
if not conversation_id:
|
||||
# 如果没有提供 conversation_id,则获取当前的
|
||||
conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
await self.db.update_conversation(
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history),
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
async def update_conversation_title(
|
||||
self, unified_msg_origin: str, title: str, conversation_id: str | None = None
|
||||
):
|
||||
"""更新会话的对话标题
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
title (str): 对话标题
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `title` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin, cid=conversation_id, title=title
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(
|
||||
self, unified_msg_origin: str, persona_id: str
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
persona_id: str,
|
||||
conversation_id: str | None = None,
|
||||
):
|
||||
"""更新会话的对话 Persona ID
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
persona_id (str): 对话 Persona ID
|
||||
|
||||
Deprecated:
|
||||
Use `update_conversation` with `persona_id` parameter instead.
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin, cid=conversation_id, persona_id=persona_id
|
||||
)
|
||||
await self.update_conversation(
|
||||
unified_msg_origin=unified_msg_origin,
|
||||
conversation_id=conversation_id,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||
|
||||
@@ -15,20 +15,23 @@ import time
|
||||
import threading
|
||||
import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from . import astrbot_config, html_renderer
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
|
||||
@@ -77,11 +80,26 @@ class AstrBotCoreLifecycle:
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||
|
||||
await self.db.initialize()
|
||||
|
||||
await html_renderer.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
default_config=self.astrbot_config, sp=sp
|
||||
)
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
|
||||
await self.persona_mgr.initialize()
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
self.provider_manager = ProviderManager(
|
||||
self.astrbot_config_mgr, self.db, self.persona_mgr
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
@@ -89,6 +107,9 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
# 初始化平台消息历史管理器
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
@@ -97,6 +118,9 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
self.astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -109,16 +133,16 @@ class AstrBotCoreLifecycle:
|
||||
await self.provider_manager.initialize()
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler = PipelineScheduler(
|
||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||
)
|
||||
await self.pipeline_scheduler.initialize()
|
||||
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
self.event_bus = EventBus(
|
||||
self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr
|
||||
)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
@@ -235,6 +259,39 @@ class AstrBotCoreLifecycle:
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(
|
||||
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
|
||||
asyncio.create_task(
|
||||
platform_inst.run(),
|
||||
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]:
|
||||
"""加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
mapping = {}
|
||||
for conf_id, ab_config in self.astrbot_config_mgr.confs.items():
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
mapping[conf_id] = scheduler
|
||||
return mapping
|
||||
|
||||
async def reload_pipeline_scheduler(self, conf_id: str):
|
||||
"""重新加载消息事件流水线调度器
|
||||
|
||||
Returns:
|
||||
dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射
|
||||
"""
|
||||
ab_config = self.astrbot_config_mgr.confs.get(conf_id)
|
||||
if not ab_config:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
scheduler = PipelineScheduler(
|
||||
PipelineContext(ab_config, self.plugin_manager, conf_id)
|
||||
)
|
||||
await scheduler.initialize()
|
||||
self.pipeline_scheduler_mapping[conf_id] = scheduler
|
||||
|
||||
+238
-115
@@ -1,7 +1,20 @@
|
||||
import abc
|
||||
import datetime
|
||||
import typing as T
|
||||
from deprecated import deprecated
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
from astrbot.core.db.po import (
|
||||
Stats,
|
||||
PlatformStat,
|
||||
ConversationV2,
|
||||
PlatformMessageHistory,
|
||||
Attachment,
|
||||
Persona,
|
||||
Preference,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -10,152 +23,262 @@ class BaseDatabase(abc.ABC):
|
||||
数据库基类
|
||||
"""
|
||||
|
||||
DATABASE_URL = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化数据库连接"""
|
||||
pass
|
||||
|
||||
def insert_base_metrics(self, metrics: dict):
|
||||
"""插入基础指标数据"""
|
||||
self.insert_platform_metrics(metrics["platform_stats"])
|
||||
self.insert_plugin_metrics(metrics["plugin_stats"])
|
||||
self.insert_command_metrics(metrics["command_stats"])
|
||||
self.insert_llm_metrics(metrics["llm_stats"])
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
"""插入平台指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_plugin_metrics(self, metrics: dict):
|
||||
"""插入插件指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_command_metrics(self, metrics: dict):
|
||||
"""插入指令指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
"""插入 LLM 指标数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
"""更新 LLM 历史记录。当不存在 session_id 时插入"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_llm_history(
|
||||
self, session_id: str = None, provider_type: str = None
|
||||
) -> List[LLMHistory]:
|
||||
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
|
||||
raise NotImplementedError
|
||||
@asynccontextmanager
|
||||
async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session."""
|
||||
if not self.inited:
|
||||
await self.initialize()
|
||||
self.inited = True
|
||||
async with self.AsyncSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_total_message_count(self) -> int:
|
||||
"""获取总消息数"""
|
||||
raise NotImplementedError
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_stats instead")
|
||||
@abc.abstractmethod
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取基础统计数据(合并)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
||||
"""插入 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
# New methods in v4.0.0
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data(self) -> List[ATRIVision]:
|
||||
"""获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def insert_platform_stats(
|
||||
self,
|
||||
platform_id: str,
|
||||
platform_type: str,
|
||||
count: int = 1,
|
||||
timestamp: datetime.datetime | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform statistic record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(
|
||||
self, url_or_path: str, id: str
|
||||
) -> ATRIVision:
|
||||
"""通过 url 或 path 获取 ATRI 视觉数据"""
|
||||
raise NotImplementedError
|
||||
async def count_platform_stats(self) -> int:
|
||||
"""Count the number of platform statistics records."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
"""通过 user_id 和 cid 获取 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
||||
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
"""新建 Conversation"""
|
||||
raise NotImplementedError
|
||||
async def get_conversations(
|
||||
self, user_id: str | None = None, platform_id: str | None = None
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations for a specific user and platform_id(optional).
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
"""删除 Conversation"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
"""更新 Conversation 标题"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
"""更新 Conversation Persona ID"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
content is not included in the result.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_filtered_conversations(
|
||||
async def get_conversation_by_id(self, cid: str) -> ConversationV2:
|
||||
"""Get a specific conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> list[ConversationV2]:
|
||||
"""Get all conversations with pagination."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表
|
||||
platform_ids: list[str] | None = None,
|
||||
search_query: str = "",
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationV2], int]:
|
||||
"""Get conversations filtered by platform IDs and search query."""
|
||||
...
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
platforms: 平台筛选列表
|
||||
message_types: 消息类型筛选列表
|
||||
search_query: 搜索关键词
|
||||
exclude_ids: 排除的用户ID列表
|
||||
exclude_platforms: 排除的平台列表
|
||||
@abc.abstractmethod
|
||||
async def create_conversation(
|
||||
self,
|
||||
user_id: str,
|
||||
platform_id: str,
|
||||
content: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
cid: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> ConversationV2:
|
||||
"""Create a new conversation."""
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@abc.abstractmethod
|
||||
async def update_conversation(
|
||||
self,
|
||||
cid: str,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversation(self, cid: str) -> None:
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict],
|
||||
sender_id: str | None = None,
|
||||
sender_name: str | None = None,
|
||||
) -> None:
|
||||
"""Insert a new platform message history record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_platform_message_offset(
|
||||
self, platform_id: str, user_id: str, offset_sec: int = 86400
|
||||
) -> None:
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_platform_message_history(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> list[PlatformMessageHistory]:
|
||||
"""Get platform message history for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_attachment(
|
||||
self,
|
||||
path: str,
|
||||
type: str,
|
||||
mime_type: str,
|
||||
):
|
||||
"""Insert a new attachment record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_attachment_by_id(self, attachment_id: str) -> Attachment:
|
||||
"""Get an attachment by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""Insert a new persona record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_persona_by_id(self, persona_id: str) -> Persona:
|
||||
"""Get a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_personas(self) -> list[Persona]:
|
||||
"""Get all personas for a specific bot."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str | None = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona | None:
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_persona(self, persona_id: str) -> None:
|
||||
"""Delete a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_preference_or_update(
|
||||
self, scope: str, scope_id: str, key: str, value: dict
|
||||
) -> Preference:
|
||||
"""Insert a new preference record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference:
|
||||
"""Get a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_preferences(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""Get all preferences for a specific scope ID or key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def remove_preference(self, scope: str, scope_id: str, key: str) -> None:
|
||||
"""Remove a preference by scope ID and key."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear_preferences(self, scope: str, scope_id: str) -> None:
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
# cid: str,
|
||||
# role: str,
|
||||
# content: list,
|
||||
# tool_calls: list = None,
|
||||
# tool_call_id: str = None,
|
||||
# parent_id: str = None,
|
||||
# ) -> LLMMessage:
|
||||
# """Insert a new LLM message into the conversation."""
|
||||
# ...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def get_llm_messages(self, cid: str) -> list[LLMMessage]:
|
||||
# """Get all LLM messages for a specific conversation."""
|
||||
# ...
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.api import logger, sp
|
||||
from .migra_3_to_4 import (
|
||||
migration_conversation_table,
|
||||
migration_platform_table,
|
||||
migration_webchat_data,
|
||||
migration_persona_data,
|
||||
migration_preferences,
|
||||
)
|
||||
|
||||
|
||||
async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
"""
|
||||
检查是否需要进行数据库迁移
|
||||
如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。
|
||||
"""
|
||||
data_v3_exists = os.path.exists(get_astrbot_data_path())
|
||||
if not data_v3_exists:
|
||||
return False
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_v4"
|
||||
)
|
||||
if migration_done:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def do_migration_v4(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
astrbot_config: AstrBotConfig,
|
||||
):
|
||||
"""
|
||||
执行数据库迁移
|
||||
迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
if not await check_migration_needed_v4(db_helper):
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
# 执行会话表迁移
|
||||
await migration_conversation_table(db_helper, platform_id_map)
|
||||
|
||||
# 执行人格数据迁移
|
||||
await migration_persona_data(db_helper, astrbot_config)
|
||||
|
||||
# 执行 WebChat 数据迁移
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行偏好设置迁移
|
||||
await migration_preferences(db_helper,platform_id_map)
|
||||
|
||||
# 执行平台统计表迁移
|
||||
await migration_platform_table(db_helper, platform_id_map)
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_v4", True)
|
||||
|
||||
logger.info("数据库迁移完成。")
|
||||
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import datetime
|
||||
from .. import BaseDatabase
|
||||
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
|
||||
from .shared_preferences_v3 import sp as sp_v3
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
|
||||
from sqlalchemy import text
|
||||
|
||||
"""
|
||||
1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
|
||||
2. 迁移旧的 platform 到新的 platform_stats 表。
|
||||
"""
|
||||
|
||||
|
||||
def get_platform_id(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_id", old_platform_name)
|
||||
|
||||
|
||||
def get_platform_type(
|
||||
platform_id_map: dict[str, dict[str, str]], old_platform_name: str
|
||||
) -> str:
|
||||
return platform_id_map.get(
|
||||
old_platform_name,
|
||||
{"platform_id": old_platform_name, "platform_type": old_platform_name},
|
||||
).get("platform_type", old_platform_name)
|
||||
|
||||
|
||||
async def migration_conversation_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, session.platform_name
|
||||
)
|
||||
session.platform_id = platform_id # 更新平台名称为新的 ID
|
||||
conv_v2 = ConversationV2(
|
||||
user_id=str(session),
|
||||
content=json.loads(conv.history) if conv.history else [],
|
||||
platform_id=platform_id,
|
||||
title=conv.title,
|
||||
persona_id=conv.persona_id,
|
||||
conversation_id=conv.cid,
|
||||
created_at=datetime.datetime.fromtimestamp(conv.created_at),
|
||||
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
|
||||
)
|
||||
dbsession.add(conv_v2)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_platform_table(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
secs_from_2023_4_10_to_now = (
|
||||
datetime.datetime.now(datetime.timezone.utc)
|
||||
- datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc)
|
||||
).total_seconds()
|
||||
offset_sec = int(secs_from_2023_4_10_to_now)
|
||||
logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。")
|
||||
stats = db_helper_v3.get_base_stats(offset_sec=offset_sec)
|
||||
logger.info(f"迁移 {len(stats.platform)} 条旧的平台数据到新的表中...")
|
||||
platform_stats_v3 = stats.platform
|
||||
|
||||
if not platform_stats_v3:
|
||||
logger.info("没有找到旧平台数据,跳过迁移。")
|
||||
return
|
||||
|
||||
first_time_stamp = platform_stats_v3[0].timestamp
|
||||
end_time_stamp = platform_stats_v3[-1].timestamp
|
||||
start_time = first_time_stamp - (first_time_stamp % 3600) # 向下取整到小时
|
||||
end_time = end_time_stamp + (3600 - (end_time_stamp % 3600)) # 向上取整到小时
|
||||
|
||||
idx = 0
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
total_buckets = (end_time - start_time) // 3600
|
||||
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
|
||||
if bucket_idx % 500 == 0:
|
||||
progress = int((bucket_idx + 1) / total_buckets * 100)
|
||||
logger.info(f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})")
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(platform_stats_v3)
|
||||
and platform_stats_v3[idx].timestamp < bucket_end
|
||||
):
|
||||
cnt += platform_stats_v3[idx].count
|
||||
idx += 1
|
||||
if cnt == 0:
|
||||
continue
|
||||
platform_id = get_platform_id(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
platform_type = get_platform_type(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
try:
|
||||
await dbsession.execute(
|
||||
text("""
|
||||
INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
|
||||
VALUES (:timestamp, :platform_id, :platform_type, :count)
|
||||
ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
|
||||
count = platform_stats.count + EXCLUDED.count
|
||||
"""),
|
||||
{
|
||||
"timestamp": datetime.datetime.fromtimestamp(
|
||||
bucket_end, tz=datetime.timezone.utc
|
||||
),
|
||||
"platform_id": platform_id,
|
||||
"platform_type": platform_type,
|
||||
"count": cnt,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
|
||||
|
||||
|
||||
async def migration_webchat_data(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
|
||||
)
|
||||
conversations, total_cnt = db_helper_v3.get_all_conversations(
|
||||
page=1, page_size=10000000
|
||||
)
|
||||
logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
|
||||
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
history = json.loads(conv.history) if conv.history else []
|
||||
for msg in history:
|
||||
type_ = msg.get("type") # user type, "bot" or "user"
|
||||
new_history = PlatformMessageHistory(
|
||||
platform_id=platform_id,
|
||||
user_id=conv.cid, # we use conv.cid as user_id for webchat
|
||||
content=msg,
|
||||
sender_id=type_,
|
||||
sender_name=type_,
|
||||
)
|
||||
dbsession.add(new_history)
|
||||
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_persona_data(
|
||||
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
|
||||
):
|
||||
"""
|
||||
迁移 Persona 数据到新的表中。
|
||||
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
|
||||
"""
|
||||
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
|
||||
total_personas = len(v3_persona_config)
|
||||
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
|
||||
|
||||
for idx, persona in enumerate(v3_persona_config):
|
||||
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
|
||||
progress = int((idx + 1) / total_personas * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
|
||||
try:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
mood_prompt = ""
|
||||
user_turn = True
|
||||
for mood_dialog in mood_imitation_dialogs:
|
||||
if user_turn:
|
||||
mood_prompt += f"A: {mood_dialog}\n"
|
||||
else:
|
||||
mood_prompt += f"B: {mood_dialog}\n"
|
||||
user_turn = not user_turn
|
||||
system_prompt = persona.get("prompt", "")
|
||||
if mood_prompt:
|
||||
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
|
||||
persona_new = await db_helper.insert_persona(
|
||||
persona_id=persona["name"],
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
)
|
||||
logger.info(
|
||||
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
|
||||
async def migration_preferences(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
):
|
||||
# 1. global scope migration
|
||||
keys = [
|
||||
"inactivated_llm_tools",
|
||||
"inactivated_plugins",
|
||||
"curr_provider",
|
||||
"curr_provider_tts",
|
||||
"curr_provider_stt",
|
||||
"alter_cmd",
|
||||
]
|
||||
for key in keys:
|
||||
value = sp_v3.get(key)
|
||||
if value is not None:
|
||||
await sp.put_async("global", "global", key, value)
|
||||
logger.info(f"迁移全局偏好设置 {key} 成功,值: {value}")
|
||||
|
||||
# 2. umo scope migration
|
||||
session_conversation = sp_v3.get("session_conversation", default={})
|
||||
for umo, conversation_id in session_conversation.items():
|
||||
if not umo or not conversation_id:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "sel_conv_id", conversation_id)
|
||||
logger.info(f"迁移会话 {umo} 的对话数据到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的对话数据失败: {e}", exc_info=True)
|
||||
|
||||
session_service_config = sp_v3.get("session_service_config", default={})
|
||||
for umo, config in session_service_config.items():
|
||||
if not umo or not config:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
await sp.put_async("umo", str(session), "session_service_config", config)
|
||||
|
||||
logger.info(f"迁移会话 {umo} 的服务配置到新表成功,平台 ID: {platform_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的服务配置失败: {e}", exc_info=True)
|
||||
|
||||
session_variables = sp_v3.get("session_variables", default={})
|
||||
for umo, variables in session_variables.items():
|
||||
if not umo or not variables:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
await sp.put_async("umo", str(session), "session_variables", variables)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的变量失败: {e}", exc_info=True)
|
||||
|
||||
session_provider_perf = sp_v3.get("session_provider_perf", default={})
|
||||
for umo, perf in session_provider_perf.items():
|
||||
if not umo or not perf:
|
||||
continue
|
||||
try:
|
||||
session = MessageSesion.from_str(session_str=umo)
|
||||
platform_id = get_platform_id(platform_id_map, session.platform_name)
|
||||
session.platform_id = platform_id
|
||||
|
||||
for provider_type, provider_id in perf.items():
|
||||
await sp.put_async(
|
||||
"umo", str(session), f"provider_perf_{provider_type}", provider_id
|
||||
)
|
||||
logger.info(
|
||||
f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
|
||||
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
sp = SharedPreferences()
|
||||
@@ -0,0 +1,493 @@
|
||||
import sqlite3
|
||||
import time
|
||||
from astrbot.core.db.po import Platform, Stats
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
|
||||
|
||||
INIT_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
"""
|
||||
|
||||
|
||||
class SQLiteDatabase():
|
||||
def __init__(self, db_path: str) -> None:
|
||||
super().__init__()
|
||||
self.db_path = db_path
|
||||
|
||||
sql = INIT_SQL
|
||||
|
||||
# 初始化数据库
|
||||
self.conn = self._get_conn(self.db_path)
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
"""
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
"""
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
"""
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: Tuple = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
conn = self._get_conn(self.db_path)
|
||||
c = conn.cursor()
|
||||
|
||||
if params:
|
||||
c.execute(sql, params)
|
||||
c.close()
|
||||
else:
|
||||
c.execute(sql)
|
||||
c.close()
|
||||
|
||||
conn.commit()
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
|
||||
""",
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform=platform)
|
||||
|
||||
def get_total_message_count(self) -> int:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT SUM(count) FROM platform
|
||||
"""
|
||||
)
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return res[0]
|
||||
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
||||
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
||||
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT name, SUM(count), timestamp FROM platform
|
||||
"""
|
||||
+ where_clause
|
||||
+ " GROUP BY name"
|
||||
)
|
||||
|
||||
platform = []
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
|
||||
if not res:
|
||||
return
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
"""
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
"""
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(
|
||||
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
||||
)
|
||||
return conversations
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
"""更新对话,并且同时更新时间"""
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(history, updated_at, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(title, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(persona_id, user_id, cid),
|
||||
)
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
"""
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
""",
|
||||
(user_id, cid),
|
||||
)
|
||||
|
||||
def get_all_conversations(
|
||||
self, page: int = 1, page_size: int = 20
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取所有对话,支持分页,按更新时间降序排序"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 获取总记录数
|
||||
c.execute("""
|
||||
SELECT COUNT(*) FROM webchat_conversation
|
||||
""")
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取分页数据,按更新时间降序排序
|
||||
c.execute(
|
||||
"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(page_size, offset),
|
||||
)
|
||||
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
|
||||
def get_filtered_conversations(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
platforms: List[str] = None,
|
||||
message_types: List[str] = None,
|
||||
search_query: str = None,
|
||||
exclude_ids: List[str] = None,
|
||||
exclude_platforms: List[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""获取筛选后的对话列表"""
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
where_clauses = []
|
||||
params = []
|
||||
|
||||
# 平台筛选
|
||||
if platforms and len(platforms) > 0:
|
||||
platform_conditions = []
|
||||
for platform in platforms:
|
||||
platform_conditions.append("user_id LIKE ?")
|
||||
params.append(f"{platform}:%")
|
||||
|
||||
if platform_conditions:
|
||||
where_clauses.append(f"({' OR '.join(platform_conditions)})")
|
||||
|
||||
# 消息类型筛选
|
||||
if message_types and len(message_types) > 0:
|
||||
message_type_conditions = []
|
||||
for msg_type in message_types:
|
||||
message_type_conditions.append("user_id LIKE ?")
|
||||
params.append(f"%:{msg_type}:%")
|
||||
|
||||
if message_type_conditions:
|
||||
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
||||
|
||||
# 搜索关键词
|
||||
if search_query:
|
||||
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||
where_clauses.append(
|
||||
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
||||
)
|
||||
search_param = f"%{search_query}%"
|
||||
params.extend([search_param, search_param, search_param, search_param])
|
||||
|
||||
# 排除特定用户ID
|
||||
if exclude_ids and len(exclude_ids) > 0:
|
||||
for exclude_id in exclude_ids:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_id}%")
|
||||
|
||||
# 排除特定平台
|
||||
if exclude_platforms and len(exclude_platforms) > 0:
|
||||
for exclude_platform in exclude_platforms:
|
||||
where_clauses.append("user_id NOT LIKE ?")
|
||||
params.append(f"{exclude_platform}:%")
|
||||
|
||||
# 构建完整的 WHERE 子句
|
||||
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
|
||||
|
||||
# 构建计数查询
|
||||
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
|
||||
|
||||
# 获取总记录数
|
||||
c.execute(count_sql, params)
|
||||
total_count = c.fetchone()[0]
|
||||
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 构建分页数据查询
|
||||
data_sql = f"""
|
||||
SELECT user_id, cid, created_at, updated_at, title, persona_id
|
||||
FROM webchat_conversation
|
||||
{where_sql}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_params = params + [page_size, offset]
|
||||
|
||||
# 获取分页数据
|
||||
c.execute(data_sql, query_params)
|
||||
rows = c.fetchall()
|
||||
|
||||
conversations = []
|
||||
|
||||
for row in rows:
|
||||
user_id, cid, created_at, updated_at, title, persona_id = row
|
||||
# 确保 cid 是字符串类型,否则使用一个默认值
|
||||
safe_cid = str(cid) if cid else "unknown"
|
||||
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
||||
|
||||
conversations.append(
|
||||
{
|
||||
"user_id": user_id or "",
|
||||
"cid": safe_cid,
|
||||
"title": title or f"对话 {display_cid}",
|
||||
"persona_id": persona_id or "",
|
||||
"created_at": created_at or 0,
|
||||
"updated_at": updated_at or 0,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total_count
|
||||
|
||||
except Exception as _:
|
||||
# 返回空列表和0,确保即使出错也有有效的返回值
|
||||
return [], 0
|
||||
finally:
|
||||
c.close()
|
||||
+229
-74
@@ -1,7 +1,233 @@
|
||||
"""指标数据"""
|
||||
import uuid
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from sqlmodel import (
|
||||
SQLModel,
|
||||
Text,
|
||||
JSON,
|
||||
UniqueConstraint,
|
||||
Field,
|
||||
)
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
|
||||
class PlatformStat(SQLModel, table=True):
|
||||
"""This class represents the statistics of bot usage across different platforms.
|
||||
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
platform_id: str = Field(nullable=False)
|
||||
platform_type: str = Field(nullable=False) # such as "aiocqhttp", "slack", etc.
|
||||
count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"timestamp",
|
||||
"platform_id",
|
||||
"platform_type",
|
||||
name="uix_platform_stats",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conversation_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False)
|
||||
content: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
title: Optional[str] = Field(default=None, max_length=255)
|
||||
persona_id: Optional[str] = Field(default=None)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conversation_id",
|
||||
name="uix_conversation_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Persona(SQLModel, table=True):
|
||||
"""Persona is a set of instructions for LLMs to follow.
|
||||
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
persona_id: str = Field(max_length=255, nullable=False)
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""a list of strings, each representing a dialog to start with"""
|
||||
tools: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"persona_id",
|
||||
name="uix_persona_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
scope: str = Field(nullable=False)
|
||||
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
|
||||
scope_id: str = Field(nullable=False)
|
||||
"""ID of the scope, such as 'global', 'umo', 'plugin_name'."""
|
||||
key: str = Field(nullable=False)
|
||||
value: dict = Field(sa_type=JSON, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"scope",
|
||||
"scope_id",
|
||||
"key",
|
||||
name="uix_preference_scope_scope_id_key",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PlatformMessageHistory(SQLModel, table=True):
|
||||
"""This class represents the message history for a specific platform.
|
||||
|
||||
It is used to store messages that are not LLM-generated, such as user messages
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
sender_name: Optional[str] = Field(
|
||||
default=None
|
||||
) # Name of the sender in the platform
|
||||
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class Attachment(SQLModel, table=True):
|
||||
"""This class represents attachments for messages in AstrBot.
|
||||
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
path: str = Field(nullable=False) # Path to the file on disk
|
||||
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
|
||||
mime_type: str = Field(nullable=False) # MIME type of the file
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"attachment_id",
|
||||
name="uix_attachment_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
|
||||
对于 WebChat,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
|
||||
在 v4.0.0 版本及之后,WebChat 的历史记录被迁移至 `PlatformMessageHistory` 表中,
|
||||
"""
|
||||
|
||||
platform_id: str
|
||||
user_id: str
|
||||
cid: str
|
||||
"""对话 ID, 是 uuid 格式的字符串"""
|
||||
history: str = ""
|
||||
"""字符串格式的对话列表。"""
|
||||
title: str | None = ""
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
"""LLM 人格类。
|
||||
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
# ====
|
||||
# Deprecated, and will be removed in future versions.
|
||||
# ====
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -13,77 +239,6 @@ class Platform:
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""供应商使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Plugin:
|
||||
"""插件使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Command:
|
||||
"""命令使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Stats:
|
||||
platform: List[Platform] = field(default_factory=list)
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMHistory:
|
||||
"""LLM 聊天时持久化的信息"""
|
||||
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ATRIVision:
|
||||
"""Deprecated"""
|
||||
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
is_meme: bool
|
||||
keywords: List[str]
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
"""
|
||||
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
"""字符串格式的列表。"""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
platform: list[Platform] = field(default_factory=list)
|
||||
|
||||
+529
-554
File diff suppressed because it is too large
Load Diff
@@ -1,50 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS platform(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS plugin(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS command(
|
||||
name VARCHAR(32),
|
||||
count INTEGER,
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT, -- 会话 id
|
||||
cid TEXT, -- 对话 id
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
|
||||
PRAGMA encoding = 'UTF-8';
|
||||
@@ -5,6 +5,7 @@ from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
@@ -17,6 +18,7 @@ class FaissVecDB(BaseVecDB):
|
||||
doc_store_path: str,
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
):
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
@@ -26,11 +28,14 @@ class FaissVecDB(BaseVecDB):
|
||||
embedding_provider.get_dim(), index_store_path
|
||||
)
|
||||
self.embedding_provider = embedding_provider
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def initialize(self):
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
@@ -53,7 +58,12 @@ class FaissVecDB(BaseVecDB):
|
||||
return int_id
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, k: int = 5, fetch_k: int = 20, metadata_filters: dict = None
|
||||
self,
|
||||
query: str,
|
||||
k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
@@ -62,6 +72,7 @@ class FaissVecDB(BaseVecDB):
|
||||
query (str): 查询文本
|
||||
k (int): 返回的最相似文档的数量
|
||||
fetch_k (int): 在根据 metadata 过滤前从 FAISS 中获取的数量
|
||||
rerank (bool): 是否使用重排序。这需要在实例化时提供 rerank_provider, 如果未提供并且 rerank 为 True, 不会抛出异常。
|
||||
metadata_filters (dict): 元数据过滤器
|
||||
|
||||
Returns:
|
||||
@@ -72,7 +83,6 @@ class FaissVecDB(BaseVecDB):
|
||||
vector=np.array([embedding]).astype("float32"),
|
||||
k=fetch_k if metadata_filters else k,
|
||||
)
|
||||
# TODO: rerank
|
||||
if len(indices[0]) == 0 or indices[0][0] == -1:
|
||||
return []
|
||||
# normalize scores
|
||||
@@ -83,7 +93,7 @@ class FaissVecDB(BaseVecDB):
|
||||
)
|
||||
if not fetched_docs:
|
||||
return []
|
||||
result_docs = []
|
||||
result_docs: list[Result] = []
|
||||
|
||||
idx_pos = {fetch_doc["id"]: idx for idx, fetch_doc in enumerate(fetched_docs)}
|
||||
for i, indice_idx in enumerate(indices[0]):
|
||||
@@ -93,7 +103,20 @@ class FaissVecDB(BaseVecDB):
|
||||
fetch_doc = fetched_docs[pos]
|
||||
score = scores[0][i]
|
||||
result_docs.append(Result(similarity=float(score), data=fetch_doc))
|
||||
return result_docs[:k]
|
||||
|
||||
top_k_results = result_docs[:k]
|
||||
|
||||
if rerank and self.rerank_provider:
|
||||
documents = [doc.data["text"] for doc in top_k_results]
|
||||
reranked_results = await self.rerank_provider.rerank(query, documents)
|
||||
reranked_results = sorted(
|
||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||
)
|
||||
top_k_results = [
|
||||
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
||||
]
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
"""
|
||||
|
||||
+19
-17
@@ -16,30 +16,32 @@ from asyncio import Queue
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||
from astrbot.core import logger
|
||||
from .platform import AstrMessageEvent
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""事件总线: 用于处理事件的分发和处理
|
||||
"""用于处理事件的分发和处理"""
|
||||
|
||||
维护一个异步队列, 来接受各种消息事件
|
||||
"""
|
||||
|
||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||
while True:
|
||||
event: AstrMessageEvent = (
|
||||
await self.event_queue.get()
|
||||
) # 从事件队列中获取新的事件
|
||||
self._print_event(event) # 打印日志
|
||||
asyncio.create_task(
|
||||
self.pipeline_scheduler.execute(event)
|
||||
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
@@ -48,10 +50,10 @@ class EventBus:
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||
else:
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Persona, Personality
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot import logger
|
||||
|
||||
DEFAULT_PERSONALITY = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
begin_dialogs=[],
|
||||
mood_imitation_dialogs=[],
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
|
||||
self.db = db_helper
|
||||
self.acm = acm
|
||||
default_ps = acm.default_conf.get("provider_settings", {})
|
||||
self.default_persona: str = default_ps.get("default_personality", "default")
|
||||
self.personas: list[Persona] = []
|
||||
self.selected_default_persona: Persona | None = None
|
||||
|
||||
self.personas_v3: list[Personality] = []
|
||||
self.selected_default_persona_v3: Personality | None = None
|
||||
self.persona_v3_config: list[dict] = []
|
||||
|
||||
async def initialize(self):
|
||||
self.personas = await self.get_all_personas()
|
||||
self.get_v3_persona_data()
|
||||
logger.info(f"已加载 {len(self.personas)} 个人格。")
|
||||
|
||||
async def get_persona(self, persona_id: str):
|
||||
"""获取指定 persona 的信息"""
|
||||
persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
return persona
|
||||
|
||||
async def get_default_persona_v3(
|
||||
self, umo: str | MessageSession | None = None
|
||||
) -> Personality:
|
||||
"""获取默认 persona"""
|
||||
cfg = self.acm.get_conf(umo)
|
||||
default_persona_id = cfg.get("provider_settings", {}).get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
if not default_persona_id or default_persona_id == "default":
|
||||
return DEFAULT_PERSONALITY
|
||||
try:
|
||||
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
|
||||
except Exception:
|
||||
return DEFAULT_PERSONALITY
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
await self.db.delete_persona(persona_id)
|
||||
self.personas = [p for p in self.personas if p.persona_id != persona_id]
|
||||
self.get_v3_persona_data()
|
||||
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
):
|
||||
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
existing_persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not existing_persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
persona = await self.db.update_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
if persona:
|
||||
for i, p in enumerate(self.personas):
|
||||
if p.persona_id == persona_id:
|
||||
self.personas[i] = persona
|
||||
break
|
||||
self.get_v3_persona_data()
|
||||
return persona
|
||||
|
||||
async def get_all_personas(self) -> list[Persona]:
|
||||
"""获取所有 personas"""
|
||||
return await self.db.get_personas()
|
||||
|
||||
async def create_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} already exists.")
|
||||
new_persona = await self.db.insert_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
self.personas.append(new_persona)
|
||||
self.get_v3_persona_data()
|
||||
return new_persona
|
||||
|
||||
def get_v3_persona_data(
|
||||
self,
|
||||
) -> tuple[list[dict], list[Personality], Personality]:
|
||||
"""获取 AstrBot <4.0.0 版本的 persona 数据。
|
||||
|
||||
Returns:
|
||||
- list[dict]: 包含 persona 配置的字典列表。
|
||||
- list[Personality]: 包含 Personality 对象的列表。
|
||||
- Personality: 默认选择的 Personality 对象。
|
||||
"""
|
||||
v3_persona_config = [
|
||||
{
|
||||
"prompt": persona.system_prompt,
|
||||
"name": persona.persona_id,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"mood_imitation_dialogs": [], # deprecated
|
||||
"tools": persona.tools,
|
||||
}
|
||||
for persona in self.personas
|
||||
]
|
||||
|
||||
personas_v3: list[Personality] = []
|
||||
selected_default_persona: Personality | None = None
|
||||
|
||||
for persona_cfg in v3_persona_config:
|
||||
begin_dialogs = persona_cfg.get("begin_dialogs", [])
|
||||
bd_processed = []
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona_cfg,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed="", # deprecated
|
||||
)
|
||||
if persona["name"] == self.default_persona:
|
||||
selected_default_persona = persona
|
||||
personas_v3.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not selected_default_persona and len(personas_v3) > 0:
|
||||
# 默认选择第一个
|
||||
selected_default_persona = personas_v3[0]
|
||||
|
||||
if not selected_default_persona:
|
||||
selected_default_persona = DEFAULT_PERSONALITY
|
||||
personas_v3.append(selected_default_persona)
|
||||
|
||||
self.personas_v3 = personas_v3
|
||||
self.selected_default_persona_v3 = selected_default_persona
|
||||
self.persona_v3_config = v3_persona_config
|
||||
self.selected_default_persona = Persona(
|
||||
persona_id=selected_default_persona["name"],
|
||||
system_prompt=selected_default_persona["prompt"],
|
||||
begin_dialogs=selected_default_persona["begin_dialogs"],
|
||||
tools=selected_default_persona["tools"] or None,
|
||||
)
|
||||
|
||||
return v3_persona_config, personas_v3, selected_default_persona
|
||||
@@ -4,7 +4,6 @@ from astrbot.core.message.message_event_result import (
|
||||
)
|
||||
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
@@ -21,7 +20,6 @@ STAGES_ORDER = [
|
||||
"SessionStatusCheckStage", # 检查会话是否整体启用
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
@@ -34,7 +32,6 @@ __all__ = [
|
||||
"SessionStatusCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
|
||||
@@ -1,14 +1,7 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from .context_utils import call_handler, call_event_hook
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,97 +10,6 @@ class PipelineContext:
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
|
||||
async def call_event_hook(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
) -> bool:
|
||||
"""调用事件钩子函数
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
"""
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
|
||||
return event.is_stopped()
|
||||
|
||||
async def call_handler(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
astrbot_config_id: str
|
||||
call_handler = call_handler
|
||||
call_event_hook = call_event_hook
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from astrbot import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
|
||||
async def call_event_hook(
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
"""调用事件钩子函数
|
||||
|
||||
Returns:
|
||||
bool: 如果事件被终止,返回 True
|
||||
# """
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args, **kwargs)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
|
||||
return event.is_stopped()
|
||||
@@ -1,56 +0,0 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class PlatformCompatibilityStage(Stage):
|
||||
"""检查所有处理器的平台兼容性。
|
||||
|
||||
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化平台兼容性检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
|
||||
# 获取已激活的处理器
|
||||
activated_handlers = event.get_extra("activated_handlers")
|
||||
if activated_handlers is None:
|
||||
activated_handlers = []
|
||||
|
||||
# 标记不兼容的处理器
|
||||
for handler in activated_handlers:
|
||||
if not isinstance(handler, StarHandlerMetadata):
|
||||
continue
|
||||
# 检查处理器是否在当前平台启用
|
||||
enabled = handler.is_enabled_for_platform(platform_id)
|
||||
if not enabled:
|
||||
if handler.handler_module_path in star_map:
|
||||
plugin_name = star_map[handler.handler_module_path].name
|
||||
logger.debug(
|
||||
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||
)
|
||||
# 设置处理器为平台不兼容状态
|
||||
# TODO: 更好的标记方式
|
||||
handler.platform_compatible = False
|
||||
else:
|
||||
# 确保处理器为平台兼容状态
|
||||
handler.platform_compatible = True
|
||||
|
||||
# 更新已激活的处理器列表
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
@@ -20,12 +20,270 @@ from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from ...context import PipelineContext, call_event_hook, call_handler
|
||||
from ..stage import Stage
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
try:
|
||||
import mcp
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@classmethod
|
||||
async def execute(cls, tool, run_context, **tool_args):
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
"""
|
||||
if isinstance(tool, HandoffTool):
|
||||
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
if tool.origin == "local":
|
||||
async for r in cls._execute_local(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
elif tool.origin == "mcp":
|
||||
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
||||
yield r
|
||||
return
|
||||
|
||||
raise Exception(f"Unknown function origin: {tool.origin}")
|
||||
|
||||
@classmethod
|
||||
async def _execute_handoff(
|
||||
cls,
|
||||
tool: HandoffTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
input_ = tool_args.get("input", "agent")
|
||||
agent_runner = AgentRunner()
|
||||
|
||||
# make toolset for the agent
|
||||
tools = tool.agent.tools
|
||||
if tools:
|
||||
toolset = ToolSet()
|
||||
for t in tools:
|
||||
if isinstance(t, str):
|
||||
_t = llm_tools.get_func(t)
|
||||
if _t:
|
||||
toolset.add_tool(_t)
|
||||
elif isinstance(t, FunctionTool):
|
||||
toolset.add_tool(t)
|
||||
else:
|
||||
toolset = None
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description,
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=run_context.context.provider,
|
||||
first_provider_request=run_context.context.first_provider_request,
|
||||
curr_provider_request=request,
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
||||
await run_context.event.send(
|
||||
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
|
||||
)
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=run_context.context.provider,
|
||||
request=request,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx, event=run_context.event
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
streaming=run_context.context.streaming,
|
||||
)
|
||||
|
||||
async for _ in run_agent(agent_runner, 15, True):
|
||||
pass
|
||||
|
||||
if agent_runner.done():
|
||||
llm_response = agent_runner.get_final_llm_resp()
|
||||
logger.debug(
|
||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||
)
|
||||
|
||||
result = (
|
||||
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
||||
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
||||
)
|
||||
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=result,
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
yield mcp.types.TextContent(
|
||||
type="text",
|
||||
text=f"error when deligate task to {tool.agent.name}",
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
return
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
if not run_context.event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
|
||||
# 检查 tool 下有没有 run 方法
|
||||
if not tool.handler and not hasattr(tool, "run"):
|
||||
raise ValueError("Tool must have a valid handler or 'run' method.")
|
||||
awaitable = tool.handler or getattr(tool, "run")
|
||||
|
||||
wrapper = call_handler(
|
||||
event=run_context.event,
|
||||
handler=awaitable,
|
||||
**tool_args,
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
if isinstance(resp, mcp.types.CallToolResult):
|
||||
yield resp
|
||||
else:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
|
||||
@classmethod
|
||||
async def _execute_mcp(
|
||||
cls,
|
||||
tool: FunctionTool,
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
**tool_args,
|
||||
):
|
||||
if not tool.mcp_client:
|
||||
raise ValueError("MCP client is not available for MCP function tools.")
|
||||
res = await tool.mcp_client.session.call_tool(
|
||||
name=tool.name,
|
||||
arguments=tool_args,
|
||||
)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
run_context.event, EventType.OnLLMResponseEvent, llm_response
|
||||
)
|
||||
|
||||
|
||||
MAIN_AGENT_HOOKS = MainAgentHooks()
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
|
||||
) -> AsyncGenerator[MessageChain, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.event
|
||||
while step_idx < max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use or astr_event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await astr_event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not agent_runner.streaming:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
astr_event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
astr_event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
astr_event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -65,6 +323,20 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
# 获取对话上下文
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
return conversation
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, _nested: bool = False
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -100,30 +372,14 @@ class LLMRequestSubStage(Stage):
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
# 获取对话上下文
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
conversation_id = await self.conv_manager.new_conversation(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, conversation_id
|
||||
)
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
@@ -133,7 +389,7 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
if await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
@@ -167,92 +423,62 @@ class LLMRequestSubStage(Stage):
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
# Call Agent
|
||||
tool_loop_agent = ToolLoopAgent(
|
||||
provider=provider,
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
)
|
||||
# check provider modalities
|
||||
# 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持图像,清空图像列表。")
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。")
|
||||
req.func_tool = None
|
||||
# 插件可用性设置
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
plugin = star_map.get(tool.handler_module_path)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
async def requesting():
|
||||
step_idx = 0
|
||||
while step_idx < self.max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
if event.is_stopped():
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
resp.data["chain"].type = "tool_call_result"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if (
|
||||
self.show_tool_use
|
||||
or event.get_platform_name() == "webchat"
|
||||
):
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not self.streaming_response:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||
)
|
||||
)
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
)
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
provider=provider,
|
||||
first_provider_request=req,
|
||||
curr_provider_request=req,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=self.streaming_response,
|
||||
)
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting())
|
||||
.set_async_stream(
|
||||
run_agent(agent_runner, self.max_step, self.show_tool_use)
|
||||
)
|
||||
)
|
||||
yield
|
||||
if tool_loop_agent.done():
|
||||
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
@@ -266,15 +492,15 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
else:
|
||||
async for _ in requesting():
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
@@ -307,19 +533,10 @@ class LLMRequestSubStage(Stage):
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
unified_msg_origin=event.unified_msg_origin,
|
||||
title=title,
|
||||
conversation_id=req.conversation.cid,
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
"""
|
||||
|
||||
from ...context import PipelineContext
|
||||
from ...context import PipelineContext, call_handler
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -33,16 +33,6 @@ class StarRequestSubStage(Stage):
|
||||
handlers_parsed_params = {}
|
||||
|
||||
for handler in activated_handlers:
|
||||
# 检查处理器是否在当前平台兼容
|
||||
if (
|
||||
hasattr(handler, "platform_compatible")
|
||||
and handler.platform_compatible is False
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||
)
|
||||
continue
|
||||
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
@@ -50,7 +40,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
)
|
||||
wrapper = self.ctx.call_handler(event, handler.handler, **params)
|
||||
wrapper = call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
@@ -128,7 +128,7 @@ class RespondStage(Stage):
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
logger.info(f"应用流式输出({event.get_platform_id()})")
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
@@ -214,7 +214,7 @@ class RespondStage(Stage):
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -64,9 +64,10 @@ class ResultDecorateStage(Stage):
|
||||
]
|
||||
self.content_safe_check_stage = None
|
||||
if self.content_safe_check_reply:
|
||||
for stage in registered_stages:
|
||||
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage
|
||||
for stage_cls in registered_stages:
|
||||
if stage_cls.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -98,7 +99,7 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||
EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -11,16 +11,17 @@ class PipelineScheduler:
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__)
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
self.stages = [] # 存储阶段实例
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage in registered_stages:
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
for stage_cls in registered_stages:
|
||||
stage_instance = stage_cls() # 创建实例
|
||||
await stage_instance.initialize(self.ctx)
|
||||
self.stages.append(stage_instance)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
"""依次执行各个阶段
|
||||
@@ -29,9 +30,9 @@ class PipelineScheduler:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||
"""
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
for i in range(from_stage, len(self.stages)):
|
||||
stage = self.stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
|
||||
coroutine = stage.process(
|
||||
event
|
||||
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, AsyncGenerator, Union
|
||||
from typing import List, AsyncGenerator, Union, Type
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
|
||||
|
||||
|
||||
def register_stage(cls):
|
||||
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
|
||||
registered_stages.append(cls())
|
||||
registered_stages.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
|
||||
@@ -112,8 +112,17 @@ class WakingCheckStage(Stage):
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
|
||||
# 将 plugins_name 设置到 event 中
|
||||
enabled_plugins_name = self.ctx.astrbot_config.get("plugin_set", ["*"])
|
||||
if enabled_plugins_name == ["*"]:
|
||||
# 如果是 *,则表示所有插件都启用
|
||||
event.plugins_name = None
|
||||
else:
|
||||
event.plugins_name = enabled_plugins_name
|
||||
logger.debug(f"enabled_plugins_name: {enabled_plugins_name}")
|
||||
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.AdapterMessageEvent
|
||||
EventType.AdapterMessageEvent, plugins_name=event.plugins_name
|
||||
):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
|
||||
@@ -3,9 +3,10 @@ import asyncio
|
||||
import re
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
Plain,
|
||||
@@ -23,21 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSesion:
|
||||
platform_name: str
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_name, message_type, session_id = session_str.split(":")
|
||||
return MessageSesion(platform_name, MessageType(message_type), session_id)
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
@@ -64,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -78,13 +65,23 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.call_llm = False
|
||||
"""是否在此消息事件中禁止默认的 LLM 请求"""
|
||||
|
||||
self.plugins_name: list[str] | None = None
|
||||
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
|
||||
|
||||
# back_compability
|
||||
self.platform = platform_meta
|
||||
|
||||
def get_platform_name(self):
|
||||
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
|
||||
return self.platform_meta.name
|
||||
|
||||
def get_platform_id(self):
|
||||
"""获取这个事件所属的平台的 ID。
|
||||
|
||||
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
|
||||
"""
|
||||
return self.platform_meta.id
|
||||
|
||||
def get_message_str(self) -> str:
|
||||
@@ -188,6 +185,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
清除额外的信息。
|
||||
"""
|
||||
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
|
||||
self._extras.clear()
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
|
||||
@@ -18,6 +18,9 @@ class PlatformManager:
|
||||
|
||||
self.platforms_config = config["platform"]
|
||||
self.settings = config["platform_settings"]
|
||||
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
|
||||
这个配置中的 unique_session 需要特殊处理,
|
||||
约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
|
||||
self.event_queue = event_queue
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSession:
|
||||
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
|
||||
如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
|
||||
|
||||
platform_name: str
|
||||
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
platform_id: str = None
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
def __post_init__(self):
|
||||
self.platform_id = self.platform_name
|
||||
|
||||
@staticmethod
|
||||
def from_str(session_str: str):
|
||||
platform_id, message_type, session_id = session_str.split(":")
|
||||
return MessageSession(platform_id, MessageType(message_type), session_id)
|
||||
|
||||
|
||||
MessageSesion = MessageSession # back compatibility
|
||||
@@ -5,7 +5,7 @@ from asyncio import Queue
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .astr_message_event import MessageSesion
|
||||
from .message_session import MessageSesion
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata:
|
||||
name: str
|
||||
"""平台的名称"""
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str = None
|
||||
|
||||
@@ -77,7 +77,7 @@ class WebChatAdapter(Platform):
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat", description="webchat", id=self.config.get("id", "")
|
||||
name="webchat", description="webchat", id="webchat"
|
||||
)
|
||||
|
||||
async def send_by_session(
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import PlatformMessageHistory
|
||||
|
||||
|
||||
class PlatformMessageHistoryManager:
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
self.db = db_helper
|
||||
|
||||
async def insert(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
content: list[dict], # TODO: parse from message chain
|
||||
sender_id: str = None,
|
||||
sender_name: str = None,
|
||||
):
|
||||
"""Insert a new platform message history record."""
|
||||
await self.db.insert_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
sender_id=sender_id,
|
||||
sender_name=sender_name,
|
||||
)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
platform_id: str,
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 200,
|
||||
) -> list[PlatformMessageHistory]:
|
||||
"""Get platform message history for a specific user."""
|
||||
history = await self.db.get_platform_message_history(
|
||||
platform_id=platform_id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
history.reverse()
|
||||
return history
|
||||
|
||||
async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400):
|
||||
"""Delete platform message history records older than the specified offset."""
|
||||
await self.db.delete_platform_message_offset(
|
||||
platform_id=platform_id, user_id=user_id, offset_sec=offset_sec
|
||||
)
|
||||
@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
@@ -20,6 +20,7 @@ class ProviderType(enum.Enum):
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -97,7 +98,7 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall | None = None
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
@@ -293,3 +294,10 @@ class LLMResponse:
|
||||
}
|
||||
)
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
class RerankResult:
|
||||
index: int
|
||||
"""在候选列表中的索引位置"""
|
||||
relevance_score: float
|
||||
"""相关性分数"""
|
||||
|
||||
@@ -1,32 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
import aiohttp
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, List, Awaitable
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
from astrbot.core import sp
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.agent.mcp_client import MCPClient
|
||||
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
try:
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning(
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
@@ -39,6 +24,10 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
# alias
|
||||
FuncTool = FunctionTool
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
@@ -105,181 +94,9 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
用于描述一个函数调用工具。
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
handler: Awaitable = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
async def execute(self, **args) -> Any:
|
||||
"""执行函数调用"""
|
||||
if self.origin == "local":
|
||||
if not self.handler:
|
||||
raise Exception(f"Local function {self.name} has no handler")
|
||||
return await self.handler(**args)
|
||||
elif self.origin == "mcp":
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
actual_tool_name = (
|
||||
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||
)
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
self.session: Optional[mcp.ClientSession] = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
self.running_event = asyncio.Event()
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在:
|
||||
1. 当 transport 指定为 `streamable_http` 时,使用 Streamable HTTP 连接方式。
|
||||
1. 当 transport 指定为 `sse` 时,使用 SSE 连接方式。
|
||||
2. 如果没有指定,默认使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
if "url" in cfg:
|
||||
success, error_msg = await _quick_test_mcp_connection(cfg)
|
||||
if not success:
|
||||
raise Exception(error_msg)
|
||||
|
||||
if cfg.get("transport") != "streamable_http":
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=cfg.get("timeout", 5),
|
||||
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
|
||||
)
|
||||
streams = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
*streams,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
else:
|
||||
timeout = timedelta(seconds=cfg.get("timeout", 30))
|
||||
sse_read_timeout = timedelta(
|
||||
seconds=cfg.get("sse_read_timeout", 60 * 5)
|
||||
)
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=cfg["url"],
|
||||
headers=cfg.get("headers", {}),
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
terminate_on_close=cfg.get("terminate_on_close", True),
|
||||
)
|
||||
read_s, write_s, _ = await self.exit_stack.enter_async_context(
|
||||
self._streams_context
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read_stream=read_s,
|
||||
write_stream=write_s,
|
||||
read_timeout_seconds=read_timeout,
|
||||
logging_callback=logging_callback, # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
await self.session.initialize()
|
||||
|
||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||
"""List all tools from the server and save them to self.tools"""
|
||||
response = await self.session.list_tools()
|
||||
self.tools = response.tools
|
||||
return response
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
|
||||
|
||||
class FuncCall:
|
||||
class FunctionToolManager:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
@@ -287,6 +104,29 @@ class FuncCall:
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def spec_to_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Awaitable,
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
return FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
@@ -304,22 +144,14 @@ class FuncCall:
|
||||
# check if the tool has been added before
|
||||
self.remove_func(name)
|
||||
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
self.func_list.append(
|
||||
self.spec_to_func(
|
||||
name=name,
|
||||
func_args=func_args,
|
||||
desc=desc,
|
||||
handler=handler,
|
||||
)
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
logger.info(f"添加函数调用工具: {name}")
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
@@ -331,11 +163,15 @@ class FuncCall:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def get_func(self, name) -> FuncTool:
|
||||
def get_func(self, name) -> FuncTool | None:
|
||||
for f in self.func_list:
|
||||
if f.name == name:
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""获取完整工具集"""
|
||||
tool_set = ToolSet(self.func_list.copy())
|
||||
return tool_set
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
@@ -556,203 +392,179 @@ class FuncCall:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
_l = []
|
||||
# 处理所有工具(包括本地和MCP工具)
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
func_ = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
# "parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
func_["function"]["parameters"] = f.parameters
|
||||
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||
del func_["function"]["parameters"]
|
||||
_l.append(func_)
|
||||
return _l
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.openai_schema(
|
||||
omit_empty_parameter_field=omit_empty_parameter_field
|
||||
)
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
"""
|
||||
获得 Anthropic API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
# Convert internal format to Anthropic style
|
||||
tool = {
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": f.parameters.get("properties", {}),
|
||||
# Keep the required field from the original parameters if it exists
|
||||
"required": f.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
tools.append(tool)
|
||||
return tools
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.anthropic_schema()
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> dict:
|
||||
"""
|
||||
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.google_schema()
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties: # 只在有非空属性时添加
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
return True
|
||||
return False
|
||||
|
||||
TOOLS:
|
||||
可用的函数列表:
|
||||
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
|
||||
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
{func_definition}
|
||||
func_tool.active = True
|
||||
|
||||
LIMIT:
|
||||
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
|
||||
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
|
||||
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
|
||||
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
|
||||
inactivated_llm_tools: list = sp.get(
|
||||
"inactivated_llm_tools", [], scope="global", scope_id="global"
|
||||
)
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put(
|
||||
"inactivated_llm_tools",
|
||||
inactivated_llm_tools,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
EXAMPLE:
|
||||
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
|
||||
return True
|
||||
return False
|
||||
|
||||
用户的提问是:{question}
|
||||
""")
|
||||
@property
|
||||
def mcp_config_path(self):
|
||||
data_dir = get_astrbot_data_path()
|
||||
return os.path.join(data_dir, "mcp_server.json")
|
||||
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
if res.find("```") != -1:
|
||||
res = res[res.find("```json") + 7 : res.rfind("```")]
|
||||
res = json.loads(res)
|
||||
break
|
||||
except Exception as e:
|
||||
_c += 1
|
||||
if _c == 3:
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
def load_mcp_config(self):
|
||||
if not os.path.exists(self.mcp_config_path):
|
||||
# 配置文件不存在,创建默认配置
|
||||
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
if "res" in res and not res["res"]:
|
||||
return "", False
|
||||
try:
|
||||
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 MCP 配置失败: {e}")
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
tool_call_result = []
|
||||
for tool in res:
|
||||
# 说明有函数调用
|
||||
func_name = tool["name"]
|
||||
args = tool["args"]
|
||||
# 调用函数
|
||||
func_tool = self.get_func(func_name)
|
||||
if not func_tool:
|
||||
raise Exception(f"Request function {func_name} not found.")
|
||||
def save_mcp_config(self, config: dict):
|
||||
try:
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存 MCP 配置失败: {e}")
|
||||
return False
|
||||
|
||||
ret = await func_tool.execute(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
|
||||
"""从 ModelScope 平台同步 MCP 服务器配置"""
|
||||
base_url = "https://www.modelscope.cn/openapi/v1"
|
||||
url = f"{base_url}/mcp/servers/operational"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token.strip()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
mcp_server_list = data.get("data", {}).get(
|
||||
"mcp_server_list", []
|
||||
)
|
||||
local_mcp_config = self.load_mcp_config()
|
||||
|
||||
synced_count = 0
|
||||
for server in mcp_server_list:
|
||||
server_name = server["name"]
|
||||
operational_urls = server.get("operational_urls", [])
|
||||
if not operational_urls:
|
||||
continue
|
||||
url_info = operational_urls[0]
|
||||
server_url = url_info.get("url")
|
||||
if not server_url:
|
||||
continue
|
||||
# 添加到配置中(同名会覆盖)
|
||||
local_mcp_config["mcpServers"][server_name] = {
|
||||
"url": server_url,
|
||||
"transport": "sse",
|
||||
"active": True,
|
||||
"provider": "modelscope",
|
||||
}
|
||||
synced_count += 1
|
||||
|
||||
if synced_count > 0:
|
||||
self.save_mcp_config(local_mcp_config)
|
||||
tasks = []
|
||||
for server in mcp_server_list:
|
||||
name = server["name"]
|
||||
tasks.append(
|
||||
self.enable_mcp_server(
|
||||
name=name,
|
||||
config=local_mcp_config["mcpServers"][name],
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(
|
||||
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
|
||||
)
|
||||
else:
|
||||
logger.warning("没有找到可用的 ModelScope MCP 服务器")
|
||||
else:
|
||||
raise Exception(
|
||||
f"ModelScope API 请求失败: HTTP {response.status}"
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"网络连接错误: {str(e)}")
|
||||
except Exception as e:
|
||||
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}")
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
|
||||
# alias
|
||||
FuncCall = FunctionToolManager
|
||||
|
||||
@@ -3,89 +3,32 @@ import traceback
|
||||
from typing import List
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
from ..persona_mgr import PersonaManager
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
acm: AstrBotConfigManager,
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: List = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
for persona in self.persona_configs:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
bd_processed = []
|
||||
mid_processed = ""
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
mood_imitation_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in mood_imitation_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
mid_processed += "\n"
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed=mid_processed,
|
||||
)
|
||||
if persona["name"] == self.default_persona_name:
|
||||
self.selected_default_persona = persona
|
||||
self.personas.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||||
self.default_persona_name = persona_mgr.default_persona
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
"""加载的 Provider 的实例"""
|
||||
@@ -100,46 +43,111 @@ class ProviderManager:
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider | None = None
|
||||
"""默认的 Provider 实例"""
|
||||
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_stt_provider_inst: STTProvider | None = None
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_tts_provider_inst: TTSProvider | None = None
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
# kdb(experimental)
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
@property
|
||||
def persona_configs(self) -> list:
|
||||
"""动态获取最新的 persona 配置"""
|
||||
return self.persona_mgr.persona_v3_config
|
||||
|
||||
@property
|
||||
def personas(self) -> list:
|
||||
"""动态获取最新的 personas 列表"""
|
||||
return self.persona_mgr.personas_v3
|
||||
|
||||
@property
|
||||
def selected_default_persona(self):
|
||||
"""动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()"""
|
||||
return self.persona_mgr.selected_default_persona_v3
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str | None = None
|
||||
):
|
||||
"""设置提供商。
|
||||
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Version 4.0.0: 这个版本下已经默认隔离提供商
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
if umo and self.provider_settings["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
session_perf[provider_type.value] = provider_id
|
||||
perf[umo] = session_perf
|
||||
sp.put("session_provider_perf", perf)
|
||||
if umo:
|
||||
await sp.session_put(
|
||||
umo,
|
||||
f"provider_perf_{provider_type.value}",
|
||||
provider_id,
|
||||
)
|
||||
return
|
||||
# 不启用提供商会话隔离模式的情况
|
||||
self.curr_provider_inst = self.inst_map[provider_id]
|
||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
sp.put("curr_provider_tts", provider_id)
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
sp.put("curr_provider_stt", provider_id)
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Returns:
|
||||
Provider: 正在使用的提供商实例。
|
||||
"""
|
||||
provider = None
|
||||
if umo:
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
None,
|
||||
scope="umo",
|
||||
scope_id=umo,
|
||||
)
|
||||
if provider_id:
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
# default setting
|
||||
config = self.acm.get_conf(umo)
|
||||
if provider_type == ProviderType.CHAT_COMPLETION:
|
||||
provider_id = config["provider_settings"].get("default_provider_id")
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = self.provider_insts[0] if self.provider_insts else None
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
provider_id = config["provider_stt_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.stt_provider_insts[0] if self.stt_provider_insts else None
|
||||
)
|
||||
elif provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
provider_id = config["provider_tts_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.tts_provider_insts[0] if self.tts_provider_insts else None
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
@@ -148,13 +156,22 @@ class ProviderManager:
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider", self.provider_settings.get("default_provider_id")
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_stt_provider_id = sp.get(
|
||||
"curr_provider_stt", self.provider_stt_settings.get("provider_id")
|
||||
"curr_provider_stt",
|
||||
self.provider_stt_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_tts_provider_id = sp.get(
|
||||
"curr_provider_tts", self.provider_tts_settings.get("provider_id")
|
||||
"curr_provider_tts",
|
||||
self.provider_tts_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
@@ -262,6 +279,10 @@ class ProviderManager:
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -345,7 +366,7 @@ class ProviderManager:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config, self.provider_settings
|
||||
)
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ToolCallsResult,
|
||||
ProviderType,
|
||||
RerankResult,
|
||||
)
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.db.po import Personality
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: List[str] = []
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
id: str
|
||||
@@ -90,7 +85,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
@@ -119,7 +114,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
@@ -206,3 +201,17 @@ class EmbeddingProvider(AbstractProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
|
||||
class RerankProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def rerank(
|
||||
self, query: str, documents: list[str], top_n: int | None = None
|
||||
) -> list[RerankResult]:
|
||||
"""获取查询和文档的重排序分数"""
|
||||
...
|
||||
|
||||
@@ -75,8 +75,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
|
||||
if (
|
||||
|
||||
@@ -97,8 +97,7 @@ class ProviderDify(Provider):
|
||||
# 获得会话变量
|
||||
payload_vars = self.variables.copy()
|
||||
# 动态变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var = await sp.session_get(session_id, "session_variables", default={})
|
||||
payload_vars.update(session_var)
|
||||
payload_vars["system_prompt"] = system_prompt
|
||||
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import aiohttp
|
||||
from ..provider import RerankProvider
|
||||
from ..register import register_provider_adapter
|
||||
from ..entities import ProviderType, RerankResult
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"vllm_rerank",
|
||||
"VLLM Rerank 适配器",
|
||||
provider_type=ProviderType.RERANK,
|
||||
)
|
||||
class VLLMRerankProvider(RerankProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
self.auth_key = provider_config.get("rerank_api_key", "")
|
||||
self.base_url = provider_config.get("rerank_api_base", "http://127.0.0.1:8000")
|
||||
self.base_url = self.base_url.rstrip("/")
|
||||
self.timeout = provider_config.get("timeout", 20)
|
||||
self.model = provider_config.get("rerank_model", "BAAI/bge-reranker-base")
|
||||
|
||||
h = {}
|
||||
if self.auth_key:
|
||||
h["Authorization"] = f"Bearer {self.auth_key}"
|
||||
self.client = aiohttp.ClientSession(
|
||||
headers=h,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
)
|
||||
|
||||
async def rerank(
|
||||
self, query: str, documents: list[str], top_n: int | None = None
|
||||
) -> list[RerankResult]:
|
||||
payload = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"model": self.model,
|
||||
}
|
||||
if top_n is not None:
|
||||
payload["top_n"] = top_n
|
||||
async with self.client.post(
|
||||
f"{self.base_url}/v1/rerank", json=payload
|
||||
) as response:
|
||||
response_data = await response.json()
|
||||
results = response_data.get("results", [])
|
||||
|
||||
return [
|
||||
RerankResult(
|
||||
index=result["index"],
|
||||
relevance_score=result["relevance_score"],
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
async def terminate(self) -> None:
|
||||
"""关闭客户端会话"""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
@@ -34,7 +34,7 @@ class Star(CommandParserMixin):
|
||||
|
||||
@staticmethod
|
||||
async def html_render(
|
||||
tmpl: str, data: dict, return_url=True, options: dict = None
|
||||
tmpl: str, data: dict, return_url=True, options: dict | None = None
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
|
||||
from astrbot.core.provider.provider import (
|
||||
Provider,
|
||||
TTSProvider,
|
||||
STTProvider,
|
||||
EmbeddingProvider,
|
||||
)
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
@@ -22,6 +29,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
PlatformAdapterType,
|
||||
ADAPTER_NAME_2_TYPE,
|
||||
)
|
||||
from deprecated import deprecated
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -29,19 +37,6 @@ class Context:
|
||||
暴露给插件的接口上下文。
|
||||
"""
|
||||
|
||||
_event_queue: Queue = None
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
|
||||
_config: AstrBotConfig = None
|
||||
"""AstrBot 配置信息"""
|
||||
|
||||
_db: BaseDatabase = None
|
||||
"""AstrBot 数据库"""
|
||||
|
||||
provider_manager: ProviderManager = None
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
# back compatibility
|
||||
@@ -53,18 +48,27 @@ class Context:
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
provider_manager: ProviderManager,
|
||||
platform_manager: PlatformManager,
|
||||
conversation_manager: ConversationManager,
|
||||
message_history_manager: PlatformMessageHistoryManager,
|
||||
persona_manager: PersonaManager,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
"""事件队列。消息平台通过事件队列传递消息事件。"""
|
||||
self._config = config
|
||||
"""AstrBot 默认配置"""
|
||||
self._db = db
|
||||
"""AstrBot 数据库"""
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
self.message_history_manager = message_history_manager
|
||||
self.persona_manager = persona_manager
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata | None:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
for star in star_registry:
|
||||
if star.name == star_name:
|
||||
@@ -74,7 +78,7 @@ class Context:
|
||||
"""获取当前载入的所有插件 Metadata 的列表"""
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
def get_llm_tool_manager(self) -> FunctionToolManager:
|
||||
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
@@ -84,40 +88,14 @@ class Context:
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""
|
||||
@@ -125,7 +103,7 @@ class Context:
|
||||
"""
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
||||
return self.provider_manager.inst_map.get(provider_id)
|
||||
|
||||
@@ -145,51 +123,49 @@ class Context:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
||||
"""
|
||||
获取当前使用的用于 TTS 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
||||
"""
|
||||
获取当前使用的用于 STT 任务的 Provider。
|
||||
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo and self._config["provider_settings"]["separate_provider"]:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||
"""获取 AstrBot 的配置。"""
|
||||
return self._config
|
||||
if not umo:
|
||||
# using default config
|
||||
return self._config
|
||||
else:
|
||||
return self.astrbot_config_mgr.get_conf(umo)
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
@@ -201,9 +177,14 @@ class Context:
|
||||
"""
|
||||
return self._event_queue
|
||||
|
||||
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(
|
||||
self, platform_type: Union[PlatformAdapterType, str]
|
||||
) -> Platform | None:
|
||||
"""
|
||||
获取指定类型的平台适配器。
|
||||
|
||||
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
name = platform.meta().name
|
||||
@@ -217,6 +198,20 @@ class Context:
|
||||
):
|
||||
return platform
|
||||
|
||||
def get_platform_inst(self, platform_id: str) -> Platform | None:
|
||||
"""
|
||||
获取指定 ID 的平台适配器实例。
|
||||
|
||||
Args:
|
||||
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
|
||||
|
||||
Returns:
|
||||
Platform: 平台适配器实例,如果未找到则返回 None。
|
||||
"""
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().id == platform_id:
|
||||
return platform
|
||||
|
||||
async def send_message(
|
||||
self, session: Union[str, MessageSesion], message_chain: MessageChain
|
||||
) -> bool:
|
||||
@@ -240,7 +235,7 @@ class Context:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
if platform.meta().id == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -11,6 +11,7 @@ from .star_handler import (
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
register_agent,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent,
|
||||
)
|
||||
@@ -28,6 +29,7 @@ __all__ = [
|
||||
"register_on_llm_request",
|
||||
"register_on_llm_response",
|
||||
"register_llm_tool",
|
||||
"register_agent",
|
||||
"register_on_decorating_result",
|
||||
"register_after_message_sent",
|
||||
]
|
||||
|
||||
@@ -15,6 +15,11 @@ from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.agent.agent import Agent
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
|
||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
@@ -306,7 +311,7 @@ def register_on_llm_response(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
def register_llm_tool(name: str = None, **kwargs):
|
||||
"""为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
@@ -340,6 +345,9 @@ def register_llm_tool(name: str = None):
|
||||
"""
|
||||
|
||||
name_ = name
|
||||
registering_agent = None
|
||||
if kwargs.get("registering_agent"):
|
||||
registering_agent = kwargs["registering_agent"]
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
@@ -357,15 +365,69 @@ def register_llm_tool(name: str = None):
|
||||
"description": arg.description,
|
||||
}
|
||||
)
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(
|
||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||
)
|
||||
# print(llm_tool_name, registering_agent)
|
||||
if not registering_agent:
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(
|
||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
||||
)
|
||||
else:
|
||||
assert isinstance(registering_agent, RegisteringAgent)
|
||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||
if registering_agent._agent.tools is None:
|
||||
registering_agent._agent.tools = []
|
||||
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
||||
))
|
||||
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RegisteringAgent:
|
||||
"""用于 Agent 注册"""
|
||||
|
||||
def llm_tool(self, *args, **kwargs):
|
||||
kwargs["registering_agent"] = self
|
||||
return register_llm_tool(*args, **kwargs)
|
||||
|
||||
def __init__(self, agent: Agent[AstrAgentContext]):
|
||||
self._agent = agent
|
||||
|
||||
|
||||
def register_agent(
|
||||
name: str,
|
||||
instruction: str,
|
||||
tools: list[str | FunctionTool] = None,
|
||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
||||
):
|
||||
"""注册一个 Agent
|
||||
|
||||
Args:
|
||||
name: Agent 的名称
|
||||
instruction: Agent 的指令
|
||||
tools: Agent 使用的工具列表
|
||||
run_hooks: Agent 运行时的钩子函数
|
||||
"""
|
||||
tools_ = tools or []
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
AstrAgent = Agent[AstrAgentContext]
|
||||
agent = AstrAgent(
|
||||
name=name,
|
||||
instructions=instruction,
|
||||
tools=tools_,
|
||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||
)
|
||||
handoff_tool = HandoffTool(agent=agent)
|
||||
handoff_tool.handler=awaitable
|
||||
llm_tools.func_list.append(handoff_tool)
|
||||
return RegisteringAgent(agent)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_decorating_result(**kwargs):
|
||||
"""在发送消息前的事件"""
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
@@ -26,8 +24,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
@@ -45,16 +44,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置LLM状态
|
||||
session_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["llm_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -88,8 +84,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
@@ -107,16 +104,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置TTS状态
|
||||
session_config[session_id]["tts_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["tts_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -150,8 +144,9 @@ class SessionServiceManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
session_enabled = session_services.get("session_enabled")
|
||||
@@ -169,16 +164,13 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置会话整体状态
|
||||
session_config[session_id]["session_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
session_config["session_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -202,7 +194,7 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_custom_name(session_id: str) -> str:
|
||||
def get_session_custom_name(session_id: str) -> str | None:
|
||||
"""获取会话的自定义名称
|
||||
|
||||
Args:
|
||||
@@ -211,8 +203,9 @@ class SessionServiceManager:
|
||||
Returns:
|
||||
str: 自定义名称,如果没有设置则返回None
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
session_services = sp.get(
|
||||
"session_service_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_services.get("custom_name")
|
||||
|
||||
@staticmethod
|
||||
@@ -223,20 +216,17 @@ class SessionServiceManager:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
custom_name: 自定义名称,可以为空字符串来清除名称
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
# 设置自定义名称
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
)
|
||||
if custom_name and custom_name.strip():
|
||||
session_config[session_id]["custom_name"] = custom_name.strip()
|
||||
session_config["custom_name"] = custom_name.strip()
|
||||
else:
|
||||
# 如果传入空名称,则删除自定义名称
|
||||
session_config[session_id].pop("custom_name", None)
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
session_config.pop("custom_name", None)
|
||||
sp.put(
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}"
|
||||
@@ -258,36 +248,3 @@ class SessionServiceManager:
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
# =============================================================================
|
||||
# 通用配置方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的服务配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含session_enabled、llm_enabled、tts_enabled的字典
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
return session_config.get(
|
||||
session_id,
|
||||
{
|
||||
"session_enabled": True, # 默认启用
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||
"""获取所有会话的服务配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||
"""
|
||||
return sp.get("session_service_config", {}) or {}
|
||||
|
||||
@@ -22,7 +22,9 @@ class SessionPluginManager:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
@@ -51,7 +53,9 @@ class SessionPluginManager:
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
@@ -79,7 +83,9 @@ class SessionPluginManager:
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put("session_plugin_config", session_plugin_config)
|
||||
sp.put(
|
||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
@@ -95,7 +101,9 @@ class SessionPluginManager:
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config", {}, scope="umo", scope_id=session_id
|
||||
)
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@@ -56,32 +56,8 @@ class StarMetadata:
|
||||
star_handler_full_names: list[str] = field(default_factory=list)
|
||||
"""注册的 Handler 的全名列表"""
|
||||
|
||||
supported_platforms: dict[str, bool] = field(default_factory=dict)
|
||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Plugin {self.name} ({self.version}) by {self.author}: {self.desc}"
|
||||
|
||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||
"""更新插件支持的平台列表
|
||||
|
||||
Args:
|
||||
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
|
||||
"""
|
||||
if not plugin_enable_config:
|
||||
return
|
||||
|
||||
# 清空之前的配置
|
||||
self.supported_platforms.clear()
|
||||
|
||||
# 遍历所有平台配置
|
||||
for platform_id, plugins in plugin_enable_config.items():
|
||||
# 检查该插件在当前平台的配置
|
||||
if self.name in plugins:
|
||||
self.supported_platforms[platform_id] = plugins[self.name]
|
||||
else:
|
||||
# 如果没有明确配置,默认为启用
|
||||
self.supported_platforms[platform_id] = True
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star import star_map
|
||||
|
||||
T = TypeVar("T", bound="StarHandlerMetadata")
|
||||
|
||||
|
||||
class StarHandlerRegistry(Generic[T]):
|
||||
def __init__(self):
|
||||
self.star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -26,7 +27,10 @@ class StarHandlerRegistry(Generic[T]):
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
self,
|
||||
event_type: EventType,
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> List[StarHandlerMetadata]:
|
||||
handlers = []
|
||||
for handler in self._handlers:
|
||||
@@ -36,8 +40,15 @@ class StarHandlerRegistry(Generic[T]):
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
if plugins_name is not None and plugins_name != ["*"]:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not plugin:
|
||||
continue
|
||||
if (
|
||||
plugin.name not in plugins_name
|
||||
and event_type != EventType.OnAstrBotLoadedEvent
|
||||
and not plugin.reserved
|
||||
):
|
||||
continue
|
||||
handlers.append(handler)
|
||||
return handlers
|
||||
@@ -49,7 +60,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
self, module_name: str
|
||||
) -> List[StarHandlerMetadata]:
|
||||
return [
|
||||
handler for handler in self._handlers
|
||||
handler
|
||||
for handler in self._handlers
|
||||
if handler.handler_module_path == module_name
|
||||
]
|
||||
|
||||
@@ -67,6 +79,7 @@ class StarHandlerRegistry(Generic[T]):
|
||||
def __len__(self):
|
||||
return len(self._handlers)
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
|
||||
@@ -119,32 +132,3 @@ class StarHandlerMetadata:
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
"priority", 0
|
||||
)
|
||||
|
||||
def is_enabled_for_platform(self, platform_id: str) -> bool:
|
||||
"""检查插件是否在指定平台启用
|
||||
|
||||
Args:
|
||||
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
|
||||
|
||||
Returns:
|
||||
bool: 是否启用,True表示启用,False表示禁用
|
||||
"""
|
||||
plugin = star_map.get(self.handler_module_path)
|
||||
|
||||
# 如果插件元数据不存在,默认允许执行
|
||||
if not plugin or not plugin.name:
|
||||
return True
|
||||
|
||||
# 先检查插件是否被激活
|
||||
if not plugin.activated:
|
||||
return False
|
||||
|
||||
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
|
||||
if (
|
||||
hasattr(plugin, "supported_platforms")
|
||||
and platform_id in plugin.supported_platforms
|
||||
):
|
||||
return plugin.supported_platforms[platform_id]
|
||||
|
||||
# 如果没有缓存数据,默认允许执行
|
||||
return True
|
||||
|
||||
@@ -22,6 +22,7 @@ from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from astrbot.core.agent.handoff import HandoffTool, FunctionTool
|
||||
|
||||
from . import StarMetadata
|
||||
from .context import Context
|
||||
@@ -336,30 +337,8 @@ class PluginManager:
|
||||
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
return result
|
||||
|
||||
async def update_all_platform_compatibility(self):
|
||||
"""更新所有插件的平台兼容性设置"""
|
||||
# 获取最新的平台插件启用配置
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
logger.debug(
|
||||
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
|
||||
)
|
||||
|
||||
# 遍历所有插件,更新平台兼容性
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin.update_platform_compatibility(plugin_enable_config)
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def load(self, specified_module_path=None, specified_dir_name=None):
|
||||
"""载入插件。
|
||||
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
|
||||
@@ -373,10 +352,9 @@ class PluginManager:
|
||||
- success (bool): 是否全部加载成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
alter_cmd = sp.get("alter_cmd", {})
|
||||
inactivated_plugins = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", [])
|
||||
alter_cmd = await sp.global_get("alter_cmd", {})
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
@@ -480,12 +458,6 @@ class PluginManager:
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 更新插件的平台兼容性
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
assert metadata.module_path is not None, (
|
||||
f"插件 {metadata.name} 的模块路径为空。"
|
||||
)
|
||||
@@ -503,17 +475,27 @@ class PluginManager:
|
||||
)
|
||||
# 绑定 llm_tool handler
|
||||
for func_tool in llm_tools.func_list:
|
||||
if (
|
||||
func_tool.handler
|
||||
and func_tool.handler.__module__ == metadata.module_path
|
||||
):
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(
|
||||
func_tool.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
if isinstance(func_tool, HandoffTool):
|
||||
need_apply = []
|
||||
sub_tools = func_tool.agent.tools
|
||||
for sub_tool in sub_tools:
|
||||
if isinstance(sub_tool, FunctionTool):
|
||||
need_apply.append(sub_tool)
|
||||
else:
|
||||
need_apply = [func_tool]
|
||||
|
||||
for ft in need_apply:
|
||||
if (
|
||||
ft.handler
|
||||
and ft.handler.__module__ == metadata.module_path
|
||||
):
|
||||
ft.handler_module_path = metadata.module_path
|
||||
ft.handler = functools.partial(
|
||||
ft.handler,
|
||||
metadata.star_cls, # type: ignore
|
||||
)
|
||||
if ft.name in inactivated_llm_tools:
|
||||
ft.active = False
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
@@ -776,12 +758,12 @@ class PluginManager:
|
||||
await self._terminate_plugin(plugin)
|
||||
|
||||
# 加入到 shared_preferences 中
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = list(
|
||||
set(sp.get("inactivated_llm_tools", []))
|
||||
set(await sp.global_get("inactivated_llm_tools", []))
|
||||
) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
@@ -791,8 +773,8 @@ class PluginManager:
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
|
||||
@@ -818,11 +800,11 @@ class PluginManager:
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
inactivated_plugins.remove(plugin.module_path)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
await sp.global_put("inactivated_plugins", inactivated_plugins)
|
||||
|
||||
# 启用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
@@ -832,7 +814,7 @@ class PluginManager:
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
func_tool.active = True
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
await sp.global_put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
await self.reload(plugin_name)
|
||||
|
||||
|
||||
@@ -56,7 +56,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
try:
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
if os.name == "nt":
|
||||
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
||||
args = [
|
||||
f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]
|
||||
]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
@@ -66,13 +68,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
|
||||
async def check_update(
|
||||
self, url: str, current_version: str, consider_prerelease: bool = True
|
||||
) -> ReleaseInfo:
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
"""检查更新"""
|
||||
return await super().check_update(
|
||||
self.ASTRBOT_RELEASE_API, VERSION, consider_prerelease
|
||||
)
|
||||
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
|
||||
async def get_releases(self) -> list:
|
||||
return await self.fetch_release_info(self.ASTRBOT_RELEASE_API)
|
||||
|
||||
@@ -227,7 +227,7 @@ async def download_dashboard(
|
||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||
|
||||
if latest or len(str(version)) != 40:
|
||||
logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
|
||||
logger.info("准备下载最新发行版本的 AstrBot WebUI")
|
||||
ver_name = "latest" if latest else version
|
||||
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
||||
try:
|
||||
|
||||
@@ -58,9 +58,10 @@ class Metric:
|
||||
pass
|
||||
try:
|
||||
if "adapter_name" in kwargs:
|
||||
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
|
||||
if "llm_name" in kwargs:
|
||||
db_helper.insert_llm_metrics({kwargs["llm_name"]: 1})
|
||||
await db_helper.insert_platform_stats(
|
||||
platform_id=kwargs["adapter_name"],
|
||||
platform_type=kwargs.get("adapter_type", "unknown"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"保存指标到数据库失败: {e}")
|
||||
pass
|
||||
|
||||
@@ -1,43 +1,180 @@
|
||||
import json
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Preference
|
||||
import threading
|
||||
import asyncio
|
||||
import os
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar, Any, overload
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
def __init__(self, db_helper: BaseDatabase, json_storage_path=None):
|
||||
if json_storage_path is None:
|
||||
json_storage_path = os.path.join(
|
||||
get_astrbot_data_path(), "shared_preferences.json"
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
self._sync_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=self._sync_loop.run_forever, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
async def get_async(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
) -> _VT:
|
||||
"""获取指定范围和键的偏好设置"""
|
||||
if scope_id is not None and key is not None:
|
||||
result = await self.db_helper.get_preference(scope, scope_id, key)
|
||||
if result:
|
||||
ret = result.value["val"]
|
||||
else:
|
||||
ret = default
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
async def range_get_async(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置
|
||||
Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。
|
||||
"""
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: str, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: str, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
@overload
|
||||
async def session_get(
|
||||
self, umo: None, key: None, default: Any = None
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self, umo: str | None, key: str | None = None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取会话范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if umo is None or key is None:
|
||||
return await self.range_get_async("umo", umo, key)
|
||||
return await self.get_async("umo", umo, key, default)
|
||||
|
||||
@overload
|
||||
async def global_get(self, key: None, default: Any = None) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def global_get(self, key: str, default: _VT = None) -> _VT: ...
|
||||
|
||||
async def global_get(
|
||||
self, key: str | None, default: _VT = None
|
||||
) -> _VT | list[Preference]:
|
||||
"""获取全局范围的偏好设置
|
||||
|
||||
Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。
|
||||
"""
|
||||
if key is None:
|
||||
return await self.range_get_async("global", "global", key)
|
||||
return await self.get_async("global", "global", key, default)
|
||||
|
||||
async def put_async(self, scope: str, scope_id: str, key: str, value: Any):
|
||||
"""设置指定范围和键的偏好设置"""
|
||||
await self.db_helper.insert_preference_or_update(
|
||||
scope, scope_id, key, {"val": value}
|
||||
)
|
||||
|
||||
async def session_put(self, umo: str, key: str, value: Any):
|
||||
await self.put_async("umo", umo, key, value)
|
||||
|
||||
async def global_put(self, key: str, value: Any):
|
||||
await self.put_async("global", "global", key, value)
|
||||
|
||||
async def remove_async(self, scope: str, scope_id: str, key: str):
|
||||
"""删除指定范围和键的偏好设置"""
|
||||
await self.db_helper.remove_preference(scope, scope_id, key)
|
||||
|
||||
async def session_remove(self, umo: str, key: str):
|
||||
await self.remove_async("umo", umo, key)
|
||||
|
||||
async def global_remove(self, key: str):
|
||||
"""删除全局偏好设置"""
|
||||
await self.remove_async("global", "global", key)
|
||||
|
||||
async def clear_async(self, scope: str, scope_id: str):
|
||||
"""清空指定范围的所有偏好设置"""
|
||||
await self.db_helper.clear_preferences(scope, scope_id)
|
||||
|
||||
# ====
|
||||
# DEPRECATED METHODS
|
||||
# ====
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
scope: str | None = None,
|
||||
scope_id: str | None = "",
|
||||
) -> _VT:
|
||||
"""获取偏好设置(已弃用)"""
|
||||
if scope_id == "":
|
||||
scope_id = "unknown"
|
||||
if scope_id is None or key is None:
|
||||
# result = asyncio.run(self.range_get_async(scope, scope_id, key))
|
||||
raise ValueError(
|
||||
"scope_id and key cannot be None when getting a specific preference."
|
||||
)
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
self.get_async(scope or "unknown", scope_id or "unknown", key, default),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
return result if result is not None else default
|
||||
|
||||
def range_get(
|
||||
self, scope: str, scope_id: str | None = None, key: str | None = None
|
||||
) -> list[Preference]:
|
||||
"""获取指定范围的偏好设置(已弃用)"""
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
self.range_get_async(scope, scope_id, key), self._sync_loop
|
||||
).result()
|
||||
|
||||
return result
|
||||
|
||||
def put(self, key, value, scope: str | None = None, scope_id: str | None = None):
|
||||
"""设置偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.put_async(scope or "unknown", scope_id or "unknown", key, value),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
def remove(self, key, scope: str | None = None, scope_id: str | None = None):
|
||||
"""删除偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_async(scope or "unknown", scope_id or "unknown", key),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
def clear(self, scope: str | None = None, scope_id: str | None = None):
|
||||
"""清空偏好设置(已弃用)"""
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.clear_async(scope or "unknown", scope_id or "unknown"),
|
||||
self._sync_loop,
|
||||
).result()
|
||||
|
||||
@@ -1,37 +1,76 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import certifi
|
||||
|
||||
import logging
|
||||
import random
|
||||
from . import RenderStrategy
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html")
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str | None = None) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template")
|
||||
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
else:
|
||||
self.BASE_RENDER_URL = self._clean_url(base_url)
|
||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html")
|
||||
with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
self.DEFAULT_TEMPLATE = f.read()
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
self.endpoints = [self.BASE_RENDER_URL]
|
||||
|
||||
def set_endpoint(self, base_url: str):
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
self.BASE_RENDER_URL = base_url
|
||||
async def initialize(self):
|
||||
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
|
||||
asyncio.create_task(self.get_official_endpoints())
|
||||
|
||||
if self.BASE_RENDER_URL.endswith("/"):
|
||||
self.BASE_RENDER_URL = self.BASE_RENDER_URL[:-1]
|
||||
if not self.BASE_RENDER_URL.endswith("text2img"):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
async def get_template(self) -> str:
|
||||
"""获取文转图 HTML 模板
|
||||
|
||||
Returns:
|
||||
str: 文转图 HTML 模板字符串
|
||||
"""
|
||||
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
|
||||
with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
return self.DEFAULT_TEMPLATE
|
||||
|
||||
async def get_official_endpoints(self):
|
||||
"""获取官方的 t2i 端点列表。"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
"https://api.soulter.top/astrbot/t2i-endpoints"
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
all_endpoints: list[dict] = data.get("data", [])
|
||||
self.endpoints = [
|
||||
ep.get("url")
|
||||
for ep in all_endpoints
|
||||
if ep.get("active") and ep.get("url")
|
||||
]
|
||||
logger.info(
|
||||
f"Successfully got {len(self.endpoints)} official T2I endpoints."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get official endpoints: {e}")
|
||||
|
||||
def _clean_url(self, url: str):
|
||||
if url.endswith("/"):
|
||||
url = url[:-1]
|
||||
if not url.endswith("text2img"):
|
||||
url += "/text2img"
|
||||
return url
|
||||
|
||||
async def render_custom_template(
|
||||
self,
|
||||
@@ -41,6 +80,7 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""使用自定义文转图模板"""
|
||||
|
||||
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
|
||||
if options:
|
||||
default_options |= options
|
||||
@@ -51,30 +91,44 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
"tmpldata": tmpl_data,
|
||||
"options": default_options,
|
||||
}
|
||||
if return_url:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.BASE_RENDER_URL}/generate", json=post_data
|
||||
) as resp:
|
||||
ret = await resp.json()
|
||||
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
|
||||
return await download_image_by_url(
|
||||
f"{self.BASE_RENDER_URL}/generate", post=True, post_data=post_data
|
||||
)
|
||||
|
||||
endpoints = self.endpoints.copy() if self.endpoints else [self.BASE_RENDER_URL]
|
||||
random.shuffle(endpoints)
|
||||
last_exception = None
|
||||
for endpoint in endpoints:
|
||||
try:
|
||||
if return_url:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True, connector=connector
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{endpoint}/generate", json=post_data
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
ret = await resp.json()
|
||||
return f"{endpoint}/{ret['data']['id']}"
|
||||
else:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
else:
|
||||
# download_image_by_url 失败时抛异常
|
||||
return await download_image_by_url(
|
||||
f"{endpoint}/generate", post=True, post_data=post_data
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f"Endpoint {endpoint} failed: {e}, trying next...")
|
||||
continue
|
||||
# 全部失败
|
||||
logger.error(f"All endpoints failed: {last_exception}")
|
||||
raise RuntimeError(f"All endpoints failed: {last_exception}")
|
||||
|
||||
async def render(self, text: str, return_url: bool = False) -> str:
|
||||
"""
|
||||
返回图像的文件路径
|
||||
"""
|
||||
with open(
|
||||
os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding="utf-8"
|
||||
) as f:
|
||||
tmpl_str = f.read()
|
||||
assert tmpl_str
|
||||
tmpl_str = await self.get_template()
|
||||
text = text.replace("`", "\\`")
|
||||
return await self.render_custom_template(
|
||||
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url
|
||||
|
||||
@@ -10,10 +10,8 @@ class HtmlRenderer:
|
||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||
self.local_strategy = LocalRenderStrategy()
|
||||
|
||||
def set_network_endpoint(self, endpoint_url: str):
|
||||
"""设置 t2i 的网络端点。"""
|
||||
logger.info("文本转图像服务接口: " + endpoint_url)
|
||||
self.network_strategy.set_endpoint(endpoint_url)
|
||||
async def initialize(self):
|
||||
await self.network_strategy.initialize()
|
||||
|
||||
async def render_custom_template(
|
||||
self,
|
||||
|
||||
@@ -107,38 +107,16 @@ class RepoZipUpdator:
|
||||
"""Semver 版本比较"""
|
||||
return VersionComparator.compare_version(v1, v2)
|
||||
|
||||
async def check_update(
|
||||
self, url: str, current_version: str, consider_prerelease: bool = True
|
||||
) -> ReleaseInfo | None:
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo | None:
|
||||
update_data = await self.fetch_release_info(url)
|
||||
|
||||
sel_release_data = None
|
||||
if consider_prerelease:
|
||||
tag_name = update_data[0]["tag_name"]
|
||||
sel_release_data = update_data[0]
|
||||
else:
|
||||
for data in update_data:
|
||||
# 跳过带有 alpha、beta 等预发布标签的版本
|
||||
if re.search(
|
||||
r"[\-_.]?(alpha|beta|rc|dev)[\-_.]?\d*$",
|
||||
data["tag_name"],
|
||||
re.IGNORECASE,
|
||||
):
|
||||
continue
|
||||
tag_name = data["tag_name"]
|
||||
sel_release_data = data
|
||||
break
|
||||
|
||||
if not sel_release_data or not tag_name:
|
||||
logger.error("未找到合适的发布版本")
|
||||
return None
|
||||
tag_name = update_data[0]["tag_name"]
|
||||
|
||||
if self.compare_version(current_version, tag_name) >= 0:
|
||||
return None
|
||||
return ReleaseInfo(
|
||||
version=tag_name,
|
||||
published_at=sel_release_data["published_at"],
|
||||
body=f"{tag_name}\n\n{sel_release_data['body']}",
|
||||
published_at=update_data[0]["published_at"],
|
||||
body=update_data[0]["body"],
|
||||
)
|
||||
|
||||
async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""):
|
||||
|
||||
@@ -6,11 +6,11 @@ from .stat import StatRoute
|
||||
from .log import LogRoute
|
||||
from .static_file import StaticFileRoute
|
||||
from .chat import ChatRoute
|
||||
from .tools import ToolsRoute # 导入新的ToolsRoute
|
||||
from .tools import ToolsRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
|
||||
from .persona import PersonaRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
@@ -25,4 +25,5 @@ __all__ = [
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
"SessionManagementRoute",
|
||||
"PersonaRoute",
|
||||
]
|
||||
|
||||
@@ -9,6 +9,7 @@ import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
|
||||
|
||||
class ChatRoute(Route):
|
||||
@@ -29,28 +30,15 @@ class ChatRoute(Route):
|
||||
"/chat/get_file": ("GET", self.get_file),
|
||||
"/chat/post_image": ("POST", self.post_image),
|
||||
"/chat/post_file": ("POST", self.post_file),
|
||||
"/chat/status": ("GET", self.status),
|
||||
}
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
||||
|
||||
async def status(self):
|
||||
has_llm_enabled = (
|
||||
self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
||||
)
|
||||
has_stt_enabled = (
|
||||
self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(data={"llm_enabled": has_llm_enabled, "stt_enabled": has_stt_enabled})
|
||||
.__dict__
|
||||
)
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
|
||||
async def get_file(self):
|
||||
filename = request.args.get("filename")
|
||||
@@ -131,24 +119,23 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
|
||||
# Get conversation-specific queues
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id)
|
||||
|
||||
# append user message
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
logger.error(f"Failed to parse conversation history: {e}")
|
||||
history = []
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
|
||||
# Get conversation-specific queues
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
|
||||
new_his = {"type": "user", "message": message}
|
||||
if image_url:
|
||||
new_his["image_url"] = image_url
|
||||
if audio_url:
|
||||
new_his["audio_url"] = audio_url
|
||||
history.append(new_his)
|
||||
self.db.update_conversation(
|
||||
username, conversation_id, history=json.dumps(history)
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id=username,
|
||||
sender_name=username,
|
||||
)
|
||||
|
||||
async def stream():
|
||||
@@ -164,7 +151,6 @@ class ChatRoute(Route):
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
cid = result.get("cid")
|
||||
streaming = result.get("streaming", False)
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.05)
|
||||
@@ -173,17 +159,13 @@ class ChatRoute(Route):
|
||||
break
|
||||
elif (streaming and type == "complete") or not streaming:
|
||||
# append bot message
|
||||
conversation = self.db.get_conversation_by_user_id(
|
||||
username, cid
|
||||
)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
logger.error(f"Failed to parse conversation history: {e}")
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
|
||||
except BaseException as _:
|
||||
@@ -191,11 +173,11 @@ class ChatRoute(Route):
|
||||
return
|
||||
|
||||
# Put message to conversation-specific queue
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
username,
|
||||
conversation_id,
|
||||
webchat_conv_id,
|
||||
{
|
||||
"message": message,
|
||||
"image_url": image_url, # list
|
||||
@@ -217,25 +199,51 @@ class ChatRoute(Route):
|
||||
)
|
||||
return response
|
||||
|
||||
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
||||
"""从对话 ID 中提取 WebChat 会话 ID
|
||||
|
||||
NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。
|
||||
"""
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin="webchat", conversation_id=conversation_id
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation with ID {conversation_id} not found.")
|
||||
conv_user_id = conversation.user_id
|
||||
webchat_session_id = MessageSession.from_str(conv_user_id).session_id
|
||||
if "!" not in webchat_session_id:
|
||||
raise ValueError(f"Invalid conv user ID: {conv_user_id}")
|
||||
return webchat_session_id.split("!")[-1]
|
||||
|
||||
async def delete_conversation(self):
|
||||
username = g.get("username", "guest")
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
username = g.get("username", "guest")
|
||||
|
||||
# Clean up queues when deleting conversation
|
||||
webchat_queue_mgr.remove_queues(conversation_id)
|
||||
self.db.delete_conversation(username, conversation_id)
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
await self.conv_mgr.delete_conversation(
|
||||
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
await self.platform_history_mgr.delete(
|
||||
platform_id="webchat", user_id=webchat_conv_id, offset_sec=99999999
|
||||
)
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
username = g.get("username", "guest")
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(username, conversation_id)
|
||||
return Response().ok(data={"conversation_id": conversation_id}).__dict__
|
||||
webchat_conv_id = str(uuid.uuid4())
|
||||
conv_id = await self.conv_mgr.new_conversation(
|
||||
unified_msg_origin=f"webchat:FriendMessage:webchat!{username}!{webchat_conv_id}",
|
||||
platform_id="webchat",
|
||||
content=[],
|
||||
)
|
||||
return Response().ok(data={"conversation_id": conv_id}).__dict__
|
||||
|
||||
async def rename_conversation(self):
|
||||
username = g.get("username", "guest")
|
||||
post_data = await request.json
|
||||
if "conversation_id" not in post_data or "title" not in post_data:
|
||||
return Response().error("Missing key: conversation_id or title").__dict__
|
||||
@@ -243,20 +251,42 @@ class ChatRoute(Route):
|
||||
conversation_id = post_data["conversation_id"]
|
||||
title = post_data["title"]
|
||||
|
||||
self.db.update_conversation_title(username, conversation_id, title=title)
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin="webchat", # fake
|
||||
conversation_id=conversation_id,
|
||||
title=title,
|
||||
)
|
||||
return Response().ok(message="重命名成功!").__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get("username", "guest")
|
||||
conversations = self.db.get_conversations(username)
|
||||
return Response().ok(data=conversations).__dict__
|
||||
conversations = await self.conv_mgr.get_conversations(platform_id="webchat")
|
||||
# remove content
|
||||
conversations_ = []
|
||||
for conv in conversations:
|
||||
conv.history = None
|
||||
conversations_.append(conv)
|
||||
return Response().ok(data=conversations_).__dict__
|
||||
|
||||
async def get_conversation(self):
|
||||
username = g.get("username", "guest")
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
webchat_conv_id = await self._get_webchat_conv_id_from_conv_id(conversation_id)
|
||||
|
||||
return Response().ok(data=conversation).__dict__
|
||||
# Get platform message history
|
||||
history_ls = await self.platform_history_mgr.get(
|
||||
platform_id="webchat", user_id=webchat_conv_id, page=1, page_size=1000
|
||||
)
|
||||
|
||||
history_res = [history.model_dump() for history in history_ls]
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
data={
|
||||
"history": history_res,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
@@ -4,16 +4,22 @@ import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.config.default import (
|
||||
CONFIG_METADATA_2,
|
||||
DEFAULT_VALUE_MAP,
|
||||
CONFIG_METADATA_3,
|
||||
CONFIG_METADATA_3_SYSTEM,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, html_renderer
|
||||
from astrbot.core.provider import Provider
|
||||
import asyncio
|
||||
from astrbot.core.utils.t2i.network_strategy import CUSTOM_T2I_TEMPLATE_PATH
|
||||
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
@@ -159,33 +165,166 @@ class ConfigRoute(Route):
|
||||
super().__init__(context)
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self.acm = core_lifecycle.astrbot_config_mgr
|
||||
self.routes = {
|
||||
"/config/abconf/new": ("POST", self.create_abconf),
|
||||
"/config/abconf": ("GET", self.get_abconf),
|
||||
"/config/abconfs": ("GET", self.get_abconf_list),
|
||||
"/config/abconf/delete": ("POST", self.delete_abconf),
|
||||
"/config/abconf/update": ("POST", self.update_abconf),
|
||||
"/config/get": ("GET", self.get_configs),
|
||||
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
|
||||
"/config/plugin/update": ("POST", self.post_plugin_configs),
|
||||
"/config/platform/new": ("POST", self.post_new_platform),
|
||||
"/config/platform/update": ("POST", self.post_update_platform),
|
||||
"/config/platform/delete": ("POST", self.post_delete_platform),
|
||||
"/config/platform/list": ("GET", self.get_platform_list),
|
||||
"/config/provider/new": ("POST", self.post_new_provider),
|
||||
"/config/provider/update": ("POST", self.post_update_provider),
|
||||
"/config/provider/delete": ("POST", self.post_delete_provider),
|
||||
"/config/llmtools": ("GET", self.get_llm_tools),
|
||||
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||
"/config/provider/get_session_seperate": (
|
||||
"GET",
|
||||
lambda: Response()
|
||||
.ok({"enable": self.config["provider_settings"]["separate_provider"]})
|
||||
.__dict__,
|
||||
),
|
||||
"/config/provider/set_session_seperate": (
|
||||
"POST",
|
||||
self.post_session_seperate,
|
||||
),
|
||||
"/config/astrbot/t2i-template/get": ("GET", self.get_t2i_template),
|
||||
"/config/astrbot/t2i-template/save": ("POST", self.post_t2i_template),
|
||||
"/config/astrbot/t2i-template/delete": ("DELETE", self.delete_t2i_template),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_t2i_template(self):
|
||||
"""获取 T2I 模板"""
|
||||
try:
|
||||
template = await html_renderer.network_strategy.get_template()
|
||||
has_custom_template = os.path.exists(CUSTOM_T2I_TEMPLATE_PATH)
|
||||
return (
|
||||
Response()
|
||||
.ok({"template": template, "has_custom_template": has_custom_template})
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取模板失败: {str(e)}").__dict__
|
||||
|
||||
async def post_t2i_template(self):
|
||||
"""保存 T2I 模板"""
|
||||
try:
|
||||
post_data = await request.json
|
||||
if not post_data or "template" not in post_data:
|
||||
return Response().error("缺少模板内容").__dict__
|
||||
|
||||
template_content = post_data["template"]
|
||||
|
||||
# 保存自定义模板到文件
|
||||
with open(CUSTOM_T2I_TEMPLATE_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(template_content)
|
||||
|
||||
return Response().ok(message="模板保存成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"保存模板失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_t2i_template(self):
|
||||
"""删除自定义 T2I 模板,恢复默认模板"""
|
||||
try:
|
||||
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
|
||||
os.remove(CUSTOM_T2I_TEMPLATE_PATH)
|
||||
return Response().ok(message="已恢复默认模板").__dict__
|
||||
else:
|
||||
return Response().ok(message="未找到自定义模板文件").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除模板失败: {str(e)}").__dict__
|
||||
|
||||
async def get_abconf_list(self):
|
||||
"""获取所有 AstrBot 配置文件的列表"""
|
||||
abconf_list = self.acm.get_conf_list()
|
||||
return Response().ok({"info_list": abconf_list}).__dict__
|
||||
|
||||
async def create_abconf(self):
|
||||
"""创建新的 AstrBot 配置文件"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
umo_parts = post_data["umo_parts"]
|
||||
name = post_data.get("name", None)
|
||||
|
||||
try:
|
||||
conf_id = self.acm.create_conf(umo_parts=umo_parts, name=name)
|
||||
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_abconf(self):
|
||||
"""获取指定 AstrBot 配置文件"""
|
||||
abconf_id = request.args.get("id")
|
||||
system_config = request.args.get("system_config", "0").lower() == "1"
|
||||
if not abconf_id and not system_config:
|
||||
return Response().error("缺少配置文件 ID").__dict__
|
||||
|
||||
try:
|
||||
if system_config:
|
||||
abconf = self.acm.confs["default"]
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": abconf, "metadata": CONFIG_METADATA_3_SYSTEM})
|
||||
.__dict__
|
||||
)
|
||||
abconf = self.acm.confs[abconf_id]
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": abconf, "metadata": CONFIG_METADATA_3})
|
||||
.__dict__
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def delete_abconf(self):
|
||||
"""删除指定 AstrBot 配置文件"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
conf_id = post_data.get("id")
|
||||
if not conf_id:
|
||||
return Response().error("缺少配置文件 ID").__dict__
|
||||
|
||||
try:
|
||||
success = self.acm.delete_conf(conf_id)
|
||||
if success:
|
||||
return Response().ok(message="删除成功").__dict__
|
||||
else:
|
||||
return Response().error("删除失败").__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除配置文件失败: {str(e)}").__dict__
|
||||
|
||||
async def update_abconf(self):
|
||||
"""更新指定 AstrBot 配置文件信息"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
conf_id = post_data.get("id")
|
||||
if not conf_id:
|
||||
return Response().error("缺少配置文件 ID").__dict__
|
||||
|
||||
name = post_data.get("name")
|
||||
umo_parts = post_data.get("umo_parts")
|
||||
|
||||
try:
|
||||
success = self.acm.update_conf_info(conf_id, name=name, umo_parts=umo_parts)
|
||||
if success:
|
||||
return Response().ok(message="更新成功").__dict__
|
||||
else:
|
||||
return Response().error("更新失败").__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"更新配置文件失败: {str(e)}").__dict__
|
||||
|
||||
async def _test_single_provider(self, provider):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
@@ -210,11 +349,16 @@ class ConfigRoute(Route):
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
logger.debug(
|
||||
f"Received response from {status_info['name']}: {response}"
|
||||
)
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if hasattr(response, "completion_text") and response.completion_text:
|
||||
if (
|
||||
hasattr(response, "completion_text")
|
||||
and response.completion_text
|
||||
):
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
@@ -233,29 +377,48 @@ class ConfigRoute(Route):
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
status_info["error"] = "Test call returned None, but expected an LLMResponse object."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.")
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
status_info["error"] = "Connection timed out after 45 seconds during test call."
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.")
|
||||
status_info["error"] = (
|
||||
"Connection timed out after 45 seconds during test call."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) timed out."
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
status_info["error"] = error_message
|
||||
logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}")
|
||||
logger.debug(f"Traceback for {status_info['name']}:\n{traceback.format_exc()}")
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Traceback for {status_info['name']}:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
elif provider_capability_type == ProviderType.EMBEDDING:
|
||||
try:
|
||||
# For embedding, we can call the get_embedding method with a short prompt.
|
||||
embedding_result = await provider.get_embedding("health_check")
|
||||
if isinstance(embedding_result, list) and (not embedding_result or isinstance(embedding_result[0], float)):
|
||||
if isinstance(embedding_result, list) and (
|
||||
not embedding_result or isinstance(embedding_result[0], float)
|
||||
):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
status_info["error"] = (
|
||||
f"Embedding test failed: unexpected result type {type(embedding_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing embedding provider {provider_name}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error testing embedding provider {provider_name}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"Embedding test failed: {str(e)}"
|
||||
|
||||
@@ -267,41 +430,71 @@ class ConfigRoute(Route):
|
||||
status_info["status"] = "available"
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
status_info["error"] = (
|
||||
f"TTS test failed: unexpected result type {type(audio_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing TTS provider {provider_name}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error testing TTS provider {provider_name}: {e}", exc_info=True
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"TTS test failed: {str(e)}"
|
||||
elif provider_capability_type == ProviderType.SPEECH_TO_TEXT:
|
||||
try:
|
||||
logger.debug(f"Sending health check audio to provider: {status_info['name']}")
|
||||
sample_audio_path = os.path.join(get_astrbot_path(), "samples", "stt_health_check.wav")
|
||||
logger.debug(
|
||||
f"Sending health check audio to provider: {status_info['name']}"
|
||||
)
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(), "samples", "stt_health_check.wav"
|
||||
)
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = "STT test failed: sample audio file not found."
|
||||
logger.warning(f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}")
|
||||
status_info["error"] = (
|
||||
"STT test failed: sample audio file not found."
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}"
|
||||
)
|
||||
else:
|
||||
text_result = await provider.get_text(sample_audio_path)
|
||||
if isinstance(text_result, str) and text_result:
|
||||
status_info["status"] = "available"
|
||||
snippet = text_result[:70] + "..." if len(text_result) > 70 else text_result
|
||||
logger.info(f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'")
|
||||
snippet = (
|
||||
text_result[:70] + "..."
|
||||
if len(text_result) > 70
|
||||
else text_result
|
||||
)
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'"
|
||||
)
|
||||
else:
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: unexpected result type {type(text_result)}"
|
||||
logger.warning(f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}")
|
||||
status_info["error"] = (
|
||||
f"STT test failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
logger.warning(
|
||||
f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing STT provider {provider_name}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error testing STT provider {provider_name}: {e}", exc_info=True
|
||||
)
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = f"STT test failed: {str(e)}"
|
||||
else:
|
||||
logger.debug(f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}")
|
||||
logger.debug(
|
||||
f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}"
|
||||
)
|
||||
status_info["status"] = "available"
|
||||
status_info["error"] = "This provider type is not tested and is assumed to be available."
|
||||
status_info["error"] = (
|
||||
"This provider type is not tested and is assumed to be available."
|
||||
)
|
||||
|
||||
return status_info
|
||||
|
||||
def _error_response(self, message: str, status_code: int = 500, log_fn=logger.error):
|
||||
def _error_response(
|
||||
self, message: str, status_code: int = 500, log_fn=logger.error
|
||||
):
|
||||
log_fn(message)
|
||||
# 记录更详细的traceback信息,但只在是严重错误时
|
||||
if status_code == 500:
|
||||
@@ -312,7 +505,9 @@ class ConfigRoute(Route):
|
||||
"""API: check a single LLM Provider's status by id"""
|
||||
provider_id = request.args.get("id")
|
||||
if not provider_id:
|
||||
return self._error_response("Missing provider_id parameter", 400, logger.warning)
|
||||
return self._error_response(
|
||||
"Missing provider_id parameter", 400, logger.warning
|
||||
)
|
||||
|
||||
logger.info(f"API call: /config/provider/check_one id={provider_id}")
|
||||
try:
|
||||
@@ -320,16 +515,21 @@ class ConfigRoute(Route):
|
||||
target = prov_mgr.inst_map.get(provider_id)
|
||||
|
||||
if not target:
|
||||
logger.warning(f"Provider with id '{provider_id}' not found in provider_manager.")
|
||||
return Response().error(f"Provider with id '{provider_id}' not found").__dict__
|
||||
logger.warning(
|
||||
f"Provider with id '{provider_id}' not found in provider_manager."
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.error(f"Provider with id '{provider_id}' not found")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
result = await self._test_single_provider(target)
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
except Exception as e:
|
||||
return self._error_response(
|
||||
f"Critical error checking provider {provider_id}: {e}",
|
||||
500
|
||||
f"Critical error checking provider {provider_id}: {e}", 500
|
||||
)
|
||||
|
||||
async def get_configs(self):
|
||||
@@ -340,29 +540,15 @@ class ConfigRoute(Route):
|
||||
return Response().ok(await self._get_astrbot_config()).__dict__
|
||||
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
|
||||
|
||||
async def post_session_seperate(self):
|
||||
"""设置提供商会话隔离"""
|
||||
post_config = await request.json
|
||||
enable = post_config.get("enable", None)
|
||||
if enable is None:
|
||||
return Response().error("缺少参数 enable").__dict__
|
||||
|
||||
astrbot_config = self.core_lifecycle.astrbot_config
|
||||
astrbot_config["provider_settings"]["separate_provider"] = enable
|
||||
try:
|
||||
astrbot_config.save_config()
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "设置成功~").__dict__
|
||||
|
||||
async def get_provider_config_list(self):
|
||||
provider_type = request.args.get("provider_type", None)
|
||||
if not provider_type:
|
||||
return Response().error("缺少参数 provider_type").__dict__
|
||||
provider_type_ls = provider_type.split(",")
|
||||
provider_list = []
|
||||
astrbot_config = self.core_lifecycle.astrbot_config
|
||||
for provider in astrbot_config["provider"]:
|
||||
if provider.get("provider_type", None) == provider_type:
|
||||
if provider.get("provider_type", None) in provider_type_ls:
|
||||
provider_list.append(provider)
|
||||
return Response().ok(provider_list).__dict__
|
||||
|
||||
@@ -388,11 +574,21 @@ class ConfigRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_platform_list(self):
|
||||
"""获取所有平台的列表"""
|
||||
platform_list = []
|
||||
for platform in self.config["platform"]:
|
||||
platform_list.append(platform)
|
||||
return Response().ok({"platforms": platform_list}).__dict__
|
||||
|
||||
async def post_astrbot_configs(self):
|
||||
post_configs = await request.json
|
||||
data = await request.json
|
||||
config = data.get("config", None)
|
||||
conf_id = data.get("conf_id", None)
|
||||
try:
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
await self._save_astrbot_configs(config, conf_id)
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
return Response().ok(None, "保存成功~").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -509,12 +705,6 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = tool_mgr.get_func_desc_openai_style()
|
||||
return Response().ok(tools).__dict__
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
@@ -557,10 +747,12 @@ class ConfigRoute(Route):
|
||||
|
||||
return ret
|
||||
|
||||
async def _save_astrbot_configs(self, post_configs: dict):
|
||||
async def _save_astrbot_configs(self, post_configs: dict, conf_id: str = None):
|
||||
try:
|
||||
save_config(post_configs, self.config, is_core=True)
|
||||
await self.core_lifecycle.restart()
|
||||
if conf_id not in self.acm.confs:
|
||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||
astrbot_config = self.acm.confs[conf_id]
|
||||
save_config(post_configs, astrbot_config, is_core=True)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ class ConversationRoute(Route):
|
||||
),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
@@ -54,7 +55,6 @@ class ConversationRoute(Route):
|
||||
exclude_platforms.split(",") if exclude_platforms else []
|
||||
)
|
||||
|
||||
# 限制页面大小,防止请求过大数据
|
||||
if page < 1:
|
||||
page = 1
|
||||
if page_size < 1:
|
||||
@@ -62,9 +62,11 @@ class ConversationRoute(Route):
|
||||
if page_size > 100:
|
||||
page_size = 100
|
||||
|
||||
# 使用数据库的分页方法获取会话列表和总数,传入筛选条件
|
||||
try:
|
||||
conversations, total_count = self.db_helper.get_filtered_conversations(
|
||||
(
|
||||
conversations,
|
||||
total_count,
|
||||
) = await self.conv_mgr.get_filtered_conversations(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
platforms=platform_list,
|
||||
@@ -108,7 +110,9 @@ class ConversationRoute(Route):
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
|
||||
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
|
||||
@@ -143,14 +147,18 @@ class ConversationRoute(Route):
|
||||
|
||||
if not user_id or not cid:
|
||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
if title is not None:
|
||||
self.db_helper.update_conversation_title(user_id, cid, title)
|
||||
if persona_id is not None:
|
||||
self.db_helper.update_conversation_persona_id(user_id, cid, persona_id)
|
||||
|
||||
if title is not None or persona_id is not None:
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
return Response().ok({"message": "对话信息更新成功"}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
@@ -201,11 +209,17 @@ class ConversationRoute(Route):
|
||||
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
|
||||
)
|
||||
|
||||
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid
|
||||
)
|
||||
if not conversation:
|
||||
return Response().error("对话不存在").__dict__
|
||||
|
||||
self.db_helper.update_conversation(user_id, cid, history)
|
||||
history = json.loads(history) if isinstance(history, str) else history
|
||||
|
||||
await self.conv_mgr.update_conversation(
|
||||
unified_msg_origin=user_id, conversation_id=cid, history=history
|
||||
)
|
||||
|
||||
return Response().ok({"message": "对话历史更新成功"}).__dict__
|
||||
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
class PersonaRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/persona/list": ("GET", self.list_personas),
|
||||
"/persona/detail": ("POST", self.get_persona_detail),
|
||||
"/persona/create": ("POST", self.create_persona),
|
||||
"/persona/update": ("POST", self.update_persona),
|
||||
"/persona/delete": ("POST", self.delete_persona),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.persona_mgr = core_lifecycle.persona_mgr
|
||||
self.register_routes()
|
||||
|
||||
async def list_personas(self):
|
||||
"""获取所有人格列表"""
|
||||
try:
|
||||
personas = await self.persona_mgr.get_all_personas()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
[
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
for persona in personas
|
||||
]
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格列表失败: {str(e)}").__dict__
|
||||
|
||||
async def get_persona_detail(self):
|
||||
"""获取指定人格的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
persona = await self.persona_mgr.get_persona(persona_id)
|
||||
if not persona:
|
||||
return Response().error("人格不存在").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格详情失败: {str(e)}").__dict__
|
||||
|
||||
async def create_persona(self):
|
||||
"""创建新人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id", "").strip()
|
||||
system_prompt = data.get("system_prompt", "").strip()
|
||||
begin_dialogs = data.get("begin_dialogs", [])
|
||||
tools = data.get("tools")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("人格ID不能为空").__dict__
|
||||
|
||||
if not system_prompt:
|
||||
return Response().error("系统提示词不能为空").__dict__
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
persona = await self.persona_mgr.create_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs if begin_dialogs else None,
|
||||
tools=tools if tools else None,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": "人格创建成功",
|
||||
"persona": {
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools or [],
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
},
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"创建人格失败: {str(e)}").__dict__
|
||||
|
||||
async def update_persona(self):
|
||||
"""更新人格信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
system_prompt = data.get("system_prompt")
|
||||
begin_dialogs = data.get("begin_dialogs")
|
||||
tools = data.get("tools")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
await self.persona_mgr.update_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return Response().ok({"message": "人格更新成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新人格失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_persona(self):
|
||||
"""删除人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
await self.persona_mgr.delete_persona(persona_id)
|
||||
|
||||
return Response().ok({"message": "人格删除成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"删除人格失败: {str(e)}").__dict__
|
||||
@@ -40,8 +40,6 @@ class PluginRoute(Route):
|
||||
"/plugin/on": ("POST", self.on_plugin),
|
||||
"/plugin/reload": ("POST", self.reload_plugins),
|
||||
"/plugin/readme": ("GET", self.get_plugin_readme),
|
||||
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
|
||||
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
@@ -286,14 +284,6 @@ class PluginRoute(Route):
|
||||
f"{filter.parent_command_names[0]} {filter.command_name}"
|
||||
)
|
||||
info["cmd"] = info["cmd"].strip()
|
||||
if (
|
||||
self.core_lifecycle.astrbot_config["wake_prefix"]
|
||||
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
|
||||
> 0
|
||||
):
|
||||
info["cmd"] = (
|
||||
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
)
|
||||
elif isinstance(filter, CommandGroupFilter):
|
||||
info["type"] = "指令组"
|
||||
info["cmd"] = filter.get_complete_command_names()[0]
|
||||
@@ -301,14 +291,6 @@ class PluginRoute(Route):
|
||||
info["sub_command"] = filter.print_cmd_tree(
|
||||
filter.sub_command_filters
|
||||
)
|
||||
if (
|
||||
self.core_lifecycle.astrbot_config["wake_prefix"]
|
||||
and len(self.core_lifecycle.astrbot_config["wake_prefix"])
|
||||
> 0
|
||||
):
|
||||
info["cmd"] = (
|
||||
f"{self.core_lifecycle.astrbot_config['wake_prefix'][0]}{info['cmd']}"
|
||||
)
|
||||
elif isinstance(filter, RegexFilter):
|
||||
info["type"] = "正则匹配"
|
||||
info["cmd"] = filter.regex_str
|
||||
@@ -498,90 +480,3 @@ class PluginRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
|
||||
return Response().error(f"读取README文件失败: {str(e)}").__dict__
|
||||
|
||||
async def get_plugin_platform_enable(self):
|
||||
"""获取插件在各平台的可用性配置"""
|
||||
try:
|
||||
platform_enable = self.core_lifecycle.astrbot_config.get(
|
||||
"platform_settings", {}
|
||||
).get("plugin_enable", {})
|
||||
|
||||
# 获取所有可用平台
|
||||
platforms = []
|
||||
|
||||
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
|
||||
platform_type = platform.get("type", "")
|
||||
platform_id = platform.get("id", "")
|
||||
|
||||
platforms.append(
|
||||
{
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
}
|
||||
)
|
||||
|
||||
adjusted_platform_enable = {}
|
||||
for platform_id, plugins in platform_enable.items():
|
||||
adjusted_platform_enable[platform_id] = plugins
|
||||
|
||||
# 获取所有插件,包括系统内部插件
|
||||
plugins = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
plugins.append(
|
||||
{
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def set_plugin_platform_enable(self):
|
||||
"""设置插件在各平台的可用性配置"""
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
data = await request.json
|
||||
platform_enable = data.get("platform_enable", {})
|
||||
|
||||
# 更新配置
|
||||
config = self.core_lifecycle.astrbot_config
|
||||
platform_settings = config.get("platform_settings", {})
|
||||
platform_settings["plugin_enable"] = platform_enable
|
||||
config["platform_settings"] = platform_settings
|
||||
config.save_config()
|
||||
|
||||
# 更新插件的平台兼容性缓存
|
||||
await self.plugin_manager.update_all_platform_compatibility()
|
||||
|
||||
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
|
||||
|
||||
return Response().ok(None, "插件平台可用性配置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -21,17 +21,19 @@ class Route:
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
status: str = None
|
||||
message: str = None
|
||||
data: dict = None
|
||||
status: str | None = None
|
||||
message: str | None = None
|
||||
data: dict | list | None = None
|
||||
|
||||
def error(self, message: str):
|
||||
self.status = "error"
|
||||
self.message = message
|
||||
return self
|
||||
|
||||
def ok(self, data: dict = {}, message: str = None):
|
||||
def ok(self, data: dict | list | None = None, message: str | None = None):
|
||||
self.status = "ok"
|
||||
if data is None:
|
||||
data = {}
|
||||
self.data = data
|
||||
self.message = message
|
||||
return self
|
||||
|
||||
@@ -24,7 +24,6 @@ class SessionManagementRoute(Route):
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
"/session/update_provider": ("POST", self.update_session_provider),
|
||||
"/session/get_session_info": ("POST", self.get_session_info),
|
||||
"/session/plugins": ("GET", self.get_session_plugins),
|
||||
"/session/update_plugin": ("POST", self.update_session_plugin),
|
||||
"/session/update_llm": ("POST", self.update_session_llm),
|
||||
@@ -32,24 +31,20 @@ class SessionManagementRoute(Route):
|
||||
"/session/update_name": ("POST", self.update_session_name),
|
||||
"/session/update_status": ("POST", self.update_session_status),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
# 获取会话对话映射
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
|
||||
# 获取会话提供商偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
|
||||
# 获取可用的 personas
|
||||
personas = self.core_lifecycle.star_context.provider_manager.personas
|
||||
|
||||
# 获取可用的 providers
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[])
|
||||
session_conversations = {}
|
||||
for pref in preferences:
|
||||
session_conversations[pref.scope_id] = pref.value["val"]
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = self.core_lifecycle.persona_mgr
|
||||
personas = persona_mgr.personas_v3
|
||||
|
||||
sessions = []
|
||||
|
||||
@@ -59,13 +54,9 @@ class SessionManagementRoute(Route):
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"session_enabled": SessionServiceManager.is_session_enabled(
|
||||
session_id
|
||||
),
|
||||
@@ -90,74 +81,46 @@ class SessionManagementRoute(Route):
|
||||
}
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=session_id, conversation_id=conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
session_info["persona_id"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
session_info["persona_id"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = (
|
||||
default_stt_provider.meta().id
|
||||
)
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = (
|
||||
default_tts_provider.meta().id
|
||||
)
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
chat_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
|
||||
)
|
||||
tts_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH, umo=session_id
|
||||
)
|
||||
stt_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT, umo=session_id
|
||||
)
|
||||
if chat_provider:
|
||||
meta = chat_provider.meta()
|
||||
session_info["chat_provider_id"] = meta.id
|
||||
if tts_provider:
|
||||
meta = tts_provider.meta()
|
||||
session_info["tts_provider_id"] = meta.id
|
||||
if stt_provider:
|
||||
meta = stt_provider.meta()
|
||||
session_info["stt_provider_id"] = meta.id
|
||||
|
||||
sessions.append(session_info)
|
||||
|
||||
@@ -311,133 +274,6 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话提供商失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_info(self):
|
||||
"""获取指定会话的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
# 获取会话对话信息
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
conversation_id = session_conversations.get(session_id)
|
||||
|
||||
if not conversation_id:
|
||||
return Response().error(f"会话 {session_id} 未找到对话").__dict__
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": None, # 将在下面设置
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取TTS状态
|
||||
session_info["tts_enabled"] = (
|
||||
SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
)
|
||||
|
||||
# 获取对话信息
|
||||
conversation = self.db_helper.get_conversation_by_user_id(
|
||||
session_id, conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
personas = provider_manager.personas
|
||||
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = default_stt_provider.meta().id
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = default_tts_provider.meta().id
|
||||
|
||||
return Response().ok(session_info).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话信息失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_plugins(self):
|
||||
"""获取指定会话的插件配置信息"""
|
||||
try:
|
||||
|
||||
@@ -11,6 +11,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4
|
||||
|
||||
|
||||
class StatRoute(Route):
|
||||
@@ -59,6 +60,8 @@ class StatRoute(Route):
|
||||
)
|
||||
|
||||
async def get_version(self):
|
||||
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
@@ -66,6 +69,7 @@ class StatRoute(Route):
|
||||
"version": VERSION,
|
||||
"dashboard_version": await get_dashboard_version(),
|
||||
"change_pwd_hint": self.is_default_cred(),
|
||||
"need_migration": need_migration,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
@@ -84,7 +88,7 @@ class StatRoute(Route):
|
||||
message_time_based_stats = []
|
||||
|
||||
idx = 0
|
||||
for bucket_end in range(start_time, now, 1800):
|
||||
for bucket_end in range(start_time, now, 3600):
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(stat.platform)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import aiohttp
|
||||
@@ -7,7 +5,7 @@ from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.star import star_map
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -25,44 +23,17 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/add": ("POST", self.add_mcp_server),
|
||||
"/tools/mcp/update": ("POST", self.update_mcp_server),
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
"/tools/mcp/test": ("POST", self.test_mcp_connection),
|
||||
"/tools/list": ("GET", self.get_tool_list),
|
||||
"/tools/toggle-tool": ("POST", self.toggle_tool),
|
||||
"/tools/mcp/sync-provider": ("POST", self.sync_provider),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
|
||||
@property
|
||||
def mcp_config_path(self):
|
||||
data_dir = get_astrbot_data_path()
|
||||
return os.path.join(data_dir, "mcp_server.json")
|
||||
|
||||
def load_mcp_config(self):
|
||||
if not os.path.exists(self.mcp_config_path):
|
||||
# 配置文件不存在,创建默认配置
|
||||
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
try:
|
||||
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 MCP 配置失败: {e}")
|
||||
return DEFAULT_MCP_CONFIG
|
||||
|
||||
def save_mcp_config(self, config):
|
||||
try:
|
||||
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=4)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存 MCP 配置失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_mcp_servers(self):
|
||||
try:
|
||||
config = self.load_mcp_config()
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
servers = []
|
||||
|
||||
# 获取所有服务器并添加它们的工具列表
|
||||
@@ -125,14 +96,14 @@ class ToolsRoute(Route):
|
||||
if not has_valid_config:
|
||||
return Response().error("必须提供有效的服务器配置").__dict__
|
||||
|
||||
config = self.load_mcp_config()
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 已存在").__dict__
|
||||
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
try:
|
||||
await self.tool_mgr.enable_mcp_server(
|
||||
name, server_config, timeout=30
|
||||
@@ -162,7 +133,7 @@ class ToolsRoute(Route):
|
||||
if not name:
|
||||
return Response().error("服务器名称不能为空").__dict__
|
||||
|
||||
config = self.load_mcp_config()
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||
@@ -198,7 +169,7 @@ class ToolsRoute(Route):
|
||||
|
||||
config["mcpServers"][name] = server_config
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
# 处理MCP客户端状态变化
|
||||
if active:
|
||||
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
|
||||
@@ -266,14 +237,14 @@ class ToolsRoute(Route):
|
||||
if not name:
|
||||
return Response().error("服务器名称不能为空").__dict__
|
||||
|
||||
config = self.load_mcp_config()
|
||||
config = self.tool_mgr.load_mcp_config()
|
||||
|
||||
if name not in config["mcpServers"]:
|
||||
return Response().error(f"服务器 {name} 不存在").__dict__
|
||||
|
||||
del config["mcpServers"][name]
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
if self.tool_mgr.save_mcp_config(config):
|
||||
if name in self.tool_mgr.mcp_client_dict:
|
||||
try:
|
||||
await self.tool_mgr.disable_mcp_server(name, timeout=10)
|
||||
@@ -295,31 +266,6 @@ class ToolsRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__
|
||||
|
||||
async def get_mcp_markets(self):
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 10, type=int)
|
||||
BASE_URL = (
|
||||
"https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format(
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{BASE_URL}") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return Response().ok(data["data"]).__dict__
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"获取市场数据失败: HTTP {response.status}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as _:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error("获取市场数据失败").__dict__
|
||||
|
||||
async def test_mcp_connection(self):
|
||||
"""
|
||||
测试 MCP 服务器连接
|
||||
@@ -336,3 +282,57 @@ class ToolsRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
|
||||
|
||||
async def get_tool_list(self):
|
||||
"""获取所有注册的工具列表"""
|
||||
try:
|
||||
tools = self.tool_mgr.func_list
|
||||
tools_dict = [tool.__dict__() for tool in tools]
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取工具列表失败: {str(e)}").__dict__
|
||||
|
||||
async def toggle_tool(self):
|
||||
"""启用或停用指定的工具"""
|
||||
try:
|
||||
data = await request.json
|
||||
tool_name = data.get("name")
|
||||
action = data.get("activate") # True or False
|
||||
|
||||
if not tool_name or action is None:
|
||||
return Response().error("缺少必要参数: name 或 action").__dict__
|
||||
|
||||
if action:
|
||||
try:
|
||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||
except ValueError as e:
|
||||
return Response().error(f"启用工具失败: {str(e)}").__dict__
|
||||
else:
|
||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||
|
||||
if ok:
|
||||
return Response().ok(None, "操作成功。").__dict__
|
||||
else:
|
||||
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"操作工具失败: {str(e)}").__dict__
|
||||
|
||||
async def sync_provider(self):
|
||||
"""同步 MCP 提供者配置"""
|
||||
try:
|
||||
data = await request.json
|
||||
provider_name = data.get("name") # modelscope, or others
|
||||
match provider_name:
|
||||
case "modelscope":
|
||||
access_token = data.get("access_token", "")
|
||||
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
|
||||
case _:
|
||||
return Response().error(f"未知: {provider_name}").__dict__
|
||||
|
||||
return Response().ok(message="同步成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"同步失败: {str(e)}").__dict__
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.core import logger, pip_installer
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4
|
||||
|
||||
|
||||
class UpdateRoute(Route):
|
||||
@@ -23,11 +24,27 @@ class UpdateRoute(Route):
|
||||
"/update/do": ("POST", self.update_project),
|
||||
"/update/dashboard": ("POST", self.update_dashboard),
|
||||
"/update/pip-install": ("POST", self.install_pip_package),
|
||||
"/update/migration": ("POST", self.do_migration),
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def do_migration(self):
|
||||
need_migration = await check_migration_needed_v4(self.core_lifecycle.db)
|
||||
if not need_migration:
|
||||
return Response().ok(None, "不需要进行迁移。").__dict__
|
||||
try:
|
||||
data = await request.json
|
||||
pim = data.get("platform_id_map", {})
|
||||
await do_migration_v4(
|
||||
self.core_lifecycle.db, pim, self.core_lifecycle.astrbot_config
|
||||
)
|
||||
return Response().ok(None, "迁移成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"迁移失败: {traceback.format_exc()}")
|
||||
return Response().error(f"迁移失败: {str(e)}").__dict__
|
||||
|
||||
async def check_update(self):
|
||||
type_ = request.args.get("type", None)
|
||||
|
||||
@@ -40,7 +57,7 @@ class UpdateRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
ret = await self.astrbot_updator.check_update(None, None, False)
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
@@ -83,7 +100,9 @@ class UpdateRoute(Route):
|
||||
)
|
||||
|
||||
try:
|
||||
await download_dashboard(latest=latest, version=version, proxy=proxy)
|
||||
await download_dashboard(
|
||||
latest=latest, version=version, proxy=proxy
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
|
||||
@@ -114,7 +133,7 @@ class UpdateRoute(Route):
|
||||
async def update_dashboard(self):
|
||||
try:
|
||||
try:
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
|
||||
@@ -60,6 +60,9 @@ class AstrBotDashboard:
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
self.persona_route = PersonaRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复:构建 docker 镜像时同时构建 webui,并放入镜像中。
|
||||
2. 修复:下载 WebUI 文件时,明确版本号,以防止 latest 不一致导致下载的 WebUI 文件版本号与实际所需不符的问题。
|
||||
3. 优化:优化版本检测,考虑预发布版本,移除 `更新到最新版本` 按钮
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
@@ -1,40 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 24.1.2, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<svg version="1.1" id="Layer_3" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 128 128" style="enable-background:new 0 0 128 128;" xml:space="preserve">
|
||||
<g>
|
||||
<linearGradient id="SVGID_1_" gradientUnits="userSpaceOnUse" x1="93.7287" y1="106.6446" x2="52.9011" y2="81.6944">
|
||||
<stop offset="0.0969" style="stop-color:#FFB300"/>
|
||||
<stop offset="1" style="stop-color:#FFB300;stop-opacity:0"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_1_);" d="M123.04,107.67c-4.08-4.12-9.38-9.48-14.92-15.06c-0.34,1.29-0.93,2.39-1.79,3.26
|
||||
c-6.43,6.43-25.6-1.99-45.31-19.1c-2.46-2.13-16.74,20.28-14.1,22.87c3.27,3.2,26,17.86,33.78,20.73
|
||||
c22.66,8.35,34.3,0.22,38.24-3.59C121.16,114.61,122.51,111.5,123.04,107.67z"/>
|
||||
<linearGradient id="SVGID_2_" gradientUnits="userSpaceOnUse" x1="115.2813" y1="82.3624" x2="14.863" y2="0.8196">
|
||||
<stop offset="0" style="stop-color:#FFB300"/>
|
||||
<stop offset="0.7062" style="stop-color:#FDD835"/>
|
||||
<stop offset="0.8408" style="stop-color:#FDDC36"/>
|
||||
<stop offset="0.9842" style="stop-color:#FFE93A"/>
|
||||
<stop offset="1" style="stop-color:#FFEB3B"/>
|
||||
</linearGradient>
|
||||
<path style="fill:url(#SVGID_2_);" d="M25.05,27.7c-1.54-4.81-2.88-11.1-0.4-13.5c7.51-7.3,31.69,4.88,54.25,27.43
|
||||
c22.55,22.55,34.84,46.84,27.43,54.25c-0.07,0.07-0.16,0.13-0.23,0.2c6.13,5.82,12.2,11.6,16.1,15.31
|
||||
c4.87-14.43-6.45-44.11-31.5-69.96c-4.07-4.2-16.12-16.56-26.55-23.56C54.61,11.47,44.19,5.59,32.57,4.2
|
||||
C25,3.29,11.45,5.24,14.25,15.98c0.55,2.12,2.31,7.22,8.15,13.3C23.56,30.49,25.56,29.3,25.05,27.7z"/>
|
||||
<g>
|
||||
<path style="fill:#FDD835;" d="M55.98,42.1l-0.75,20c-0.06,1.53,0.72,2.98,2.04,3.77l16.86,10.11c1.85,1.25,1.46,4.09-0.66,4.79
|
||||
L54.79,85.5c-1.51,0.38-2.69,1.57-3.06,3.08l-4.89,19.93c-0.62,2.15-3.43,2.65-4.76,0.85L31.06,92.91
|
||||
c-0.85-1.26-2.31-1.97-3.83-1.85L7.49,92.61c-2.23,0.07-3.58-2.45-2.28-4.27l12.6-16.19c0.96-1.23,1.16-2.89,0.52-4.31
|
||||
l-7.88-17.57c-0.76-2.1,1.22-4.17,3.35-3.49l18.39,6.95c1.44,0.54,3.05,0.26,4.22-0.74l15.22-13
|
||||
C53.39,38.62,55.96,39.87,55.98,42.1z"/>
|
||||
<g>
|
||||
<path style="fill:#FFFF8D;" d="M46.99,59.33l4.66-12.75c0.28-0.7,0.7-1.93,1.79-1.4c0.86,0.42,0.46,2.43,0.46,2.43l-1.05,11.54
|
||||
c-0.41,4.39-1.6,5.38-3.3,5.49C47.6,64.75,45.65,62.98,46.99,59.33z"/>
|
||||
</g>
|
||||
<g>
|
||||
<path style="fill:#F4B400;" d="M53.89,83.73l14.53-3.13c0.73-0.18,2.01-0.42,1.64-1.58c-0.29-0.91-2.34-0.8-2.34-0.8l-10.97-0.86
|
||||
c-3.21-0.38-5.72,0.14-6.74,1.84C48.65,81.48,49.89,84.32,53.89,83.73z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.6 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 54 KiB |
@@ -2,6 +2,9 @@
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
const props = defineProps({
|
||||
@@ -48,6 +51,47 @@ function openEditorDialog(key, value, theme, language) {
|
||||
function saveEditedContent() {
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function getValueBySelector(obj, selector) {
|
||||
const keys = selector.split('.')
|
||||
let current = obj
|
||||
for (const key of keys) {
|
||||
if (current && typeof current === 'object' && key in current) {
|
||||
current = current[key]
|
||||
} else {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
function shouldShowItem(itemMeta, itemKey) {
|
||||
if (!itemMeta?.condition) {
|
||||
return true
|
||||
}
|
||||
for (const [conditionKey, expectedValue] of Object.entries(itemMeta.condition)) {
|
||||
const actualValue = getValueBySelector(props.iterable, conditionKey)
|
||||
if (actualValue !== expectedValue) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
function hasVisibleItemsAfter(items, currentIndex) {
|
||||
const itemEntries = Object.entries(items)
|
||||
|
||||
// 检查当前索引之后是否还有可见的配置项
|
||||
for (let i = currentIndex + 1; i < itemEntries.length; i++) {
|
||||
const [itemKey, itemValue] = itemEntries[i]
|
||||
const itemMeta = props.metadata[props.metadataKey].items[itemKey]
|
||||
if (!itemMeta?.invisible && shouldShowItem(itemMeta, itemKey)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -79,7 +123,7 @@ function saveEditedContent() {
|
||||
<div v-for="(val, key, index) in filteredIterable" :key="key" class="config-item">
|
||||
<!-- Nested Object -->
|
||||
<div v-if="metadata[metadataKey].items[key]?.type === 'object'" class="nested-object">
|
||||
<div v-if="metadata[metadataKey].items[key] && !metadata[metadataKey].items[key]?.invisible" class="nested-container">
|
||||
<div v-if="metadata[metadataKey].items[key] && !metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="nested-container">
|
||||
<v-expand-transition>
|
||||
<AstrBotConfig :metadata="metadata[metadataKey].items" :iterable="iterable[key]" :metadataKey="key">
|
||||
</AstrBotConfig>
|
||||
@@ -89,7 +133,7 @@ function saveEditedContent() {
|
||||
|
||||
<!-- Regular Property -->
|
||||
<template v-else>
|
||||
<v-row v-if="!metadata[metadataKey].items[key]?.invisible" class="config-row">
|
||||
<v-row v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
@@ -120,9 +164,64 @@ function saveEditedContent() {
|
||||
|
||||
<v-col cols="12" sm="5" class="config-input">
|
||||
<div v-if="metadata[metadataKey].items[key]" class="w-100">
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-if="metadata[metadataKey].items[key]?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_tts'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'text_to_speech'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<!-- List item with options-->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
|
||||
class="d-flex flex-wrap gap-20">
|
||||
<v-checkbox
|
||||
v-for="(option, index) in metadata[metadataKey].items[key]?.options"
|
||||
v-model="iterable[key]"
|
||||
:label="metadata[metadataKey].items[key]?.labels ? metadata[metadataKey].items[key].labels[index] : option"
|
||||
:value="option"
|
||||
class="mr-2"
|
||||
color="primary"
|
||||
hide-details
|
||||
></v-checkbox>
|
||||
</div>
|
||||
<!-- List item with options-->
|
||||
<v-combobox
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
:items="metadata[metadataKey].items[key]?.options"
|
||||
:disabled="metadata[metadataKey].items[key]?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
chips
|
||||
multiple
|
||||
></v-combobox>
|
||||
<!-- Select input -->
|
||||
<v-select
|
||||
v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-else-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
:items="metadata[metadataKey].items[key]?.options"
|
||||
:disabled="metadata[metadataKey].items[key]?.readonly"
|
||||
@@ -198,7 +297,7 @@ function saveEditedContent() {
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
|
||||
:value="iterable[key]"
|
||||
v-model="iterable[key]"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
@@ -218,7 +317,7 @@ function saveEditedContent() {
|
||||
</v-row>
|
||||
|
||||
<v-divider
|
||||
v-if="index !== Object.keys(iterable).length - 1 && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-if="hasVisibleItemsAfter(filteredIterable, index) && !metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)"
|
||||
class="config-divider"
|
||||
></v-divider>
|
||||
</template>
|
||||
@@ -309,9 +408,9 @@ function saveEditedContent() {
|
||||
></v-switch>
|
||||
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
|
||||
:value="iterable[metadataKey]"
|
||||
v-model="iterable[metadataKey]"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,456 @@
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import PluginSetSelector from './PluginSetSelector.vue'
|
||||
import T2ITemplateEditor from './T2ITemplateEditor.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
|
||||
const props = defineProps({
|
||||
metadata: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
iterable: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
metadataKey: {
|
||||
type: String,
|
||||
required: true
|
||||
}
|
||||
})
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const dialog = ref(false)
|
||||
const currentEditingKey = ref('')
|
||||
const currentEditingLanguage = ref('json')
|
||||
const currentEditingTheme = ref('vs-light')
|
||||
let currentEditingKeyIterable = null
|
||||
|
||||
function getValueBySelector(obj, selector) {
|
||||
const keys = selector.split('.')
|
||||
let current = obj
|
||||
for (const key of keys) {
|
||||
if (current && typeof current === 'object' && key in current) {
|
||||
current = current[key]
|
||||
} else {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
function setValueBySelector(obj, selector, value) {
|
||||
const keys = selector.split('.')
|
||||
let current = obj
|
||||
|
||||
// 创建嵌套对象路径
|
||||
for (let i = 0; i < keys.length - 1; i++) {
|
||||
const key = keys[i]
|
||||
if (!current[key] || typeof current[key] !== 'object') {
|
||||
current[key] = {}
|
||||
}
|
||||
current = current[key]
|
||||
}
|
||||
|
||||
// 设置最终值
|
||||
current[keys[keys.length - 1]] = value
|
||||
}
|
||||
|
||||
// 创建一个计算属性来处理 JSON selector 的获取和设置
|
||||
function createSelectorModel(selector) {
|
||||
return computed({
|
||||
get() {
|
||||
return getValueBySelector(props.iterable, selector)
|
||||
},
|
||||
set(value) {
|
||||
setValueBySelector(props.iterable, selector, value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function openEditorDialog(key, value, theme, language) {
|
||||
currentEditingKey.value = key
|
||||
currentEditingLanguage.value = language || 'json'
|
||||
currentEditingTheme.value = theme || 'vs-light'
|
||||
currentEditingKeyIterable = value
|
||||
dialog.value = true
|
||||
}
|
||||
|
||||
function saveEditedContent() {
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function shouldShowItem(itemMeta, itemKey) {
|
||||
if (!itemMeta?.condition) {
|
||||
return true
|
||||
}
|
||||
for (const [conditionKey, expectedValue] of Object.entries(itemMeta.condition)) {
|
||||
const actualValue = getValueBySelector(props.iterable, conditionKey)
|
||||
if (actualValue !== expectedValue) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
function hasVisibleItemsAfter(items, currentIndex) {
|
||||
const itemEntries = Object.entries(items)
|
||||
|
||||
// 检查当前索引之后是否还有可见的配置项
|
||||
for (let i = currentIndex + 1; i < itemEntries.length; i++) {
|
||||
const [itemKey, itemMeta] = itemEntries[i]
|
||||
if (shouldShowItem(itemMeta, itemKey)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
|
||||
<v-card style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));" rounded="md" variant="outlined">
|
||||
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'">
|
||||
<v-list-item-title class="config-title">
|
||||
{{ metadata[metadataKey]?.description }}
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint">
|
||||
<span v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint" class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey]?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</v-card-text>
|
||||
|
||||
<!-- Object Type Configuration with JSON Selector Support -->
|
||||
<div v-if="metadata[metadataKey]?.type === 'object'" class="object-config">
|
||||
<div v-for="(itemMeta, itemKey, index) in metadata[metadataKey].items" :key="itemKey" class="config-item">
|
||||
<!-- Check if itemKey is a JSON selector -->
|
||||
<template v-if="shouldShowItem(itemMeta, itemKey)">
|
||||
<!-- JSON Selector Property -->
|
||||
<v-row v-if="!itemMeta?.invisible" class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
{{ itemMeta?.description || itemKey }}
|
||||
<span class="property-key">({{ itemKey }})</span>
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
<span v-if="itemMeta?.obvious_hint && itemMeta?.hint" class="important-hint">‼️</span>
|
||||
{{ itemMeta?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<div class="w-100" v-if="!itemMeta?._special">
|
||||
<!-- Select input for JSON selector -->
|
||||
<v-select v-if="itemMeta?.options" v-model="createSelectorModel(itemKey).value"
|
||||
:items="itemMeta?.options" :disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-select>
|
||||
|
||||
<!-- Code Editor for JSON selector -->
|
||||
<div v-else-if="itemMeta?.editor_mode" class="editor-container">
|
||||
<VueMonacoEditor :theme="itemMeta?.editor_theme || 'vs-light'"
|
||||
:language="itemMeta?.editor_language || 'json'"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
v-model:value="createSelectorModel(itemKey).value">
|
||||
</VueMonacoEditor>
|
||||
<v-btn icon size="small" variant="text" color="primary" class="editor-fullscreen-btn"
|
||||
@click="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
|
||||
:title="t('core.common.editor.fullscreen')">
|
||||
<v-icon>mdi-fullscreen</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- String input for JSON selector -->
|
||||
<v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value"
|
||||
density="compact" variant="outlined" class="config-field" hide-details></v-text-field>
|
||||
|
||||
<!-- Numeric input for JSON selector -->
|
||||
<v-text-field v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'"
|
||||
v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined" class="config-field"
|
||||
type="number" hide-details></v-text-field>
|
||||
|
||||
<!-- Text area for JSON selector -->
|
||||
<v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value"
|
||||
variant="outlined" rows="3" class="config-field" hide-details></v-textarea>
|
||||
|
||||
<!-- Boolean switch for JSON selector -->
|
||||
<v-switch v-else-if="itemMeta?.type === 'bool'" v-model="createSelectorModel(itemKey).value"
|
||||
color="primary" inset density="compact" hide-details style="display: flex; justify-content: end;"></v-switch>
|
||||
|
||||
<!-- List item for JSON selector -->
|
||||
<ListConfigItem
|
||||
v-else-if="itemMeta?.type === 'list'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="修改"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<!-- Fallback for JSON selector -->
|
||||
<v-text-field v-else v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-text-field>
|
||||
</div>
|
||||
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-else-if="itemMeta?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'text_to_speech'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..."
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
button-text="选择人格池..."
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_plugin_set'">
|
||||
<PluginSetSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 't2i_template'">
|
||||
<T2ITemplateEditor />
|
||||
</div>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<!-- Plugin Set Selector 全宽显示区域 -->
|
||||
<v-row v-if="!itemMeta?.invisible && itemMeta?._special === 'select_plugin_set'" class="plugin-set-display-row">
|
||||
<v-col cols="12" class="plugin-set-display">
|
||||
<div v-if="createSelectorModel(itemKey).value && createSelectorModel(itemKey).value.length > 0" class="selected-plugins-full-width">
|
||||
<div class="plugins-header">
|
||||
<small class="text-grey">已选择的插件:</small>
|
||||
</div>
|
||||
<div class="d-flex flex-wrap ga-2 mt-2">
|
||||
<v-chip
|
||||
v-for="plugin in (createSelectorModel(itemKey).value || [])"
|
||||
:key="plugin"
|
||||
size="small"
|
||||
label
|
||||
color="primary"
|
||||
variant="outlined"
|
||||
>
|
||||
{{ plugin === '*' ? '所有插件' : plugin }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</template>
|
||||
<v-divider class="config-divider" v-if="shouldShowItem(itemMeta, itemKey) && hasVisibleItemsAfter(metadata[metadataKey].items, index)"></v-divider>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</v-card>
|
||||
|
||||
<!-- Full Screen Editor Dialog -->
|
||||
<v-dialog v-model="dialog" fullscreen transition="dialog-bottom-transition" scrollable>
|
||||
<v-card>
|
||||
<v-toolbar color="primary" dark>
|
||||
<v-btn icon @click="dialog = false">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
<v-toolbar-title>{{ t('core.common.editor.editingTitle') }} - {{ currentEditingKey }}</v-toolbar-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-toolbar-items>
|
||||
<v-btn variant="text" @click="saveEditedContent">{{ t('core.common.save') }}</v-btn>
|
||||
</v-toolbar-items>
|
||||
</v-toolbar>
|
||||
<v-card-text class="pa-0">
|
||||
<VueMonacoEditor :theme="currentEditingTheme" :language="currentEditingLanguage"
|
||||
style="height: calc(100vh - 64px);" v-model:value="currentEditingKeyIterable[currentEditingKey]">
|
||||
</VueMonacoEditor>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
|
||||
|
||||
<style scoped>
|
||||
.config-section {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.config-title {
|
||||
/* font-weight: 600; */
|
||||
font-size: 1.3rem;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.config-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.metadata-key,
|
||||
.property-key {
|
||||
font-size: 0.85em;
|
||||
opacity: 0.7;
|
||||
font-weight: normal;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.important-hint {
|
||||
opacity: 1;
|
||||
margin-right: 4px;
|
||||
}
|
||||
|
||||
.object-config,
|
||||
.simple-config {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.nested-object {
|
||||
padding-left: 16px;
|
||||
}
|
||||
|
||||
.nested-container {
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
margin: 12px 0;
|
||||
background-color: rgba(0, 0, 0, 0.02);
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.config-row {
|
||||
margin: 0;
|
||||
align-items: center;
|
||||
padding: 10px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.config-row:hover {
|
||||
background-color: rgba(0, 0, 0, 0.03);
|
||||
}
|
||||
|
||||
.property-info {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.property-name {
|
||||
font-size: 0.875rem;
|
||||
/* font-weight: 600; */
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.property-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.type-indicator {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.config-input {
|
||||
padding: 4px 8px;
|
||||
}
|
||||
|
||||
.config-field {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-divider {
|
||||
border-color: rgba(0, 0, 0, 0.1);
|
||||
margin-left: 24px;
|
||||
}
|
||||
|
||||
.editor-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.editor-fullscreen-btn {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
z-index: 10;
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.editor-fullscreen-btn:hover {
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
.plugin-set-display-row {
|
||||
margin: 16px;
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.plugin-set-display {
|
||||
padding: 0 8px;
|
||||
}
|
||||
|
||||
.selected-plugins-full-width {
|
||||
background-color: rgba(var(--v-theme-primary), 0.05);
|
||||
border: 1px solid rgba(var(--v-theme-primary), 0.1);
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.plugins-header {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.nested-object {
|
||||
padding-left: 8px;
|
||||
}
|
||||
|
||||
.config-row {
|
||||
padding: 8px 0;
|
||||
}
|
||||
|
||||
.property-info,
|
||||
.type-indicator,
|
||||
.config-input {
|
||||
padding: 4px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,225 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ modelValue }}
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- Knowledge Base Selection Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
选择知识库
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<!-- 插件未安装提示 -->
|
||||
<div v-if="!loading && !pluginInstalled" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-puzzle-outline</v-icon>
|
||||
<p class="text-grey mt-4 mb-4">知识库插件未安装</p>
|
||||
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
|
||||
前往知识库页面
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 知识库列表 -->
|
||||
<v-list v-else-if="!loading && pluginInstalled" density="compact">
|
||||
<!-- 不使用选项 -->
|
||||
<v-list-item
|
||||
:value="''"
|
||||
@click="selectKnowledgeBase({ collection_name: '' })"
|
||||
:active="selectedKnowledgeBase === ''"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<template v-slot:prepend>
|
||||
<v-icon color="grey-lighten-1">mdi-close-circle-outline</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>不使用</v-list-item-title>
|
||||
<v-list-item-subtitle>不使用任何知识库</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedKnowledgeBase === ''" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<v-divider v-if="knowledgeBaseList.length > 0" class="my-2"></v-divider>
|
||||
|
||||
<!-- 知识库选项 -->
|
||||
<v-list-item
|
||||
v-for="kb in knowledgeBaseList"
|
||||
:key="kb.collection_name"
|
||||
:value="kb.collection_name"
|
||||
@click="selectKnowledgeBase(kb)"
|
||||
:active="selectedKnowledgeBase === kb.collection_name"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<template v-slot:prepend>
|
||||
<span class="emoji-icon">{{ kb.emoji || '🙂' }}</span>
|
||||
</template>
|
||||
<v-list-item-title>{{ kb.collection_name }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ kb.description || '无描述' }}
|
||||
<span v-if="kb.count !== undefined"> - {{ kb.count }} 项知识</span>
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedKnowledgeBase === kb.collection_name" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<!-- 当没有知识库时显示创建提示 -->
|
||||
<div v-if="knowledgeBaseList.length === 0" class="text-center py-4">
|
||||
<p class="text-grey mb-4">暂无知识库</p>
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="goToKnowledgeBasePage">
|
||||
创建知识库
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-list>
|
||||
|
||||
<!-- 空状态(插件未安装时保留原有逻辑) -->
|
||||
<div v-else-if="!loading && !pluginInstalled && knowledgeBaseList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-database-off</v-icon>
|
||||
<p class="text-grey mt-4 mb-4">暂无知识库</p>
|
||||
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
|
||||
创建知识库
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection"
|
||||
:disabled="selectedKnowledgeBase === null || selectedKnowledgeBase === undefined">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import { useRouter } from 'vue-router'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择知识库...'
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
const router = useRouter()
|
||||
|
||||
const dialog = ref(false)
|
||||
const knowledgeBaseList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedKnowledgeBase = ref('')
|
||||
const pluginInstalled = ref(false)
|
||||
|
||||
// 监听 modelValue 变化,同步到 selectedKnowledgeBase
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
selectedKnowledgeBase.value = newValue || ''
|
||||
}, { immediate: true })
|
||||
|
||||
async function openDialog() {
|
||||
selectedKnowledgeBase.value = props.modelValue || ''
|
||||
dialog.value = true
|
||||
await checkPluginAndLoadKnowledgeBases()
|
||||
}
|
||||
|
||||
async function checkPluginAndLoadKnowledgeBases() {
|
||||
loading.value = true
|
||||
try {
|
||||
// 首先检查插件是否安装
|
||||
const pluginResponse = await axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
|
||||
|
||||
if (pluginResponse.data.status === 'ok' && pluginResponse.data.data.length > 0) {
|
||||
pluginInstalled.value = true
|
||||
// 插件已安装,获取知识库列表
|
||||
await loadKnowledgeBases()
|
||||
} else {
|
||||
pluginInstalled.value = false
|
||||
knowledgeBaseList.value = []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查知识库插件失败:', error)
|
||||
pluginInstalled.value = false
|
||||
knowledgeBaseList.value = []
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadKnowledgeBases() {
|
||||
try {
|
||||
const response = await axios.get('/api/plug/alkaid/kb/collections')
|
||||
if (response.data.status === 'ok') {
|
||||
knowledgeBaseList.value = response.data.data || []
|
||||
} else {
|
||||
knowledgeBaseList.value = []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载知识库列表失败:', error)
|
||||
knowledgeBaseList.value = []
|
||||
}
|
||||
}
|
||||
|
||||
function selectKnowledgeBase(kb) {
|
||||
selectedKnowledgeBase.value = kb.collection_name
|
||||
}
|
||||
|
||||
function confirmSelection() {
|
||||
emit('update:modelValue', selectedKnowledgeBase.value)
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelSelection() {
|
||||
selectedKnowledgeBase.value = props.modelValue || ''
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function goToKnowledgeBasePage() {
|
||||
dialog.value = false
|
||||
router.push('/alkaid/knowledge-base')
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.v-list-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
|
||||
.v-list-item.v-list-item--active {
|
||||
background-color: rgba(var(--v-theme-primary), 0.08);
|
||||
}
|
||||
|
||||
.emoji-icon {
|
||||
font-size: 20px;
|
||||
margin-right: 8px;
|
||||
min-width: 28px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
</style>
|
||||
@@ -1,135 +1,216 @@
|
||||
<template>
|
||||
<div class="list-config-item">
|
||||
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: auto;">
|
||||
<v-list-item v-for="(item, index) in items" :key="index">
|
||||
<v-list-item-content style="display: flex; justify-content: space-between;">
|
||||
<v-list-item-title v-if="editIndex !== index">
|
||||
<v-chip size="small" label color="primary">{{ item }}</v-chip>
|
||||
</v-list-item-title>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<div>
|
||||
<span v-if="!modelValue || modelValue.length === 0" style="color: rgb(var(--v-theme-primaryText));">
|
||||
暂无项目
|
||||
</span>
|
||||
<div v-else class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-for="item in displayItems" :key="item" size="x-small" label color="primary">
|
||||
{{ item.length > 20 ? item.slice(0, 20) + '...' : item }}
|
||||
</v-chip>
|
||||
<v-chip v-if="modelValue.length > maxDisplayItems" size="x-small" label color="grey-lighten-1">
|
||||
+{{ modelValue.length - maxDisplayItems }}
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- List Management Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
{{ dialogTitle }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-list v-if="localItems.length > 0" density="compact">
|
||||
<v-list-item
|
||||
v-for="(item, index) in localItems"
|
||||
:key="index"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title v-if="editIndex !== index">
|
||||
{{ item }}
|
||||
</v-list-item-title>
|
||||
<v-text-field
|
||||
v-else
|
||||
v-model="editItem"
|
||||
hide-details
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
@keyup.enter="saveEdit"
|
||||
@keyup.esc="cancelEdit"
|
||||
autofocus
|
||||
></v-text-field>
|
||||
|
||||
<template v-slot:append>
|
||||
<div v-if="editIndex !== index" class="d-flex">
|
||||
<v-btn @click="startEdit(index, item)" variant="plain" icon size="small">
|
||||
<v-icon>mdi-pencil</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="removeItem(index)" variant="plain" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
<div v-else class="d-flex">
|
||||
<v-btn @click="saveEdit" variant="plain" color="success" icon size="small">
|
||||
<v-icon>mdi-check</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="cancelEdit" variant="plain" color="error" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
<div v-else class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-format-list-bulleted</v-icon>
|
||||
<p class="text-grey mt-4">暂无项目</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<!-- Add new item section -->
|
||||
<v-card-text class="pa-4">
|
||||
<div class="d-flex align-center ga-2">
|
||||
<v-text-field
|
||||
v-else
|
||||
v-model="editItem"
|
||||
dense
|
||||
hide-details
|
||||
v-model="newItem"
|
||||
:label="t('core.common.list.addItemPlaceholder')"
|
||||
@keyup.enter="addItem"
|
||||
clearable
|
||||
hide-details
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
@keyup.enter="saveEdit"
|
||||
@keyup.esc="cancelEdit"
|
||||
autofocus
|
||||
></v-text-field>
|
||||
<div v-if="editIndex !== index">
|
||||
<v-btn @click="startEdit(index, item)" variant="plain" class="edit-btn" icon size="small">
|
||||
<v-icon>mdi-pencil</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="removeItem(index)" variant="plain" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
<div v-else>
|
||||
<v-btn @click="saveEdit" variant="plain" color="success" icon size="small">
|
||||
<v-icon>mdi-check</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="cancelEdit" variant="plain" color="error" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-list-item-content>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
<div style="display: flex; align-items: center;">
|
||||
<v-text-field v-model="newItem" :label="t('core.common.list.addItemPlaceholder')" @keyup.enter="addItem" clearable dense hide-details
|
||||
variant="outlined" density="compact"></v-text-field>
|
||||
<v-btn @click="addItem" text variant="tonal">
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
{{ t('core.common.list.addButton') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
class="flex-grow-1">
|
||||
</v-text-field>
|
||||
<v-btn @click="addItem" variant="tonal" color="primary">
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
{{ t('core.common.list.addButton') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
</div>
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelDialog">取消</v-btn>
|
||||
<v-btn color="primary" @click="confirmDialog">确认</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import { useI18n } from '@/i18n/composables';
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
export default {
|
||||
name: 'ListConfigItem',
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
return { t };
|
||||
const { t } = useI18n()
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: Array,
|
||||
default: () => []
|
||||
},
|
||||
props: {
|
||||
value: {
|
||||
type: Array,
|
||||
default: () => [],
|
||||
},
|
||||
label: {
|
||||
type: String,
|
||||
default: '',
|
||||
},
|
||||
label: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
newItem: '',
|
||||
items: this.value,
|
||||
editIndex: -1,
|
||||
editItem: '',
|
||||
};
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '修改'
|
||||
},
|
||||
watch: {
|
||||
items(newVal) {
|
||||
this.$emit('input', newVal);
|
||||
},
|
||||
dialogTitle: {
|
||||
type: String,
|
||||
default: '修改列表项'
|
||||
},
|
||||
methods: {
|
||||
addItem() {
|
||||
if (this.newItem.trim() !== '') {
|
||||
this.items.push(this.newItem.trim());
|
||||
this.newItem = '';
|
||||
}
|
||||
},
|
||||
removeItem(index) {
|
||||
this.items.splice(index, 1);
|
||||
},
|
||||
startEdit(index, item) {
|
||||
this.editIndex = index;
|
||||
this.editItem = item;
|
||||
},
|
||||
saveEdit() {
|
||||
if (this.editItem.trim() !== '') {
|
||||
this.items[this.editIndex] = this.editItem.trim();
|
||||
this.cancelEdit();
|
||||
}
|
||||
},
|
||||
cancelEdit() {
|
||||
this.editIndex = -1;
|
||||
this.editItem = '';
|
||||
},
|
||||
},
|
||||
};
|
||||
maxDisplayItems: {
|
||||
type: Number,
|
||||
default: 1
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
|
||||
const dialog = ref(false)
|
||||
const localItems = ref([])
|
||||
const originalItems = ref([])
|
||||
const newItem = ref('')
|
||||
const editIndex = ref(-1)
|
||||
const editItem = ref('')
|
||||
|
||||
// 计算要显示的项目
|
||||
const displayItems = computed(() => {
|
||||
return props.modelValue.slice(0, props.maxDisplayItems)
|
||||
})
|
||||
|
||||
// 监听 modelValue 变化,同步到 localItems
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
localItems.value = [...(newValue || [])]
|
||||
}, { immediate: true })
|
||||
|
||||
function openDialog() {
|
||||
localItems.value = [...(props.modelValue || [])]
|
||||
originalItems.value = [...(props.modelValue || [])]
|
||||
dialog.value = true
|
||||
editIndex.value = -1
|
||||
editItem.value = ''
|
||||
newItem.value = ''
|
||||
}
|
||||
|
||||
function addItem() {
|
||||
if (newItem.value.trim() !== '') {
|
||||
localItems.value.push(newItem.value.trim())
|
||||
newItem.value = ''
|
||||
}
|
||||
}
|
||||
|
||||
function removeItem(index) {
|
||||
localItems.value.splice(index, 1)
|
||||
}
|
||||
|
||||
function startEdit(index, item) {
|
||||
editIndex.value = index
|
||||
editItem.value = item
|
||||
}
|
||||
|
||||
function saveEdit() {
|
||||
if (editItem.value.trim() !== '') {
|
||||
localItems.value[editIndex.value] = editItem.value.trim()
|
||||
cancelEdit()
|
||||
}
|
||||
}
|
||||
|
||||
function cancelEdit() {
|
||||
editIndex.value = -1
|
||||
editItem.value = ''
|
||||
}
|
||||
|
||||
function confirmDialog() {
|
||||
emit('update:modelValue', [...localItems.value])
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelDialog() {
|
||||
localItems.value = [...originalItems.value]
|
||||
editIndex.value = -1
|
||||
editItem.value = ''
|
||||
newItem.value = ''
|
||||
dialog.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.list-config-item {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
padding: 16px;
|
||||
margin-bottom: 8px;
|
||||
border-radius: 10px;
|
||||
background-color: var(--v-theme-background);
|
||||
}
|
||||
|
||||
.v-list-item {
|
||||
padding: 0;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.v-list-item-title {
|
||||
font-size: 14px;
|
||||
.v-list-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
|
||||
.v-btn {
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.edit-btn {
|
||||
margin-right: -8px;
|
||||
.v-chip {
|
||||
margin: 2px;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,275 @@
|
||||
<template>
|
||||
<v-dialog v-model="isOpen" persistent max-width="600" max-height="80vh" scrollable>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
{{ t('features.migration.dialog.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<p class="mb-4">{{ t('features.migration.dialog.warning') }}</p>
|
||||
|
||||
|
||||
<div v-if="migrationCompleted" class="text-center py-8">
|
||||
<v-icon size="64" color="success" class="mb-4">mdi-check-circle</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.migration.dialog.completed') }}</h3>
|
||||
{{ migrationResult?.message || t('features.migration.dialog.success') }}
|
||||
</div>
|
||||
|
||||
<div v-else-if="migrating" class="migration-in-progress">
|
||||
<div class="text-center py-4">
|
||||
<v-progress-circular indeterminate color="primary" class="mb-4"></v-progress-circular>
|
||||
<h3 class="mb-4">{{ t('features.migration.dialog.migrating') }}</h3>
|
||||
<p class="mb-4">{{ t('features.migration.dialog.migratingSubtitle') }}</p>
|
||||
</div>
|
||||
<div class="console-container">
|
||||
<ConsoleDisplayer ref="consoleDisplayer" :showLevelBtns="false" style="height: 300px;" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-else-if="loading" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" class="mb-4"></v-progress-circular>
|
||||
<p>{{ t('features.migration.dialog.loading') }}</p>
|
||||
</div>
|
||||
|
||||
<div v-else-if="error" class="text-center py-4">
|
||||
<v-alert type="error" variant="tonal" class="mb-4">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-alert</v-icon>
|
||||
</template>
|
||||
{{ error }}
|
||||
</v-alert>
|
||||
<v-btn color="primary" @click="loadPlatforms">
|
||||
{{ t('features.migration.dialog.retry') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<div v-else>
|
||||
<div v-if="platformGroups.length === 0" class="text-center py-4">
|
||||
<v-alert type="info" variant="tonal">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-information</v-icon>
|
||||
</template>
|
||||
{{ t('features.migration.dialog.noPlatforms') }}
|
||||
</v-alert>
|
||||
</div>
|
||||
|
||||
<div v-else>
|
||||
<div v-for="group in platformGroups" :key="group.type" class="mb-6">
|
||||
<v-card variant="outlined" v-if="group.platforms.length > 1">
|
||||
<v-card-subtitle class="py-2">
|
||||
{{ group.type }}
|
||||
</v-card-subtitle>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-text style="padding: 16px;">
|
||||
<small>请选择该平台类型下您主要使用的平台适配器。</small>
|
||||
<v-radio-group v-model="selectedPlatforms[group.type]" :key="group.type"
|
||||
hide-details>
|
||||
<v-radio v-for="platform in group.platforms" :key="platform.id"
|
||||
:value="platform.id" :label="getPlatformLabel(platform)" color="primary"
|
||||
class="mb-1"></v-radio>
|
||||
</v-radio-group>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="px-6 py-4">
|
||||
<v-spacer></v-spacer>
|
||||
<template v-if="migrationCompleted">
|
||||
<v-btn color="primary" variant="elevated" @click="handleClose">
|
||||
{{ t('core.common.confirm') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<template v-else>
|
||||
<v-btn color="grey" variant="text" @click="handleCancel" :disabled="migrating">
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="elevated" @click="handleMigration" :disabled="!canMigrate || migrating"
|
||||
:loading="migrating">
|
||||
{{ t('features.migration.dialog.startMigration') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import ConsoleDisplayer from './ConsoleDisplayer.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const isOpen = ref(false)
|
||||
const loading = ref(false)
|
||||
const error = ref('')
|
||||
const migrating = ref(false)
|
||||
const migrationCompleted = ref(false)
|
||||
const migrationResult = ref(null)
|
||||
const platforms = ref([])
|
||||
const selectedPlatforms = ref({})
|
||||
|
||||
let resolvePromise = null
|
||||
|
||||
// 计算属性:将平台按类型分组
|
||||
const platformGroups = computed(() => {
|
||||
const groups = {}
|
||||
platforms.value.forEach(platform => {
|
||||
const type = platform.platform_type || platform.type
|
||||
if (!groups[type]) {
|
||||
groups[type] = {
|
||||
type,
|
||||
platforms: []
|
||||
}
|
||||
}
|
||||
groups[type].platforms.push(platform)
|
||||
})
|
||||
return Object.values(groups)
|
||||
})
|
||||
|
||||
// 计算属性:检查是否可以开始迁移
|
||||
const canMigrate = computed(() => {
|
||||
return platformGroups.value.every(group => selectedPlatforms.value[group.type])
|
||||
})
|
||||
|
||||
// 监听 isOpen 变化,当对话框打开时加载平台列表
|
||||
watch(isOpen, (newVal) => {
|
||||
if (newVal) {
|
||||
loadPlatforms()
|
||||
} else {
|
||||
// 重置状态
|
||||
platforms.value = []
|
||||
selectedPlatforms.value = {}
|
||||
error.value = ''
|
||||
migrating.value = false
|
||||
migrationCompleted.value = false
|
||||
migrationResult.value = null
|
||||
}
|
||||
})
|
||||
|
||||
// 加载平台列表
|
||||
const loadPlatforms = async () => {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
|
||||
try {
|
||||
const response = await axios.get('/api/config/platform/list')
|
||||
if (response.data.status === 'ok') {
|
||||
platforms.value = response.data.data.platforms || []
|
||||
|
||||
// 为每个平台类型初始化默认选择(选择第一个)
|
||||
platformGroups.value.forEach(group => {
|
||||
if (group.platforms.length > 0) {
|
||||
selectedPlatforms.value[group.type] = group.platforms[0].id
|
||||
}
|
||||
})
|
||||
} else {
|
||||
error.value = response.data.message || t('features.migration.dialog.loadError')
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to load platforms:', err)
|
||||
error.value = t('features.migration.dialog.loadError')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 执行迁移
|
||||
const handleMigration = async () => {
|
||||
migrating.value = true
|
||||
|
||||
try {
|
||||
// 构建 platform_id_map
|
||||
const platformIdMap = {}
|
||||
|
||||
Object.entries(selectedPlatforms.value).forEach(([type, platformId]) => {
|
||||
const selectedPlatform = platforms.value.find(p => p.id === platformId)
|
||||
if (selectedPlatform) {
|
||||
platformIdMap[type] = {
|
||||
platform_id: platformId,
|
||||
platform_type: type
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('Migration platform_id_map:', platformIdMap)
|
||||
|
||||
const response = await axios.post('/api/update/migration', {
|
||||
platform_id_map: platformIdMap
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
migrationCompleted.value = true
|
||||
migrationResult.value = {
|
||||
success: true,
|
||||
message: response.data.message || t('features.migration.dialog.success')
|
||||
}
|
||||
} else {
|
||||
throw new Error(response.data.message || t('features.migration.dialog.migrationError'))
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Migration failed:', err)
|
||||
error.value = err.message || t('features.migration.dialog.migrationError')
|
||||
} finally {
|
||||
migrating.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 取消操作
|
||||
const handleCancel = () => {
|
||||
isOpen.value = false
|
||||
if (resolvePromise) {
|
||||
resolvePromise({ success: false, cancelled: true })
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭已完成的迁移对话框
|
||||
const handleClose = () => {
|
||||
isOpen.value = false
|
||||
if (resolvePromise) {
|
||||
resolvePromise(migrationResult.value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 获取平台显示标签
|
||||
const getPlatformLabel = (platform) => {
|
||||
const name = platform.name || platform.id || 'Unknown'
|
||||
return `${name}`
|
||||
}
|
||||
|
||||
// 打开对话框的方法
|
||||
const open = () => {
|
||||
isOpen.value = true
|
||||
|
||||
return new Promise((resolve) => {
|
||||
resolvePromise = resolve
|
||||
})
|
||||
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-radio-group {
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.migration-in-progress {
|
||||
min-height: 400px;
|
||||
}
|
||||
|
||||
.console-container {
|
||||
border: 1px solid var(--v-theme-border);
|
||||
border-radius: 8px;
|
||||
margin-top: 16px;
|
||||
overflow: hidden;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,152 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ modelValue === 'default' ? '默认人格' : modelValue }}
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- Persona Selection Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
选择人格
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<v-list v-if="!loading && personaList.length > 0" density="compact">
|
||||
<v-list-item
|
||||
v-for="persona in personaList"
|
||||
:key="persona.persona_id"
|
||||
:value="persona.persona_id"
|
||||
@click="selectPersona(persona)"
|
||||
:active="selectedPersona === persona.persona_id"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title>{{ persona.persona_id === 'default' ? '默认人格' : persona.persona_id }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ persona.system_prompt ? persona.system_prompt.substring(0, 50) + '...' : '无描述' }}
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedPersona === persona.persona_id" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
<div v-else-if="!loading && personaList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-account-off</v-icon>
|
||||
<p class="text-grey mt-4">暂无可用的人格</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection"
|
||||
:disabled="!selectedPersona">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择人格...'
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
|
||||
const dialog = ref(false)
|
||||
const personaList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedPersona = ref('')
|
||||
|
||||
// 监听 modelValue 变化,同步到 selectedPersona
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
selectedPersona.value = newValue || ''
|
||||
}, { immediate: true })
|
||||
|
||||
async function openDialog() {
|
||||
selectedPersona.value = props.modelValue || ''
|
||||
dialog.value = true
|
||||
await loadPersonas()
|
||||
}
|
||||
|
||||
async function loadPersonas() {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/persona/list')
|
||||
if (response.data.status === 'ok') {
|
||||
const personas = response.data.data || []
|
||||
// 添加默认人格选项
|
||||
personaList.value = [
|
||||
{
|
||||
persona_id: 'default',
|
||||
system_prompt: 'You are a helpful and friendly assistant.'
|
||||
},
|
||||
...personas
|
||||
]
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载人格列表失败:', error)
|
||||
personaList.value = [
|
||||
{
|
||||
persona_id: 'default',
|
||||
system_prompt: 'You are a helpful and friendly assistant.'
|
||||
}
|
||||
]
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function selectPersona(persona) {
|
||||
selectedPersona.value = persona.persona_id
|
||||
}
|
||||
|
||||
function confirmSelection() {
|
||||
emit('update:modelValue', selectedPersona.value)
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelSelection() {
|
||||
selectedPersona.value = props.modelValue || ''
|
||||
dialog.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.v-list-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
|
||||
.v-list-item.v-list-item--active {
|
||||
background-color: rgba(var(--v-theme-primary), 0.08);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,226 @@
|
||||
<template>
|
||||
<div>
|
||||
<!-- 顶部操作区域 -->
|
||||
<div class="d-flex align-center justify-space-between mb-2">
|
||||
<div class="flex-grow-1">
|
||||
<span v-if="!modelValue || modelValue.length === 0" style="color: rgb(var(--v-theme-primaryText));">
|
||||
未启用任何插件
|
||||
</span>
|
||||
<span v-else-if="isAllPlugins" style="color: rgb(var(--v-theme-primaryText));">
|
||||
启用所有插件 (*)
|
||||
</span>
|
||||
<span v-else style="color: rgb(var(--v-theme-primaryText));">
|
||||
已选择 {{ modelValue.length }} 个插件
|
||||
</span>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Plugin Set Selection Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="700px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
选择插件集合
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-4">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div v-if="!loading">
|
||||
<!-- 预设选项 -->
|
||||
<v-radio-group v-model="selectionMode" class="mb-4" hide-details>
|
||||
<v-radio
|
||||
value="all"
|
||||
label="启用所有插件"
|
||||
color="primary"
|
||||
></v-radio>
|
||||
<v-radio
|
||||
value="none"
|
||||
label="不启用任何插件"
|
||||
color="primary"
|
||||
></v-radio>
|
||||
<v-radio
|
||||
value="custom"
|
||||
label="自定义选择"
|
||||
color="primary"
|
||||
></v-radio>
|
||||
</v-radio-group>
|
||||
|
||||
<!-- 自定义选择时显示插件列表 -->
|
||||
<div v-if="selectionMode === 'custom'" style="max-height: 300px; overflow-y: auto;">
|
||||
<v-list v-if="pluginList.length > 0" density="compact">
|
||||
<v-list-item
|
||||
v-for="plugin in pluginList"
|
||||
:key="plugin.name"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox
|
||||
v-model="selectedPlugins"
|
||||
:value="plugin.name"
|
||||
color="primary"
|
||||
hide-details
|
||||
></v-checkbox>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>{{ plugin.name }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ plugin.desc || '无描述' }}
|
||||
<v-chip v-if="!plugin.activated" size="x-small" color="grey" class="ml-1">
|
||||
未激活
|
||||
</v-chip>
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
|
||||
<div class="pl-8 pt-2">
|
||||
<small>*不显示系统插件和已经在插件页禁用的插件。</small>
|
||||
</div>
|
||||
</v-list>
|
||||
|
||||
<div v-else class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-puzzle-outline</v-icon>
|
||||
<p class="text-grey mt-4">暂无可用的插件</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: Array,
|
||||
default: () => []
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择插件集合...'
|
||||
},
|
||||
maxDisplayItems: {
|
||||
type: Number,
|
||||
default: 3
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
|
||||
const dialog = ref(false)
|
||||
const pluginList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectionMode = ref('custom') // 'all', 'none', 'custom'
|
||||
const selectedPlugins = ref([])
|
||||
|
||||
// 判断是否为"所有插件"模式
|
||||
const isAllPlugins = computed(() => {
|
||||
return props.modelValue && props.modelValue.length === 1 && props.modelValue[0] === '*'
|
||||
})
|
||||
|
||||
// 移除插件
|
||||
function removePlugin(pluginName) {
|
||||
if (props.modelValue && props.modelValue.length > 0) {
|
||||
const newValue = props.modelValue.filter(name => name !== pluginName)
|
||||
emit('update:modelValue', newValue)
|
||||
}
|
||||
}
|
||||
|
||||
// 监听 modelValue 变化,同步内部状态
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
if (!newValue || newValue.length === 0) {
|
||||
selectionMode.value = 'none'
|
||||
selectedPlugins.value = []
|
||||
} else if (newValue.length === 1 && newValue[0] === '*') {
|
||||
selectionMode.value = 'all'
|
||||
selectedPlugins.value = []
|
||||
} else {
|
||||
selectionMode.value = 'custom'
|
||||
selectedPlugins.value = [...newValue]
|
||||
}
|
||||
}, { immediate: true })
|
||||
|
||||
async function openDialog() {
|
||||
dialog.value = true
|
||||
await loadPlugins()
|
||||
}
|
||||
|
||||
async function loadPlugins() {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/plugin/get')
|
||||
if (response.data.status === 'ok') {
|
||||
// 只显示已激活且非系统的插件,并按名称排序
|
||||
pluginList.value = (response.data.data || [])
|
||||
.filter(plugin => plugin.activated && !plugin.reserved)
|
||||
.sort((a, b) => a.name.localeCompare(b.name))
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载插件列表失败:', error)
|
||||
pluginList.value = []
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function confirmSelection() {
|
||||
let newValue = []
|
||||
|
||||
switch (selectionMode.value) {
|
||||
case 'all':
|
||||
newValue = ['*']
|
||||
break
|
||||
case 'none':
|
||||
newValue = []
|
||||
break
|
||||
case 'custom':
|
||||
newValue = [...selectedPlugins.value]
|
||||
break
|
||||
}
|
||||
|
||||
emit('update:modelValue', newValue)
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelSelection() {
|
||||
// 恢复到原始状态
|
||||
const currentValue = props.modelValue || []
|
||||
if (currentValue.length === 0) {
|
||||
selectionMode.value = 'none'
|
||||
selectedPlugins.value = []
|
||||
} else if (currentValue.length === 1 && currentValue[0] === '*') {
|
||||
selectionMode.value = 'all'
|
||||
selectedPlugins.value = []
|
||||
} else {
|
||||
selectionMode.value = 'custom'
|
||||
selectedPlugins.value = [...currentValue]
|
||||
}
|
||||
|
||||
dialog.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.v-list-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,167 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ modelValue }}
|
||||
</span>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- Provider Selection Dialog -->
|
||||
<v-dialog v-model="dialog" max-width="600px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 py-4" style="font-weight: normal;">
|
||||
选择提供商
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<v-list v-if="!loading && providerList.length > 0" density="compact">
|
||||
<!-- 不选择选项 -->
|
||||
<v-list-item
|
||||
key="none"
|
||||
value=""
|
||||
@click="selectProvider({ id: '' })"
|
||||
:active="selectedProvider === ''"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title>不选择</v-list-item-title>
|
||||
<v-list-item-subtitle>清除当前选择</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedProvider === ''" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<v-divider class="ma-1"></v-divider>
|
||||
|
||||
<v-list-item
|
||||
v-for="provider in providerList"
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
@click="selectProvider(provider)"
|
||||
:active="selectedProvider === provider.id"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<v-list-item-title>{{ provider.id }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ provider.type || provider.provider_type || '未知类型' }}
|
||||
<span v-if="provider.model_config?.model">- {{ provider.model_config.model }}</span>
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedProvider === provider.id" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
<div v-else-if="!loading && providerList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-api-off</v-icon>
|
||||
<p class="text-grey mt-4">暂无可用的提供商</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
providerType: {
|
||||
type: String,
|
||||
default: 'chat_completion'
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '选择提供商...'
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
|
||||
const dialog = ref(false)
|
||||
const providerList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedProvider = ref('')
|
||||
|
||||
// 监听 modelValue 变化,同步到 selectedProvider
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
selectedProvider.value = newValue || ''
|
||||
}, { immediate: true })
|
||||
|
||||
async function openDialog() {
|
||||
selectedProvider.value = props.modelValue || ''
|
||||
dialog.value = true
|
||||
await loadProviders()
|
||||
}
|
||||
|
||||
async function loadProviders() {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/config/provider/list', {
|
||||
params: {
|
||||
provider_type: props.providerType
|
||||
}
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
providerList.value = response.data.data || []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载提供商列表失败:', error)
|
||||
providerList.value = []
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function selectProvider(provider) {
|
||||
selectedProvider.value = provider.id
|
||||
}
|
||||
|
||||
function confirmSelection() {
|
||||
emit('update:modelValue', selectedProvider.value)
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelSelection() {
|
||||
selectedProvider.value = props.modelValue || ''
|
||||
dialog.value = false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.v-list-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.v-list-item:hover {
|
||||
background-color: rgba(var(--v-theme-primary), 0.04);
|
||||
}
|
||||
|
||||
.v-list-item.v-list-item--active {
|
||||
background-color: rgba(var(--v-theme-primary), 0.08);
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,307 @@
|
||||
<template>
|
||||
<v-dialog v-model="dialog" max-width="1400px" persistent scrollable>
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn
|
||||
v-bind="props"
|
||||
variant="outlined"
|
||||
color="primary"
|
||||
size="small"
|
||||
:loading="loading"
|
||||
>
|
||||
自定义 T2I 模板
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<v-card>
|
||||
<v-card-title class="d-flex align-center justify-space-between">
|
||||
<span>自定义文转图 HTML 模板</span>
|
||||
<div class="d-flex gap-2">
|
||||
<v-btn
|
||||
v-if="hasCustomTemplate"
|
||||
variant="outlined"
|
||||
color="warning"
|
||||
size="small"
|
||||
@click="resetToDefault"
|
||||
:loading="resetLoading"
|
||||
>
|
||||
恢复默认
|
||||
</v-btn>
|
||||
<v-btn
|
||||
variant="text"
|
||||
icon
|
||||
@click="closeDialog"
|
||||
>
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0">
|
||||
<v-row no-gutters style="height: 70vh;">
|
||||
<!-- 左侧编辑器 -->
|
||||
<v-col cols="6" class="d-flex flex-column">
|
||||
<v-toolbar density="compact" color="surface-variant">
|
||||
<v-toolbar-title class="text-subtitle-2">HTML 模板编辑器</v-toolbar-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
variant="text"
|
||||
size="small"
|
||||
@click="saveTemplate"
|
||||
:loading="saveLoading"
|
||||
color="primary"
|
||||
>
|
||||
保存模板
|
||||
</v-btn>
|
||||
</v-toolbar>
|
||||
<div class="flex-grow-1" style="border-right: 1px solid rgba(0,0,0,0.1);">
|
||||
<VueMonacoEditor
|
||||
v-model:value="templateContent"
|
||||
:theme="editorTheme"
|
||||
language="html"
|
||||
:options="editorOptions"
|
||||
style="height: 100%;"
|
||||
/>
|
||||
</div>
|
||||
</v-col>
|
||||
|
||||
<!-- 右侧预览 -->
|
||||
<v-col cols="6" class="d-flex flex-column">
|
||||
<v-toolbar density="compact" color="surface-variant">
|
||||
<v-toolbar-title class="text-subtitle-2">实时预览(可能有差异)</v-toolbar-title>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
variant="text"
|
||||
size="small"
|
||||
@click="refreshPreview"
|
||||
:loading="previewLoading"
|
||||
>
|
||||
刷新预览
|
||||
</v-btn>
|
||||
</v-toolbar>
|
||||
<div class="flex-grow-1 preview-container">
|
||||
<iframe
|
||||
ref="previewFrame"
|
||||
:srcdoc="previewContent"
|
||||
style="width: 100%; height: 100%; border: none; zoom: 0.6;"
|
||||
/>
|
||||
</div>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="px-6 py-4">
|
||||
<v-row no-gutters class="align-center">
|
||||
<v-col>
|
||||
<div class="text-caption text-grey">
|
||||
<v-icon size="16" class="mr-1">mdi-information</v-icon>
|
||||
支持 jinja2 语法。可用变量:<code> text | safe </code>(要渲染的文本), <code> version </code>(AstrBot 版本)
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="auto">
|
||||
<v-btn
|
||||
variant="text"
|
||||
@click="closeDialog"
|
||||
>
|
||||
取消
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="saveTemplate"
|
||||
:loading="saveLoading"
|
||||
>
|
||||
保存并应用
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
|
||||
<!-- 确认重置对话框 -->
|
||||
<v-dialog v-model="resetDialog" max-width="400px">
|
||||
<v-card>
|
||||
<v-card-title>确认重置</v-card-title>
|
||||
<v-card-text>
|
||||
确定要恢复默认模板吗?这将删除您的自定义模板,此操作无法撤销。
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="resetDialog = false">取消</v-btn>
|
||||
<v-btn color="warning" @click="confirmReset" :loading="resetLoading">确认重置</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, computed, nextTick, watch } from 'vue'
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import axios from 'axios'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
// 响应式数据
|
||||
const dialog = ref(false)
|
||||
const resetDialog = ref(false)
|
||||
const loading = ref(false)
|
||||
const saveLoading = ref(false)
|
||||
const resetLoading = ref(false)
|
||||
const previewLoading = ref(false)
|
||||
const templateContent = ref('')
|
||||
const hasCustomTemplate = ref(false)
|
||||
const previewFrame = ref(null)
|
||||
|
||||
// 编辑器配置
|
||||
const editorTheme = computed(() => 'vs-light')
|
||||
const editorOptions = {
|
||||
automaticLayout: true,
|
||||
fontSize: 12,
|
||||
lineNumbers: 'on',
|
||||
wordWrap: 'on',
|
||||
minimap: { enabled: false },
|
||||
scrollBeyondLastLine: false,
|
||||
}
|
||||
|
||||
// 示例数据用于预览
|
||||
const previewData = {
|
||||
text: '这是一个示例文本,用于预览模板效果。\n\n这里可以包含多行文本,支持换行和各种格式。',
|
||||
version: 'v4.0.0'
|
||||
}
|
||||
|
||||
// 生成预览内容
|
||||
const previewContent = computed(() => {
|
||||
try {
|
||||
// 简单的模板替换,模拟 Jinja2 渲染
|
||||
let content = templateContent.value
|
||||
content = content.replace(/\{\{\s*text\s*\|\s*safe\s*\}\}/g, previewData.text)
|
||||
content = content.replace(/\{\{\s*version\s*\}\}/g, previewData.version)
|
||||
return content
|
||||
} catch (error) {
|
||||
return `<div style="color: red; padding: 20px;">模板渲染错误: ${error.message}</div>`
|
||||
}
|
||||
})
|
||||
|
||||
// 方法
|
||||
const loadTemplate = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/config/astrbot/t2i-template/get')
|
||||
if (response.data.status === 'ok') {
|
||||
templateContent.value = response.data.data.template
|
||||
hasCustomTemplate.value = response.data.data.has_custom_template
|
||||
} else {
|
||||
console.error('加载模板失败:', response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载模板失败:', error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const saveTemplate = async () => {
|
||||
saveLoading.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/config/astrbot/t2i-template/save', {
|
||||
template: templateContent.value
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
hasCustomTemplate.value = true
|
||||
closeDialog()
|
||||
} else {
|
||||
console.error('保存模板失败:', response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('保存模板失败:', error)
|
||||
} finally {
|
||||
saveLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const resetToDefault = () => {
|
||||
resetDialog.value = true
|
||||
}
|
||||
|
||||
const confirmReset = async () => {
|
||||
resetLoading.value = true
|
||||
try {
|
||||
const response = await axios.delete('/api/config/astrbot/t2i-template/delete')
|
||||
if (response.data.status === 'ok') {
|
||||
hasCustomTemplate.value = false
|
||||
resetDialog.value = false
|
||||
// 重新加载默认模板
|
||||
await loadTemplate()
|
||||
} else {
|
||||
console.error('重置模板失败:', response.data.message)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('重置模板失败:', error)
|
||||
} finally {
|
||||
resetLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const refreshPreview = () => {
|
||||
previewLoading.value = true
|
||||
nextTick(() => {
|
||||
if (previewFrame.value) {
|
||||
previewFrame.value.contentWindow.location.reload()
|
||||
}
|
||||
setTimeout(() => {
|
||||
previewLoading.value = false
|
||||
}, 500)
|
||||
})
|
||||
}
|
||||
|
||||
const closeDialog = () => {
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
watch(dialog, (newVal) => {
|
||||
if (newVal && !templateContent.value) {
|
||||
loadTemplate()
|
||||
}
|
||||
})
|
||||
|
||||
defineExpose({
|
||||
openDialog: () => {
|
||||
dialog.value = true
|
||||
if (!templateContent.value) {
|
||||
loadTemplate()
|
||||
}
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.preview-container {
|
||||
background-color: #f5f5f5;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.preview-container::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-image:
|
||||
linear-gradient(45deg, #ccc 25%, transparent 25%),
|
||||
linear-gradient(-45deg, #ccc 25%, transparent 25%),
|
||||
linear-gradient(45deg, transparent 75%, #ccc 75%),
|
||||
linear-gradient(-45deg, transparent 75%, #ccc 75%);
|
||||
background-size: 20px 20px;
|
||||
background-position: 0 0, 0 10px, 10px -10px, -10px 0px;
|
||||
opacity: 0.1;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
code {
|
||||
background-color: rgba(0,0,0,0.05);
|
||||
padding: 2px 4px;
|
||||
border-radius: 3px;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
</style>
|
||||
@@ -52,6 +52,8 @@ export class I18nLoader {
|
||||
{ name: 'features/alkaid/index', path: 'features/alkaid/index.json' },
|
||||
{ name: 'features/alkaid/knowledge-base', path: 'features/alkaid/knowledge-base.json' },
|
||||
{ name: 'features/alkaid/memory', path: 'features/alkaid/memory.json' },
|
||||
{ name: 'features/persona', path: 'features/persona.json' },
|
||||
{ name: 'features/migration', path: 'features/migration.json' },
|
||||
|
||||
// 消息模块
|
||||
{ name: 'messages/errors', path: 'messages/errors.json' },
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
"tip": "💡 TIP:",
|
||||
"tipLink": "",
|
||||
"tipContinue": "By default, the corresponding version of the WebUI files will be downloaded when switching versions. The WebUI code is located in the dashboard directory of the project, and you can use npm to build it yourself.",
|
||||
"dockerTip": "When switching versions, it will try to update both the bot main program and the dashboard. If you are using Docker deployment, you can also re-pull the image or use",
|
||||
"dockerTip": "The `Update to Latest Version` button will try to update both the bot main program and the dashboard. If you are using Docker deployment, you can also re-pull the image or use",
|
||||
"dockerTipLink": "watchtower",
|
||||
"dockerTipContinue": "to automatically monitor and pull.",
|
||||
"table": {
|
||||
|
||||
@@ -2,15 +2,16 @@
|
||||
"dashboard": "Dashboard",
|
||||
"platforms": "Platforms",
|
||||
"providers": "Providers",
|
||||
"persona": "Persona",
|
||||
"toolUse": "MCP Tools",
|
||||
"config": "Config",
|
||||
"extension": "Extensions",
|
||||
"extensionMarketplace": "Extension Market",
|
||||
"chat": "Chat",
|
||||
"conversation": "Conversations",
|
||||
"sessionManagement": "Session Management",
|
||||
"console": "Console",
|
||||
"alkaid": "Alkaid Lab",
|
||||
"knowledgeBase": "Knowledge Base",
|
||||
"about": "About",
|
||||
"settings": "Settings",
|
||||
"documentation": "Documentation",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user