chore: code quality
This commit is contained in:
@@ -51,6 +51,15 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
self.final_llm_resp = None
|
||||
self.is_done = False
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
@@ -62,29 +71,25 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
|
||||
# 执行 LLM 请求
|
||||
llm_resp_result = None
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for llm_response in stream: # type: ignore
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
llm_response.completion_text
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
llm_resp_result = llm_response
|
||||
else:
|
||||
llm_resp_result = await self.provider.text_chat(**self.req.__dict__)
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text)
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
return
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from dataclasses import dataclass
|
||||
from deprecated import deprecated
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -97,9 +97,6 @@ class ProviderAnthropic(Provider):
|
||||
)
|
||||
else:
|
||||
new_messages.append(message)
|
||||
|
||||
logger.debug(f"message: {messages}")
|
||||
logger.debug(f"new message: {new_messages}")
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
@@ -233,7 +230,7 @@ class ProviderAnthropic(Provider):
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=[],
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
|
||||
@@ -249,8 +249,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
logger.debug(f"payloads: {payloads}")
|
||||
|
||||
return payloads, context_query
|
||||
|
||||
async def _handle_api_error(
|
||||
|
||||
Reference in New Issue
Block a user