perf: 进一步防止llm递归调用
This commit is contained in:
@@ -18,7 +18,7 @@ class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
@@ -58,6 +58,8 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
try:
|
||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
@@ -91,7 +93,7 @@ class LLMRequestSubStage(Stage):
|
||||
for tool_name, tool_result in function_calling_result.items():
|
||||
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
||||
req.prompt += extra_prompt
|
||||
async for _ in self.process(event):
|
||||
async for _ in self.process(event, _nested=True):
|
||||
yield
|
||||
|
||||
except BaseException as e:
|
||||
|
||||
Reference in New Issue
Block a user