流式输出完成后，将完整的 LLM 响应设置为事件结果
This commit is contained in:
@@ -157,6 +157,8 @@ class ResultContentType(enum.Enum):
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
STREAMING_FINISH= enum.auto()
|
||||
"""流式输出完成"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
@@ -151,15 +151,15 @@ class LLMRequestSubStage(Stage):
|
||||
final_llm_response = None
|
||||
|
||||
if self.streaming_response:
|
||||
stream = provider.text_chat_stream(
|
||||
**req.__dict__
|
||||
)
|
||||
stream = provider.text_chat_stream(**req.__dict__)
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
yield MessageChain().message(llm_response.completion_text)
|
||||
yield MessageChain().message(
|
||||
llm_response.completion_text
|
||||
)
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
@@ -210,6 +210,14 @@ class LLMRequestSubStage(Stage):
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, final_llm_response)
|
||||
|
||||
# 流式输出完成后,将完整的LLM响应设置为事件结果
|
||||
if bool(self.streaming_response):
|
||||
event.clear_result()
|
||||
async for _ in self._handle_llm_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
pass
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
@@ -227,9 +235,14 @@ class LLMRequestSubStage(Stage):
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting(req))
|
||||
)
|
||||
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||
yield
|
||||
|
||||
async def _handle_llm_response(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""处理 LLM 响应。
|
||||
|
||||
@@ -239,19 +252,29 @@ class LLMRequestSubStage(Stage):
|
||||
Yields:
|
||||
Iterator[bool]: 将 event 交付给下一个 stage
|
||||
"""
|
||||
is_stream = bool(self.streaming_response)
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
).set_result_content_type(
|
||||
ResultContentType.STREAMING_FINISH
|
||||
if is_stream
|
||||
else ResultContentType.LLM_RESULT
|
||||
)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
.set_result_content_type(
|
||||
ResultContentType.STREAMING_FINISH
|
||||
if is_stream
|
||||
else ResultContentType.LLM_RESULT
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
|
||||
@@ -78,6 +78,8 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
return
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
|
||||
@@ -72,15 +72,18 @@ class ResultDecorateStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None or not result.chain:
|
||||
return
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果暂时不进行处理
|
||||
return
|
||||
|
||||
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
|
||||
|
||||
# 回复时检查内容安全
|
||||
if (
|
||||
self.content_safe_check_reply
|
||||
and self.content_safe_check_stage
|
||||
and result.is_llm_result()
|
||||
and not is_stream # 流式输出不检查内容安全
|
||||
):
|
||||
text = ""
|
||||
for comp in result.chain:
|
||||
@@ -100,6 +103,10 @@ class ResultDecorateStage(Stage):
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
if is_stream:
|
||||
logger.warning(
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
logger.debug(
|
||||
@@ -114,6 +121,11 @@ class ResultDecorateStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
# 流式输出不执行下面的逻辑
|
||||
if is_stream:
|
||||
logger.info("流式输出已启用,跳过结果装饰阶段")
|
||||
return
|
||||
|
||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
|
||||
Reference in New Issue
Block a user