chore: code quality

This commit is contained in:
Soulter
2025-06-29 15:51:56 +08:00
parent 20b760529e
commit 4e3d5641c8
4 changed files with 28 additions and 29 deletions
@@ -51,6 +51,15 @@ class ToolLoopAgent(BaseAgentRunner):
self.final_llm_resp = None
self.is_done = False
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
@override
async def step(self):
"""
@@ -62,29 +71,25 @@ class ToolLoopAgent(BaseAgentRunner):
# 执行 LLM 请求
llm_resp_result = None
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
async for llm_response in stream: # type: ignore
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(
llm_response.completion_text
)
),
)
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=llm_response.result_chain),
)
else:
llm_resp_result = llm_response
else:
llm_resp_result = await self.provider.text_chat(**self.req.__dict__)
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(llm_response.completion_text)
),
)
continue
llm_resp_result = llm_response
break # got final response
if not llm_resp_result:
return
-1
View File
@@ -4,7 +4,6 @@ from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from dataclasses import dataclass
from deprecated import deprecated
class Personality(TypedDict):
@@ -97,9 +97,6 @@ class ProviderAnthropic(Provider):
)
else:
new_messages.append(message)
logger.debug(f"message: {messages}")
logger.debug(f"new message: {new_messages}")
return system_prompt, new_messages
@@ -233,7 +230,7 @@ class ProviderAnthropic(Provider):
self,
prompt,
session_id=None,
image_urls=[],
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
@@ -249,8 +249,6 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, **model_config}
logger.debug(f"payloads: {payloads}")
return payloads, context_query
async def _handle_api_error(