diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 3d6926ebe..222fa64d6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -61,6 +61,7 @@ DEFAULT_CONFIG = { "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, + "show_tool_use_status": False, "streaming_segmented": False, "separate_provider": False, }, @@ -441,7 +442,7 @@ CONFIG_METADATA_2 = { "ignore_bot_self_message": { "description": "是否忽略机器人自身的消息", "type": "bool", - "hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人", + "hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人", }, "ignore_at_all": { "description": "是否忽略 @ 全体成员", @@ -770,17 +771,6 @@ CONFIG_METADATA_2 = { "model": "deepseek/deepseek-r1", }, }, - "LLMTuner": { - "id": "llmtuner_default", - "type": "llm_tuner", - "provider_type": "chat_completion", - "enable": True, - "base_model_path": "", - "adapter_model_path": "", - "llmtuner_template": "", - "finetuning_type": "lora", - "quantization_bit": 4, - }, "Dify": { "id": "dify_app_default", "type": "dify", @@ -1699,10 +1689,15 @@ CONFIG_METADATA_2 = { "type": "bool", "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台", }, + "show_tool_use_status": { + "description": "函数调用状态输出", + "type": "bool", + "hint": "在触发函数调用时输出其函数名和内容。", + }, "streaming_segmented": { "description": "不支持流式回复的平台分段输出", "type": "bool", - "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项", + "hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项", }, }, }, diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 28c92fa89..7bfdd34c8 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -24,6 +24,8 @@ class MessageChain: chain: List[BaseMessageComponent] = field(default_factory=list) use_t2i_: Optional[bool] = None # None 为跟随用户设置 + type: Optional[str] = None + """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" def message(self, message: str): """添加一条文本消息到消息链 `chain` 中。 @@ -98,6 +100,15 @@ class MessageChain: self.chain.append(Image.fromFileSystem(path)) return self + def base64_image(self, base64_str: str): + """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 + Example: + + CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...") + """ + self.chain.append(Image.fromBase64(base64_str)) + return self + def use_t2i(self, use_t2i: bool): """设置是否使用文本转图片服务。 @@ -157,7 +168,7 @@ class ResultContentType(enum.Enum): """普通的消息结果""" STREAMING_RESULT = enum.auto() """调用 LLM 产生的流式结果""" - STREAMING_FINISH= enum.auto() + STREAMING_FINISH = enum.auto() """流式输出完成""" diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index eb5ffb1cd..d98f7c341 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,6 +1,14 @@ +import inspect +import traceback +import typing as T from dataclasses import dataclass from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star import PluginManager +from astrbot.api import logger +from astrbot.core.star.star_handler import star_handlers_registry, EventType +from astrbot.core.star.star import star_map +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult @dataclass @@ -9,3 +17,91 @@ class PipelineContext: astrbot_config: AstrBotConfig # AstrBot 配置对象 plugin_manager: PluginManager # 插件管理器对象 + + async def call_event_hook( + self, + event: AstrMessageEvent, + hook_type: EventType, + *args, + ): + platform_id = event.get_platform_id() + handlers = star_handlers_registry.get_handlers_by_event_type( + hook_type, platform_id=platform_id + ) + for handler in handlers: + try: + logger.debug( + f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, *args) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + return + + async def call_handler( + self, + event: AstrMessageEvent, + handler: T.Awaitable, + *args, + **kwargs, + ) -> T.AsyncGenerator[None, None]: + """执行事件处理函数并处理其返回结果 + + 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: + 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层 + 2. 协程: 执行一次并处理返回值 + + Args: + ctx (PipelineContext): 消息管道上下文对象 + event (AstrMessageEvent): 事件对象 + handler (Awaitable): 事件处理函数 + + Returns: + AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 + """ + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + try: + ready_to_call = handler(event, *args, **kwargs) + except TypeError as _: + # 向下兼容 + trace_ = traceback.format_exc() + # 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份 + ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs) + + if inspect.isasyncgen(ready_to_call): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/pipeline/process_stage/agent_runner/base.py new file mode 100644 index 000000000..431a95ca6 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/agent_runner/base.py @@ -0,0 +1,57 @@ +import abc +import typing as T +from dataclasses import dataclass +from astrbot.core.provider.entities import LLMResponse +from ....message.message_event_result import MessageChain +from enum import Enum, auto + + +class AgentState(Enum): + """Agent 状态枚举""" + IDLE = auto() # 初始状态 + RUNNING = auto() # 运行中 + DONE = auto() # 完成 + ERROR = auto() # 错误状态 + + +class AgentResponseData(T.TypedDict): + chain: MessageChain + + +@dataclass +class AgentResponse: + type: str + data: AgentResponseData + + +class BaseAgentRunner: + @abc.abstractmethod + async def reset(self) -> None: + """ + Reset the agent to its initial state. + This method should be called before starting a new run. + """ + ... + + @abc.abstractmethod + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: + """ + Process a single step of the agent. + """ + ... + + @abc.abstractmethod + def done(self) -> bool: + """ + Check if the agent has completed its task. + Returns True if the agent is done, False otherwise. + """ + ... + + @abc.abstractmethod + def get_final_llm_resp(self) -> LLMResponse | None: + """ + Get the final observation from the agent. + This method should be called after the agent is done. + """ + ... diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py new file mode 100644 index 000000000..3163e02e4 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py @@ -0,0 +1,300 @@ +import sys +import traceback +import typing as T +from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState +from ...context import PipelineContext +from astrbot.core.provider.provider import Provider +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import ( + MessageChain, +) +from astrbot.core.provider.entities import ( + ProviderRequest, + LLMResponse, + ToolCallMessageSegment, + AssistantMessageSegment, + ToolCallsResult, +) +from mcp.types import ( + TextContent, + ImageContent, + EmbeddedResource, + TextResourceContents, + BlobResourceContents, +) +from astrbot.core.star.star_handler import EventType +from astrbot import logger + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +# TODO: +# 1. 处理平台不兼容的处理器 + + +class ToolLoopAgent(BaseAgentRunner): + def __init__( + self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext + ) -> None: + self.provider = provider + self.req = None + self.event = event + self.pipeline_ctx = pipeline_ctx + self._state = AgentState.IDLE + self.final_llm_resp = None + self.streaming = False + + @override + async def reset(self, req: ProviderRequest, streaming: bool) -> None: + self.req = req + self.streaming = streaming + self.final_llm_resp = None + self._state = AgentState.IDLE + + def _transition_state(self, new_state: AgentState) -> None: + """转换 Agent 状态""" + if self._state != new_state: + logger.debug(f"Agent state transition: {self._state} -> {new_state}") + self._state = new_state + + async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: + """Yields chunks *and* a final LLMResponse.""" + if self.streaming: + stream = self.provider.text_chat_stream(**self.req.__dict__) + async for resp in stream: # type: ignore + yield resp + else: + yield await self.provider.text_chat(**self.req.__dict__) + + @override + async def step(self): + """ + Process a single step of the agent. + This method should return the result of the step. + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + llm_resp_result = None + + async for llm_response in self._iter_llm_responses(): + assert isinstance(llm_response, LLMResponse) + if llm_response.is_chunk: + if llm_response.result_chain: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=llm_response.result_chain), + ) + else: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(llm_response.completion_text) + ), + ) + continue + llm_resp_result = llm_response + break # got final response + + if not llm_resp_result: + return + + # 处理 LLM 响应 + llm_resp = llm_resp_result + logger.debug(f"LLMResp: {llm_resp}") + + if llm_resp.role == "err": + # 如果 LLM 响应错误,转换到错误状态 + self.final_llm_resp = llm_resp + self._transition_state(AgentState.ERROR) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message( + f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}" + ) + ), + ) + + if not llm_resp.tools_call_name: + # 如果没有工具调用,转换到完成状态 + self.final_llm_resp = llm_resp + self._transition_state(AgentState.DONE) + + # 执行事件钩子 + await self.pipeline_ctx.call_event_hook( + self.event, EventType.OnLLMResponseEvent, llm_resp + ) + + # 返回 LLM 结果 + if llm_resp.result_chain: + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=llm_resp.result_chain), + ) + elif llm_resp.completion_text: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain().message(llm_resp.completion_text) + ), + ) + + # 如果有工具调用,还需处理工具调用 + if llm_resp.tools_call_name: + tool_call_result_blocks = [] + for tool_call_name in llm_resp.tools_call_name: + yield AgentResponse( + type="tool_call", + data=AgentResponseData( + chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}") + ), + ) + async for result in self._handle_function_tools(self.req, llm_resp): + if isinstance(result, list): + tool_call_result_blocks = result + elif isinstance(result, MessageChain): + yield AgentResponse( + type="tool_call_result", + data=AgentResponseData(chain=result), + ) + # 将结果添加到上下文中 + tool_calls_result = ToolCallsResult( + tool_calls_info=AssistantMessageSegment( + role="assistant", + tool_calls=llm_resp.to_openai_tool_calls(), + content=llm_resp.completion_text, + ), + tool_calls_result=tool_call_result_blocks, + ) + self.req.append_tool_calls_result(tool_calls_result) + + async def _handle_function_tools( + self, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]: + """处理函数工具调用。""" + tool_call_result_blocks: list[ToolCallMessageSegment] = [] + logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") + + # 执行函数调用 + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + if not req.func_tool: + return + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] + res = await client.session.call_tool(func_tool.name, func_tool_args) + if not res: + continue + if isinstance(res.content[0], TextContent): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) + ) + yield MessageChain().message(res.content[0].text) + elif isinstance(res.content[0], ImageContent): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + yield MessageChain().base64_image(res.content[0].data) + elif isinstance(res.content[0], EmbeddedResource): + resource = res.content[0].resource + if isinstance(resource, TextResourceContents): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resource.text, + ) + ) + yield MessageChain().message(resource.text) + elif ( + isinstance(resource, BlobResourceContents) + and resource.mimeType + and resource.mimeType.startswith("image/") + ): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + yield MessageChain().base64_image(res.content[0].data) + else: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回的数据类型不受支持", + ) + ) + yield MessageChain().message("返回的数据类型不受支持。") + else: + logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") + # 尝试调用工具函数 + wrapper = self.pipeline_ctx.call_handler( + self.event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: + # Tool 返回结果 + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resp, + ) + ) + yield MessageChain().message(resp) + else: + # Tool 直接请求发送消息给用户 + # 这里我们将直接结束 Agent Loop。 + self._transition_state(AgentState.DONE) + if res := self.event.get_result(): + if res.chain: + yield MessageChain(chain=res.chain) + + self.event.clear_result() + except Exception as e: + logger.warning(traceback.format_exc()) + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", + ) + ) + + # 处理函数调用响应 + if tool_call_result_blocks: + yield tool_call_result_blocks + + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 549440ee3..2ebe4bd42 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -3,6 +3,7 @@ """ import traceback +import copy import asyncio import json from typing import Union, AsyncGenerator @@ -20,39 +21,27 @@ from astrbot.core.utils.metrics import Metric from astrbot.core.provider.entities import ( ProviderRequest, LLMResponse, - ToolCallMessageSegment, - AssistantMessageSegment, - ToolCallsResult, -) -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map -from mcp.types import ( - TextContent, - ImageContent, - EmbeddedResource, - TextResourceContents, - BlobResourceContents, ) +from astrbot.core.star.star_handler import EventType from astrbot.core import web_chat_back_queue +from ..agent_runner.tool_loop_agent import ToolLoopAgent class LLMRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx - self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list - self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][ - "wake_prefix" - ] # str - self.max_context_length = ctx.astrbot_config["provider_settings"][ - "max_context_length" - ] # int - self.dequeue_context_length = min( - max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]), + conf = ctx.astrbot_config + settings = conf["provider_settings"] + self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list + self.provider_wake_prefix: str = settings["wake_prefix"] # str + self.max_context_length = settings["max_context_length"] # int + self.dequeue_context_length: int = min( + max(1, settings["dequeue_context_length"]), self.max_context_length - 1, - ) # int - self.streaming_response = ctx.astrbot_config["provider_settings"][ - "streaming_response" - ] # bool + ) + self.streaming_response: bool = settings["streaming_response"] + self.max_step: int = settings.get("max_agent_step", 10) + self.show_tool_use: bool = settings.get("show_tool_use_status", True) for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -83,10 +72,7 @@ class LLMRequestSubStage(Stage): ) if req.conversation: - all_contexts = json.loads(req.conversation.history) - req.contexts = self._process_tool_message_pairs( - all_contexts, remove_tags=True - ) + req.contexts = json.loads(req.conversation.history) else: req = ProviderRequest(prompt="", image_urls=[]) @@ -127,26 +113,7 @@ class LLMRequestSubStage(Stage): return # 执行请求 LLM 前事件钩子。 - # 装饰 system_prompt 等功能 - # 获取当前平台ID - platform_id = event.get_platform_id() - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMRequestEvent, platform_id=platform_id - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, req) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return + await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req) if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) @@ -176,77 +143,62 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - async def requesting(req: ProviderRequest): - try: - need_loop = True - while need_loop: - need_loop = False - logger.debug(f"提供商请求 Payload: {req}") + # fix messages + req.contexts = self.fix_messages(req.contexts) - final_llm_response = None + # Call Agent + tool_loop_agent = ToolLoopAgent( + provider=provider, + event=event, + pipeline_ctx=self.ctx, + ) + await tool_loop_agent.reset(req=req, streaming=self.streaming_response) - if self.streaming_response: - stream = provider.text_chat_stream(**req.__dict__) - async for llm_response in stream: - if llm_response.is_chunk: - if llm_response.result_chain: - yield llm_response.result_chain # MessageChain - else: - yield MessageChain().message( - llm_response.completion_text - ) - else: - final_llm_response = llm_response - else: - final_llm_response = await provider.text_chat( - **req.__dict__ - ) # 请求 LLM + async def requesting(): + step_idx = 0 + while step_idx < self.max_step: + step_idx += 1 + try: + async for resp in tool_loop_agent.step(): + if resp.type == "tool_call_result": + continue # 跳过工具调用结果 + if resp.type == "tool_call": + if self.streaming_response: + # 用来标记流式响应需要分节 + yield MessageChain(chain=[], type="break") + if self.show_tool_use or event.get_platform_name() == "webchat": + resp.data["chain"].type = "tool_call" + await event.send(resp.data["chain"]) + continue - if not final_llm_response: - raise Exception("LLM response is None.") + if not self.streaming_response: + content_typ = ( + ResultContentType.LLM_RESULT + if resp.type == "llm_result" + else ResultContentType.GENERAL_RESULT + ) + event.set_result( + MessageEventResult( + chain=resp.data["chain"].chain, + result_content_type=content_typ, + ) + ) + yield + event.clear_result() + else: + if resp.type == "streaming_delta": + yield resp.data["chain"] # MessageChain + if tool_loop_agent.done(): + break - # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMResponseEvent + except Exception as e: + logger.error(traceback.format_exc()) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, final_llm_response) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return - - if self.streaming_response: - # 流式输出的处理 - async for result in self._handle_llm_stream_response( - event, req, final_llm_response - ): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True - else: - yield - else: - # 非流式输出的处理 - async for result in self._handle_llm_response( - event, req, final_llm_response - ): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True - else: - yield - + return asyncio.create_task( Metric.upload( llm_tick=1, @@ -255,44 +207,38 @@ class LLMRequestSubStage(Stage): ) ) - # 保存到历史记录 - await self._save_to_history(event, req, final_llm_response) - - except BaseException as e: - logger.error(traceback.format_exc()) - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" - ) - ) - - if not self.streaming_response: - event.set_extra("tool_call_result", None) - async for _ in requesting(req): - yield - else: + if self.streaming_response: + # 流式响应 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(requesting(req)) + .set_async_stream(requesting()) ) - # 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理 yield - - if event.get_extra("tool_call_result"): - event.set_result(event.get_extra("tool_call_result")) - event.set_extra("tool_call_result", None) + if tool_loop_agent.done(): + if final_llm_resp := tool_loop_agent.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain().message(final_llm_resp.completion_text).chain + ) + else: + chain = final_llm_resp.result_chain.chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ) + ) + else: + async for _ in requesting(): yield - # 暂时直接发出去 - if img_b64 := event.get_extra("tool_call_img_respond"): - await event.send(MessageChain(chain=[Image.fromBase64(img_b64)])) - event.set_extra("tool_call_img_respond", None) - + # 异步处理 WebChat 特殊情况 if event.get_platform_name() == "webchat": - # 异步处理 WebChat 特殊情况 asyncio.create_task(self._handle_webchat(event, req)) + await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp()) + async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" conversation = await self.conv_manager.get_conversation( @@ -305,10 +251,6 @@ class LLMRequestSubStage(Stage): return provider = self.ctx.plugin_manager.context.get_using_provider() cleaned_text = "User: " + latest_pair[0].get("content", "").strip() - # if len(latest_pair) > 1: - # cleaned_text += ( - # "\nAssistant: " + latest_pair[1].get("content", "").strip() - # ) logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}") llm_resp = await provider.text_chat( system_prompt="You are expert in summarizing user's query.", @@ -349,322 +291,50 @@ class LLMRequestSubStage(Stage): } ) - async def _handle_llm_response( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理非流式 LLM 响应。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM - - Yields: - Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM - """ - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.LLM_RESULT) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # 处理函数工具调用 - async for result in self._handle_function_tools(event, req, llm_response): - yield result - - async def _handle_llm_stream_response( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理流式 LLM 响应。 - - 专门用于处理流式输出完成后的响应,与非流式响应处理分离。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM - - Yields: - Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM - """ - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.STREAMING_FINISH) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.STREAMING_FINISH) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # 处理函数工具调用 - async for result in self._handle_function_tools(event, req, llm_response): - yield result - - async def _handle_function_tools( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理函数工具调用。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM - """ - # function calling - tool_call_result: list[ToolCallMessageSegment] = [] - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" - ) - for func_tool_name, func_tool_args, func_tool_id in zip( - llm_response.tools_call_name, - llm_response.tools_call_args, - llm_response.tools_call_ids, - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" - ) - client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] - res = await client.session.call_tool(func_tool.name, func_tool_args) - if res: - # TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback - if isinstance(res.content[0], TextContent): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, - ) - ) - elif isinstance(res.content[0], ImageContent): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - event.set_extra( - "tool_call_img_respond", - res.content[0].data, - ) - elif isinstance(res.content[0], EmbeddedResource): - resource = res.content[0].resource - if isinstance(resource, TextResourceContents): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resource.text, - ) - ) - elif ( - isinstance(resource, BlobResourceContents) - and resource.mimeType - and resource.mimeType.startswith("image/") - ): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - event.set_extra( - "tool_call_img_respond", - res.content[0].data, - ) - else: - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回的数据类型不受支持", - ) - ) - else: - # 获取处理器,过滤掉平台不兼容的处理器 - platform_id = event.get_platform_id() - star_md = star_map.get(func_tool.handler_module_path) - if ( - star_md - and platform_id in star_md.supported_platforms - and not star_md.supported_platforms[platform_id] - ): - logger.debug( - f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行" - ) - # 直接跳过,不添加任何消息到tool_call_result - continue - - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - else: - res = event.get_result() - if res and res.chain: - event.set_extra("tool_call_result", res) - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {str(e)}", - ) - ) - if tool_call_result: - # 函数调用结果 - req.func_tool = None # 暂时不支持递归工具调用 - assistant_msg_seg = AssistantMessageSegment( - role="assistant", tool_calls=llm_response.to_openai_tool_calls() - ) - # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 - req.tool_calls_result = ToolCallsResult( - tool_calls_info=assistant_msg_seg, - tool_calls_result=tool_call_result, - ) - yield req # 再次执行 LLM 请求 - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) - async def _save_to_history( - self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse | None, ): - if not req or not req.conversation or not llm_response: + if ( + not req + or not req.conversation + or not llm_response + or llm_response.role != "assistant" + ): return - if llm_response.role == "assistant": - # 文本回复 - contexts = req.contexts.copy() - contexts.append(await req.assemble_context()) + # 历史上下文 + messages = copy.deepcopy(req.contexts) + # 这一轮对话请求的用户输入 + messages.append(await req.assemble_context()) + # 这一轮对话的 LLM 响应 + if req.tool_calls_result: + if not isinstance(req.tool_calls_result, list): + messages.extend(req.tool_calls_result.to_openai_messages()) + elif isinstance(req.tool_calls_result, list): + for tcr in req.tool_calls_result: + messages.extend(tcr.to_openai_messages()) + messages.append({"role": "assistant", "content": llm_response.completion_text}) + messages = list(filter(lambda item: "_no_save" not in item, messages)) + await self.conv_manager.update_conversation( + event.unified_msg_origin, req.conversation.cid, history=messages + ) + logger.debug(f"messages persisted: {messages}") - # 记录并标记函数调用结果 - if req.tool_calls_result: - tool_calls_messages = req.tool_calls_result.to_openai_messages() - - # 添加标记 - for message in tool_calls_messages: - message["_tool_call_history"] = True - - processed_tool_messages = self._process_tool_message_pairs( - tool_calls_messages, remove_tags=False - ) - - contexts.extend(processed_tool_messages) - - contexts.append( - {"role": "assistant", "content": llm_response.completion_text} - ) - contexts_to_save = list( - filter(lambda item: "_no_save" not in item, contexts) - ) - await self.conv_manager.update_conversation( - event.unified_msg_origin, req.conversation.cid, history=contexts_to_save - ) - - def _process_tool_message_pairs(self, messages, remove_tags=True): - """处理工具调用消息,确保assistant和tool消息成对出现 - - Args: - messages (list): 消息列表 - remove_tags (bool): 是否移除_tool_call_history标记 - - Returns: - list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现 - """ - result = [] - i = 0 - - while i < len(messages): - current_msg = messages[i] - - # 普通消息直接添加 - if "_tool_call_history" not in current_msg: - result.append(current_msg.copy() if remove_tags else current_msg) - i += 1 - continue - - # 工具调用消息成对处理 - if current_msg.get("role") == "assistant" and "tool_calls" in current_msg: - assistant_msg = current_msg.copy() - - if remove_tags and "_tool_call_history" in assistant_msg: - del assistant_msg["_tool_call_history"] - - related_tools = [] - j = i + 1 - while ( - j < len(messages) - and messages[j].get("role") == "tool" - and "_tool_call_history" in messages[j] - ): - tool_msg = messages[j].copy() - - if remove_tags: - del tool_msg["_tool_call_history"] - - related_tools.append(tool_msg) - j += 1 - - # 成对的时候添加到结果 - if related_tools: - result.append(assistant_msg) - result.extend(related_tools) - - i = j # 跳过已处理 + def fix_messages(self, messages: list[dict]) -> list[dict]: + """验证并且修复上下文""" + fixed_messages = [] + for message in messages: + if message.get("role") == "tool": + # tool block 前面必须要有 user 和 assistant block + if len(fixed_messages) < 2: + # 这种情况可能是上下文被截断导致的 + # 我们直接将之前的上下文都清空 + fixed_messages = [] + else: + fixed_messages.append(message) else: - # 单独的tool消息 - i += 1 - - return result + fixed_messages.append(message) + return fixed_messages diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index c7817e49c..00f58d55b 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -50,7 +50,7 @@ class StarRequestSubStage(Stage): logger.debug( f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" ) - wrapper = self._call_handler(self.ctx, event, handler.handler, **params) + wrapper = self.ctx.call_handler(event, handler.handler, **params) async for ret in wrapper: yield ret event.clear_result() # 清除上一个 handler 的结果 diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index c7d4ff792..b41794733 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,12 +1,8 @@ from __future__ import annotations import abc -import inspect -import traceback -from astrbot.api import logger -from typing import List, AsyncGenerator, Union, Awaitable +from typing import List, AsyncGenerator, Union from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类 @@ -41,70 +37,3 @@ class Stage(abc.ABC): Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) """ raise NotImplementedError - - async def _call_handler( - self, - ctx: PipelineContext, - event: AstrMessageEvent, - handler: Awaitable, - *args, - **kwargs, - ) -> AsyncGenerator[None, None]: - """执行事件处理函数并处理其返回结果 - - 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: - 1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层 - 2. 协程: 执行一次并处理返回值 - - Args: - ctx (PipelineContext): 消息管道上下文对象 - event (AstrMessageEvent): 待处理的事件对象 - handler (Awaitable): 事件处理函数 - *args: 传递给handler的位置参数 - **kwargs: 传递给handler的关键字参数 - - Returns: - AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 - """ - ready_to_call = None # 一个协程或者异步生成器(async def) - - trace_ = None - - try: - ready_to_call = handler(event, *args, **kwargs) - except TypeError as _: - # 向下兼容 - trace_ = traceback.format_exc() - # 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份 - ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs) - - if isinstance(ready_to_call, AsyncGenerator): - # 如果是一个异步生成器, 进入洋葱模型 - _has_yielded = False # 是否返回过值 - try: - async for ret in ready_to_call: - # 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码 - # 返回值只能是 MessageEventResult 或者 None(无返回值) - _has_yielded = True - if isinstance(ret, (MessageEventResult, CommandResult)): - # 如果返回值是 MessageEventResult, 设置结果并继续 - event.set_result(ret) - yield # 传递控制权给上一层的process函数 - else: - # 如果返回值是 None, 则不设置结果并继续 - # 继续执行后续阶段 - yield ret # 传递控制权给上一层的process函数 - if not _has_yielded: - # 如果这个异步生成器没有执行到yield分支 - yield - except Exception as e: - logger.error(f"Previous Error: {trace_}") - raise e - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个协程, 直接执行 - ret = await ready_to_call - if isinstance(ret, (MessageEventResult, CommandResult)): - event.set_result(ret) - yield # 传递控制权给上一层的process函数 - else: - yield ret # 传递控制权给上一层的process函数 diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 3636cd611..5b3a1d916 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -158,6 +158,12 @@ class TelegramPlatformEvent(AstrMessageEvent): async for chain in generator: if isinstance(chain, MessageChain): + if chain.type == "break": + # 分割符 + message_id = None # 重置消息 ID + delta = "" # 重置 delta + continue + # 处理消息链中的每个组件 for i in chain.chain: if isinstance(i, Plain): diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 76b5dc85d..111027a5c 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -35,6 +35,7 @@ class WebChatMessageEvent(AstrMessageEvent): "cid": cid, "data": data, "streaming": streaming, + "chain_type": message.type, } ) elif isinstance(comp, Image): @@ -110,6 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent): async def send_streaming(self, generator, use_fallback: bool = False): final_data = "" async for chain in generator: + if chain.type == "break" and final_data: + # 分割符 + await web_chat_back_queue.put( + { + "type": "end", + "data": final_data, + "streaming": True, + "cid": self.session_id.split("!")[-1], + } + ) + final_data = "" + continue final_data += await WebChatMessageEvent._send( chain, session_id=self.session_id, streaming=True ) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index e01e46cf9..abb01960c 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -58,7 +58,7 @@ class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" content: str = None - tool_calls: List[ChatCompletionMessageToolCall | Dict] = None + tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) role: str = "assistant" def to_dict(self): @@ -67,7 +67,7 @@ class AssistantMessageSegment: } if self.content: ret["content"] = self.content - elif self.tool_calls: + if self.tool_calls: ret["tool_calls"] = self.tool_calls return ret @@ -95,19 +95,19 @@ class ProviderRequest: """提示词""" session_id: str = "" """会话 ID""" - image_urls: List[str] = None + image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" - func_tool: FuncCall = None + func_tool: FuncCall | None = None """可用的函数工具""" - contexts: List = None + contexts: list[dict] = field(default_factory=list) """上下文。格式与 openai 的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages """ system_prompt: str = "" """系统提示词""" - conversation: Conversation = None + conversation: Conversation | None = None - tool_calls_result: ToolCallsResult = None + tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" def __repr__(self): @@ -116,6 +116,14 @@ class ProviderRequest: def __str__(self): return self.__repr__() + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + """添加工具调用结果到请求中""" + if not self.tool_calls_result: + self.tool_calls_result = [] + if isinstance(self.tool_calls_result, ToolCallsResult): + self.tool_calls_result = [self.tool_calls_result] + self.tool_calls_result.append(tool_calls_result) + def _print_friendly_context(self): """打印友好的消息上下文。将 image_url 的值替换为 """ if not self.contexts: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 5886b8083..2abe59d65 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -190,11 +190,6 @@ class ProviderManager: from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import ( - LLMTunerModelLoader as LLMTunerModelLoader, - ) case "dify": from .sources.dify_source import ProviderDify as ProviderDify case "dashscope": @@ -330,8 +325,6 @@ class ProviderManager: inst = provider_metadata.cls_type( provider_config, self.provider_settings, - self.db_helper, - self.provider_settings.get("persistant_history", True), self.selected_default_persona, ) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index c285ebd42..1ecca3537 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,6 +1,5 @@ import abc from typing import List -from astrbot.core.db import BaseDatabase from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.entities import LLMResponse, ToolCallsResult @@ -53,15 +52,13 @@ class Provider(AbstractProvider): self, provider_config: dict, provider_settings: dict, - persistant_history: bool = True, - db_helper: BaseDatabase = None, - default_persona: Personality = None, + default_persona: Personality | None = None, ) -> None: super().__init__(provider_config) self.provider_settings = provider_settings - self.curr_personality: Personality = default_persona + self.curr_personality = default_persona """维护了当前的使用的 persona,即人格。可能为 None""" @abc.abstractmethod @@ -86,11 +83,11 @@ class Provider(AbstractProvider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts: List = None, + contexts: list = None, system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -114,11 +111,11 @@ class Provider(AbstractProvider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts: List = None, + contexts: list = None, system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index c3ad45868..a53250fb7 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,3 +1,6 @@ +import json +import anthropic +import base64 from typing import List from mimetypes import guess_type @@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic from anthropic.types import Message from astrbot.core.utils.io import download_image_by_url -from astrbot.core.db import BaseDatabase -from astrbot.api.provider import Provider, Personality +from astrbot.api.provider import Provider from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse, ToolCallsResult -from .openai_source import ProviderOpenAIOfficial +from astrbot.core.provider.entities import LLMResponse +from typing import AsyncGenerator @register_provider_adapter( "anthropic_chat_completion", "Anthropic Claude API 提供商适配器" ) -class ProviderAnthropic(ProviderOpenAIOfficial): +class ProviderAnthropic(Provider): def __init__( self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, - default_persona: Personality = None, + provider_config, + provider_settings, + default_persona=None, ) -> None: - # Skip OpenAI's __init__ and call Provider's __init__ directly - Provider.__init__( - self, + super().__init__( provider_config, provider_settings, - persistant_history, - db_helper, default_persona, ) - self.chosen_api_key = None + self.chosen_api_key: str = "" self.api_keys: List = provider_config.get("key", []) - self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None + self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): @@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial): self.set_model(provider_config["model_config"]["model"]) + def _prepare_payload(self, messages: list[dict]): + """准备 Anthropic API 的请求 payload + + Args: + messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息 + Returns: + system_prompt: 系统提示内容 + new_messages: 处理后的消息列表,去除系统提示 + """ + system_prompt = "" + new_messages = [] + for message in messages: + if message["role"] == "system": + system_prompt = message["content"] + elif message["role"] == "assistant": + blocks = [] + if isinstance(message["content"], str): + blocks.append({"type": "text", "text": message["content"]}) + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + blocks.append( # noqa: PERF401 + { + "type": "tool_use", + "name": tool_call["function"]["name"], + "input": json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"], + "id": tool_call["id"], + } + ) + new_messages.append( + { + "role": "assistant", + "content": blocks, + } + ) + elif message["role"] == "tool": + new_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": message["tool_call_id"], + "content": message["content"], + } + ], + } + ) + else: + new_messages.append(message) + + return system_prompt, new_messages + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: - tool_list = tools.get_func_desc_anthropic_style() - if tool_list: + if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list completion = await self.client.messages.create(**payloads, stream=False) @@ -64,70 +112,157 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if len(completion.content) == 0: raise Exception("API 返回的 completion 为空。") - # TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容 - # 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求 - content = completion.content[-1] - llm_response = LLMResponse("assistant") + llm_response = LLMResponse(role="assistant") - if content.type == "text": - # text completion - completion_text = str(content.text).strip() - # llm_response.completion_text = completion_text - llm_response.result_chain = MessageChain().message(completion_text) - - # Anthropic每次只返回一个函数调用 - if completion.stop_reason == "tool_use": - # tools call (function calling) - args_ls = [] - func_name_ls = [] - tool_use_ids = [] - func_name_ls.append(content.name) - args_ls.append(content.input) - tool_use_ids.append(content.id) - llm_response.role = "tool" - llm_response.tools_call_args = args_ls - llm_response.tools_call_name = func_name_ls - llm_response.tools_call_ids = tool_use_ids + for content_block in completion.content: + if content_block.type == "text": + completion_text = str(content_block.text).strip() + llm_response.completion_text = completion_text + if content_block.type == "tool_use": + llm_response.tools_call_args.append(content_block.input) + llm_response.tools_call_name.append(content_block.name) + llm_response.tools_call_ids.append(content_block.id) + # TODO(Soulter): 处理 end_turn 情况 if not llm_response.completion_text and not llm_response.tools_call_args: - logger.error(f"API 返回的 completion 无法解析:{completion}。") - raise Exception(f"API 返回的 completion 无法解析:{completion}。") - - llm_response.raw_completion = completion + raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。") return llm_response + async def _query_stream( + self, payloads: dict, tools: FuncCall + ) -> AsyncGenerator[LLMResponse, None]: + if tools: + if tool_list := tools.get_func_desc_anthropic_style(): + payloads["tools"] = tool_list + + # 用于累积工具调用信息 + tool_use_buffer = {} + # 用于累积最终结果 + final_text = "" + final_tool_calls = [] + + async with self.client.messages.stream(**payloads) as stream: + assert isinstance(stream, anthropic.AsyncMessageStream) + async for event in stream: + if event.type == "content_block_start": + if event.content_block.type == "text": + # 文本块开始 + yield LLMResponse( + role="assistant", completion_text="", is_chunk=True + ) + elif event.content_block.type == "tool_use": + # 工具使用块开始,初始化缓冲区 + tool_use_buffer[event.index] = { + "id": event.content_block.id, + "name": event.content_block.name, + "input": {}, + } + + elif event.type == "content_block_delta": + if event.delta.type == "text_delta": + # 文本增量 + final_text += event.delta.text + yield LLMResponse( + role="assistant", + completion_text=event.delta.text, + is_chunk=True, + ) + elif event.delta.type == "input_json_delta": + # 工具调用参数增量 + if event.index in tool_use_buffer: + # 累积 JSON 输入 + if "input_json" not in tool_use_buffer[event.index]: + tool_use_buffer[event.index]["input_json"] = "" + tool_use_buffer[event.index]["input_json"] += ( + event.delta.partial_json + ) + + elif event.type == "content_block_stop": + # 内容块结束 + if event.index in tool_use_buffer: + # 解析完整的工具调用 + tool_info = tool_use_buffer[event.index] + try: + if "input_json" in tool_info: + tool_info["input"] = json.loads(tool_info["input_json"]) + + # 添加到最终结果 + final_tool_calls.append( + { + "id": tool_info["id"], + "name": tool_info["name"], + "input": tool_info["input"], + } + ) + + yield LLMResponse( + role="tool", + completion_text="", + tools_call_args=[tool_info["input"]], + tools_call_name=[tool_info["name"]], + tools_call_ids=[tool_info["id"]], + is_chunk=True, + ) + except json.JSONDecodeError: + # JSON 解析失败,跳过这个工具调用 + logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}") + + # 清理缓冲区 + del tool_use_buffer[event.index] + + # 返回最终的完整结果 + final_response = LLMResponse( + role="assistant", completion_text=final_text, is_chunk=False + ) + + if final_tool_calls: + final_response.tools_call_args = [ + call["input"] for call in final_tool_calls + ] + final_response.tools_call_name = [call["name"] for call in final_tool_calls] + final_response.tools_call_ids = [call["id"] for call in final_tool_calls] + + yield final_response + async def text_chat( self, - prompt: str, - session_id: str = None, - image_urls: List[str] = [], - func_tool: FuncCall = None, + prompt, + session_id=None, + image_urls=None, + func_tool=None, contexts=None, system_prompt=None, - tool_calls_result: ToolCallsResult = None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: if contexts is None: contexts = [] - if not prompt: - prompt = "" - new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) for part in context_query: if "_no_save" in part: del part["_no_save"] + # tool calls result if tool_calls_result: - # 暂时这样写。 - prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" + if not isinstance(tool_calls_result, list): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) + + system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": new_messages, **model_config} - payloads = {"messages": context_query, **model_config} # Anthropic has a different way of handling system prompts if system_prompt: payloads["system"] = system_prompt @@ -135,32 +270,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial): llm_response = None try: llm_response = await self._query(payloads, func_tool) - except Exception as e: - if "maximum context length" in str(e): - retry_cnt = 20 - while retry_cnt > 0: - logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" - ) - try: - await self.pop_record(context_query) - response = await self.client.messages.create( - messages=context_query, **model_config - ) - llm_response = LLMResponse("assistant") - llm_response.result_chain = MessageChain().message(response.content[0].text) - llm_response.raw_completion = response - return llm_response - except Exception as e: - if "maximum context length" in str(e): - retry_cnt -= 1 - else: - raise e - return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。") - else: - logger.error(f"发生了错误。Provider 配置如下: {model_config}") - raise e + logger.error(f"发生了错误。Provider 配置如下: {model_config}") + raise e return llm_response @@ -175,21 +287,34 @@ class ProviderAnthropic(ProviderOpenAIOfficial): tool_calls_result=None, **kwargs, ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + if contexts is None: + contexts = [] + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + + system_prompt, new_messages = self._prepare_payload(context_query) + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": new_messages, **model_config} + + # Anthropic has a different way of handling system prompts + if system_prompt: + payloads["system"] = system_prompt + + async for llm_response in self._query_stream(payloads, func_tool): + yield llm_response async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" @@ -232,3 +357,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial): ) return {"role": "user", "content": content} + + async def encode_image_bs64(self, image_url: str) -> str: + """ + 将图片转换为 base64 + """ + if image_url.startswith("base64://"): + return image_url.replace("base64://", "data:image/jpeg;base64,") + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 + return "" + + def get_current_key(self) -> str: + return self.chosen_api_key + + async def get_models(self) -> List[str]: + models_str = [] + models = await self.client.models.list() + models = sorted(models.data, key=lambda x: x.id) + for model in models: + models_str.append(model.id) + return models_str + + def set_key(self, key: str): + self.chosen_api_key = key diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index f719190a1..3498f8346 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -5,7 +5,6 @@ from typing import List from .. import Provider, Personality from ..entities import LLMResponse from ..func_tool_manager import FuncCall -from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter from astrbot.core.message.message_event_result import MessageChain from .openai_source import ProviderOpenAIOfficial @@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial): self, provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=False, - default_persona: Personality = None, + default_persona: Personality | None = None, ) -> None: Provider.__init__( self, provider_config, provider_settings, - persistant_history, - db_helper, default_persona, ) self.api_key = provider_config.get("dashscope_api_key", "") diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 348c3e72c..81c910d66 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,10 +1,9 @@ import astrbot.core.message.components as Comp import os from typing import List -from .. import Provider, Personality +from .. import Provider from ..entities import LLMResponse from ..func_tool_manager import FuncCall -from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.io import download_image_by_url, download_file @@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path class ProviderDify(Provider): def __init__( self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=False, - default_persona: Personality = None, + provider_config, + provider_settings, + default_persona = None, ) -> None: super().__init__( provider_config, provider_settings, - persistant_history, - db_helper, default_persona, ) self.api_key = provider_config.get("dify_api_key", "") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c16b39415..e1d1f11bd 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -12,8 +12,7 @@ from google.genai.errors import APIError import astrbot.core.message.components as Comp from astrbot import logger -from astrbot.api.provider import Personality, Provider -from astrbot.core.db import BaseDatabase +from astrbot.api.provider import Provider from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.func_tool_manager import FuncCall @@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider): def __init__( self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, - default_persona: Personality = None, + provider_config, + provider_settings, + default_persona=None, ) -> None: super().__init__( provider_config, provider_settings, - persistant_history, - db_helper, default_persona, ) self.api_keys: list = provider_config.get("key", []) @@ -264,12 +259,10 @@ class ProviderGoogleGenAI(Provider): contents.append(content_cls(parts=part)) gemini_contents: list[types.Content] = [] - native_tool_enabled = any( - [ - self.provider_config.get("gm_native_coderunner", False), - self.provider_config.get("gm_native_search", False), - ] - ) + native_tool_enabled = any([ + self.provider_config.get("gm_native_coderunner", False), + self.provider_config.get("gm_native_search", False), + ]) for message in payloads["messages"]: role, content = message["role"], message.get("content") @@ -506,12 +499,12 @@ class ProviderGoogleGenAI(Provider): async def text_chat( self, prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: FuncCall = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + session_id=None, + image_urls=None, + func_tool=None, + contexts=None, + system_prompt=None, + tool_calls_result=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -527,7 +520,11 @@ class ProviderGoogleGenAI(Provider): # tool calls result if tool_calls_result: - context_query.extend(tool_calls_result.to_openai_messages()) + if not isinstance(tool_calls_result, list): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() @@ -631,9 +628,10 @@ class ProviderGoogleGenAI(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append( - {"type": "image_url", "image_url": {"url": image_data}} - ) + user_content["content"].append({ + "type": "image_url", + "image_url": {"url": image_data}, + }) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py deleted file mode 100644 index 8648512d0..000000000 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from llmtuner.chat import ChatModel -from typing import List -from .. import Provider -from ..entities import LLMResponse -from ..func_tool_manager import FuncCall -from astrbot.core.db import BaseDatabase -from ..register import register_provider_adapter - - -@register_provider_adapter( - "llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型" -) -class LLMTunerModelLoader(Provider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, - default_persona=None, - ) -> None: - super().__init__( - provider_config, - provider_settings, - persistant_history, - db_helper, - default_persona, - ) - if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists( - provider_config["adapter_model_path"] - ): - raise FileNotFoundError("模型文件路径不存在。") - self.base_model_path = provider_config["base_model_path"] - self.adapter_model_path = provider_config["adapter_model_path"] - self.model = ChatModel( - { - "model_name_or_path": self.base_model_path, - "adapter_name_or_path": self.adapter_model_path, - "template": provider_config["llmtuner_template"], - "finetuning_type": provider_config["finetuning_type"], - "quantization_bit": provider_config["quantization_bit"], - } - ) - self.set_model( - os.path.basename(self.base_model_path) - + "_" - + os.path.basename(self.adapter_model_path) - ) - - async def assemble_context(self, text: str, image_urls: List[str] = None): - """ - 组装上下文。 - """ - return {"role": "user", "content": text} - - async def text_chat( - self, - prompt: str, - session_id: str = None, - image_urls: List[str] = None, - func_tool: FuncCall = None, - contexts: List = None, - system_prompt: str = None, - **kwargs, - ) -> LLMResponse: - if contexts is None: - contexts = [] - system_prompt = "" - new_record = {"role": "user", "content": prompt} - query_context = [*contexts, new_record] - - # 提取出系统提示 - system_idxs = [] - for idx, context in enumerate(query_context): - if context["role"] == "system": - system_idxs.append(idx) - - if "_no_save" in context: - del context["_no_save"] - - for idx in reversed(system_idxs): - system_prompt += " " + query_context.pop(idx)["content"] - - conf = { - "messages": query_context, - "system": system_prompt, - } - if func_tool: - tool_list = func_tool.get_func_desc_openai_style() - if tool_list: - conf["tools"] = tool_list - - responses = await self.model.achat(**conf) - - llm_response = LLMResponse("assistant", responses[-1].response_text) - - return llm_response - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def get_current_key(self): - return "none" - - async def set_key(self, key): - pass - - async def get_models(self): - return [self.get_model()] diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 15104db3c..ef6131d8c 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion -# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai._exceptions import NotFoundError, UnprocessableEntityError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.db import BaseDatabase -from astrbot.api.provider import Provider, Personality +from astrbot.api.provider import Provider from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List, AsyncGenerator @@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult class ProviderOpenAIOfficial(Provider): def __init__( self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, - default_persona: Personality = None, + provider_config, + provider_settings, + default_persona = None, ) -> None: super().__init__( provider_config, provider_settings, - persistant_history, - db_helper, default_persona, ) self.chosen_api_key = None @@ -224,12 +218,10 @@ class ProviderOpenAIOfficial(Provider): async def _prepare_chat_payload( self, prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: FuncCall = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + image_urls: list[str] | None = None, + contexts: list | None = None, + system_prompt: str | None = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -246,14 +238,18 @@ class ProviderOpenAIOfficial(Provider): # tool calls result if tool_calls_result: - context_query.extend(tool_calls_result.to_openai_messages()) + if isinstance(tool_calls_result, ToolCallsResult): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} - return payloads, context_query, func_tool + return payloads, context_query async def _handle_api_error( self, @@ -352,11 +348,9 @@ class ProviderOpenAIOfficial(Provider): tool_calls_result=None, **kwargs, ) -> LLMResponse: - payloads, context_query, func_tool = await self._prepare_chat_payload( + payloads, context_query = await self._prepare_chat_payload( prompt, - session_id, image_urls, - func_tool, contexts, system_prompt, tool_calls_result, @@ -422,11 +416,9 @@ class ProviderOpenAIOfficial(Provider): **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" - payloads, context_query, func_tool = await self._prepare_chat_payload( + payloads, context_query = await self._prepare_chat_payload( prompt, - session_id, image_urls, - func_tool, contexts, system_prompt, tool_calls_result, diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index e7e9d4a14..428dee8f4 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -1,4 +1,3 @@ -from astrbot.core.db import BaseDatabase from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List @@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial): self, provider_config: dict, provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, default_persona=None, ) -> None: super().__init__( provider_config, provider_settings, - db_helper, - persistant_history, default_persona, ) diff --git a/dashboard/src/theme/DarkTheme.ts b/dashboard/src/theme/DarkTheme.ts index 9276c8f98..5906eca32 100644 --- a/dashboard/src/theme/DarkTheme.ts +++ b/dashboard/src/theme/DarkTheme.ts @@ -39,7 +39,8 @@ const PurpleThemeDark: ThemeTypes = { background: '#111111', overlay: '#111111aa', codeBg: '#282833', - code: '#ffffffdd' + code: '#ffffffdd', + chatMessageBubble: '#2d2e30', } }; diff --git a/dashboard/src/theme/LightTheme.ts b/dashboard/src/theme/LightTheme.ts index 35aa1339a..a555fddd7 100644 --- a/dashboard/src/theme/LightTheme.ts +++ b/dashboard/src/theme/LightTheme.ts @@ -39,7 +39,8 @@ const PurpleTheme: ThemeTypes = { background: '#f9fafcf4', overlay: '#ffffffaa', codeBg: '#f5f0ff', - code: '#673ab7' + code: '#673ab7', + chatMessageBubble: '#e7ebf4', } }; diff --git a/dashboard/src/types/themeTypes/ThemeType.ts b/dashboard/src/types/themeTypes/ThemeType.ts index 69b00a1ab..f5e2e5491 100644 --- a/dashboard/src/types/themeTypes/ThemeType.ts +++ b/dashboard/src/types/themeTypes/ThemeType.ts @@ -35,5 +35,6 @@ export type ThemeTypes = { secondary200?: string; codeBg?: string; code?: string; + chatMessageBubble?: string; }; }; diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index fadba8aac..c47509ab5 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -175,7 +175,7 @@
-
+
{{ msg.message }} @@ -195,15 +195,12 @@
- - -
- - + +
@@ -1574,28 +1571,25 @@ export default { } .message-bubble { - padding: 12px 16px; - border-radius: 18px; + padding: 8px 16px; + border-radius: 12px; max-width: 80%; - box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1); } .user-bubble { - background-color: var(--v-theme-background); + background-color: var(--v-theme-chatMessageBubble); color: var(--v-theme-primaryText); - border-top-right-radius: 4px; } .bot-bubble { - background-color: var(--v-theme-surface); border: 1px solid var(--v-theme-border); color: var(--v-theme-primaryText); - border-top-left-radius: 4px; } .user-avatar, .bot-avatar { - align-self: flex-end; + align-self: flex-start; + margin-top: 12px; } /* 附件样式 */ diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 220393dea..44e7f8206 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -414,7 +414,6 @@ export default { "anthropic_chat_completion": "chat_completion", "googlegenai_chat_completion": "chat_completion", "zhipu_chat_completion": "chat_completion", - "llm_tuner": "chat_completion", "dify": "chat_completion", "dashscope": "chat_completion", "openai_whisper_api": "speech_to_text",