feat: 支持展示工具使用过程
This commit is contained in:
@@ -61,6 +61,7 @@ DEFAULT_CONFIG = {
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
@@ -445,7 +446,7 @@ CONFIG_METADATA_2 = {
|
||||
"ignore_bot_self_message": {
|
||||
"description": "是否忽略机器人自身的消息",
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
"hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"ignore_at_all": {
|
||||
"description": "是否忽略 @ 全体成员",
|
||||
@@ -1692,10 +1693,15 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"show_tool_use_status": {
|
||||
"description": "函数调用状态输出",
|
||||
"type": "bool",
|
||||
"hint": "在触发函数调用时输出其函数名和内容。",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -24,6 +24,8 @@ class MessageChain:
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
type: Optional[str] = None
|
||||
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
|
||||
|
||||
def message(self, message: str):
|
||||
"""添加一条文本消息到消息链 `chain` 中。
|
||||
|
||||
@@ -148,6 +148,13 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
@@ -183,6 +190,8 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
@@ -200,6 +209,7 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -219,6 +229,7 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -240,6 +251,7 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
@@ -256,6 +268,7 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
|
||||
@@ -31,15 +31,17 @@ class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = conf["provider_settings"]["wake_prefix"] # str
|
||||
self.max_context_length = conf["provider_settings"]["max_context_length"] # int
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, conf["provider_settings"]["dequeue_context_length"]),
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = conf["provider_settings"]["streaming_response"]
|
||||
self.max_step: int = conf["provider_settings"].get("max_agent_step", 10)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.max_step: int = settings.get("max_agent_step", 10)
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -158,10 +160,17 @@ class LLMRequestSubStage(Stage):
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
if resp.type == "tool_call_result":
|
||||
continue # 跳过工具调用结果
|
||||
if resp.type == "tool_call":
|
||||
if self.show_tool_use or event.get_platform_name() == "webchat":
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not self.streaming_response:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_resp"
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
event.set_result(
|
||||
@@ -173,9 +182,14 @@ class LLMRequestSubStage(Stage):
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
yield resp.data["chain"].chain
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应结束
|
||||
yield MessageChain(chain=[], type="break")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
@@ -268,11 +282,13 @@ class LLMRequestSubStage(Stage):
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
web_chat_back_queue.put_nowait({
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
})
|
||||
web_chat_back_queue.put_nowait(
|
||||
{
|
||||
"type": "update_title",
|
||||
"cid": cid,
|
||||
"data": title,
|
||||
}
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
|
||||
@@ -158,6 +158,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain.chain:
|
||||
if isinstance(i, Plain):
|
||||
|
||||
@@ -96,6 +96,14 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -110,6 +118,18 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
chain, session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user