diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index a08a94543..17cfbd39e 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -774,17 +774,6 @@ CONFIG_METADATA_2 = { "model": "deepseek/deepseek-r1", }, }, - "LLMTuner": { - "id": "llmtuner_default", - "type": "llm_tuner", - "provider_type": "chat_completion", - "enable": True, - "base_model_path": "", - "adapter_model_path": "", - "llmtuner_template": "", - "finetuning_type": "lora", - "quantization_bit": 4, - }, "Dify": { "id": "dify_app_default", "type": "dify", diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 28c92fa89..c5123bca3 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -98,6 +98,15 @@ class MessageChain: self.chain.append(Image.fromFileSystem(path)) return self + def base64_image(self, base64_str: str): + """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 + Example: + + CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...") + """ + self.chain.append(Image.fromBase64(base64_str)) + return self + def use_t2i(self, use_t2i: bool): """设置是否使用文本转图片服务。 @@ -157,7 +166,7 @@ class ResultContentType(enum.Enum): """普通的消息结果""" STREAMING_RESULT = enum.auto() """调用 LLM 产生的流式结果""" - STREAMING_FINISH= enum.auto() + STREAMING_FINISH = enum.auto() """流式输出完成""" diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index eb5ffb1cd..64a8762ed 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,6 +1,14 @@ +import inspect +import traceback +import typing as T from dataclasses import dataclass from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star import PluginManager +from astrbot.api import logger +from astrbot.core.star.star_handler import star_handlers_registry, EventType +from 
astrbot.core.star.star import star_map +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult @dataclass @@ -9,3 +17,91 @@ class PipelineContext: astrbot_config: AstrBotConfig # AstrBot 配置对象 plugin_manager: PluginManager # 插件管理器对象 + + async def call_event_hook( + self, + event: AstrMessageEvent, + hook_type: EventType, + *args, + ): + platform_id = event.get_platform_id() + handlers = star_handlers_registry.get_handlers_by_event_type( + hook_type, platform_id=platform_id + ) + for handler in handlers: + try: + logger.debug( + f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, *args) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + return + + async def call_handler( + self, + event: AstrMessageEvent, + handler: T.Awaitable, + *args, + **kwargs, + ) -> T.AsyncGenerator[None, None]: + """执行事件处理函数并处理其返回结果 + + 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: + 1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层 + 2. 
协程: 执行一次并处理返回值 + + Args: + ctx (PipelineContext): 消息管道上下文对象 + event (AstrMessageEvent): 事件对象 + handler (Awaitable): 事件处理函数 + + Returns: + AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 + """ + ready_to_call = None # 一个协程或者异步生成器 + + trace_ = None + + try: + ready_to_call = handler(event, *args, **kwargs) + except TypeError as _: + # 向下兼容 + trace_ = traceback.format_exc() + # 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份 + ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs) + + if isinstance(ready_to_call, T.AsyncGenerator): + _has_yielded = False + try: + async for ret in ready_to_call: + # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) + _has_yielded = True + if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 + event.set_result(ret) + yield + else: + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret + if not _has_yielded: + # 如果这个异步生成器没有执行到 yield 分支 + yield + except Exception as e: + logger.error(f"Previous Error: {trace_}") + raise e + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个协程, 直接执行 + ret = await ready_to_call + if isinstance(ret, (MessageEventResult, CommandResult)): + event.set_result(ret) + yield + else: + yield ret diff --git a/astrbot/core/pipeline/process_stage/agent_runner/base.py b/astrbot/core/pipeline/process_stage/agent_runner/base.py new file mode 100644 index 000000000..fd3c7d4ae --- /dev/null +++ b/astrbot/core/pipeline/process_stage/agent_runner/base.py @@ -0,0 +1,43 @@ +import abc +import typing as T +from dataclasses import dataclass +from astrbot.core.provider.entities import LLMResponse + + +@dataclass +class AgentResponse: + type: str + data: dict + + +class BaseAgentRunner: + @abc.abstractmethod + async def reset(self) -> None: + """ + Reset the agent to its initial state. + This method should be called before starting a new run. + """ + ... 
+ + @abc.abstractmethod + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: + """ + Process a single step of the agent. + """ + ... + + @abc.abstractmethod + def done(self) -> bool: + """ + Check if the agent has completed its task. + Returns True if the agent is done, False otherwise. + """ + ... + + @abc.abstractmethod + def get_final_llm_resp(self) -> LLMResponse | None: + """ + Get the final observation from the agent. + This method should be called after the agent is done. + """ + ... diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py new file mode 100644 index 000000000..d26701d6b --- /dev/null +++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py @@ -0,0 +1,276 @@ +import sys +import traceback +import typing as T +from .base import BaseAgentRunner, AgentResponse +from ...context import PipelineContext +from astrbot.core.provider.provider import Provider +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import ( + MessageChain, +) +from astrbot.core.provider.entities import ( + ProviderRequest, + LLMResponse, + ToolCallMessageSegment, + AssistantMessageSegment, + ToolCallsResult, +) +from mcp.types import ( + TextContent, + ImageContent, + EmbeddedResource, + TextResourceContents, + BlobResourceContents, +) +from astrbot.core.star.star_handler import EventType +from astrbot import logger + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +# TODO: +# 1. 
# TODO: handle tool handlers that are incompatible with (or disabled on) the
# current platform.


class ToolLoopAgent(BaseAgentRunner):
    """Agent runner that alternates between LLM requests and tool calls.

    One call to :meth:`step` performs a single LLM request and, when the
    response asks for tools, executes them and appends their results to the
    request so that the next step can feed them back to the model. The caller
    drives the loop and stops once :meth:`done` returns True.
    """

    def __init__(
        self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
    ) -> None:
        self.provider = provider
        self.req = None
        self.event = event
        self.pipeline_ctx = pipeline_ctx
        # Per-run state gets defined defaults so that done() and
        # get_final_llm_resp() are safe to call before the first reset().
        self.streaming = False
        self.final_llm_resp = None
        self.is_done = False

    @override
    async def reset(self, req: ProviderRequest, streaming: bool) -> None:
        """Prepare the agent for a new run.

        Args:
            req: The provider request that every step sends to the LLM.
            streaming: Whether to use the provider's streaming API.
        """
        self.req = req
        self.streaming = streaming
        self.final_llm_resp = None
        self.is_done = False

    @staticmethod
    def _tool_segment(tool_call_id, content) -> ToolCallMessageSegment:
        """Build the ``role="tool"`` message that answers one tool call."""
        return ToolCallMessageSegment(
            role="tool",
            tool_call_id=tool_call_id,
            content=content,
        )

    @override
    async def step(self):
        """Run one LLM request and, if requested, the resulting tool calls.

        Yields:
            AgentResponse: ``streaming_delta`` for streamed chunks, ``err`` for
            an error response, ``llm_result`` for the model's answer and
            ``tool_call_result`` for messages produced by tools.

        Raises:
            ValueError: If :meth:`reset` has not supplied a request yet.
        """
        if not self.req:
            raise ValueError("Request is not set. Please call reset() first.")

        # NOTE(review): the "chain" payload below is a plain component list
        # when the response carries a result_chain, but a MessageChain object
        # when it is built from completion_text - confirm what consumers expect.

        # Perform the LLM request.
        llm_resp = None
        if self.streaming:
            stream = self.provider.text_chat_stream(**self.req.__dict__)
            async for llm_response in stream:  # type: ignore
                assert isinstance(llm_response, LLMResponse)
                if not llm_response.is_chunk:
                    # The non-chunk item is the complete response.
                    llm_resp = llm_response
                    continue
                if llm_response.result_chain:
                    delta = llm_response.result_chain.chain
                else:
                    delta = MessageChain().message(llm_response.completion_text)
                yield AgentResponse(type="streaming_delta", data={"chain": delta})
        else:
            llm_resp = await self.provider.text_chat(**self.req.__dict__)

        if not llm_resp:
            return

        # Run the on-LLM-response hooks with the response that was just
        # received (self.final_llm_resp is still unset/stale at this point).
        await self.pipeline_ctx.call_event_hook(
            self.event, EventType.OnLLMResponseEvent, llm_resp
        )

        # Handle the LLM response.
        logger.info(f"LLMResp: {llm_resp}")
        if llm_resp.role == "err":
            # Error response: report it once and end the loop. Returning here
            # avoids emitting the same error text again as an "llm_result".
            self.final_llm_resp = llm_resp
            self.is_done = True
            yield AgentResponse(
                type="err",
                data={
                    "chain": MessageChain().message(
                        f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
                    ),
                },
            )
            return

        if not llm_resp.tools_call_name:
            # No tool calls requested: this is the final answer.
            self.final_llm_resp = llm_resp
            self.is_done = True

        # Emit the LLM result.
        if llm_resp.result_chain:
            yield AgentResponse(
                type="llm_result",
                data={
                    "chain": llm_resp.result_chain.chain,
                },
            )
        elif llm_resp.completion_text:
            yield AgentResponse(
                type="llm_result",
                data={
                    "chain": MessageChain().message(llm_resp.completion_text),
                },
            )

        # Execute requested tools and record their results on the request.
        if llm_resp.tools_call_name:
            tool_call_result_blocks = []
            async for result in self._handle_function_tools(self.req, llm_resp):
                if isinstance(result, list):
                    tool_call_result_blocks = result
                elif isinstance(result, MessageChain):
                    yield AgentResponse(
                        type="tool_call_result",
                        data={
                            "chain": result.chain,
                        },
                    )
            # Append the results to the context used by the next step.
            tool_calls_result = ToolCallsResult(
                tool_calls_info=AssistantMessageSegment(
                    role="assistant", tool_calls=llm_resp.to_openai_tool_calls()
                ),
                tool_calls_result=tool_call_result_blocks,
            )
            self.req.append_tool_calls_result(tool_calls_result)

        logger.info("done: %s", self.is_done)

    async def _handle_function_tools(
        self,
        req: ProviderRequest,
        llm_response: LLMResponse,
    ) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
        """Execute every tool call requested by ``llm_response``.

        Args:
            req: The current request; its ``func_tool`` resolves tool names.
            llm_response: The LLM response carrying the tool calls.

        Yields:
            MessageChain: A message a tool wants delivered to the user.
            list[ToolCallMessageSegment]: Yielded once at the end (only when
            non-empty) with the tool results to feed back to the LLM.
        """
        # Local import: only needed so cancellation can propagate below.
        import asyncio

        tool_call_result_blocks: list[ToolCallMessageSegment] = []
        logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")

        # Execute the function calls.
        for func_tool_name, func_tool_args, func_tool_id in zip(
            llm_response.tools_call_name,
            llm_response.tools_call_args,
            llm_response.tools_call_ids,
        ):
            try:
                func_tool = req.func_tool.get_func(func_tool_name)
                if func_tool.origin == "mcp":
                    logger.info(
                        f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
                    )
                    client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
                    res = await client.session.call_tool(func_tool.name, func_tool_args)
                    if not res:
                        continue
                    # Only the first content item of the MCP result is used.
                    content = res.content[0]
                    if isinstance(content, TextContent):
                        tool_call_result_blocks.append(
                            self._tool_segment(func_tool_id, content.text)
                        )
                    elif isinstance(content, ImageContent):
                        tool_call_result_blocks.append(
                            self._tool_segment(
                                func_tool_id, "返回了图片(已直接发送给用户)"
                            )
                        )
                        yield MessageChain().base64_image(content.data)
                    elif isinstance(content, EmbeddedResource):
                        resource = content.resource
                        if isinstance(resource, TextResourceContents):
                            tool_call_result_blocks.append(
                                self._tool_segment(func_tool_id, resource.text)
                            )
                        elif (
                            isinstance(resource, BlobResourceContents)
                            and resource.mimeType
                            and resource.mimeType.startswith("image/")
                        ):
                            tool_call_result_blocks.append(
                                self._tool_segment(
                                    func_tool_id, "返回了图片(已直接发送给用户)"
                                )
                            )
                            # The base64 payload lives on the embedded
                            # resource (`blob`), not on the EmbeddedResource
                            # wrapper, which has no `data` attribute.
                            yield MessageChain().base64_image(resource.blob)
                        else:
                            tool_call_result_blocks.append(
                                self._tool_segment(
                                    func_tool_id, "返回的数据类型不受支持"
                                )
                            )
                else:
                    logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
                    # Call the locally registered tool handler.
                    wrapper = self.pipeline_ctx.call_handler(
                        self.event, func_tool.handler, **func_tool_args
                    )
                    async for resp in wrapper:
                        if resp is not None:
                            # The tool returned a value for the LLM.
                            tool_call_result_blocks.append(
                                self._tool_segment(func_tool_id, resp)
                            )
                        else:
                            # The tool asked to message the user directly;
                            # that ends the agent loop.
                            self.is_done = True
                            if res := self.event.get_result():
                                if res.chain:
                                    yield MessageChain(chain=res.chain)

                    self.event.clear_result()
            except (GeneratorExit, asyncio.CancelledError, KeyboardInterrupt):
                # Control-flow exceptions must propagate: swallowing
                # GeneratorExit and yielding again raises RuntimeError, and
                # swallowing CancelledError breaks task cancellation.
                raise
            except BaseException as e:
                # Best effort: report any other tool failure back to the LLM.
                logger.warning(traceback.format_exc())
                tool_call_result_blocks.append(
                    self._tool_segment(func_tool_id, f"error: {str(e)}")
                )

        # Hand the collected tool results back to the caller.
        if tool_call_result_blocks:
            yield tool_call_result_blocks

    def done(self) -> bool:
        """Return True once the agent loop has finished."""
        return self.is_done

    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the final LLM response, or None if there is none yet."""
        return self.final_llm_resp
a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 549440ee3..7d5fa8135 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -3,6 +3,7 @@ """ import traceback +import copy import asyncio import json from typing import Union, AsyncGenerator @@ -20,39 +21,25 @@ from astrbot.core.utils.metrics import Metric from astrbot.core.provider.entities import ( ProviderRequest, LLMResponse, - ToolCallMessageSegment, - AssistantMessageSegment, - ToolCallsResult, -) -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map -from mcp.types import ( - TextContent, - ImageContent, - EmbeddedResource, - TextResourceContents, - BlobResourceContents, ) +from astrbot.core.star.star_handler import EventType from astrbot.core import web_chat_back_queue +from ..agent_runner.tool_loop_agent import ToolLoopAgent class LLMRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx - self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list - self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][ - "wake_prefix" - ] # str - self.max_context_length = ctx.astrbot_config["provider_settings"][ - "max_context_length" - ] # int - self.dequeue_context_length = min( - max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]), + conf = ctx.astrbot_config + self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list + self.provider_wake_prefix: str = conf["provider_settings"]["wake_prefix"] # str + self.max_context_length = conf["provider_settings"]["max_context_length"] # int + self.dequeue_context_length: int = min( + max(1, conf["provider_settings"]["dequeue_context_length"]), self.max_context_length - 1, - ) # int - self.streaming_response = ctx.astrbot_config["provider_settings"][ - "streaming_response" 
- ] # bool + ) + self.streaming_response: bool = conf["provider_settings"]["streaming_response"] + self.max_step: int = conf["provider_settings"].get("max_agent_step", 10) for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -83,10 +70,7 @@ class LLMRequestSubStage(Stage): ) if req.conversation: - all_contexts = json.loads(req.conversation.history) - req.contexts = self._process_tool_message_pairs( - all_contexts, remove_tags=True - ) + req.contexts = json.loads(req.conversation.history) else: req = ProviderRequest(prompt="", image_urls=[]) @@ -127,26 +111,7 @@ class LLMRequestSubStage(Stage): return # 执行请求 LLM 前事件钩子。 - # 装饰 system_prompt 等功能 - # 获取当前平台ID - platform_id = event.get_platform_id() - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMRequestEvent, platform_id=platform_id - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, req) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return + await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req) if isinstance(req.contexts, str): req.contexts = json.loads(req.contexts) @@ -176,77 +141,46 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - async def requesting(req: ProviderRequest): - try: - need_loop = True - while need_loop: - need_loop = False - logger.debug(f"提供商请求 Payload: {req}") + # Call Agent + tool_loop_agent = ToolLoopAgent( + provider=provider, + event=event, + pipeline_ctx=self.ctx, + ) + await tool_loop_agent.reset(req=req, streaming=self.streaming_response) - final_llm_response = None - - if self.streaming_response: - stream = provider.text_chat_stream(**req.__dict__) - async for llm_response in 
stream: - if llm_response.is_chunk: - if llm_response.result_chain: - yield llm_response.result_chain # MessageChain - else: - yield MessageChain().message( - llm_response.completion_text - ) - else: - final_llm_response = llm_response - else: - final_llm_response = await provider.text_chat( - **req.__dict__ - ) # 请求 LLM - - if not final_llm_response: - raise Exception("LLM response is None.") - - # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMResponseEvent + async def requesting(): + step_idx = 0 + while step_idx < self.max_step: + step_idx += 1 + try: + async for resp in tool_loop_agent.step(): + if not self.streaming_response: + content_typ = ( + ResultContentType.LLM_RESULT + if resp.type == "llm_resp" + else ResultContentType.GENERAL_RESULT + ) + event.set_result( + MessageEventResult( + chain=resp.data["chain"], + result_content_type=content_typ, + ) + ) + yield + event.clear_result() + else: + yield resp.data["chain"] + if tool_loop_agent.done(): + break + except Exception as e: + logger.error(traceback.format_exc()) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, final_llm_response) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return - - if self.streaming_response: - # 流式输出的处理 - async for result in self._handle_llm_stream_response( - event, req, final_llm_response - ): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True - else: - yield - else: - # 非流式输出的处理 - async for result in self._handle_llm_response( - event, req, final_llm_response - ): - if 
isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True - else: - yield - + return asyncio.create_task( Metric.upload( llm_tick=1, @@ -255,44 +189,38 @@ class LLMRequestSubStage(Stage): ) ) - # 保存到历史记录 - await self._save_to_history(event, req, final_llm_response) - - except BaseException as e: - logger.error(traceback.format_exc()) - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" - ) - ) - - if not self.streaming_response: - event.set_extra("tool_call_result", None) - async for _ in requesting(req): - yield - else: + if self.streaming_response: + # 流式响应 event.set_result( MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream(requesting(req)) + .set_async_stream(requesting()) ) - # 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理 yield - - if event.get_extra("tool_call_result"): - event.set_result(event.get_extra("tool_call_result")) - event.set_extra("tool_call_result", None) + if tool_loop_agent.done(): + if final_llm_resp := tool_loop_agent.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain().message(final_llm_resp.completion_text).chain + ) + else: + chain = final_llm_resp.result_chain.chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ) + ) + else: + async for _ in requesting(): yield - # 暂时直接发出去 - if img_b64 := event.get_extra("tool_call_img_respond"): - await event.send(MessageChain(chain=[Image.fromBase64(img_b64)])) - event.set_extra("tool_call_img_respond", None) - + # 异步处理 WebChat 特殊情况 if event.get_platform_name() == "webchat": - # 异步处理 WebChat 特殊情况 asyncio.create_task(self._handle_webchat(event, req)) + await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp()) + async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" 
conversation = await self.conv_manager.get_conversation( @@ -305,10 +233,6 @@ class LLMRequestSubStage(Stage): return provider = self.ctx.plugin_manager.context.get_using_provider() cleaned_text = "User: " + latest_pair[0].get("content", "").strip() - # if len(latest_pair) > 1: - # cleaned_text += ( - # "\nAssistant: " + latest_pair[1].get("content", "").strip() - # ) logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}") llm_resp = await provider.text_chat( system_prompt="You are expert in summarizing user's query.", @@ -349,322 +273,34 @@ class LLMRequestSubStage(Stage): } ) - async def _handle_llm_response( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理非流式 LLM 响应。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM - - Yields: - Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM - """ - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.LLM_RESULT) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.LLM_RESULT) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # 处理函数工具调用 - async for result in self._handle_function_tools(event, req, llm_response): - yield result - - async def _handle_llm_stream_response( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理流式 LLM 响应。 - - 专门用于处理流式输出完成后的响应,与非流式响应处理分离。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 
ProviderRequest,表示需要再次调用 LLM - - Yields: - Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM - """ - if llm_response.role == "assistant": - # text completion - if llm_response.result_chain: - event.set_result( - MessageEventResult( - chain=llm_response.result_chain.chain - ).set_result_content_type(ResultContentType.STREAMING_FINISH) - ) - else: - event.set_result( - MessageEventResult() - .message(llm_response.completion_text) - .set_result_content_type(ResultContentType.STREAMING_FINISH) - ) - elif llm_response.role == "err": - event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" - ) - ) - elif llm_response.role == "tool": - # 处理函数工具调用 - async for result in self._handle_function_tools(event, req, llm_response): - yield result - - async def _handle_function_tools( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse, - ) -> AsyncGenerator[Union[None, ProviderRequest], None]: - """处理函数工具调用。 - - Returns: - AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM - """ - # function calling - tool_call_result: list[ToolCallMessageSegment] = [] - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" - ) - for func_tool_name, func_tool_args, func_tool_id in zip( - llm_response.tools_call_name, - llm_response.tools_call_args, - llm_response.tools_call_ids, - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" - ) - client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] - res = await client.session.call_tool(func_tool.name, func_tool_args) - if res: - # TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback - if isinstance(res.content[0], TextContent): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - 
tool_call_id=func_tool_id, - content=res.content[0].text, - ) - ) - elif isinstance(res.content[0], ImageContent): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - event.set_extra( - "tool_call_img_respond", - res.content[0].data, - ) - elif isinstance(res.content[0], EmbeddedResource): - resource = res.content[0].resource - if isinstance(resource, TextResourceContents): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resource.text, - ) - ) - elif ( - isinstance(resource, BlobResourceContents) - and resource.mimeType - and resource.mimeType.startswith("image/") - ): - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - event.set_extra( - "tool_call_img_respond", - res.content[0].data, - ) - else: - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回的数据类型不受支持", - ) - ) - else: - # 获取处理器,过滤掉平台不兼容的处理器 - platform_id = event.get_platform_id() - star_md = star_map.get(func_tool.handler_module_path) - if ( - star_md - and platform_id in star_md.supported_platforms - and not star_md.supported_platforms[platform_id] - ): - logger.debug( - f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行" - ) - # 直接跳过,不添加任何消息到tool_call_result - continue - - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - else: - res = event.get_result() - if res and res.chain: - event.set_extra("tool_call_result", res) - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - 
logger.warning(traceback.format_exc()) - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {str(e)}", - ) - ) - if tool_call_result: - # 函数调用结果 - req.func_tool = None # 暂时不支持递归工具调用 - assistant_msg_seg = AssistantMessageSegment( - role="assistant", tool_calls=llm_response.to_openai_tool_calls() - ) - # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 - req.tool_calls_result = ToolCallsResult( - tool_calls_info=assistant_msg_seg, - tool_calls_result=tool_call_result, - ) - yield req # 再次执行 LLM 请求 - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) - async def _save_to_history( - self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse | None, ): - if not req or not req.conversation or not llm_response: + if ( + not req + or not req.conversation + or not llm_response + or llm_response.role != "assistant" + ): return - if llm_response.role == "assistant": - # 文本回复 - contexts = req.contexts.copy() - contexts.append(await req.assemble_context()) - - # 记录并标记函数调用结果 - if req.tool_calls_result: - tool_calls_messages = req.tool_calls_result.to_openai_messages() - - # 添加标记 - for message in tool_calls_messages: - message["_tool_call_history"] = True - - processed_tool_messages = self._process_tool_message_pairs( - tool_calls_messages, remove_tags=False - ) - - contexts.extend(processed_tool_messages) - - contexts.append( - {"role": "assistant", "content": llm_response.completion_text} - ) - contexts_to_save = list( - filter(lambda item: "_no_save" not in item, contexts) - ) - await self.conv_manager.update_conversation( - event.unified_msg_origin, req.conversation.cid, history=contexts_to_save - ) - - def _process_tool_message_pairs(self, messages, remove_tags=True): - """处理工具调用消息,确保assistant和tool消息成对出现 - - Args: - messages (list): 消息列表 
- remove_tags (bool): 是否移除_tool_call_history标记 - - Returns: - list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现 - """ - result = [] - i = 0 - - while i < len(messages): - current_msg = messages[i] - - # 普通消息直接添加 - if "_tool_call_history" not in current_msg: - result.append(current_msg.copy() if remove_tags else current_msg) - i += 1 - continue - - # 工具调用消息成对处理 - if current_msg.get("role") == "assistant" and "tool_calls" in current_msg: - assistant_msg = current_msg.copy() - - if remove_tags and "_tool_call_history" in assistant_msg: - del assistant_msg["_tool_call_history"] - - related_tools = [] - j = i + 1 - while ( - j < len(messages) - and messages[j].get("role") == "tool" - and "_tool_call_history" in messages[j] - ): - tool_msg = messages[j].copy() - - if remove_tags: - del tool_msg["_tool_call_history"] - - related_tools.append(tool_msg) - j += 1 - - # 成对的时候添加到结果 - if related_tools: - result.append(assistant_msg) - result.extend(related_tools) - - i = j # 跳过已处理 - else: - # 单独的tool消息 - i += 1 - - return result + # 历史上下文 + messages = copy.deepcopy(req.contexts) + # 这一轮对话请求的用户输入 + messages.append(await req.assemble_context()) + # 这一轮对话的 LLM 响应 + if req.tool_calls_result: + if not isinstance(req.tool_calls_result, list): + messages.extend(req.tool_calls_result.to_openai_messages()) + elif isinstance(req.tool_calls_result, list): + for tcr in req.tool_calls_result: + messages.extend(tcr.to_openai_messages()) + messages.append({"role": "assistant", "content": llm_response.completion_text}) + messages = list(filter(lambda item: "_no_save" not in item, messages)) + await self.conv_manager.update_conversation( + event.unified_msg_origin, req.conversation.cid, history=messages + ) + logger.debug(f"messages persisted: {messages}") diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index c7817e49c..00f58d55b 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ 
b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -50,7 +50,7 @@ class StarRequestSubStage(Stage): logger.debug( f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" ) - wrapper = self._call_handler(self.ctx, event, handler.handler, **params) + wrapper = self.ctx.call_handler(event, handler.handler, **params) async for ret in wrapper: yield ret event.clear_result() # 清除上一个 handler 的结果 diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index c7d4ff792..b41794733 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,12 +1,8 @@ from __future__ import annotations import abc -import inspect -import traceback -from astrbot.api import logger -from typing import List, AsyncGenerator, Union, Awaitable +from typing import List, AsyncGenerator, Union from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类 @@ -41,70 +37,3 @@ class Stage(abc.ABC): Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) """ raise NotImplementedError - - async def _call_handler( - self, - ctx: PipelineContext, - event: AstrMessageEvent, - handler: Awaitable, - *args, - **kwargs, - ) -> AsyncGenerator[None, None]: - """执行事件处理函数并处理其返回结果 - - 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: - 1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层 - 2. 
协程: 执行一次并处理返回值 - - Args: - ctx (PipelineContext): 消息管道上下文对象 - event (AstrMessageEvent): 待处理的事件对象 - handler (Awaitable): 事件处理函数 - *args: 传递给handler的位置参数 - **kwargs: 传递给handler的关键字参数 - - Returns: - AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 - """ - ready_to_call = None # 一个协程或者异步生成器(async def) - - trace_ = None - - try: - ready_to_call = handler(event, *args, **kwargs) - except TypeError as _: - # 向下兼容 - trace_ = traceback.format_exc() - # 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份 - ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs) - - if isinstance(ready_to_call, AsyncGenerator): - # 如果是一个异步生成器, 进入洋葱模型 - _has_yielded = False # 是否返回过值 - try: - async for ret in ready_to_call: - # 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码 - # 返回值只能是 MessageEventResult 或者 None(无返回值) - _has_yielded = True - if isinstance(ret, (MessageEventResult, CommandResult)): - # 如果返回值是 MessageEventResult, 设置结果并继续 - event.set_result(ret) - yield # 传递控制权给上一层的process函数 - else: - # 如果返回值是 None, 则不设置结果并继续 - # 继续执行后续阶段 - yield ret # 传递控制权给上一层的process函数 - if not _has_yielded: - # 如果这个异步生成器没有执行到yield分支 - yield - except Exception as e: - logger.error(f"Previous Error: {trace_}") - raise e - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个协程, 直接执行 - ret = await ready_to_call - if isinstance(ret, (MessageEventResult, CommandResult)): - event.set_result(ret) - yield # 传递控制权给上一层的process函数 - else: - yield ret # 传递控制权给上一层的process函数 diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index e01e46cf9..7dd96e8d0 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -95,19 +95,19 @@ class ProviderRequest: """提示词""" session_id: str = "" """会话 ID""" - image_urls: List[str] = None + image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" - func_tool: FuncCall = None + func_tool: FuncCall | None = None """可用的函数工具""" - contexts: List = None + contexts: list[dict] = field(default_factory=list) """上下文。格式与 openai 
的上下文格式一致: 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages """ system_prompt: str = "" """系统提示词""" - conversation: Conversation = None + conversation: Conversation | None = None - tool_calls_result: ToolCallsResult = None + tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" def __repr__(self): @@ -116,6 +116,14 @@ class ProviderRequest: def __str__(self): return self.__repr__() + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + """添加工具调用结果到请求中""" + if not self.tool_calls_result: + self.tool_calls_result = [] + if isinstance(self.tool_calls_result, ToolCallsResult): + self.tool_calls_result = [self.tool_calls_result] + self.tool_calls_result.append(tool_calls_result) + def _print_friendly_context(self): """打印友好的消息上下文。将 image_url 的值替换为 """ if not self.contexts: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 5886b8083..0dd108474 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -190,11 +190,6 @@ class ProviderManager: from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) - case "llm_tuner": - logger.info("加载 LLM Tuner 工具 ...") - from .sources.llmtuner_source import ( - LLMTunerModelLoader as LLMTunerModelLoader, - ) case "dify": from .sources.dify_source import ProviderDify as ProviderDify case "dashscope": diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index c285ebd42..2cc956ae1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -86,11 +86,11 @@ class Provider(AbstractProvider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts: List = None, + contexts: list = None, system_prompt: str = None, - 
tool_calls_result: ToolCallsResult = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -114,11 +114,11 @@ class Provider(AbstractProvider): self, prompt: str, session_id: str = None, - image_urls: List[str] = None, + image_urls: list[str] = None, func_tool: FuncCall = None, - contexts: List = None, + contexts: list = None, system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index c3ad45868..f40daa5ec 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,3 +1,6 @@ +import json +import anthropic +import base64 from typing import List from mimetypes import guess_type @@ -10,15 +13,14 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse, ToolCallsResult -from .openai_source import ProviderOpenAIOfficial +from astrbot.core.provider.entities import LLMResponse +from typing import AsyncGenerator @register_provider_adapter( "anthropic_chat_completion", "Anthropic Claude API 提供商适配器" ) -class ProviderAnthropic(ProviderOpenAIOfficial): +class ProviderAnthropic(Provider): def __init__( self, provider_config: dict, @@ -27,9 +29,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): persistant_history=True, default_persona: Personality = None, ) -> None: - # Skip OpenAI's __init__ and call Provider's __init__ directly - Provider.__init__( - self, + 
super().__init__( provider_config, provider_settings, persistant_history, @@ -51,10 +51,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial): self.set_model(provider_config["model_config"]["model"]) + def _prepare_payload(self, messages: list[dict]): + """准备 Anthropic API 的请求 payload + + Args: + messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息 + Returns: + system_prompt: 系统提示内容 + new_messages: 处理后的消息列表,去除系统提示 + """ + system_prompt = "" + new_messages = [] + for message in messages: + if message["role"] == "system": + system_prompt = message["content"] + elif message["role"] == "assistant": + blocks = [] + if isinstance(message["content"], str): + blocks.append({"type": "text", "text": message["content"]}) + if "tool_calls" in message: + for tool_call in message["tool_calls"]: + blocks.append( # noqa: PERF401 + { + "type": "tool_use", + "name": tool_call["function"]["name"], + "input": json.loads(tool_call["function"]["arguments"]) + if isinstance(tool_call["function"]["arguments"], str) + else tool_call["function"]["arguments"], + "id": tool_call["id"], + } + ) + new_messages.append( + { + "role": "assistant", + "content": blocks, + } + ) + elif message["role"] == "tool": + new_messages.append( + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": message["tool_call_id"], + "content": message["content"], + } + ], + } + ) + else: + new_messages.append(message) + + return system_prompt, new_messages + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: - tool_list = tools.get_func_desc_anthropic_style() - if tool_list: + if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list completion = await self.client.messages.create(**payloads, stream=False) @@ -64,70 +117,157 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if len(completion.content) == 0: raise Exception("API 返回的 completion 为空。") - # TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容 - # 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求 - 
content = completion.content[-1] - llm_response = LLMResponse("assistant") + llm_response = LLMResponse(role="assistant") - if content.type == "text": - # text completion - completion_text = str(content.text).strip() - # llm_response.completion_text = completion_text - llm_response.result_chain = MessageChain().message(completion_text) - - # Anthropic每次只返回一个函数调用 - if completion.stop_reason == "tool_use": - # tools call (function calling) - args_ls = [] - func_name_ls = [] - tool_use_ids = [] - func_name_ls.append(content.name) - args_ls.append(content.input) - tool_use_ids.append(content.id) - llm_response.role = "tool" - llm_response.tools_call_args = args_ls - llm_response.tools_call_name = func_name_ls - llm_response.tools_call_ids = tool_use_ids + for content_block in completion.content: + if content_block.type == "text": + completion_text = str(content_block.text).strip() + llm_response.completion_text = completion_text + if content_block.type == "tool_use": + llm_response.tools_call_args.append(content_block.input) + llm_response.tools_call_name.append(content_block.name) + llm_response.tools_call_ids.append(content_block.id) + # TODO(Soulter): 处理 end_turn 情况 if not llm_response.completion_text and not llm_response.tools_call_args: - logger.error(f"API 返回的 completion 无法解析:{completion}。") - raise Exception(f"API 返回的 completion 无法解析:{completion}。") - - llm_response.raw_completion = completion + raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。") return llm_response + async def _query_stream( + self, payloads: dict, tools: FuncCall + ) -> AsyncGenerator[LLMResponse, None]: + if tools: + if tool_list := tools.get_func_desc_anthropic_style(): + payloads["tools"] = tool_list + + # 用于累积工具调用信息 + tool_use_buffer = {} + # 用于累积最终结果 + final_text = "" + final_tool_calls = [] + + async with self.client.messages.stream(**payloads) as stream: + assert isinstance(stream, anthropic.AsyncMessageStream) + async for event in stream: + if event.type == 
"content_block_start": + if event.content_block.type == "text": + # 文本块开始 + yield LLMResponse( + role="assistant", completion_text="", is_chunk=True + ) + elif event.content_block.type == "tool_use": + # 工具使用块开始,初始化缓冲区 + tool_use_buffer[event.index] = { + "id": event.content_block.id, + "name": event.content_block.name, + "input": {}, + } + + elif event.type == "content_block_delta": + if event.delta.type == "text_delta": + # 文本增量 + final_text += event.delta.text + yield LLMResponse( + role="assistant", + completion_text=event.delta.text, + is_chunk=True, + ) + elif event.delta.type == "input_json_delta": + # 工具调用参数增量 + if event.index in tool_use_buffer: + # 累积 JSON 输入 + if "input_json" not in tool_use_buffer[event.index]: + tool_use_buffer[event.index]["input_json"] = "" + tool_use_buffer[event.index]["input_json"] += ( + event.delta.partial_json + ) + + elif event.type == "content_block_stop": + # 内容块结束 + if event.index in tool_use_buffer: + # 解析完整的工具调用 + tool_info = tool_use_buffer[event.index] + try: + if "input_json" in tool_info: + tool_info["input"] = json.loads(tool_info["input_json"]) + + # 添加到最终结果 + final_tool_calls.append( + { + "id": tool_info["id"], + "name": tool_info["name"], + "input": tool_info["input"], + } + ) + + yield LLMResponse( + role="tool", + completion_text="", + tools_call_args=[tool_info["input"]], + tools_call_name=[tool_info["name"]], + tools_call_ids=[tool_info["id"]], + is_chunk=True, + ) + except json.JSONDecodeError: + # JSON 解析失败,跳过这个工具调用 + logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}") + + # 清理缓冲区 + del tool_use_buffer[event.index] + + # 返回最终的完整结果 + final_response = LLMResponse( + role="assistant", completion_text=final_text, is_chunk=False + ) + + if final_tool_calls: + final_response.tools_call_args = [ + call["input"] for call in final_tool_calls + ] + final_response.tools_call_name = [call["name"] for call in final_tool_calls] + final_response.tools_call_ids = [call["id"] for call in final_tool_calls] + + yield 
final_response + async def text_chat( self, - prompt: str, - session_id: str = None, - image_urls: List[str] = [], - func_tool: FuncCall = None, + prompt, + session_id = None, + image_urls = [], + func_tool = None, contexts=None, system_prompt=None, - tool_calls_result: ToolCallsResult = None, + tool_calls_result = None, **kwargs, ) -> LLMResponse: if contexts is None: contexts = [] - if not prompt: - prompt = "" - new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) for part in context_query: if "_no_save" in part: del part["_no_save"] + # tool calls result if tool_calls_result: - # 暂时这样写。 - prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}" + if not isinstance(tool_calls_result, list): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) + + system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": new_messages, **model_config} - payloads = {"messages": context_query, **model_config} # Anthropic has a different way of handling system prompts if system_prompt: payloads["system"] = system_prompt @@ -135,32 +275,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial): llm_response = None try: llm_response = await self._query(payloads, func_tool) - except Exception as e: - if "maximum context length" in str(e): - retry_cnt = 20 - while retry_cnt > 0: - logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" - ) - try: - await self.pop_record(context_query) - response = await self.client.messages.create( - messages=context_query, **model_config - ) - llm_response = LLMResponse("assistant") - llm_response.result_chain = 
MessageChain().message(response.content[0].text) - llm_response.raw_completion = response - return llm_response - except Exception as e: - if "maximum context length" in str(e): - retry_cnt -= 1 - else: - raise e - return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。") - else: - logger.error(f"发生了错误。Provider 配置如下: {model_config}") - raise e + logger.error(f"发生了错误。Provider 配置如下: {model_config}") + raise e return llm_response @@ -175,21 +292,34 @@ class ProviderAnthropic(ProviderOpenAIOfficial): tool_calls_result=None, **kwargs, ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response + if contexts is None: + contexts = [] + new_record = await self.assemble_context(prompt, image_urls) + context_query = [*contexts, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + + # tool calls result + if tool_calls_result: + context_query.extend(tool_calls_result.to_openai_messages()) + + system_prompt, new_messages = self._prepare_payload(context_query) + + model_config = self.provider_config.get("model_config", {}) + model_config["model"] = self.get_model() + + payloads = {"messages": new_messages, **model_config} + + # Anthropic has a different way of handling system prompts + if system_prompt: + payloads["system"] = system_prompt + + async for llm_response in self._query_stream(payloads, func_tool): + yield llm_response async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" @@ -232,3 +362,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial): ) 
return {"role": "user", "content": content} + + async def encode_image_bs64(self, image_url: str) -> str: + """ + 将图片转换为 base64 + """ + if image_url.startswith("base64://"): + return image_url.replace("base64://", "data:image/jpeg;base64,") + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 + return "" diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c16b39415..b08705907 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -506,12 +506,12 @@ class ProviderGoogleGenAI(Provider): async def text_chat( self, prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: FuncCall = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + session_id = None, + image_urls = None, + func_tool = None, + contexts = None, + system_prompt = None, + tool_calls_result = None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -527,7 +527,11 @@ class ProviderGoogleGenAI(Provider): # tool calls result if tool_calls_result: - context_query.extend(tool_calls_result.to_openai_messages()) + if not isinstance(tool_calls_result, list): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py deleted file mode 100644 index 8648512d0..000000000 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from llmtuner.chat import ChatModel -from typing import List -from .. 
import Provider -from ..entities import LLMResponse -from ..func_tool_manager import FuncCall -from astrbot.core.db import BaseDatabase -from ..register import register_provider_adapter - - -@register_provider_adapter( - "llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型" -) -class LLMTunerModelLoader(Provider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - db_helper: BaseDatabase, - persistant_history=True, - default_persona=None, - ) -> None: - super().__init__( - provider_config, - provider_settings, - persistant_history, - db_helper, - default_persona, - ) - if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists( - provider_config["adapter_model_path"] - ): - raise FileNotFoundError("模型文件路径不存在。") - self.base_model_path = provider_config["base_model_path"] - self.adapter_model_path = provider_config["adapter_model_path"] - self.model = ChatModel( - { - "model_name_or_path": self.base_model_path, - "adapter_name_or_path": self.adapter_model_path, - "template": provider_config["llmtuner_template"], - "finetuning_type": provider_config["finetuning_type"], - "quantization_bit": provider_config["quantization_bit"], - } - ) - self.set_model( - os.path.basename(self.base_model_path) - + "_" - + os.path.basename(self.adapter_model_path) - ) - - async def assemble_context(self, text: str, image_urls: List[str] = None): - """ - 组装上下文。 - """ - return {"role": "user", "content": text} - - async def text_chat( - self, - prompt: str, - session_id: str = None, - image_urls: List[str] = None, - func_tool: FuncCall = None, - contexts: List = None, - system_prompt: str = None, - **kwargs, - ) -> LLMResponse: - if contexts is None: - contexts = [] - system_prompt = "" - new_record = {"role": "user", "content": prompt} - query_context = [*contexts, new_record] - - # 提取出系统提示 - system_idxs = [] - for idx, context in enumerate(query_context): - if context["role"] == "system": - system_idxs.append(idx) - - if "_no_save" in 
context: - del context["_no_save"] - - for idx in reversed(system_idxs): - system_prompt += " " + query_context.pop(idx)["content"] - - conf = { - "messages": query_context, - "system": system_prompt, - } - if func_tool: - tool_list = func_tool.get_func_desc_openai_style() - if tool_list: - conf["tools"] = tool_list - - responses = await self.model.achat(**conf) - - llm_response = LLMResponse("assistant", responses[-1].response_text) - - return llm_response - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def get_current_key(self): - return "none" - - async def set_key(self, key): - pass - - async def get_models(self): - return [self.get_model()] diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 15104db3c..10db3c31a 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -9,7 +9,6 @@ import astrbot.core.message.components as Comp from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion -# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai._exceptions import NotFoundError, UnprocessableEntityError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url @@ -224,12 +223,10 @@ class ProviderOpenAIOfficial(Provider): async def 
_prepare_chat_payload( self, prompt: str, - session_id: str = None, - image_urls: list[str] = None, - func_tool: FuncCall = None, - contexts: list = None, - system_prompt: str = None, - tool_calls_result: ToolCallsResult = None, + image_urls: list[str] | None = None, + contexts: list | None = None, + system_prompt: str | None = None, + tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -246,14 +243,20 @@ class ProviderOpenAIOfficial(Provider): # tool calls result if tool_calls_result: - context_query.extend(tool_calls_result.to_openai_messages()) + if isinstance(tool_calls_result, ToolCallsResult): + context_query.extend(tool_calls_result.to_openai_messages()) + else: + for tcr in tool_calls_result: + context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) model_config["model"] = self.get_model() payloads = {"messages": context_query, **model_config} - return payloads, context_query, func_tool + logger.debug(f"payloads: {payloads}") + + return payloads, context_query async def _handle_api_error( self, @@ -352,11 +355,9 @@ class ProviderOpenAIOfficial(Provider): tool_calls_result=None, **kwargs, ) -> LLMResponse: - payloads, context_query, func_tool = await self._prepare_chat_payload( + payloads, context_query = await self._prepare_chat_payload( prompt, - session_id, image_urls, - func_tool, contexts, system_prompt, tool_calls_result, @@ -422,11 +423,9 @@ class ProviderOpenAIOfficial(Provider): **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" - payloads, context_query, func_tool = await self._prepare_chat_payload( + payloads, context_query = await self._prepare_chat_payload( prompt, - session_id, image_urls, - func_tool, contexts, system_prompt, tool_calls_result, diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 220393dea..44e7f8206 100644 --- 
a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -414,7 +414,6 @@ export default { "anthropic_chat_completion": "chat_completion", "googlegenai_chat_completion": "chat_completion", "zhipu_chat_completion": "chat_completion", - "llm_tuner": "chat_completion", "dify": "chat_completion", "dashscope": "chat_completion", "openai_whisper_api": "speech_to_text", diff --git a/dump_plugins.py b/dump_plugins.py new file mode 100644 index 000000000..efe34b230 --- /dev/null +++ b/dump_plugins.py @@ -0,0 +1,82 @@ +import urllib.request +import json + +# --- 配置 --- +REPO_OWNER = "AstrBotDevs" +REPO_NAME = "AstrBot" +START_ISSUE_NUMBER = 970 +LABEL_TO_FIND = "plugin-publish" +OUTPUT_FILENAME = "plugin_publish_issues.txt" +ISSUE_STATE = "closed" # 只筛选状态为 closed 的 Issue + + +def fetch_and_format_issues( + repo_owner, repo_name, start_issue_number, label_name, output_file, issue_state +): + api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/issues?state={issue_state}&labels={label_name}&sort=created&direction=asc&per_page=100" + headers = {"Accept": "application/vnd.github+json"} + found_issues = [] + page = 1 + + while True: + request_url = f"{api_url}&page={page}" + request = urllib.request.Request(request_url, headers=headers) + try: + with urllib.request.urlopen(request) as response: + data = json.loads(response.read().decode("utf-8")) + if not data: + break # 没有更多 Issues 了 + + for issue in data: + issue_number = issue.get("number") + if issue_number is not None and issue_number >= start_issue_number: + title = issue.get("title", "No Title") + author = issue.get("user", {}).get("login", "Unknown") + found_issues.append(f"{title} by @{author} in #{issue_number}") + + # 检查是否有下一页 + if "Link" in response.headers: + links = response.headers["Link"].split(",") + next_page_exists = False + for link in links: + if 'rel="next"' in link: + next_page_exists = True + break + if not next_page_exists: + break + page += 1 + else: + break # 没有 
Link header,假设没有更多页了 + + except urllib.error.HTTPError as e: + print(f"HTTP Error: {e.code} - {e.reason}") + return + except urllib.error.URLError as e: + print(f"URL Error: {e.reason}") + return + except json.JSONDecodeError: + print("Error decoding JSON response.") + return + + if found_issues: + with open(output_file, "w", encoding="utf-8") as f: + for line in found_issues: + f.write(line + "\n") + print( + f"已找到 {len(found_issues)} 个状态为 '{issue_state}',带有 '{label_name}' 标签且 Issue Number 大于等于 {start_issue_number} 的 Issues,并已保存到 '{output_file}'。" + ) + else: + print( + f"未找到任何状态为 '{issue_state}',带有 '{label_name}' 标签且 Issue Number 大于等于 {start_issue_number} 的 Issues。" + ) + + +if __name__ == "__main__": + fetch_and_format_issues( + REPO_OWNER, + REPO_NAME, + START_ISSUE_NUMBER, + LABEL_TO_FIND, + OUTPUT_FILENAME, + ISSUE_STATE, + )