✨ feat: 完善流式处理
This commit is contained in:
@@ -12,6 +12,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
@@ -156,7 +157,10 @@ class LLMRequestSubStage(Stage):
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
logger.debug(llm_response)
|
||||
yield llm_response.result_chain
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
yield MessageChain().message(llm_response.completion_text)
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
|
||||
@@ -202,7 +202,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.role == "admin"
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]):
|
||||
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram。
|
||||
"""
|
||||
|
||||
@@ -9,7 +9,6 @@ from astrbot.api.message_components import (
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
@@ -110,15 +109,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
payload["reply_to_message_id"] = message_thread_id
|
||||
|
||||
delta = ""
|
||||
current_content = ""
|
||||
message_id = None
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||
|
||||
async for chain in generator:
|
||||
logger.debug(f"streaming: {chain}")
|
||||
if isinstance(chain, list):
|
||||
if isinstance(chain, MessageChain):
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain:
|
||||
for i in chain.chain:
|
||||
if isinstance(i, Plain):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
@@ -144,6 +143,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e}")
|
||||
message_id = msg.message_id
|
||||
@@ -163,13 +163,14 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e}")
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
|
||||
if delta:
|
||||
if delta and current_content != delta:
|
||||
await self.client.edit_message_text(
|
||||
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||
)
|
||||
|
||||
@@ -110,7 +110,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
MessageChain(chain=chain), session_id=self.session_id, streaming=True
|
||||
chain, session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
|
||||
@@ -10,6 +10,7 @@ from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
@@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
if content.type == "text":
|
||||
# text completion
|
||||
completion_text = str(content.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
# llm_response.completion_text = completion_text
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# Anthropic每次只返回一个函数调用
|
||||
if completion.stop_reason == "tool_use":
|
||||
@@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
messages=context_query, **model_config
|
||||
)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.completion_text = response.content[0].text
|
||||
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||
llm_response.raw_completion = response
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
@@ -171,7 +173,21 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
@@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
)
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
||||
result_chain=MessageChain().message(
|
||||
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
@@ -149,7 +152,10 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=output_text)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(output_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
@@ -162,7 +168,21 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
@@ -200,7 +200,21 @@ class ProviderDify(Provider):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
if isinstance(chunk, str):
|
||||
|
||||
@@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient:
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
|
||||
async def stream_generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
"stream": True,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
async for line in resp.content:
|
||||
if line:
|
||||
yield line
|
||||
|
||||
@register_provider_adapter(
|
||||
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
||||
@@ -349,7 +382,21 @@ class ProviderGoogleGenAI(Provider):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
@@ -12,6 +12,7 @@ from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
@@ -149,7 +150,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 处理文本内容
|
||||
if delta.content:
|
||||
completion_text = delta.content
|
||||
llm_response.result_chain = [Comp.Plain(completion_text)]
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(completion_text)])
|
||||
yield llm_response
|
||||
|
||||
final_completion = state.get_final_completion()
|
||||
|
||||
Reference in New Issue
Block a user