perf: 为发送流式消息的Fallback可选
This commit is contained in:
@@ -52,6 +52,7 @@ DEFAULT_CONFIG = {
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -1008,6 +1009,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
|
||||
@@ -146,9 +146,12 @@ class RespondStage(Stage):
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream)
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
await event._post_send()
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
@@ -159,7 +162,7 @@ class RespondStage(Stage):
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
|
||||
await event._pre_send()
|
||||
|
||||
# 检查消息链是否为空
|
||||
|
||||
@@ -220,9 +220,12 @@ class AstrMessageEvent(abc.ABC):
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
return buffer
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
|
||||
@@ -83,7 +83,22 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
@@ -100,7 +115,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if isinstance(group_id, str) and group_id.isdigit():
|
||||
|
||||
@@ -61,7 +61,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -72,4 +72,4 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -220,7 +220,22 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator):
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
@@ -237,4 +252,4 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -104,7 +104,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -115,4 +115,4 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -33,7 +33,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
"""流式输出仅支持消息列表私聊"""
|
||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
@@ -66,7 +66,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||
self.send_buffer = None
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _post_send(self, stream: dict = None):
|
||||
if not self.send_buffer:
|
||||
@@ -97,7 +97,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
|
||||
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
|
||||
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||
payload["msg_seq"] = random.randint(1, 10000)
|
||||
|
||||
match type(source):
|
||||
|
||||
@@ -91,7 +91,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
@@ -183,16 +183,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
text=markdown_text,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
parse_mode="MarkdownV2"
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id
|
||||
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -106,7 +106,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
@@ -121,4 +121,4 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -85,7 +85,7 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
@@ -96,4 +96,4 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
Reference in New Issue
Block a user