perf: 使发送流式消息时的 Fallback(分段输出)变为可选

This commit is contained in:
Raven95676
2025-04-15 21:21:02 +08:00
parent 98830d147f
commit 3753fce912
11 changed files with 64 additions and 24 deletions
+6
View File
@@ -52,6 +52,7 @@ DEFAULT_CONFIG = {
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
"streaming_segmented": False,
},
"provider_stt_settings": {
"enable": False,
@@ -1008,6 +1009,11 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
},
"streaming_segmented": {
"description": "不支持流式回复的平台分段输出",
"type": "bool",
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
},
},
},
"persona": {
+5 -2
View File
@@ -146,9 +146,12 @@ class RespondStage(Stage):
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream)
await event.send_streaming(result.async_stream, use_fallback)
await event._post_send()
return
elif len(result.chain) > 0:
@@ -159,7 +162,7 @@ class RespondStage(Stage):
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
await event._pre_send()
# 检查消息链是否为空
+4 -1
View File
@@ -220,9 +220,12 @@ class AstrMessageEvent(abc.ABC):
await asyncio.sleep(1.5) # 限速
return buffer
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegram、qq official 私聊。
Fallback仅支持 aiocqhttp, gewechat。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
@@ -83,7 +83,22 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator: AsyncGenerator):
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
@@ -100,7 +115,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
@@ -61,7 +61,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
@@ -72,4 +72,4 @@ class DingtalkMessageEvent(AstrMessageEvent):
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
@@ -220,7 +220,22 @@ class GewechatPlatformEvent(AstrMessageEvent):
members=members,
)
async def send_streaming(self, generator: AsyncGenerator):
async def send_streaming(
self, generator: AsyncGenerator, use_fallback: bool = False
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
@@ -237,4 +252,4 @@ class GewechatPlatformEvent(AstrMessageEvent):
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
@@ -104,7 +104,7 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
@@ -115,4 +115,4 @@ class LarkMessageEvent(AstrMessageEvent):
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
@@ -33,7 +33,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
else:
self.send_buffer.chain.extend(message.chain)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊"""
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
@@ -66,7 +66,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
self.send_buffer = None
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
async def _post_send(self, stream: dict = None):
if not self.send_buffer:
@@ -97,7 +97,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"msg_id": self.message_obj.message_id,
}
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
payload["msg_seq"] = random.randint(1, 10000)
match type(source):
@@ -91,7 +91,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -183,16 +183,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
text=markdown_text,
chat_id=payload["chat_id"],
message_id=message_id,
parse_mode="MarkdownV2"
parse_mode="MarkdownV2",
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id
text=delta, chat_id=payload["chat_id"], message_id=message_id
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)
@@ -106,7 +106,7 @@ class WebChatMessageEvent(AstrMessageEvent):
)
await super().send(message)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
async for chain in generator:
final_data += await WebChatMessageEvent._send(
@@ -121,4 +121,4 @@ class WebChatMessageEvent(AstrMessageEvent):
"cid": self.session_id.split("!")[-1],
}
)
await super().send_streaming(generator)
await super().send_streaming(generator, use_fallback)
@@ -85,7 +85,7 @@ class WecomPlatformEvent(AstrMessageEvent):
await super().send(message)
async def send_streaming(self, generator):
async def send_streaming(self, generator, use_fallback: bool = False):
buffer = None
async for chain in generator:
if not buffer:
@@ -96,4 +96,4 @@ class WecomPlatformEvent(AstrMessageEvent):
return
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator)
return await super().send_streaming(generator, use_fallback)