From 36ffcf3cc31a4fb668bc72b3a063e6ffd6c3d2f6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 21 Oct 2025 10:56:44 +0800 Subject: [PATCH] fix: typing error --- .../process_stage/method/llm_request.py | 17 +++++++--- astrbot/core/platform/astr_message_event.py | 8 ++--- .../core/provider/sources/anthropic_source.py | 8 ++--- .../core/provider/sources/dashscope_source.py | 19 ++++++----- astrbot/core/provider/sources/dify_source.py | 14 ++++---- .../core/provider/sources/openai_source.py | 33 ++++++++++--------- .../provider/sources/openai_tts_api_source.py | 2 +- 7 files changed, 53 insertions(+), 48 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index b2245e4da..514a87a96 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -44,7 +44,7 @@ except (ModuleNotFoundError, ImportError): AgentContextWrapper = ContextWrapper[AstrAgentContext] -AgentRunner = ToolLoopAgentRunner[AgentContextWrapper] +AgentRunner = ToolLoopAgentRunner[AstrAgentContext] class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @@ -102,7 +102,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): request = ProviderRequest( prompt=input_, - system_prompt=tool.description, + system_prompt=tool.description or "", image_urls=[], # 暂时不传递原始 agent 的上下文 contexts=[], # 暂时不传递原始 agent 的上下文 func_tool=toolset, @@ -239,7 +239,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): yield res -class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]): +class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 await call_event_hook( @@ -337,7 +337,7 @@ class LLMRequestSubStage(Stage): self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent) -> Provider | None: + def _select_provider(self, event: AstrMessageEvent): """选择使用的 LLM 提供商""" sel_provider = event.get_extra("selected_provider") _ctx = self.ctx.plugin_manager.context @@ -382,6 +382,9 @@ class LLMRequestSubStage(Stage): provider = self._select_provider(event) if provider is None: return + if not isinstance(provider, Provider): + logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。") + return if event.get_extra("provider_request"): req = event.get_extra("provider_request") @@ -520,8 +523,10 @@ class LLMRequestSubStage(Stage): chain = ( MessageChain().message(final_llm_resp.completion_text).chain ) - else: + elif final_llm_resp.result_chain: chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain event.set_result( MessageEventResult( chain=chain, @@ -553,6 +558,8 @@ class LLMRequestSubStage(Stage): self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider ): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" + if not req.conversation: + return conversation = await self.conv_manager.get_conversation( event.unified_msg_origin, req.conversation.cid ) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 05169c4fe..3a4b8c128 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -4,7 +4,7 @@ import re import hashlib import uuid -from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any +from typing import List, Union, Optional, AsyncGenerator, Any from astrbot import logger from astrbot.core.db.po import Conversation @@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata from .message_session import MessageSession, MessageSesion # noqa -_VT = TypeVar("_VT") - class AstrMessageEvent(abc.ABC): def __init__( @@ -177,9 +175,7 @@ class AstrMessageEvent(abc.ABC): """ self._extras[key] = value - def get_extra( - self, key: str | None = None, default: _VT = None - ) -> dict[str, Any] | _VT: + def get_extra(self, key: str | None = None, default=None) -> Any: """ 获取额外的信息。 """ diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 57bffdc81..cd4206ce7 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,7 @@ from anthropic.types import Message from astrbot.core.utils.io import download_image_by_url from astrbot.api.provider import Provider from astrbot import logger -from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.func_tool_manager import ToolSet from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse from typing import AsyncGenerator @@ -104,7 +104,7 @@ class ProviderAnthropic(Provider): return system_prompt, new_messages - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: if tool_list := tools.get_func_desc_anthropic_style(): payloads["tools"] = tool_list @@ -135,7 +135,7 @@ class ProviderAnthropic(Provider): return llm_response async def _query_stream( - self, payloads: dict, tools: FuncCall + self, payloads: dict, tools: ToolSet | None ) -> AsyncGenerator[LLMResponse, None]: if tools: if tool_list := tools.get_func_desc_anthropic_style(): @@ -326,7 +326,7 @@ class ProviderAnthropic(Provider): async for llm_response in self._query_stream(payloads, func_tool): yield llm_response - async def assemble_context(self, text: str, image_urls: List[str] = None): + async def assemble_context(self, text: str, image_urls: List[str] | None = None): """组装上下文,支持文本和图片""" if not image_urls: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 4e14d20da..0183f7244 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -1,15 +1,14 @@ import re import asyncio import functools -from typing import List from .. import Provider, Personality from ..entities import LLMResponse -from ..func_tool_manager import FuncCall from ..register import register_provider_adapter from astrbot.core.message.message_event_result import MessageChain from .openai_source import ProviderOpenAIOfficial from astrbot.core import logger, sp from dashscope import Application +from dashscope.app.application_response import ApplicationResponse @register_provider_adapter("dashscope", "Dashscope APP 适配器。") @@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial): async def text_chat( self, prompt: str, - session_id: str = None, - image_urls: List[str] = [], - func_tool: FuncCall = None, - contexts: List = None, - system_prompt: str = None, + session_id=None, + image_urls=[], + func_tool=None, + contexts=None, + system_prompt=None, model=None, **kwargs, ) -> LLMResponse: @@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) response = await asyncio.get_event_loop().run_in_executor(None, partial) + assert isinstance(response, ApplicationResponse) + logger.debug(f"dashscope resp: {response}") if response.status_code != 200: @@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial): ), ) - output_text = response.output.get("text", "") + output_text = response.output.get("text", "") or "" # RAG 引用脚标格式化 output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) if self.output_reference and response.output.get("doc_references", None): ref_str = "" - for ref in response.output.get("doc_references", []): + for ref in response.output.get("doc_references", []) or []: ref_title = ( ref.get("title", "") if ref.get("title") diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index e19e912ac..f7c4e63ca 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,9 +1,7 @@ import astrbot.core.message.components as Comp import os -from typing import List from .. import Provider from ..entities import LLMResponse -from ..func_tool_manager import FuncCall from ..register import register_provider_adapter from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.io import download_image_by_url, download_file @@ -55,11 +53,11 @@ class ProviderDify(Provider): async def text_chat( self, prompt: str, - session_id: str = None, - image_urls: List[str] = None, - func_tool: FuncCall = None, - contexts: List = None, - system_prompt: str = None, + session_id=None, + image_urls=None, + func_tool=None, + contexts=None, + system_prompt=None, tool_calls_result=None, model=None, **kwargs, @@ -223,7 +221,7 @@ class ProviderDify(Provider): # Chat return MessageChain(chain=[Comp.Plain(chunk)]) - async def parse_file(item: dict) -> Comp: + async def parse_file(item: dict): match item["type"]: case "image": return Comp.Image(file=item["url"], url=item["url"]) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 81342ad53..09c284acb 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain from astrbot.api.provider import Provider from astrbot import logger -from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.func_tool_manager import ToolSet from typing import List, AsyncGenerator from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse, ToolCallsResult @@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider): self.client = AsyncAzureOpenAI( api_key=self.chosen_api_key, api_version=provider_config.get("api_version", None), - base_url=provider_config.get("api_base", None), + base_url=provider_config.get("api_base", ""), timeout=self.timeout, ) else: @@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider): except NotFoundError as e: raise Exception(f"获取模型列表失败:{e}") - async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse: if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider): return llm_response async def _query_stream( - self, payloads: dict, tools: FuncCall + self, payloads: dict, tools: ToolSet ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" if tools: @@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider): yield llm_response - async def parse_openai_completion( - self, completion: ChatCompletion, tools: FuncCall - ): + async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet): """解析 OpenAI 的 ChatCompletion 响应""" llm_response = LLMResponse("assistant") @@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider): # workaround for #1359 tool_call = json.loads(tool_call) for tool in tools.func_list: - if tool.name == tool_call.function.name: + if ( + tool_call.type == "function" + and tool.name == tool_call.function.name + ): # workaround for #1454 if isinstance(tool_call.function.arguments, str): args = json.loads(tool_call.function.arguments) @@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider): e: Exception, payloads: dict, context_query: list, - func_tool: FuncCall, + func_tool: ToolSet, chosen_key: str, available_api_keys: List[str], retry_cnt: int, @@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider): if success: break - if retry_cnt == max_retries - 1: + if retry_cnt == max_retries - 1 or llm_response is None: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") if last_exception is None: raise Exception("未知错误") @@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider): async def text_chat_stream( self, prompt: str, - session_id: str = None, - image_urls: List[str] = [], - func_tool: FuncCall = None, - contexts=[], + session_id=None, + image_urls=None, + func_tool=None, + contexts=None, system_prompt=None, tool_calls_result=None, model=None, @@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider): def set_key(self, key): self.client.api_key = key - async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict: + async def assemble_context( + self, text: str, image_urls: List[str] | None = None + ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: user_content = { diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index c188a9fae..c5fb467b7 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider): timeout=timeout, ) - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model", "")) async def get_audio(self, text: str) -> str: temp_dir = os.path.join(get_astrbot_data_path(), "temp")