From c1cf2be533b3d71baac41ca2f20f400d498e568c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 6 Apr 2025 11:56:06 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20=E5=AE=8C=E5=96=84=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/llm_request.py | 6 ++- astrbot/core/platform/astr_message_event.py | 2 +- .../platform/sources/telegram/tg_event.py | 11 +++-- .../platform/sources/webchat/webchat_event.py | 2 +- .../core/provider/sources/anthropic_source.py | 22 +++++++-- .../core/provider/sources/dashscope_source.py | 26 ++++++++-- astrbot/core/provider/sources/dify_source.py | 16 +++++- .../core/provider/sources/gemini_source.py | 49 ++++++++++++++++++- .../core/provider/sources/openai_source.py | 3 +- 9 files changed, 120 insertions(+), 17 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 5f20b13ab..fafb81944 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -12,6 +12,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import ( MessageEventResult, ResultContentType, + MessageChain ) from astrbot.core.message.components import Image from astrbot.core import logger @@ -156,7 +157,10 @@ class LLMRequestSubStage(Stage): async for llm_response in stream: if llm_response.is_chunk: logger.debug(llm_response) - yield llm_response.result_chain + if llm_response.result_chain: + yield llm_response.result_chain # MessageChain + else: + yield MessageChain().message(llm_response.completion_text) else: final_llm_response = llm_response else: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 0f91b7087..414a6721b 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -202,7 +202,7 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" - async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]): + async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]): """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram。 """ diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 6374f8623..4759e8437 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -9,7 +9,6 @@ from astrbot.api.message_components import ( At, File, Record, - BaseMessageComponent, ) from telegram.ext import ExtBot from astrbot.core.utils.io import download_file @@ -110,15 +109,15 @@ class TelegramPlatformEvent(AstrMessageEvent): payload["reply_to_message_id"] = message_thread_id delta = "" + current_content = "" message_id = None last_edit_time = 0 # 上次编辑消息的时间 throttle_interval = 0.6 # 编辑消息的间隔时间 (秒) async for chain in generator: - logger.debug(f"streaming: {chain}") - if isinstance(chain, list): + if isinstance(chain, MessageChain): # 处理消息链中的每个组件 - for i in chain: + for i in chain.chain: if isinstance(i, Plain): delta += i.text elif isinstance(i, Image): @@ -144,6 +143,7 @@ class TelegramPlatformEvent(AstrMessageEvent): if not message_id: try: msg = await self.client.send_message(text=delta, **payload) + current_content = delta except Exception as e: logger.warning(f"发送消息失败(streaming): {e}") message_id = msg.message_id @@ -163,13 +163,14 @@ class TelegramPlatformEvent(AstrMessageEvent): chat_id=payload["chat_id"], message_id=message_id, ) + current_content = delta except Exception as e: logger.warning(f"编辑消息失败(streaming): {e}") last_edit_time = ( asyncio.get_event_loop().time() ) # 更新上次编辑的时间 - if delta: + if delta and current_content != delta: await self.client.edit_message_text( text=delta, chat_id=payload["chat_id"], message_id=message_id ) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 9aac55c23..ef5532920 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -110,7 +110,7 @@ class WebChatMessageEvent(AstrMessageEvent): final_data = "" async for chain in generator: final_data += await WebChatMessageEvent._send( - MessageChain(chain=chain), session_id=self.session_id, streaming=True + chain, session_id=self.session_id, streaming=True ) await web_chat_back_queue.put( diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 7e26018b1..2a73c3079 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,6 +10,7 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if content.type == "text": # text completion completion_text = str(content.text).strip() - llm_response.completion_text = completion_text + # llm_response.completion_text = completion_text + llm_response.result_chain = MessageChain().message(completion_text) # Anthropic每次只返回一个函数调用 if completion.stop_reason == "tool_use": @@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): messages=context_query, **model_config ) llm_response = LLMResponse("assistant") - llm_response.completion_text = response.content[0].text + llm_response.result_chain = MessageChain().message(response.content[0].text) llm_response.raw_completion = response return llm_response except Exception as e: @@ -171,7 +173,21 @@ class ProviderAnthropic(ProviderOpenAIOfficial): tool_calls_result=None, **kwargs, ): - raise NotImplementedError("This method is not implemented yet.") + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index cf1559f5e..59f2b5f98 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -7,6 +7,7 @@ from ..entites import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter +from astrbot.core.message.message_event_result import MessageChain from .openai_source import ProviderOpenAIOfficial from astrbot.core import logger, sp from dashscope import Application @@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) return LLMResponse( role="err", - completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + result_chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + ), ) output_text = response.output.get("text", "") @@ -149,7 +152,10 @@ class ProviderDashscope(ProviderOpenAIOfficial): ref_str += f"{ref['index_id']}. {ref_title}\n" output_text += f"\n\n回答来源:\n{ref_str}" - return LLMResponse(role="assistant", completion_text=output_text) + llm_response = LLMResponse("assistant") + llm_response.result_chain = MessageChain().message(output_text) + + return llm_response async def text_chat_stream( self, @@ -162,7 +168,21 @@ class ProviderDashscope(ProviderOpenAIOfficial): tool_calls_result=None, **kwargs, ): - raise NotImplementedError("This method is not implemented yet.") + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response async def forget(self, session_id): return True diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index f0c7225e3..65fe76fce 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -200,7 +200,21 @@ class ProviderDify(Provider): tool_calls_result=None, **kwargs, ): - raise NotImplementedError("This method is not implemented yet.") + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response async def parse_dify_result(self, chunk: dict | str) -> MessageChain: if isinstance(chunk, str): diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 19da42f1a..4ce775fa0 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient: logger.error(f"Gemini 返回了非 json 数据: {text}") raise Exception("Gemini 返回了非 json 数据: ") + async def stream_generate_content( + self, + contents: List[dict], + model: str = "gemini-1.5-flash", + system_instruction: str = "", + tools: dict = None, + modalities: List[str] = ["Text"], + safety_settings: List[dict] = [], + ): + payload = {} + if system_instruction: + payload["system_instruction"] = {"parts": {"text": system_instruction}} + if tools: + payload["tools"] = [tools] + payload["contents"] = contents + payload["generationConfig"] = { + "responseModalities": modalities, + "stream": True, + } + payload["safetySettings"] = [ + {"category": s["category"], "threshold": s["threshold"]} + for s in safety_settings + ] + logger.debug(f"payload: {payload}") + request_url = ( + f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" + ) + async with self.client.post( + request_url, json=payload, timeout=self.timeout + ) as resp: + async for line in resp.content: + if line: + yield line @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" @@ -349,7 +382,21 @@ class ProviderGoogleGenAI(Provider): tool_calls_result=None, **kwargs, ): - raise NotImplementedError("This method is not implemented yet.") + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index c29227926..0d080167e 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -12,6 +12,7 @@ from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import NotFoundError, UnprocessableEntityError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality @@ -149,7 +150,7 @@ class ProviderOpenAIOfficial(Provider): # 处理文本内容 if delta.content: completion_text = delta.content - llm_response.result_chain = [Comp.Plain(completion_text)] + llm_response.result_chain = MessageChain(chain=[Comp.Plain(completion_text)]) yield llm_response final_completion = state.get_final_completion()