feat: 完善流式处理

This commit is contained in:
Soulter
2025-04-06 11:56:06 +08:00
parent 109650faf3
commit c1cf2be533
9 changed files with 120 additions and 17 deletions
@@ -12,6 +12,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
MessageChain
)
from astrbot.core.message.components import Image
from astrbot.core import logger
@@ -156,7 +157,10 @@ class LLMRequestSubStage(Stage):
async for llm_response in stream:
if llm_response.is_chunk:
logger.debug(llm_response)
yield llm_response.result_chain
if llm_response.result_chain:
yield llm_response.result_chain # MessageChain
else:
yield MessageChain().message(llm_response.completion_text)
else:
final_llm_response = llm_response
else:
+1 -1
View File
@@ -202,7 +202,7 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]):
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegram。
"""
@@ -9,7 +9,6 @@ from astrbot.api.message_components import (
At,
File,
Record,
BaseMessageComponent,
)
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
@@ -110,15 +109,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
payload["reply_to_message_id"] = message_thread_id
delta = ""
current_content = ""
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
async for chain in generator:
logger.debug(f"streaming: {chain}")
if isinstance(chain, list):
if isinstance(chain, MessageChain):
# 处理消息链中的每个组件
for i in chain:
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
@@ -144,6 +143,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
if not message_id:
try:
msg = await self.client.send_message(text=delta, **payload)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e}")
message_id = msg.message_id
@@ -163,13 +163,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e}")
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
if delta:
if delta and current_content != delta:
await self.client.edit_message_text(
text=delta, chat_id=payload["chat_id"], message_id=message_id
)
@@ -110,7 +110,7 @@ class WebChatMessageEvent(AstrMessageEvent):
final_data = ""
async for chain in generator:
final_data += await WebChatMessageEvent._send(
MessageChain(chain=chain), session_id=self.session_id, streaming=True
chain, session_id=self.session_id, streaming=True
)
await web_chat_back_queue.put(
@@ -10,6 +10,7 @@ from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
@@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if content.type == "text":
# text completion
completion_text = str(content.text).strip()
llm_response.completion_text = completion_text
# llm_response.completion_text = completion_text
llm_response.result_chain = MessageChain().message(completion_text)
# Anthropic每次只返回一个函数调用
if completion.stop_reason == "tool_use":
@@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
messages=context_query, **model_config
)
llm_response = LLMResponse("assistant")
llm_response.completion_text = response.content[0].text
llm_response.result_chain = MessageChain().message(response.content[0].text)
llm_response.raw_completion = response
return llm_response
except Exception as e:
@@ -171,7 +173,21 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
tool_calls_result=None,
**kwargs,
):
raise NotImplementedError("This method is not implemented yet.")
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片"""
@@ -7,6 +7,7 @@ from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
result_chain=MessageChain().message(
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
),
)
output_text = response.output.get("text", "")
@@ -149,7 +152,10 @@ class ProviderDashscope(ProviderOpenAIOfficial):
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
llm_response = LLMResponse("assistant")
llm_response.result_chain = MessageChain().message(output_text)
return llm_response
async def text_chat_stream(
self,
@@ -162,7 +168,21 @@ class ProviderDashscope(ProviderOpenAIOfficial):
tool_calls_result=None,
**kwargs,
):
raise NotImplementedError("This method is not implemented yet.")
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def forget(self, session_id):
return True
+15 -1
View File
@@ -200,7 +200,21 @@ class ProviderDify(Provider):
tool_calls_result=None,
**kwargs,
):
raise NotImplementedError("This method is not implemented yet.")
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
+48 -1
View File
@@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient:
logger.error(f"Gemini 返回了非 json 数据: {text}")
raise Exception("Gemini 返回了非 json 数据: ")
async def stream_generate_content(
self,
contents: List[dict],
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
payload["system_instruction"] = {"parts": {"text": system_instruction}}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
"stream": True,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
)
async with self.client.post(
request_url, json=payload, timeout=self.timeout
) as resp:
async for line in resp.content:
if line:
yield line
@register_provider_adapter(
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
@@ -349,7 +382,21 @@ class ProviderGoogleGenAI(Provider):
tool_calls_result=None,
**kwargs,
):
raise NotImplementedError("This method is not implemented yet.")
# raise NotImplementedError("This method is not implemented yet.")
# 调用 text_chat 模拟流式
llm_response = await self.text_chat(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
)
llm_response.is_chunk = True
yield llm_response
llm_response.is_chunk = False
yield llm_response
def get_current_key(self) -> str:
return self.client.api_key
@@ -12,6 +12,7 @@ from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
@@ -149,7 +150,7 @@ class ProviderOpenAIOfficial(Provider):
# 处理文本内容
if delta.content:
completion_text = delta.content
llm_response.result_chain = [Comp.Plain(completion_text)]
llm_response.result_chain = MessageChain(chain=[Comp.Plain(completion_text)])
yield llm_response
final_completion = state.get_final_completion()