diff --git a/README.md b/README.md
index ca7a45eb3..40a3e95f1 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
-
+

English |
@@ -111,10 +111,14 @@ uvx astrbot init
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
-#### Replit 部署
+#### 在 Replit 上部署
[](https://repl.it/github/Soulter/AstrBot)
+#### 在 雨云 上部署
+
+[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
+
## ⚡ 消息平台支持情况
| 平台 | 支持性 |
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 3d6926ebe..4a75a11bc 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -6,14 +6,14 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
-VERSION = "3.5.17"
+VERSION = "3.5.18"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
# 默认配置
DEFAULT_CONFIG = {
"config_version": 2,
"platform_settings": {
- "plugin_enable": [],
+ "plugin_enable": {},
"unique_session": False,
"rate_limit": {
"time": 60,
@@ -54,6 +54,7 @@ DEFAULT_CONFIG = {
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
+ "display_reasoning_text": False,
"identifier": False,
"datetime_system_prompt": True,
"default_personality": "default",
@@ -61,6 +62,7 @@ DEFAULT_CONFIG = {
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
+ "show_tool_use_status": False,
"streaming_segmented": False,
"separate_provider": False,
},
@@ -441,7 +443,7 @@ CONFIG_METADATA_2 = {
"ignore_bot_self_message": {
"description": "是否忽略机器人自身的消息",
"type": "bool",
- "hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
+ "hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
},
"ignore_at_all": {
"description": "是否忽略 @ 全体成员",
@@ -770,17 +772,6 @@ CONFIG_METADATA_2 = {
"model": "deepseek/deepseek-r1",
},
},
- "LLMTuner": {
- "id": "llmtuner_default",
- "type": "llm_tuner",
- "provider_type": "chat_completion",
- "enable": True,
- "base_model_path": "",
- "adapter_model_path": "",
- "llmtuner_template": "",
- "finetuning_type": "lora",
- "quantization_bit": 4,
- },
"Dify": {
"id": "dify_app_default",
"type": "dify",
@@ -1662,6 +1653,11 @@ CONFIG_METADATA_2 = {
"obvious_hint": True,
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
},
+ "display_reasoning_text": {
+ "description": "显示思考内容",
+ "type": "bool",
+ "hint": "开启后,将在回复中显示模型的思考过程。",
+ },
"identifier": {
"description": "启动识别群员",
"type": "bool",
@@ -1699,10 +1695,15 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
},
+ "show_tool_use_status": {
+ "description": "函数调用状态输出",
+ "type": "bool",
+ "hint": "在触发函数调用时输出其函数名和内容。",
+ },
"streaming_segmented": {
"description": "不支持流式回复的平台分段输出",
"type": "bool",
- "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
+ "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
},
},
},
diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py
index 28c92fa89..7bfdd34c8 100644
--- a/astrbot/core/message/message_event_result.py
+++ b/astrbot/core/message/message_event_result.py
@@ -24,6 +24,8 @@ class MessageChain:
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
+ type: Optional[str] = None
+ """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
@@ -98,6 +100,15 @@ class MessageChain:
self.chain.append(Image.fromFileSystem(path))
return self
+ def base64_image(self, base64_str: str):
+ """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。
+ Example:
+
+ CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
+ """
+ self.chain.append(Image.fromBase64(base64_str))
+ return self
+
def use_t2i(self, use_t2i: bool):
"""设置是否使用文本转图片服务。
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
- STREAMING_FINISH= enum.auto()
+ STREAMING_FINISH = enum.auto()
"""流式输出完成"""
diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py
index eb5ffb1cd..d98f7c341 100644
--- a/astrbot/core/pipeline/context.py
+++ b/astrbot/core/pipeline/context.py
@@ -1,6 +1,14 @@
+import inspect
+import traceback
+import typing as T
from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star import PluginManager
+from astrbot.api import logger
+from astrbot.core.star.star_handler import star_handlers_registry, EventType
+from astrbot.core.star.star import star_map
+from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
@dataclass
@@ -9,3 +17,91 @@ class PipelineContext:
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
+
+ async def call_event_hook(
+ self,
+ event: AstrMessageEvent,
+ hook_type: EventType,
+ *args,
+ ):
+ platform_id = event.get_platform_id()
+ handlers = star_handlers_registry.get_handlers_by_event_type(
+ hook_type, platform_id=platform_id
+ )
+ for handler in handlers:
+ try:
+ logger.debug(
+ f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
+ )
+ await handler.handler(event, *args)
+ except BaseException:
+ logger.error(traceback.format_exc())
+
+ if event.is_stopped():
+ logger.info(
+ f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
+ )
+ return
+
+ async def call_handler(
+ self,
+ event: AstrMessageEvent,
+ handler: T.Awaitable,
+ *args,
+ **kwargs,
+ ) -> T.AsyncGenerator[None, None]:
+ """执行事件处理函数并处理其返回结果
+
+ 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
+ 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
+ 2. 协程: 执行一次并处理返回值
+
+ Args:
+ ctx (PipelineContext): 消息管道上下文对象
+ event (AstrMessageEvent): 事件对象
+ handler (Awaitable): 事件处理函数
+
+ Returns:
+ AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
+ """
+ ready_to_call = None # 一个协程或者异步生成器
+
+ trace_ = None
+
+ try:
+ ready_to_call = handler(event, *args, **kwargs)
+ except TypeError as _:
+ # 向下兼容
+ trace_ = traceback.format_exc()
+ # 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
+ ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
+
+ if inspect.isasyncgen(ready_to_call):
+ _has_yielded = False
+ try:
+ async for ret in ready_to_call:
+ # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
+ # 返回值只能是 MessageEventResult 或者 None(无返回值)
+ _has_yielded = True
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ # 如果返回值是 MessageEventResult, 设置结果并继续
+ event.set_result(ret)
+ yield
+ else:
+ # 如果返回值是 None, 则不设置结果并继续
+ # 继续执行后续阶段
+ yield ret
+ if not _has_yielded:
+ # 如果这个异步生成器没有执行到 yield 分支
+ yield
+ except Exception as e:
+ logger.error(f"Previous Error: {trace_}")
+ raise e
+ elif inspect.iscoroutine(ready_to_call):
+ # 如果只是一个协程, 直接执行
+ ret = await ready_to_call
+ if isinstance(ret, (MessageEventResult, CommandResult)):
+ event.set_result(ret)
+ yield
+ else:
+ yield ret
diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/pipeline/process_stage/agent_runner/base.py
new file mode 100644
index 000000000..431a95ca6
--- /dev/null
+++ b/astrbot/core/pipeline/process_stage/agent_runner/base.py
@@ -0,0 +1,57 @@
+import abc
+import typing as T
+from dataclasses import dataclass
+from astrbot.core.provider.entities import LLMResponse
+from ....message.message_event_result import MessageChain
+from enum import Enum, auto
+
+
+class AgentState(Enum):
+ """Agent 状态枚举"""
+ IDLE = auto() # 初始状态
+ RUNNING = auto() # 运行中
+ DONE = auto() # 完成
+ ERROR = auto() # 错误状态
+
+
+class AgentResponseData(T.TypedDict):
+ chain: MessageChain
+
+
+@dataclass
+class AgentResponse:
+ type: str
+ data: AgentResponseData
+
+
+class BaseAgentRunner:
+ @abc.abstractmethod
+ async def reset(self) -> None:
+ """
+ Reset the agent to its initial state.
+ This method should be called before starting a new run.
+ """
+ ...
+
+ @abc.abstractmethod
+ async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
+ """
+ Process a single step of the agent.
+ """
+ ...
+
+ @abc.abstractmethod
+ def done(self) -> bool:
+ """
+ Check if the agent has completed its task.
+ Returns True if the agent is done, False otherwise.
+ """
+ ...
+
+ @abc.abstractmethod
+ def get_final_llm_resp(self) -> LLMResponse | None:
+ """
+ Get the final observation from the agent.
+ This method should be called after the agent is done.
+ """
+ ...
diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
new file mode 100644
index 000000000..3163e02e4
--- /dev/null
+++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
@@ -0,0 +1,300 @@
+import sys
+import traceback
+import typing as T
+from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
+from ...context import PipelineContext
+from astrbot.core.provider.provider import Provider
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.message.message_event_result import (
+ MessageChain,
+)
+from astrbot.core.provider.entities import (
+ ProviderRequest,
+ LLMResponse,
+ ToolCallMessageSegment,
+ AssistantMessageSegment,
+ ToolCallsResult,
+)
+from mcp.types import (
+ TextContent,
+ ImageContent,
+ EmbeddedResource,
+ TextResourceContents,
+ BlobResourceContents,
+)
+from astrbot.core.star.star_handler import EventType
+from astrbot import logger
+
+if sys.version_info >= (3, 12):
+ from typing import override
+else:
+ from typing_extensions import override
+
+
+# TODO:
+# 1. 处理平台不兼容的处理器
+
+
+class ToolLoopAgent(BaseAgentRunner):
+ def __init__(
+ self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
+ ) -> None:
+ self.provider = provider
+ self.req = None
+ self.event = event
+ self.pipeline_ctx = pipeline_ctx
+ self._state = AgentState.IDLE
+ self.final_llm_resp = None
+ self.streaming = False
+
+ @override
+ async def reset(self, req: ProviderRequest, streaming: bool) -> None:
+ self.req = req
+ self.streaming = streaming
+ self.final_llm_resp = None
+ self._state = AgentState.IDLE
+
+ def _transition_state(self, new_state: AgentState) -> None:
+ """转换 Agent 状态"""
+ if self._state != new_state:
+ logger.debug(f"Agent state transition: {self._state} -> {new_state}")
+ self._state = new_state
+
+ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
+ """Yields chunks *and* a final LLMResponse."""
+ if self.streaming:
+ stream = self.provider.text_chat_stream(**self.req.__dict__)
+ async for resp in stream: # type: ignore
+ yield resp
+ else:
+ yield await self.provider.text_chat(**self.req.__dict__)
+
+ @override
+ async def step(self):
+ """
+ Process a single step of the agent.
+ This method should return the result of the step.
+ """
+ if not self.req:
+ raise ValueError("Request is not set. Please call reset() first.")
+
+ # 开始处理,转换到运行状态
+ self._transition_state(AgentState.RUNNING)
+ llm_resp_result = None
+
+ async for llm_response in self._iter_llm_responses():
+ assert isinstance(llm_response, LLMResponse)
+ if llm_response.is_chunk:
+ if llm_response.result_chain:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(chain=llm_response.result_chain),
+ )
+ else:
+ yield AgentResponse(
+ type="streaming_delta",
+ data=AgentResponseData(
+ chain=MessageChain().message(llm_response.completion_text)
+ ),
+ )
+ continue
+ llm_resp_result = llm_response
+ break # got final response
+
+ if not llm_resp_result:
+ return
+
+ # 处理 LLM 响应
+ llm_resp = llm_resp_result
+ logger.debug(f"LLMResp: {llm_resp}")
+
+ if llm_resp.role == "err":
+ # 如果 LLM 响应错误,转换到错误状态
+ self.final_llm_resp = llm_resp
+ self._transition_state(AgentState.ERROR)
+ yield AgentResponse(
+ type="err",
+ data=AgentResponseData(
+ chain=MessageChain().message(
+ f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
+ )
+ ),
+ )
+
+ if not llm_resp.tools_call_name:
+ # 如果没有工具调用,转换到完成状态
+ self.final_llm_resp = llm_resp
+ self._transition_state(AgentState.DONE)
+
+ # 执行事件钩子
+ await self.pipeline_ctx.call_event_hook(
+ self.event, EventType.OnLLMResponseEvent, llm_resp
+ )
+
+ # 返回 LLM 结果
+ if llm_resp.result_chain:
+ yield AgentResponse(
+ type="llm_result",
+ data=AgentResponseData(chain=llm_resp.result_chain),
+ )
+ elif llm_resp.completion_text:
+ yield AgentResponse(
+ type="llm_result",
+ data=AgentResponseData(
+ chain=MessageChain().message(llm_resp.completion_text)
+ ),
+ )
+
+ # 如果有工具调用,还需处理工具调用
+ if llm_resp.tools_call_name:
+ tool_call_result_blocks = []
+ for tool_call_name in llm_resp.tools_call_name:
+ yield AgentResponse(
+ type="tool_call",
+ data=AgentResponseData(
+ chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
+ ),
+ )
+ async for result in self._handle_function_tools(self.req, llm_resp):
+ if isinstance(result, list):
+ tool_call_result_blocks = result
+ elif isinstance(result, MessageChain):
+ yield AgentResponse(
+ type="tool_call_result",
+ data=AgentResponseData(chain=result),
+ )
+ # 将结果添加到上下文中
+ tool_calls_result = ToolCallsResult(
+ tool_calls_info=AssistantMessageSegment(
+ role="assistant",
+ tool_calls=llm_resp.to_openai_tool_calls(),
+ content=llm_resp.completion_text,
+ ),
+ tool_calls_result=tool_call_result_blocks,
+ )
+ self.req.append_tool_calls_result(tool_calls_result)
+
+ async def _handle_function_tools(
+ self,
+ req: ProviderRequest,
+ llm_response: LLMResponse,
+ ) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
+ """处理函数工具调用。"""
+ tool_call_result_blocks: list[ToolCallMessageSegment] = []
+ logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
+
+ # 执行函数调用
+ for func_tool_name, func_tool_args, func_tool_id in zip(
+ llm_response.tools_call_name,
+ llm_response.tools_call_args,
+ llm_response.tools_call_ids,
+ ):
+ try:
+ if not req.func_tool:
+ return
+ func_tool = req.func_tool.get_func(func_tool_name)
+ if func_tool.origin == "mcp":
+ logger.info(
+ f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
+ )
+ client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
+ res = await client.session.call_tool(func_tool.name, func_tool_args)
+ if not res:
+ continue
+ if isinstance(res.content[0], TextContent):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=res.content[0].text,
+ )
+ )
+ yield MessageChain().message(res.content[0].text)
+ elif isinstance(res.content[0], ImageContent):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回了图片(已直接发送给用户)",
+ )
+ )
+ yield MessageChain().base64_image(res.content[0].data)
+ elif isinstance(res.content[0], EmbeddedResource):
+ resource = res.content[0].resource
+ if isinstance(resource, TextResourceContents):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=resource.text,
+ )
+ )
+ yield MessageChain().message(resource.text)
+ elif (
+ isinstance(resource, BlobResourceContents)
+ and resource.mimeType
+ and resource.mimeType.startswith("image/")
+ ):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回了图片(已直接发送给用户)",
+ )
+ )
+ yield MessageChain().base64_image(res.content[0].data)
+ else:
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回的数据类型不受支持",
+ )
+ )
+ yield MessageChain().message("返回的数据类型不受支持。")
+ else:
+ logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
+ # 尝试调用工具函数
+ wrapper = self.pipeline_ctx.call_handler(
+ self.event, func_tool.handler, **func_tool_args
+ )
+ async for resp in wrapper:
+ if resp is not None:
+ # Tool 返回结果
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=resp,
+ )
+ )
+ yield MessageChain().message(resp)
+ else:
+ # Tool 直接请求发送消息给用户
+ # 这里我们将直接结束 Agent Loop。
+ self._transition_state(AgentState.DONE)
+ if res := self.event.get_result():
+ if res.chain:
+ yield MessageChain(chain=res.chain)
+
+ self.event.clear_result()
+ except Exception as e:
+ logger.warning(traceback.format_exc())
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=f"error: {str(e)}",
+ )
+ )
+
+ # 处理函数调用响应
+ if tool_call_result_blocks:
+ yield tool_call_result_blocks
+
+ def done(self) -> bool:
+ """检查 Agent 是否已完成工作"""
+ return self._state in (AgentState.DONE, AgentState.ERROR)
+
+ def get_final_llm_resp(self) -> LLMResponse | None:
+ return self.final_llm_resp
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index dd4349813..ae588133a 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -3,6 +3,7 @@
"""
import traceback
+import copy
import asyncio
import json
from typing import Union, AsyncGenerator
@@ -20,40 +21,27 @@ from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
- ToolCallMessageSegment,
- AssistantMessageSegment,
- ToolCallsResult,
-)
-from astrbot.core.star.star_handler import star_handlers_registry, EventType
-from astrbot.core.star.star import star_map
-from astrbot.core.star.session_llm_manager import SessionServiceManager
-from mcp.types import (
- TextContent,
- ImageContent,
- EmbeddedResource,
- TextResourceContents,
- BlobResourceContents,
)
+from astrbot.core.star.star_handler import EventType
from astrbot.core import web_chat_back_queue
+from ..agent_runner.tool_loop_agent import ToolLoopAgent
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
- self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
- self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
- "wake_prefix"
- ] # str
- self.max_context_length = ctx.astrbot_config["provider_settings"][
- "max_context_length"
- ] # int
- self.dequeue_context_length = min(
- max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
+ conf = ctx.astrbot_config
+ settings = conf["provider_settings"]
+ self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
+ self.provider_wake_prefix: str = settings["wake_prefix"] # str
+ self.max_context_length = settings["max_context_length"] # int
+ self.dequeue_context_length: int = min(
+ max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
- ) # int
- self.streaming_response = ctx.astrbot_config["provider_settings"][
- "streaming_response"
- ] # bool
+ )
+ self.streaming_response: bool = settings["streaming_response"]
+ self.max_step: int = settings.get("max_agent_step", 10)
+ self.show_tool_use: bool = settings.get("show_tool_use_status", True)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -90,10 +78,7 @@ class LLMRequestSubStage(Stage):
)
if req.conversation:
- all_contexts = json.loads(req.conversation.history)
- req.contexts = self._process_tool_message_pairs(
- all_contexts, remove_tags=True
- )
+ req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest(prompt="", image_urls=[])
@@ -135,26 +120,7 @@ class LLMRequestSubStage(Stage):
return
# 执行请求 LLM 前事件钩子。
- # 装饰 system_prompt 等功能
- # 获取当前平台ID
- platform_id = event.get_platform_id()
- handlers = star_handlers_registry.get_handlers_by_event_type(
- EventType.OnLLMRequestEvent, platform_id=platform_id
- )
- for handler in handlers:
- try:
- logger.debug(
- f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
- )
- await handler.handler(event, req)
- except BaseException:
- logger.error(traceback.format_exc())
-
- if event.is_stopped():
- logger.info(
- f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
- )
- return
+ await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req)
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
@@ -184,83 +150,63 @@ class LLMRequestSubStage(Stage):
if not req.session_id:
req.session_id = event.unified_msg_origin
- async def requesting(req: ProviderRequest):
- try:
- need_loop = True
- while need_loop:
- need_loop = False
-
- # 在每次实际请求 LLM 前检查会话级别的启停状态,这可以防止插件或函数工具调用时绕过会话级别的限制
- if not SessionServiceManager.should_process_llm_request(event):
- logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。")
- return
-
- logger.debug(f"提供商请求 Payload: {req}")
- final_llm_response = None
+ # fix messages
+ req.contexts = self.fix_messages(req.contexts)
- if self.streaming_response:
- stream = provider.text_chat_stream(**req.__dict__)
- async for llm_response in stream:
- if llm_response.is_chunk:
- if llm_response.result_chain:
- yield llm_response.result_chain # MessageChain
- else:
- yield MessageChain().message(
- llm_response.completion_text
- )
- else:
- final_llm_response = llm_response
- else:
- final_llm_response = await provider.text_chat(
- **req.__dict__
- ) # 请求 LLM
+ # Call Agent
+ tool_loop_agent = ToolLoopAgent(
+ provider=provider,
+ event=event,
+ pipeline_ctx=self.ctx,
+ )
+ await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
- if not final_llm_response:
- raise Exception("LLM response is None.")
+ async def requesting():
+ step_idx = 0
+ while step_idx < self.max_step:
+ step_idx += 1
+ try:
+ async for resp in tool_loop_agent.step():
+ if resp.type == "tool_call_result":
+ continue # 跳过工具调用结果
+ if resp.type == "tool_call":
+ if self.streaming_response:
+ # 用来标记流式响应需要分节
+ yield MessageChain(chain=[], type="break")
+ if self.show_tool_use or event.get_platform_name() == "webchat":
+ resp.data["chain"].type = "tool_call"
+ await event.send(resp.data["chain"])
+ continue
- # 执行 LLM 响应后的事件钩子。
- handlers = star_handlers_registry.get_handlers_by_event_type(
- EventType.OnLLMResponseEvent
+ if not self.streaming_response:
+ content_typ = (
+ ResultContentType.LLM_RESULT
+ if resp.type == "llm_result"
+ else ResultContentType.GENERAL_RESULT
+ )
+ event.set_result(
+ MessageEventResult(
+ chain=resp.data["chain"].chain,
+ result_content_type=content_typ,
+ )
+ )
+ yield
+ event.clear_result()
+ else:
+ if resp.type == "streaming_delta":
+ yield resp.data["chain"] # MessageChain
+ if tool_loop_agent.done():
+ break
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ event.set_result(
+ MessageEventResult().message(
+ f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
+ )
)
- for handler in handlers:
- try:
- logger.debug(
- f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
- )
- await handler.handler(event, final_llm_response)
- except BaseException:
- logger.error(traceback.format_exc())
-
- if event.is_stopped():
- logger.info(
- f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
- )
- return
-
- if self.streaming_response:
- # 流式输出的处理
- async for result in self._handle_llm_stream_response(
- event, req, final_llm_response
- ):
- if isinstance(result, ProviderRequest):
- # 有函数工具调用并且返回了结果,我们需要再次请求 LLM
- req = result
- need_loop = True
- else:
- yield
- else:
- # 非流式输出的处理
- async for result in self._handle_llm_response(
- event, req, final_llm_response
- ):
- if isinstance(result, ProviderRequest):
- # 有函数工具调用并且返回了结果,我们需要再次请求 LLM
- req = result
- need_loop = True
- else:
- yield
-
+ return
asyncio.create_task(
Metric.upload(
llm_tick=1,
@@ -269,44 +215,38 @@ class LLMRequestSubStage(Stage):
)
)
- # 保存到历史记录
- await self._save_to_history(event, req, final_llm_response)
-
- except BaseException as e:
- logger.error(traceback.format_exc())
- event.set_result(
- MessageEventResult().message(
- f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
- )
- )
-
- if not self.streaming_response:
- event.set_extra("tool_call_result", None)
- async for _ in requesting(req):
- yield
- else:
+ if self.streaming_response:
+ # 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
- .set_async_stream(requesting(req))
+ .set_async_stream(requesting())
)
- # 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
yield
-
- if event.get_extra("tool_call_result"):
- event.set_result(event.get_extra("tool_call_result"))
- event.set_extra("tool_call_result", None)
+ if tool_loop_agent.done():
+ if final_llm_resp := tool_loop_agent.get_final_llm_resp():
+ if final_llm_resp.completion_text:
+ chain = (
+ MessageChain().message(final_llm_resp.completion_text).chain
+ )
+ else:
+ chain = final_llm_resp.result_chain.chain
+ event.set_result(
+ MessageEventResult(
+ chain=chain,
+ result_content_type=ResultContentType.STREAMING_FINISH,
+ )
+ )
+ else:
+ async for _ in requesting():
yield
- # 暂时直接发出去
- if img_b64 := event.get_extra("tool_call_img_respond"):
- await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
- event.set_extra("tool_call_img_respond", None)
-
+ # 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
- # 异步处理 WebChat 特殊情况
asyncio.create_task(self._handle_webchat(event, req))
+ await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
+
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
# 检查会话级别的LLM启停状态,防止标题生成功能绕过会话级别限制
@@ -324,10 +264,6 @@ class LLMRequestSubStage(Stage):
return
provider = self.ctx.plugin_manager.context.get_using_provider()
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
- # if len(latest_pair) > 1:
- # cleaned_text += (
- # "\nAssistant: " + latest_pair[1].get("content", "").strip()
- # )
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await provider.text_chat(
system_prompt="You are expert in summarizing user's query.",
@@ -368,322 +304,50 @@ class LLMRequestSubStage(Stage):
}
)
- async def _handle_llm_response(
- self,
- event: AstrMessageEvent,
- req: ProviderRequest,
- llm_response: LLMResponse,
- ) -> AsyncGenerator[Union[None, ProviderRequest], None]:
- """处理非流式 LLM 响应。
-
- Returns:
- AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
-
- Yields:
- Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
- """
- if llm_response.role == "assistant":
- # text completion
- if llm_response.result_chain:
- event.set_result(
- MessageEventResult(
- chain=llm_response.result_chain.chain
- ).set_result_content_type(ResultContentType.LLM_RESULT)
- )
- else:
- event.set_result(
- MessageEventResult()
- .message(llm_response.completion_text)
- .set_result_content_type(ResultContentType.LLM_RESULT)
- )
- elif llm_response.role == "err":
- event.set_result(
- MessageEventResult().message(
- f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
- )
- )
- elif llm_response.role == "tool":
- # 处理函数工具调用
- async for result in self._handle_function_tools(event, req, llm_response):
- yield result
-
- async def _handle_llm_stream_response(
- self,
- event: AstrMessageEvent,
- req: ProviderRequest,
- llm_response: LLMResponse,
- ) -> AsyncGenerator[Union[None, ProviderRequest], None]:
- """处理流式 LLM 响应。
-
- 专门用于处理流式输出完成后的响应,与非流式响应处理分离。
-
- Returns:
- AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
-
- Yields:
- Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
- """
- if llm_response.role == "assistant":
- # text completion
- if llm_response.result_chain:
- event.set_result(
- MessageEventResult(
- chain=llm_response.result_chain.chain
- ).set_result_content_type(ResultContentType.STREAMING_FINISH)
- )
- else:
- event.set_result(
- MessageEventResult()
- .message(llm_response.completion_text)
- .set_result_content_type(ResultContentType.STREAMING_FINISH)
- )
- elif llm_response.role == "err":
- event.set_result(
- MessageEventResult().message(
- f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
- )
- )
- elif llm_response.role == "tool":
- # 处理函数工具调用
- async for result in self._handle_function_tools(event, req, llm_response):
- yield result
-
- async def _handle_function_tools(
- self,
- event: AstrMessageEvent,
- req: ProviderRequest,
- llm_response: LLMResponse,
- ) -> AsyncGenerator[Union[None, ProviderRequest], None]:
- """处理函数工具调用。
-
- Returns:
- AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
- """
- # function calling
- tool_call_result: list[ToolCallMessageSegment] = []
- logger.info(
- f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
- )
- for func_tool_name, func_tool_args, func_tool_id in zip(
- llm_response.tools_call_name,
- llm_response.tools_call_args,
- llm_response.tools_call_ids,
- ):
- try:
- func_tool = req.func_tool.get_func(func_tool_name)
- if func_tool.origin == "mcp":
- logger.info(
- f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
- )
- client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
- res = await client.session.call_tool(func_tool.name, func_tool_args)
- if res:
- # TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
- if isinstance(res.content[0], TextContent):
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=res.content[0].text,
- )
- )
- elif isinstance(res.content[0], ImageContent):
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回了图片(已直接发送给用户)",
- )
- )
- event.set_extra(
- "tool_call_img_respond",
- res.content[0].data,
- )
- elif isinstance(res.content[0], EmbeddedResource):
- resource = res.content[0].resource
- if isinstance(resource, TextResourceContents):
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=resource.text,
- )
- )
- elif (
- isinstance(resource, BlobResourceContents)
- and resource.mimeType
- and resource.mimeType.startswith("image/")
- ):
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回了图片(已直接发送给用户)",
- )
- )
- event.set_extra(
- "tool_call_img_respond",
- res.content[0].data,
- )
- else:
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回的数据类型不受支持",
- )
- )
- else:
- # 获取处理器,过滤掉平台不兼容的处理器
- platform_id = event.get_platform_id()
- star_md = star_map.get(func_tool.handler_module_path)
- if (
- star_md
- and platform_id in star_md.supported_platforms
- and not star_md.supported_platforms[platform_id]
- ):
- logger.debug(
- f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
- )
- # 直接跳过,不添加任何消息到tool_call_result
- continue
-
- logger.info(
- f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
- )
- # 尝试调用工具函数
- wrapper = self._call_handler(
- self.ctx, event, func_tool.handler, **func_tool_args
- )
- async for resp in wrapper:
- if resp is not None: # 有 return 返回
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=resp,
- )
- )
- else:
- res = event.get_result()
- if res and res.chain:
- event.set_extra("tool_call_result", res)
- yield # 有生成器返回
- event.clear_result() # 清除上一个 handler 的结果
- except BaseException as e:
- logger.warning(traceback.format_exc())
- tool_call_result.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=f"error: {str(e)}",
- )
- )
- if tool_call_result:
- # 函数调用结果
- req.func_tool = None # 暂时不支持递归工具调用
- assistant_msg_seg = AssistantMessageSegment(
- role="assistant", tool_calls=llm_response.to_openai_tool_calls()
- )
- # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
- req.tool_calls_result = ToolCallsResult(
- tool_calls_info=assistant_msg_seg,
- tool_calls_result=tool_call_result,
- )
- yield req # 再次执行 LLM 请求
- else:
- if llm_response.completion_text:
- event.set_result(
- MessageEventResult().message(llm_response.completion_text)
- )
-
async def _save_to_history(
- self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
+ self,
+ event: AstrMessageEvent,
+ req: ProviderRequest,
+ llm_response: LLMResponse | None,
):
- if not req or not req.conversation or not llm_response:
+ if (
+ not req
+ or not req.conversation
+ or not llm_response
+ or llm_response.role != "assistant"
+ ):
return
- if llm_response.role == "assistant":
- # 文本回复
- contexts = req.contexts.copy()
- contexts.append(await req.assemble_context())
+ # 历史上下文
+ messages = copy.deepcopy(req.contexts)
+ # 这一轮对话请求的用户输入
+ messages.append(await req.assemble_context())
+ # 这一轮对话的 LLM 响应
+ if req.tool_calls_result:
+ if not isinstance(req.tool_calls_result, list):
+ messages.extend(req.tool_calls_result.to_openai_messages())
+ elif isinstance(req.tool_calls_result, list):
+ for tcr in req.tool_calls_result:
+ messages.extend(tcr.to_openai_messages())
+ messages.append({"role": "assistant", "content": llm_response.completion_text})
+ messages = list(filter(lambda item: "_no_save" not in item, messages))
+ await self.conv_manager.update_conversation(
+ event.unified_msg_origin, req.conversation.cid, history=messages
+ )
+ logger.debug(f"messages persisted: {messages}")
- # 记录并标记函数调用结果
- if req.tool_calls_result:
- tool_calls_messages = req.tool_calls_result.to_openai_messages()
-
- # 添加标记
- for message in tool_calls_messages:
- message["_tool_call_history"] = True
-
- processed_tool_messages = self._process_tool_message_pairs(
- tool_calls_messages, remove_tags=False
- )
-
- contexts.extend(processed_tool_messages)
-
- contexts.append(
- {"role": "assistant", "content": llm_response.completion_text}
- )
- contexts_to_save = list(
- filter(lambda item: "_no_save" not in item, contexts)
- )
- await self.conv_manager.update_conversation(
- event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
- )
-
- def _process_tool_message_pairs(self, messages, remove_tags=True):
- """处理工具调用消息,确保assistant和tool消息成对出现
-
- Args:
- messages (list): 消息列表
- remove_tags (bool): 是否移除_tool_call_history标记
-
- Returns:
- list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
- """
- result = []
- i = 0
-
- while i < len(messages):
- current_msg = messages[i]
-
- # 普通消息直接添加
- if "_tool_call_history" not in current_msg:
- result.append(current_msg.copy() if remove_tags else current_msg)
- i += 1
- continue
-
- # 工具调用消息成对处理
- if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
- assistant_msg = current_msg.copy()
-
- if remove_tags and "_tool_call_history" in assistant_msg:
- del assistant_msg["_tool_call_history"]
-
- related_tools = []
- j = i + 1
- while (
- j < len(messages)
- and messages[j].get("role") == "tool"
- and "_tool_call_history" in messages[j]
- ):
- tool_msg = messages[j].copy()
-
- if remove_tags:
- del tool_msg["_tool_call_history"]
-
- related_tools.append(tool_msg)
- j += 1
-
- # 成对的时候添加到结果
- if related_tools:
- result.append(assistant_msg)
- result.extend(related_tools)
-
- i = j # 跳过已处理
+ def fix_messages(self, messages: list[dict]) -> list[dict]:
+ """验证并且修复上下文"""
+ fixed_messages = []
+ for message in messages:
+ if message.get("role") == "tool":
+ # tool block 前面必须要有 user 和 assistant block
+ if len(fixed_messages) < 2:
+ # 这种情况可能是上下文被截断导致的
+ # 我们直接将之前的上下文都清空
+ fixed_messages = []
+ else:
+ fixed_messages.append(message)
else:
- # 单独的tool消息
- i += 1
-
- return result
+ fixed_messages.append(message)
+ return fixed_messages
diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py
index c7817e49c..00f58d55b 100644
--- a/astrbot/core/pipeline/process_stage/method/star_request.py
+++ b/astrbot/core/pipeline/process_stage/method/star_request.py
@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
)
- wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
+ wrapper = self.ctx.call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py
index c7d4ff792..b41794733 100644
--- a/astrbot/core/pipeline/stage.py
+++ b/astrbot/core/pipeline/stage.py
@@ -1,12 +1,8 @@
from __future__ import annotations
import abc
-import inspect
-import traceback
-from astrbot.api import logger
-from typing import List, AsyncGenerator, Union, Awaitable
+from typing import List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
-from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError
-
- async def _call_handler(
- self,
- ctx: PipelineContext,
- event: AstrMessageEvent,
- handler: Awaitable,
- *args,
- **kwargs,
- ) -> AsyncGenerator[None, None]:
- """执行事件处理函数并处理其返回结果
-
- 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
- 1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
- 2. 协程: 执行一次并处理返回值
-
- Args:
- ctx (PipelineContext): 消息管道上下文对象
- event (AstrMessageEvent): 待处理的事件对象
- handler (Awaitable): 事件处理函数
- *args: 传递给handler的位置参数
- **kwargs: 传递给handler的关键字参数
-
- Returns:
- AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
- """
- ready_to_call = None # 一个协程或者异步生成器(async def)
-
- trace_ = None
-
- try:
- ready_to_call = handler(event, *args, **kwargs)
- except TypeError as _:
- # 向下兼容
- trace_ = traceback.format_exc()
- # 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
- ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
-
- if isinstance(ready_to_call, AsyncGenerator):
- # 如果是一个异步生成器, 进入洋葱模型
- _has_yielded = False # 是否返回过值
- try:
- async for ret in ready_to_call:
- # 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
- # 返回值只能是 MessageEventResult 或者 None(无返回值)
- _has_yielded = True
- if isinstance(ret, (MessageEventResult, CommandResult)):
- # 如果返回值是 MessageEventResult, 设置结果并继续
- event.set_result(ret)
- yield # 传递控制权给上一层的process函数
- else:
- # 如果返回值是 None, 则不设置结果并继续
- # 继续执行后续阶段
- yield ret # 传递控制权给上一层的process函数
- if not _has_yielded:
- # 如果这个异步生成器没有执行到yield分支
- yield
- except Exception as e:
- logger.error(f"Previous Error: {trace_}")
- raise e
- elif inspect.iscoroutine(ready_to_call):
- # 如果只是一个协程, 直接执行
- ret = await ready_to_call
- if isinstance(ret, (MessageEventResult, CommandResult)):
- event.set_result(ret)
- yield # 传递控制权给上一层的process函数
- else:
- yield ret # 传递控制权给上一层的process函数
diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py
index 3636cd611..5b3a1d916 100644
--- a/astrbot/core/platform/sources/telegram/tg_event.py
+++ b/astrbot/core/platform/sources/telegram/tg_event.py
@@ -158,6 +158,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
async for chain in generator:
if isinstance(chain, MessageChain):
+ if chain.type == "break":
+ # 分割符
+ message_id = None # 重置消息 ID
+ delta = "" # 重置 delta
+ continue
+
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index 76b5dc85d..111027a5c 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -35,6 +35,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"cid": cid,
"data": data,
"streaming": streaming,
+ "chain_type": message.type,
}
)
elif isinstance(comp, Image):
@@ -110,6 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
async for chain in generator:
+ if chain.type == "break" and final_data:
+ # 分割符
+ await web_chat_back_queue.put(
+ {
+ "type": "end",
+ "data": final_data,
+ "streaming": True,
+ "cid": self.session_id.split("!")[-1],
+ }
+ )
+ final_data = ""
+ continue
final_data += await WebChatMessageEvent._send(
chain, session_id=self.session_id, streaming=True
)
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index e01e46cf9..abb01960c 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
- tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
+ tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
role: str = "assistant"
def to_dict(self):
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
}
if self.content:
ret["content"] = self.content
- elif self.tool_calls:
+ if self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@@ -95,19 +95,19 @@ class ProviderRequest:
"""提示词"""
session_id: str = ""
"""会话 ID"""
- image_urls: List[str] = None
+ image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
- func_tool: FuncCall = None
+ func_tool: FuncCall | None = None
"""可用的函数工具"""
- contexts: List = None
+ contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
"""
system_prompt: str = ""
"""系统提示词"""
- conversation: Conversation = None
+ conversation: Conversation | None = None
- tool_calls_result: ToolCallsResult = None
+ tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
@@ -116,6 +116,14 @@ class ProviderRequest:
def __str__(self):
return self.__repr__()
+ def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
+ """添加工具调用结果到请求中"""
+ if not self.tool_calls_result:
+ self.tool_calls_result = []
+ if isinstance(self.tool_calls_result, ToolCallsResult):
+ self.tool_calls_result = [self.tool_calls_result]
+ self.tool_calls_result.append(tool_calls_result)
+
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 """
if not self.contexts:
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 5886b8083..2abe59d65 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -190,11 +190,6 @@ class ProviderManager:
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
- case "llm_tuner":
- logger.info("加载 LLM Tuner 工具 ...")
- from .sources.llmtuner_source import (
- LLMTunerModelLoader as LLMTunerModelLoader,
- )
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "dashscope":
@@ -330,8 +325,6 @@ class ProviderManager:
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
- self.db_helper,
- self.provider_settings.get("persistant_history", True),
self.selected_default_persona,
)
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index c285ebd42..1ecca3537 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -1,6 +1,5 @@
import abc
from typing import List
-from astrbot.core.db import BaseDatabase
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
@@ -53,15 +52,13 @@ class Provider(AbstractProvider):
self,
provider_config: dict,
provider_settings: dict,
- persistant_history: bool = True,
- db_helper: BaseDatabase = None,
- default_persona: Personality = None,
+ default_persona: Personality | None = None,
) -> None:
super().__init__(provider_config)
self.provider_settings = provider_settings
- self.curr_personality: Personality = default_persona
+ self.curr_personality = default_persona
"""维护了当前的使用的 persona,即人格。可能为 None"""
@abc.abstractmethod
@@ -86,11 +83,11 @@ class Provider(AbstractProvider):
self,
prompt: str,
session_id: str = None,
- image_urls: List[str] = None,
+ image_urls: list[str] = None,
func_tool: FuncCall = None,
- contexts: List = None,
+ contexts: list = None,
system_prompt: str = None,
- tool_calls_result: ToolCallsResult = None,
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -114,11 +111,11 @@ class Provider(AbstractProvider):
self,
prompt: str,
session_id: str = None,
- image_urls: List[str] = None,
+ image_urls: list[str] = None,
func_tool: FuncCall = None,
- contexts: List = None,
+ contexts: list = None,
system_prompt: str = None,
- tool_calls_result: ToolCallsResult = None,
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index c3ad45868..a53250fb7 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -1,3 +1,6 @@
+import json
+import anthropic
+import base64
from typing import List
from mimetypes import guess_type
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
-from astrbot.core.db import BaseDatabase
-from astrbot.api.provider import Provider, Personality
+from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
-from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
-from .openai_source import ProviderOpenAIOfficial
+from astrbot.core.provider.entities import LLMResponse
+from typing import AsyncGenerator
@register_provider_adapter(
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
)
-class ProviderAnthropic(ProviderOpenAIOfficial):
+class ProviderAnthropic(Provider):
def __init__(
self,
- provider_config: dict,
- provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=True,
- default_persona: Personality = None,
+ provider_config,
+ provider_settings,
+ default_persona=None,
) -> None:
- # Skip OpenAI's __init__ and call Provider's __init__ directly
- Provider.__init__(
- self,
+ super().__init__(
provider_config,
provider_settings,
- persistant_history,
- db_helper,
default_persona,
)
- self.chosen_api_key = None
+ self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
- self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
+ self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
self.set_model(provider_config["model_config"]["model"])
+ def _prepare_payload(self, messages: list[dict]):
+ """准备 Anthropic API 的请求 payload
+
+ Args:
+ messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
+ Returns:
+ system_prompt: 系统提示内容
+ new_messages: 处理后的消息列表,去除系统提示
+ """
+ system_prompt = ""
+ new_messages = []
+ for message in messages:
+ if message["role"] == "system":
+ system_prompt = message["content"]
+ elif message["role"] == "assistant":
+ blocks = []
+ if isinstance(message["content"], str):
+ blocks.append({"type": "text", "text": message["content"]})
+ if "tool_calls" in message:
+ for tool_call in message["tool_calls"]:
+ blocks.append( # noqa: PERF401
+ {
+ "type": "tool_use",
+ "name": tool_call["function"]["name"],
+ "input": json.loads(tool_call["function"]["arguments"])
+ if isinstance(tool_call["function"]["arguments"], str)
+ else tool_call["function"]["arguments"],
+ "id": tool_call["id"],
+ }
+ )
+ new_messages.append(
+ {
+ "role": "assistant",
+ "content": blocks,
+ }
+ )
+ elif message["role"] == "tool":
+ new_messages.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "tool_result",
+ "tool_use_id": message["tool_call_id"],
+ "content": message["content"],
+ }
+ ],
+ }
+ )
+ else:
+ new_messages.append(message)
+
+ return system_prompt, new_messages
+
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
- tool_list = tools.get_func_desc_anthropic_style()
- if tool_list:
+ if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
completion = await self.client.messages.create(**payloads, stream=False)
@@ -64,70 +112,157 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if len(completion.content) == 0:
raise Exception("API 返回的 completion 为空。")
- # TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
- # 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
- content = completion.content[-1]
- llm_response = LLMResponse("assistant")
+ llm_response = LLMResponse(role="assistant")
- if content.type == "text":
- # text completion
- completion_text = str(content.text).strip()
- # llm_response.completion_text = completion_text
- llm_response.result_chain = MessageChain().message(completion_text)
-
- # Anthropic每次只返回一个函数调用
- if completion.stop_reason == "tool_use":
- # tools call (function calling)
- args_ls = []
- func_name_ls = []
- tool_use_ids = []
- func_name_ls.append(content.name)
- args_ls.append(content.input)
- tool_use_ids.append(content.id)
- llm_response.role = "tool"
- llm_response.tools_call_args = args_ls
- llm_response.tools_call_name = func_name_ls
- llm_response.tools_call_ids = tool_use_ids
+ for content_block in completion.content:
+ if content_block.type == "text":
+ completion_text = str(content_block.text).strip()
+ llm_response.completion_text = completion_text
+ if content_block.type == "tool_use":
+ llm_response.tools_call_args.append(content_block.input)
+ llm_response.tools_call_name.append(content_block.name)
+ llm_response.tools_call_ids.append(content_block.id)
+ # TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args:
- logger.error(f"API 返回的 completion 无法解析:{completion}。")
- raise Exception(f"API 返回的 completion 无法解析:{completion}。")
-
- llm_response.raw_completion = completion
+ raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
return llm_response
+ async def _query_stream(
+ self, payloads: dict, tools: FuncCall
+ ) -> AsyncGenerator[LLMResponse, None]:
+ if tools:
+ if tool_list := tools.get_func_desc_anthropic_style():
+ payloads["tools"] = tool_list
+
+ # 用于累积工具调用信息
+ tool_use_buffer = {}
+ # 用于累积最终结果
+ final_text = ""
+ final_tool_calls = []
+
+ async with self.client.messages.stream(**payloads) as stream:
+ assert isinstance(stream, anthropic.AsyncMessageStream)
+ async for event in stream:
+ if event.type == "content_block_start":
+ if event.content_block.type == "text":
+ # 文本块开始
+ yield LLMResponse(
+ role="assistant", completion_text="", is_chunk=True
+ )
+ elif event.content_block.type == "tool_use":
+ # 工具使用块开始,初始化缓冲区
+ tool_use_buffer[event.index] = {
+ "id": event.content_block.id,
+ "name": event.content_block.name,
+ "input": {},
+ }
+
+ elif event.type == "content_block_delta":
+ if event.delta.type == "text_delta":
+ # 文本增量
+ final_text += event.delta.text
+ yield LLMResponse(
+ role="assistant",
+ completion_text=event.delta.text,
+ is_chunk=True,
+ )
+ elif event.delta.type == "input_json_delta":
+ # 工具调用参数增量
+ if event.index in tool_use_buffer:
+ # 累积 JSON 输入
+ if "input_json" not in tool_use_buffer[event.index]:
+ tool_use_buffer[event.index]["input_json"] = ""
+ tool_use_buffer[event.index]["input_json"] += (
+ event.delta.partial_json
+ )
+
+ elif event.type == "content_block_stop":
+ # 内容块结束
+ if event.index in tool_use_buffer:
+ # 解析完整的工具调用
+ tool_info = tool_use_buffer[event.index]
+ try:
+ if "input_json" in tool_info:
+ tool_info["input"] = json.loads(tool_info["input_json"])
+
+ # 添加到最终结果
+ final_tool_calls.append(
+ {
+ "id": tool_info["id"],
+ "name": tool_info["name"],
+ "input": tool_info["input"],
+ }
+ )
+
+ yield LLMResponse(
+ role="tool",
+ completion_text="",
+ tools_call_args=[tool_info["input"]],
+ tools_call_name=[tool_info["name"]],
+ tools_call_ids=[tool_info["id"]],
+ is_chunk=True,
+ )
+ except json.JSONDecodeError:
+ # JSON 解析失败,跳过这个工具调用
+ logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
+
+ # 清理缓冲区
+ del tool_use_buffer[event.index]
+
+ # 返回最终的完整结果
+ final_response = LLMResponse(
+ role="assistant", completion_text=final_text, is_chunk=False
+ )
+
+ if final_tool_calls:
+ final_response.tools_call_args = [
+ call["input"] for call in final_tool_calls
+ ]
+ final_response.tools_call_name = [call["name"] for call in final_tool_calls]
+ final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
+
+ yield final_response
+
async def text_chat(
self,
- prompt: str,
- session_id: str = None,
- image_urls: List[str] = [],
- func_tool: FuncCall = None,
+ prompt,
+ session_id=None,
+ image_urls=None,
+ func_tool=None,
contexts=None,
system_prompt=None,
- tool_calls_result: ToolCallsResult = None,
+ tool_calls_result=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
- if not prompt:
- prompt = ""
-
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
+ if system_prompt:
+ context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
+ # tool calls result
if tool_calls_result:
- # 暂时这样写。
- prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
+ if not isinstance(tool_calls_result, list):
+ context_query.extend(tool_calls_result.to_openai_messages())
+ else:
+ for tcr in tool_calls_result:
+ context_query.extend(tcr.to_openai_messages())
+
+ system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
+ model_config["model"] = self.get_model()
+
+ payloads = {"messages": new_messages, **model_config}
- payloads = {"messages": context_query, **model_config}
# Anthropic has a different way of handling system prompts
if system_prompt:
payloads["system"] = system_prompt
@@ -135,32 +270,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
-
except Exception as e:
- if "maximum context length" in str(e):
- retry_cnt = 20
- while retry_cnt > 0:
- logger.warning(
- f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
- )
- try:
- await self.pop_record(context_query)
- response = await self.client.messages.create(
- messages=context_query, **model_config
- )
- llm_response = LLMResponse("assistant")
- llm_response.result_chain = MessageChain().message(response.content[0].text)
- llm_response.raw_completion = response
- return llm_response
- except Exception as e:
- if "maximum context length" in str(e):
- retry_cnt -= 1
- else:
- raise e
- return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
- else:
- logger.error(f"发生了错误。Provider 配置如下: {model_config}")
- raise e
+ logger.error(f"发生了错误。Provider 配置如下: {model_config}")
+ raise e
return llm_response
@@ -175,21 +287,34 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
tool_calls_result=None,
**kwargs,
):
- # raise NotImplementedError("This method is not implemented yet.")
- # 调用 text_chat 模拟流式
- llm_response = await self.text_chat(
- prompt=prompt,
- session_id=session_id,
- image_urls=image_urls,
- func_tool=func_tool,
- contexts=contexts,
- system_prompt=system_prompt,
- tool_calls_result=tool_calls_result,
- )
- llm_response.is_chunk = True
- yield llm_response
- llm_response.is_chunk = False
- yield llm_response
+ if contexts is None:
+ contexts = []
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = [*contexts, new_record]
+ if system_prompt:
+ context_query.insert(0, {"role": "system", "content": system_prompt})
+
+ for part in context_query:
+ if "_no_save" in part:
+ del part["_no_save"]
+
+ # tool calls result
+ if tool_calls_result:
+ context_query.extend(tool_calls_result.to_openai_messages())
+
+ system_prompt, new_messages = self._prepare_payload(context_query)
+
+ model_config = self.provider_config.get("model_config", {})
+ model_config["model"] = self.get_model()
+
+ payloads = {"messages": new_messages, **model_config}
+
+ # Anthropic has a different way of handling system prompts
+ if system_prompt:
+ payloads["system"] = system_prompt
+
+ async for llm_response in self._query_stream(payloads, func_tool):
+ yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片"""
@@ -232,3 +357,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
)
return {"role": "user", "content": content}
+
+ async def encode_image_bs64(self, image_url: str) -> str:
+ """
+ 将图片转换为 base64
+ """
+ if image_url.startswith("base64://"):
+ return image_url.replace("base64://", "data:image/jpeg;base64,")
+ with open(image_url, "rb") as f:
+ image_bs64 = base64.b64encode(f.read()).decode("utf-8")
+ return "data:image/jpeg;base64," + image_bs64
+ return ""
+
+ def get_current_key(self) -> str:
+ return self.chosen_api_key
+
+ async def get_models(self) -> List[str]:
+ models_str = []
+ models = await self.client.models.list()
+ models = sorted(models.data, key=lambda x: x.id)
+ for model in models:
+ models_str.append(model.id)
+ return models_str
+
+ def set_key(self, key: str):
+ self.chosen_api_key = key
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
index f719190a1..3498f8346 100644
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ b/astrbot/core/provider/sources/dashscope_source.py
@@ -5,7 +5,6 @@ from typing import List
from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
-from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=False,
- default_persona: Personality = None,
+ default_persona: Personality | None = None,
) -> None:
Provider.__init__(
self,
provider_config,
provider_settings,
- persistant_history,
- db_helper,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py
index 348c3e72c..81c910d66 100644
--- a/astrbot/core/provider/sources/dify_source.py
+++ b/astrbot/core/provider/sources/dify_source.py
@@ -1,10 +1,9 @@
import astrbot.core.message.components as Comp
import os
from typing import List
-from .. import Provider, Personality
+from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
-from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url, download_file
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
class ProviderDify(Provider):
def __init__(
self,
- provider_config: dict,
- provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=False,
- default_persona: Personality = None,
+ provider_config,
+ provider_settings,
+ default_persona = None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- persistant_history,
- db_helper,
default_persona,
)
self.api_key = provider_config.get("dify_api_key", "")
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index c16b39415..e1d1f11bd 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -12,8 +12,7 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
-from astrbot.api.provider import Personality, Provider
-from astrbot.core.db import BaseDatabase
+from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.func_tool_manager import FuncCall
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
def __init__(
self,
- provider_config: dict,
- provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=True,
- default_persona: Personality = None,
+ provider_config,
+ provider_settings,
+ default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- persistant_history,
- db_helper,
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
@@ -264,12 +259,10 @@ class ProviderGoogleGenAI(Provider):
contents.append(content_cls(parts=part))
gemini_contents: list[types.Content] = []
- native_tool_enabled = any(
- [
- self.provider_config.get("gm_native_coderunner", False),
- self.provider_config.get("gm_native_search", False),
- ]
- )
+ native_tool_enabled = any([
+ self.provider_config.get("gm_native_coderunner", False),
+ self.provider_config.get("gm_native_search", False),
+ ])
for message in payloads["messages"]:
role, content = message["role"], message.get("content")
@@ -506,12 +499,12 @@ class ProviderGoogleGenAI(Provider):
async def text_chat(
self,
prompt: str,
- session_id: str = None,
- image_urls: list[str] = None,
- func_tool: FuncCall = None,
- contexts: list = None,
- system_prompt: str = None,
- tool_calls_result: ToolCallsResult = None,
+ session_id=None,
+ image_urls=None,
+ func_tool=None,
+ contexts=None,
+ system_prompt=None,
+ tool_calls_result=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -527,7 +520,11 @@ class ProviderGoogleGenAI(Provider):
# tool calls result
if tool_calls_result:
- context_query.extend(tool_calls_result.to_openai_messages())
+ if not isinstance(tool_calls_result, list):
+ context_query.extend(tool_calls_result.to_openai_messages())
+ else:
+ for tcr in tool_calls_result:
+ context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
@@ -631,9 +628,10 @@ class ProviderGoogleGenAI(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
- user_content["content"].append(
- {"type": "image_url", "image_url": {"url": image_data}}
- )
+ user_content["content"].append({
+ "type": "image_url",
+ "image_url": {"url": image_data},
+ })
return user_content
else:
return {"role": "user", "content": text}
diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py
deleted file mode 100644
index 8648512d0..000000000
--- a/astrbot/core/provider/sources/llmtuner_source.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import os
-from llmtuner.chat import ChatModel
-from typing import List
-from .. import Provider
-from ..entities import LLMResponse
-from ..func_tool_manager import FuncCall
-from astrbot.core.db import BaseDatabase
-from ..register import register_provider_adapter
-
-
-@register_provider_adapter(
- "llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
-)
-class LLMTunerModelLoader(Provider):
- def __init__(
- self,
- provider_config: dict,
- provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=True,
- default_persona=None,
- ) -> None:
- super().__init__(
- provider_config,
- provider_settings,
- persistant_history,
- db_helper,
- default_persona,
- )
- if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
- provider_config["adapter_model_path"]
- ):
- raise FileNotFoundError("模型文件路径不存在。")
- self.base_model_path = provider_config["base_model_path"]
- self.adapter_model_path = provider_config["adapter_model_path"]
- self.model = ChatModel(
- {
- "model_name_or_path": self.base_model_path,
- "adapter_name_or_path": self.adapter_model_path,
- "template": provider_config["llmtuner_template"],
- "finetuning_type": provider_config["finetuning_type"],
- "quantization_bit": provider_config["quantization_bit"],
- }
- )
- self.set_model(
- os.path.basename(self.base_model_path)
- + "_"
- + os.path.basename(self.adapter_model_path)
- )
-
- async def assemble_context(self, text: str, image_urls: List[str] = None):
- """
- 组装上下文。
- """
- return {"role": "user", "content": text}
-
- async def text_chat(
- self,
- prompt: str,
- session_id: str = None,
- image_urls: List[str] = None,
- func_tool: FuncCall = None,
- contexts: List = None,
- system_prompt: str = None,
- **kwargs,
- ) -> LLMResponse:
- if contexts is None:
- contexts = []
- system_prompt = ""
- new_record = {"role": "user", "content": prompt}
- query_context = [*contexts, new_record]
-
- # 提取出系统提示
- system_idxs = []
- for idx, context in enumerate(query_context):
- if context["role"] == "system":
- system_idxs.append(idx)
-
- if "_no_save" in context:
- del context["_no_save"]
-
- for idx in reversed(system_idxs):
- system_prompt += " " + query_context.pop(idx)["content"]
-
- conf = {
- "messages": query_context,
- "system": system_prompt,
- }
- if func_tool:
- tool_list = func_tool.get_func_desc_openai_style()
- if tool_list:
- conf["tools"] = tool_list
-
- responses = await self.model.achat(**conf)
-
- llm_response = LLMResponse("assistant", responses[-1].response_text)
-
- return llm_response
-
- async def text_chat_stream(
- self,
- prompt,
- session_id=None,
- image_urls=...,
- func_tool=None,
- contexts=...,
- system_prompt=None,
- tool_calls_result=None,
- **kwargs,
- ):
- # raise NotImplementedError("This method is not implemented yet.")
- # 调用 text_chat 模拟流式
- llm_response = await self.text_chat(
- prompt=prompt,
- session_id=session_id,
- image_urls=image_urls,
- func_tool=func_tool,
- contexts=contexts,
- system_prompt=system_prompt,
- tool_calls_result=tool_calls_result,
- )
- llm_response.is_chunk = True
- yield llm_response
- llm_response.is_chunk = False
- yield llm_response
-
- async def get_current_key(self):
- return "none"
-
- async def set_key(self, key):
- pass
-
- async def get_models(self):
- return [self.get_model()]
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 15104db3c..ef6131d8c 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
-# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.db import BaseDatabase
-from astrbot.api.provider import Provider, Personality
+from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List, AsyncGenerator
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
class ProviderOpenAIOfficial(Provider):
def __init__(
self,
- provider_config: dict,
- provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=True,
- default_persona: Personality = None,
+ provider_config,
+ provider_settings,
+ default_persona = None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- persistant_history,
- db_helper,
default_persona,
)
self.chosen_api_key = None
@@ -224,12 +218,10 @@ class ProviderOpenAIOfficial(Provider):
async def _prepare_chat_payload(
self,
prompt: str,
- session_id: str = None,
- image_urls: list[str] = None,
- func_tool: FuncCall = None,
- contexts: list = None,
- system_prompt: str = None,
- tool_calls_result: ToolCallsResult = None,
+ image_urls: list[str] | None = None,
+ contexts: list | None = None,
+ system_prompt: str | None = None,
+ tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -246,14 +238,18 @@ class ProviderOpenAIOfficial(Provider):
# tool calls result
if tool_calls_result:
- context_query.extend(tool_calls_result.to_openai_messages())
+ if isinstance(tool_calls_result, ToolCallsResult):
+ context_query.extend(tool_calls_result.to_openai_messages())
+ else:
+ for tcr in tool_calls_result:
+ context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config}
- return payloads, context_query, func_tool
+ return payloads, context_query
async def _handle_api_error(
self,
@@ -352,11 +348,9 @@ class ProviderOpenAIOfficial(Provider):
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
- payloads, context_query, func_tool = await self._prepare_chat_payload(
+ payloads, context_query = await self._prepare_chat_payload(
prompt,
- session_id,
image_urls,
- func_tool,
contexts,
system_prompt,
tool_calls_result,
@@ -422,11 +416,9 @@ class ProviderOpenAIOfficial(Provider):
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
- payloads, context_query, func_tool = await self._prepare_chat_payload(
+ payloads, context_query = await self._prepare_chat_payload(
prompt,
- session_id,
image_urls,
- func_tool,
contexts,
system_prompt,
tool_calls_result,
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
index e7e9d4a14..428dee8f4 100644
--- a/astrbot/core/provider/sources/zhipu_source.py
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -1,4 +1,3 @@
-from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
self,
provider_config: dict,
provider_settings: dict,
- db_helper: BaseDatabase,
- persistant_history=True,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
- db_helper,
- persistant_history,
default_persona,
)
diff --git a/changelogs/v3.5.18.md b/changelogs/v3.5.18.md
new file mode 100644
index 000000000..d9eaa6de8
--- /dev/null
+++ b/changelogs/v3.5.18.md
@@ -0,0 +1,18 @@
+# What's Changed
+
+> 重构了大模型请求部分,如果发现此部分使用时有问题请提交 issue
+
+1. 修复: 安装插件按钮被删除、无法自定义安装插件
+2. 修复: 环境变量中的代理地址无法生效
+1. 修复: randomize jwt secret
+2. 修复: 在 Node 消息段发送简单文本信息的问题
+1. 修复: QQ 官方机器人适配器使用 SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台
+4. 修复: Discord 适配器无法优雅重载
+1. 修复: Telegram 适配器无法主动回复
+1. 修复: 仪表盘的『插件配置』中不显示代码编辑器
+3. 新增: Gemini TTS API
+1. 新增: 允许 html_render 方法传入 Playwright.screenshot 配置参数
+1. 优化: 修复 CommandFilter 支持对布尔类型进行解析
+4. 新增: WechatPadPro 发送 TTS 时 添加对 MP3 格式音频支持
+1. 重构: 将大模型请求部分抽象成 AgentRunner,提高可读性和可扩展性,工具调用结果支持持久化保存到数据库,完善 Agent 的多轮工具调用能力。
+1. 移除: LLMTuner 模型提供商适配器。请使用 Ollama 来加载微调模型
\ No newline at end of file
diff --git a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
index 22ac90006..1612afcb8 100644
--- a/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
+++ b/dashboard/src/layouts/full/vertical-header/VerticalHeader.vue
@@ -29,7 +29,7 @@ let dashboardCurrentVersion = ref('');
let version = ref('');
let releases = ref([]);
let devCommits = ref([]);
-
+let updatingDashboardLoading = ref(false);
let installLoading = ref(false);
let tab = ref(0);
@@ -217,6 +217,7 @@ function switchVersion(version: string) {
}
function updateDashboard() {
+ updatingDashboardLoading.value = true;
updateStatus.value = t('core.header.updateDialog.status.updating');
axios.post('/api/update/dashboard')
.then((res) => {
@@ -230,7 +231,9 @@ function updateDashboard() {
.catch((err) => {
console.log(err);
updateStatus.value = err
- });
+ }).finally(() => {
+ updatingDashboardLoading.value = false;
+ });
}
function toggleDarkMode() {
@@ -416,7 +419,7 @@ commonStore.getStartTime();
+ :disabled="!dashboardHasNewVersion" :loading="updatingDashboardLoading">
{{ t('core.header.updateDialog.dashboardUpdate.downloadAndUpdate') }}
diff --git a/dashboard/src/theme/DarkTheme.ts b/dashboard/src/theme/DarkTheme.ts
index 9276c8f98..5906eca32 100644
--- a/dashboard/src/theme/DarkTheme.ts
+++ b/dashboard/src/theme/DarkTheme.ts
@@ -39,7 +39,8 @@ const PurpleThemeDark: ThemeTypes = {
background: '#111111',
overlay: '#111111aa',
codeBg: '#282833',
- code: '#ffffffdd'
+ code: '#ffffffdd',
+ chatMessageBubble: '#2d2e30',
}
};
diff --git a/dashboard/src/theme/LightTheme.ts b/dashboard/src/theme/LightTheme.ts
index 35aa1339a..94e011867 100644
--- a/dashboard/src/theme/LightTheme.ts
+++ b/dashboard/src/theme/LightTheme.ts
@@ -20,7 +20,7 @@ const PurpleTheme: ThemeTypes = {
lightsuccess: '#b9f6ca',
lighterror: '#f9d8d8',
lightwarning: '#fff8e1',
- primaryText: '#000000dd',
+ primaryText: '#1b1c1d',
secondaryText: '#000000aa',
darkprimary: '#1565c0',
darksecondary: '#4527a0',
@@ -39,7 +39,8 @@ const PurpleTheme: ThemeTypes = {
background: '#f9fafcf4',
overlay: '#ffffffaa',
codeBg: '#f5f0ff',
- code: '#673ab7'
+ code: '#673ab7',
+ chatMessageBubble: '#e7ebf4',
}
};
diff --git a/dashboard/src/types/themeTypes/ThemeType.ts b/dashboard/src/types/themeTypes/ThemeType.ts
index 69b00a1ab..f5e2e5491 100644
--- a/dashboard/src/types/themeTypes/ThemeType.ts
+++ b/dashboard/src/types/themeTypes/ThemeType.ts
@@ -35,5 +35,6 @@ export type ThemeTypes = {
secondary200?: string;
codeBg?: string;
code?: string;
+ chatMessageBubble?: string;
};
};
diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue
index fadba8aac..2954153ae 100644
--- a/dashboard/src/views/ChatPage.vue
+++ b/dashboard/src/views/ChatPage.vue
@@ -3,13 +3,16 @@