feat: 支持流式输出

This commit is contained in:
Soulter
2025-04-06 00:56:33 +08:00
parent 849913276d
commit 109650faf3
18 changed files with 762 additions and 247 deletions
+6
View File
@@ -50,6 +50,7 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
"streaming_response": False,
},
"provider_stt_settings": {
"enable": False,
@@ -992,6 +993,11 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
"streaming_response": {
"description": "启用流式回复",
"type": "bool",
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API 以及 Telegram 平台,并且暂不支持工具调用(后续将更新)",
},
},
},
"persona": {
+11 -1
View File
@@ -1,6 +1,6 @@
import enum
from typing import List, Optional, Union
from typing import List, Optional, Union, AsyncGenerator
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
@@ -131,6 +131,8 @@ class ResultContentType(enum.Enum):
"""调用 LLM 产生的结果"""
GENERAL_RESULT = enum.auto()
"""普通的消息结果"""
STREAMING_RESULT = enum.auto()
"""调用 LLM 产生的流式结果"""
@dataclass
@@ -152,6 +154,9 @@ class MessageEventResult(MessageChain):
default_factory=lambda: ResultContentType.GENERAL_RESULT
)
async_stream: Optional[AsyncGenerator] = None
"""异步流"""
def stop_event(self) -> "MessageEventResult":
"""终止事件传播。"""
self.result_type = EventResultType.STOP
@@ -168,6 +173,11 @@ class MessageEventResult(MessageChain):
"""
return self.result_type == EventResultType.STOP
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
    """Attach an async generator to this result.

    Stores ``stream`` on ``self.async_stream`` and returns ``self`` so the
    call can be chained with other setters.
    """
    self.async_stream = stream
    return self
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
"""设置事件处理的结果类型。
@@ -37,6 +37,9 @@ class LLMRequestSubStage(Stage):
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -137,59 +140,90 @@ class LLMRequestSubStage(Stage):
if not req.session_id:
req.session_id = event.unified_msg_origin
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
async def requesting(req: ProviderRequest):
try:
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
final_llm_response = None
if self.streaming_response:
stream = provider.text_chat_stream(
**req.__dict__
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async for result in self._handle_llm_response(event, req, llm_response):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
async for llm_response in stream:
if llm_response.is_chunk:
logger.debug(llm_response)
yield llm_response.result_chain
else:
final_llm_response = llm_response
else:
yield
final_llm_response = await provider.text_chat(
**req.__dict__
) # 请求 LLM
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
if not final_llm_response:
raise Exception("LLM response is None.")
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, final_llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async for result in self._handle_llm_response(
event, req, final_llm_response
):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
)
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
# 保存到历史记录
await self._save_to_history(event, req, final_llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
)
if not self.streaming_response:
async for _ in requesting(req):
yield
else:
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
)
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(requesting(req))
)
return
async def _handle_llm_response(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
+9 -2
View File
@@ -6,7 +6,7 @@ from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -79,7 +79,14 @@ class RespondStage(Stage):
if result is None:
return
if len(result.chain) > 0:
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果直接交付平台适配器处理
logger.info(f"应用流式输出({event.get_platform_name()})")
await event._pre_send()
await event.send_streaming(result.async_stream)
await event._post_send()
return
elif len(result.chain) > 0:
await event._pre_send()
if self.enable_seg and (
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
@@ -71,6 +72,9 @@ class ResultDecorateStage(Stage):
result = event.get_result()
if result is None or not result.chain:
return
if result.result_content_type == ResultContentType.STREAMING_RESULT:
# 流式结果暂时不进行处理
return
# 回复时检查内容安全
if (
+10 -1
View File
@@ -1,7 +1,7 @@
import abc
import asyncio
from dataclasses import dataclass
from typing import List, Union, Optional
from typing import List, Union, Optional, AsyncGenerator
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
@@ -202,6 +202,15 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]):
    """Send a streamed message to the messaging platform, driven by an async generator.

    This base implementation never reads ``generator``: it only schedules a
    metrics upload and flags the event as having performed a send operation.
    The original note says only Telegram is supported at this point.
    """
    # Fire-and-forget metrics upload; the task is not awaited here.
    asyncio.create_task(
        Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
    )
    # Record that this event has already sent something.
    self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
@@ -1,7 +1,16 @@
import asyncio
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
from astrbot.api.message_components import (
Plain,
Image,
Reply,
At,
File,
Record,
BaseMessageComponent,
)
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
@@ -82,3 +91,87 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
await self.send_with_client(self.client, message, self.get_sender_id())
await super().send(message)
async def send_streaming(self, generator):
    """Stream a reply to Telegram by sending one message and editing it in place.

    Plain text from every yielded chain is accumulated; the first piece is
    sent as a new message and later pieces are shown by editing that message,
    at most once per ``throttle_interval`` seconds. Image, File and Record
    components are sent immediately as separate messages. Items yielded by
    ``generator`` that are not lists are ignored.

    Args:
        generator: Async generator yielding lists of message components.

    Returns:
        Whatever the parent ``send_streaming`` returns.
    """
    message_thread_id = None

    if self.get_message_type() == MessageType.GROUP_MESSAGE:
        user_name = self.message_obj.group_id
    else:
        user_name = self.get_sender_id()

    if "#" in user_name:
        # it's a supergroup chat with message_thread_id
        user_name, message_thread_id = user_name.split("#")
    payload = {
        "chat_id": user_name,
    }
    if message_thread_id:
        payload["reply_to_message_id"] = message_thread_id

    delta = ""  # full text accumulated so far
    shown = None  # text currently displayed in the Telegram message
    message_id = None
    last_edit_time = 0  # time of the last successful send / attempted edit
    throttle_interval = 0.6  # minimum seconds between two edits

    async for chain in generator:
        logger.debug(f"streaming: {chain}")
        if not isinstance(chain, list):
            continue

        # Handle every component of the chain.
        for i in chain:
            if isinstance(i, Plain):
                delta += i.text
            elif isinstance(i, Image):
                image_path = await i.convert_to_file_path()
                await self.client.send_photo(photo=image_path, **payload)
                continue
            elif isinstance(i, File):
                if i.file.startswith("https://"):
                    path = "data/temp/" + i.name
                    await download_file(i.file, path)
                    i.file = path
                await self.client.send_document(
                    document=i.file, filename=i.name, **payload
                )
                continue
            elif isinstance(i, Record):
                path = await i.convert_to_file_path()
                await self.client.send_voice(voice=path, **payload)
                continue

            if not delta or delta == shown:
                # Nothing new to display; sending empty or unchanged text
                # would only produce an API error.
                continue

            if not message_id:
                try:
                    msg = await self.client.send_message(text=delta, **payload)
                except Exception as e:
                    logger.warning(f"发送消息失败(streaming): {e}")
                    # No message exists yet; the next chunk retries with the
                    # accumulated text instead of touching an unbound `msg`.
                    continue
                message_id = msg.message_id
                shown = delta
                last_edit_time = asyncio.get_event_loop().time()
            else:
                current_time = asyncio.get_event_loop().time()
                # Only edit once the throttle interval has elapsed.
                if current_time - last_edit_time >= throttle_interval:
                    try:
                        await self.client.edit_message_text(
                            text=delta,
                            chat_id=payload["chat_id"],
                            message_id=message_id,
                        )
                        shown = delta
                    except Exception as e:
                        logger.warning(f"编辑消息失败(streaming): {e}")
                    last_edit_time = asyncio.get_event_loop().time()

    # Final flush: show whatever text has not been displayed yet.
    if delta and delta != shown:
        if message_id:
            try:
                await self.client.edit_message_text(
                    text=delta, chat_id=payload["chat_id"], message_id=message_id
                )
            except Exception as e:
                logger.warning(f"编辑消息失败(streaming): {e}")
        else:
            # Every earlier send attempt failed; try once more with the full text.
            try:
                await self.client.send_message(text=delta, **payload)
            except Exception as e:
                logger.warning(f"发送消息失败(streaming): {e}")

    return await super().send_streaming(generator)
@@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True)
@staticmethod
async def _send(message: MessageChain, session_id: str):
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
if not message:
web_chat_back_queue.put_nowait(None)
await web_chat_back_queue.put(
{"type": "end", "data": "", "streaming": False}
)
return
cid = session_id.split("!")[-1]
data = ""
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid))
data = comp.text
await web_chat_back_queue.put(
{
"type": "plain",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
@@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
"type": "image",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
@@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
"type": "record",
"cid": cid,
"data": data,
"streaming": streaming,
}
)
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
return data
async def send(self, message: MessageChain):
    """Send a complete (non-streaming) message to the web chat queue.

    The message is first pushed through ``WebChatMessageEvent._send``; an
    ``"end"`` marker is then queued, and finally the parent ``send`` is called.
    """
    await WebChatMessageEvent._send(message, session_id=self.session_id)
    # Queue the terminating marker; the cid is the part of the session id
    # after the last "!".
    await web_chat_back_queue.put(
        {
            "type": "end",
            "data": "",
            "streaming": False,
            "cid": self.session_id.split("!")[-1],
        }
    )
    await super().send(message)
async def send_streaming(self, generator):
    """Forward a streamed reply to the web chat queue chunk by chunk.

    Every non-empty chain yielded by ``generator`` is pushed through
    ``WebChatMessageEvent._send`` with ``streaming=True``. The data strings
    returned by ``_send`` are concatenated and delivered in a final ``"end"``
    event, after which the parent ``send_streaming`` is called.

    Args:
        generator: Async generator yielding lists of message components.
            Empty or ``None`` items are skipped.
    """
    final_data = ""
    async for chain in generator:
        if not chain:
            # The generator may yield bare None / empty items; wrapping those
            # in a MessageChain would make _send iterate over None.
            continue
        sent = await WebChatMessageEvent._send(
            MessageChain(chain=chain), session_id=self.session_id, streaming=True
        )
        # _send returns None on its early-exit path; treat that as no data.
        final_data += sent or ""

    await web_chat_back_queue.put(
        {
            "type": "end",
            "data": final_data,
            "streaming": True,
            "cid": self.session_id.split("!")[-1],
        }
    )
    await super().send_streaming(generator)
+5
View File
@@ -204,6 +204,9 @@ class LLMResponse:
_completion_text: str = ""
is_chunk: bool = False
"""是否是流式输出的单个 Chunk"""
def __init__(
self,
role: str,
@@ -214,6 +217,7 @@ class LLMResponse:
tools_call_ids: List[str] = [],
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
is_chunk: bool = False,
):
"""初始化 LLMResponse
@@ -233,6 +237,7 @@ class LLMResponse:
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
self.is_chunk = is_chunk
@property
def completion_text(self):
+30 -2
View File
@@ -1,7 +1,7 @@
import abc
from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from dataclasses import dataclass
@@ -108,7 +108,35 @@ class Provider(AbstractProvider):
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
"""
raise NotImplementedError()
...
async def text_chat_stream(
    self,
    prompt: str,
    session_id: str = None,
    image_urls: List[str] = None,
    func_tool: FuncCall = None,
    contexts: List = None,
    system_prompt: str = None,
    tool_calls_result: ToolCallsResult = None,
    **kwargs,
) -> AsyncGenerator[LLMResponse, None]:
    """Stream the LLM's text-chat result using the current model.

    Implementations yield partial responses and, at the very end, yield the
    complete result once.

    Args:
        prompt: The prompt.
        session_id: Session ID (this attribute is deprecated).
        image_urls: List of image URLs.
        func_tool: Function-calling tools.
        contexts: Conversation context.
        system_prompt: System prompt.
        tool_calls_result: Tool-call results passed back to the LLM. See
            https://platform.openai.com/docs/guides/function-calling
        kwargs: Other parameters.

    Raises:
        NotImplementedError: Always, in this base implementation.

    Notes:
        - If image_urls is given, the images are attached to the conversation.
          An error is raised if the model does not support image input.
        - If tools are given, they are used for function calling. An error is
          raised if the model does not support function calling.
    """
    raise NotImplementedError()
    # Unreachable on purpose: the presence of `yield` makes this an async
    # generator, so `async for` callers get NotImplementedError rather than a
    # TypeError about iterating a coroutine.
    yield
async def pop_record(self, context: List):
"""
@@ -160,6 +160,19 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
return llm_response
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Streaming chat is not implemented for this provider.

    Raises:
        NotImplementedError: On the first iteration step, always.
    """
    raise NotImplementedError("This method is not implemented yet.")
    # Unreachable on purpose: `yield` turns this into an async generator so
    # that `async for` callers receive the NotImplementedError instead of a
    # TypeError about a coroutine that is not async-iterable.
    yield
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
@@ -141,12 +141,29 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
ref_title = (
ref.get("title", "")
if ref.get("title")
else ref.get("doc_name", "")
)
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Streaming chat is not implemented for this provider.

    Raises:
        NotImplementedError: On the first iteration step, always.
    """
    raise NotImplementedError("This method is not implemented yet.")
    # Unreachable on purpose: `yield` turns this into an async generator so
    # that `async for` callers receive the NotImplementedError instead of a
    # TypeError about a coroutine that is not async-iterable.
    yield
async def forget(self, session_id):
return True
@@ -189,6 +189,19 @@ class ProviderDify(Provider):
return LLMResponse(role="assistant", result_chain=chain)
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Streaming chat is not implemented for this provider.

    Raises:
        NotImplementedError: On the first iteration step, always.
    """
    raise NotImplementedError("This method is not implemented yet.")
    # Unreachable on purpose: `yield` turns this into an async generator so
    # that `async for` callers receive the NotImplementedError instead of a
    # TypeError about a coroutine that is not async-iterable.
    yield
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
@@ -338,6 +338,19 @@ class ProviderGoogleGenAI(Provider):
return llm_response
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Streaming chat is not implemented for this provider.

    Raises:
        NotImplementedError: On the first iteration step, always.
    """
    raise NotImplementedError("This method is not implemented yet.")
    # Unreachable on purpose: `yield` turns this into an async generator so
    # that `async for` callers receive the NotImplementedError instead of a
    # TypeError about a coroutine that is not async-iterable.
    yield
def get_current_key(self) -> str:
return self.client.api_key
@@ -95,6 +95,19 @@ class LLMTunerModelLoader(Provider):
return llm_response
async def text_chat_stream(
    self,
    prompt,
    session_id=None,
    image_urls=None,
    func_tool=None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
):
    """Streaming chat is not implemented for this provider.

    Raises:
        NotImplementedError: On the first iteration step, always.
    """
    raise NotImplementedError("This method is not implemented yet.")
    # Unreachable on purpose: `yield` turns this into an async generator so
    # that `async for` callers receive the NotImplementedError instead of a
    # TypeError about a coroutine that is not async-iterable.
    yield
async def get_current_key(self):
return "none"
+255 -58
View File
@@ -4,17 +4,20 @@ import os
import inspect
import random
import asyncio
import astrbot.core.message.components as Comp
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai._exceptions import NotFoundError, UnprocessableEntityError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from typing import List, AsyncGenerator
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
@@ -107,12 +110,63 @@ class ProviderOpenAIOfficial(Provider):
logger.debug(f"completion: {completion}")
llm_response = await self.parse_openai_completion(completion, tools)
return llm_response
async def _query_stream(
    self, payloads: dict, tools: FuncCall
) -> AsyncGenerator[LLMResponse, None]:
    """Query the chat-completions API in streaming mode.

    Yields one chunk ``LLMResponse`` (``is_chunk=True``) per text delta and,
    once the stream is exhausted, one final response parsed from the
    aggregated completion.

    Args:
        payloads: Request arguments. Not modified: a shallow copy is used so
            the caller can resend the same payload when it retries.
        tools: Function-calling tool set, or a falsy value for none.
    """
    # Work on a copy. Deleting keys from the caller's dict would silently
    # drop the extra_body parameters on any retry.
    request_args = dict(payloads)

    if tools:
        tool_list = tools.get_func_desc_openai_style()
        if tool_list:
            request_args["tools"] = tool_list

    # Arguments that are not among the default parameters go into extra_body.
    extra_body = {
        key: value
        for key, value in request_args.items()
        if key not in self.default_params
    }
    for key in extra_body:
        del request_args[key]

    stream = await self.client.chat.completions.create(
        **request_args, stream=True, extra_body=extra_body
    )

    state = ChatCompletionStreamState()

    async for chunk in stream:
        state.handle_chunk(chunk)
        if len(chunk.choices) == 0:
            continue
        delta = chunk.choices[0].delta
        # Text content: emit a fresh response object per chunk so consumers
        # that keep a reference never see it mutated by a later chunk.
        if delta.content:
            chunk_response = LLMResponse("assistant", is_chunk=True)
            chunk_response.result_chain = [Comp.Plain(delta.content)]
            yield chunk_response

    final_completion = state.get_final_completion()
    llm_response = await self.parse_openai_completion(final_completion, tools)

    yield llm_response
async def parse_openai_completion(
self, completion: ChatCompletion, tools: FuncCall
):
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
llm_response = LLMResponse("assistant")
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
@@ -148,7 +202,7 @@ class ProviderOpenAIOfficial(Provider):
return llm_response
async def text_chat(
async def _prepare_chat_payload(
self,
prompt: str,
session_id: str = None,
@@ -158,7 +212,8 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
new_record = await self.assemble_context(prompt, image_urls)
context_query = [*contexts, new_record]
if system_prompt:
@@ -177,8 +232,117 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, **model_config}
llm_response = None
return payloads, context_query, func_tool
async def _handle_api_error(
    self,
    e: Exception,
    payloads: dict,
    context_query: list,
    func_tool: FuncCall,
    chosen_key: str,
    available_api_keys: List[str],
    retry_cnt: int,
    max_retries: int,
) -> tuple:
    """Classify an API error and adjust the request state so it can be retried.

    The error is recognised by matching substrings of ``str(e)``:

    - ``"429"``: drop the current key (mutates ``available_api_keys``) and
      pick another one; re-raise if no key is left.
    - ``"maximum context length"``: pop the oldest record from the context.
    - ``"The model is not a VLM"``: strip all images from the context.
    - function calling unsupported: remove ``tools`` and disable ``func_tool``.
    - anything else: log diagnostics and re-raise.

    Args:
        e: The exception raised by the API call.
        payloads: Request payload; ``messages`` / ``tools`` may be updated in place.
        context_query: The message list currently being sent.
        func_tool: The active function-calling tool set.
        chosen_key: The API key used for the failed call.
        available_api_keys: Keys still eligible for use.
        retry_cnt: Zero-based index of the current attempt.
        max_retries: Total number of attempts allowed.

    Returns:
        ``(success, chosen_key, available_api_keys, payloads, context_query,
        func_tool)``. ``success`` is always ``False`` here: the caller should
        retry with the returned values.

    Raises:
        Exception: ``e`` itself, when the error is not recoverable.
    """
    message = str(e)
    lowered = message.lower()

    if "429" in message:
        logger.warning(
            f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
        )
        # Do not wait before the last attempt.
        if retry_cnt < max_retries - 1:
            await asyncio.sleep(1)
        available_api_keys.remove(chosen_key)
        if not available_api_keys:
            raise e
        chosen_key = random.choice(available_api_keys)
    elif "maximum context length" in message:
        logger.warning(
            f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
        )
        await self.pop_record(context_query)
        payloads["messages"] = context_query
    elif "The model is not a VLM" in message:  # siliconcloud
        # Try again with every image removed.
        context_query = await self._remove_image_from_context(context_query)
        payloads["messages"] = context_query
    elif (
        "Function calling is not enabled" in message
        or ("tool" in lowered and "support" in lowered)
        or ("function" in lowered and "support" in lowered)
    ):
        # openai, ollama, gemini openai and siliconcloud do not share error
        # messages or codes, so string matching is the only option.
        logger.info(
            f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
        )
        if "tools" in payloads:
            del payloads["tools"]
        func_tool = None
    else:
        # NOTE(review): provider_config may contain credentials - confirm it
        # is safe to write to the log.
        logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
        if "Connection error." in message:
            proxy = os.environ.get("http_proxy", None)
            if proxy:
                logger.error(
                    f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
                )
        raise e

    return False, chosen_key, available_api_keys, payloads, context_query, func_tool
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
payloads, context_query, func_tool = await self._prepare_chat_payload(
prompt,
session_id,
image_urls,
func_tool,
contexts,
system_prompt,
tool_calls_result,
**kwargs,
)
llm_response = None
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
@@ -197,64 +361,97 @@ class ProviderOpenAIOfficial(Provider):
payloads["messages"] = new_contexts
context_query = new_contexts
except Exception as e:
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
continue
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
func_tool = None
else:
logger.error(
f"发生了错误。Provider 配置如下: {self.provider_config}"
)
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error(
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
)
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
return llm_response
async def text_chat_stream(
    self,
    prompt: str,
    session_id: str = None,
    image_urls: List[str] = None,
    func_tool: FuncCall = None,
    contexts=None,
    system_prompt=None,
    tool_calls_result=None,
    **kwargs,
) -> AsyncGenerator[LLMResponse, None]:
    """Stream a chat completion from the provider, retrying on recoverable errors.

    Yields every response produced by ``_query_stream``. On failure the error
    is handed to ``_handle_api_error`` (or, for ``UnprocessableEntityError``,
    images are stripped from the context) and the request is retried, up to
    ``max_retries`` attempts.

    NOTE(review): if an error occurs after some chunks were already yielded,
    the retry starts a new stream, so consumers may see repeated text.

    Args:
        prompt: The prompt.
        session_id: Session ID.
        image_urls: Image URLs to attach; ``None`` means no images.
        func_tool: Function-calling tool set.
        contexts: Previous conversation records; ``None`` means empty.
        system_prompt: System prompt.
        tool_calls_result: Tool-call results passed back to the LLM.
        kwargs: Other parameters forwarded to the payload builder.

    Raises:
        Exception: The last API error once every attempt has failed, or
            immediately when ``_handle_api_error`` deems it unrecoverable.
    """
    payloads, context_query, func_tool = await self._prepare_chat_payload(
        prompt,
        session_id,
        [] if image_urls is None else image_urls,
        func_tool,
        [] if contexts is None else contexts,
        system_prompt,
        tool_calls_result,
        **kwargs,
    )

    max_retries = 10
    available_api_keys = self.api_keys.copy()
    chosen_key = random.choice(available_api_keys)

    # Python unbinds the `except ... as e` name when the handler exits, so the
    # most recent error is tracked explicitly for the final re-raise.
    last_exception = None

    for retry_cnt in range(max_retries):
        try:
            self.client.api_key = chosen_key
            async for response in self._query_stream(payloads, func_tool):
                yield response
            break
        except UnprocessableEntityError as e:
            last_exception = e
            logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
            # Try again with every image removed.
            new_contexts = await self._remove_image_from_context(context_query)
            payloads["messages"] = new_contexts
            context_query = new_contexts
        except Exception as e:
            last_exception = e
            (
                success,
                chosen_key,
                available_api_keys,
                payloads,
                context_query,
                func_tool,
            ) = await self._handle_api_error(
                e,
                payloads,
                context_query,
                func_tool,
                chosen_key,
                available_api_keys,
                retry_cnt,
                max_retries,
            )
            if success:
                break
    else:
        # Reached only when no attempt ended with `break`; a success on the
        # final attempt is therefore not reported as a failure.
        logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
        raise last_exception
async def _remove_image_from_context(self, contexts: List):
"""
从上下文中删除所有带有 image 的记录
+27 -16
View File
@@ -161,42 +161,53 @@ class ChatRoute(Route):
username = g.get("username", "guest")
if username in self.curr_chat_sse:
return "[ERROR]\n"
return Response().error("Already connected").__dict__
self.curr_chat_sse[username] = None
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
async def stream():
try:
yield "[HB]\n"
yield f"data: {heartbeat}\n\n" # 心跳包
while True:
try:
result = await asyncio.wait_for(
web_chat_back_queue.get(), timeout=10
) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield "[HB]\n" # 心跳包
yield f"data: {heartbeat}\n\n" # 心跳包
continue
if not result:
continue
result_text, cid = result
result_text = result["data"]
type = result.get("type")
cid = result.get("cid")
streaming = result.get("streaming", False)
if cid != self.curr_user_cid.get(username):
# 丢弃
continue
yield result_text + "\n"
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.15)
conversation = self.db.get_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
if streaming and type != "end":
continue
await asyncio.sleep(0.5)
if result_text:
conversation = self.db.get_conversation_by_user_id(
username, cid
)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({"type": "bot", "message": result_text})
self.db.update_conversation(
username, cid, history=json.dumps(history)
)
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
self.curr_chat_sse.pop(username)
+104 -114
View File
@@ -1,6 +1,7 @@
<script setup>
import axios from 'axios';
import { marked } from 'marked';
import { ref } from 'vue';
marked.setOptions({
breaks: true
@@ -13,27 +14,30 @@ marked.setOptions({
<div class="chat-layout">
<!-- 左侧对话列表面板 -->
<div class="sidebar-panel">
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC"
:disabled="!currCid">
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC" :disabled="!currCid">
<v-icon class="mr-2">mdi-plus</v-icon>创建对话
</v-btn>
<v-card class="conversation-list-card" v-if="conversations.length > 0">
<v-list density="compact" nav class="conversation-list" @update:selected="getConversationMessages">
<v-list density="compact" nav class="conversation-list"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl" class="conversation-item">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at)
}}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<div class="status-chips">
<v-chip class="status-chip" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
<v-chip class="status-chip" color="primary"
:append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
LLM
</v-chip>
<v-chip class="status-chip" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
<v-chip class="status-chip" color="success"
:append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
语音转文本
</v-chip>
</div>
@@ -77,14 +81,15 @@ marked.setOptions({
<div v-if="msg.type == 'user'" class="user-message">
<div class="message-bubble user-bubble">
<span>{{ msg.message }}</span>
<!-- 图片附件 -->
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index" class="image-attachment">
<div v-for="(img, index) in msg.image_url" :key="index"
class="image-attachment">
<img :src="img" class="attached-image" />
</div>
</div>
<!-- 音频附件 -->
<div class="audio-attachment" v-if="msg.audio_url && msg.audio_url.length > 0">
<audio controls class="audio-player">
@@ -97,7 +102,7 @@ marked.setOptions({
<v-icon icon="mdi-account" />
</v-avatar>
</div>
<!-- 机器人消息 -->
<div v-else class="bot-message">
<v-avatar class="bot-avatar" color="deep-purple" size="36">
@@ -113,49 +118,30 @@ marked.setOptions({
<!-- 输入区域 -->
<div class="input-area fade-in">
<v-text-field
id="input-field"
variant="outlined"
v-model="prompt"
:label="inputFieldLabel"
placeholder="开始输入..."
:loading="loadingChat"
clear-icon="mdi-close-circle"
clearable
@click:clear="clearMessage"
class="message-input"
@keydown="handleInputKeyDown"
hide-details
>
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
placeholder="开始输入..." :loading="loadingChat" clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" class="message-input" @keydown="handleInputKeyDown"
hide-details>
<template v-slot:loader>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple" indeterminate></v-progress-linear>
<v-progress-linear :active="loadingChat" height="3" color="deep-purple"
indeterminate></v-progress-linear>
</template>
<template v-slot:append>
<v-tooltip text="发送">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="sendMessage"
class="send-btn"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl"
/>
<v-btn v-bind="props" @click="sendMessage" class="send-btn" icon="mdi-send"
variant="text" color="deep-purple"
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl" />
</template>
</v-tooltip>
<v-tooltip text="语音输入">
<template v-slot:activator="{ props }">
<v-btn
v-bind="props"
@click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'"
variant="text"
:color="isRecording ? 'error' : 'deep-purple'"
/>
<v-btn v-bind="props" @click="isRecording ? stopRecording() : startRecording()"
class="record-btn"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
:color="isRecording ? 'error' : 'deep-purple'" />
</template>
</v-tooltip>
</template>
@@ -165,15 +151,17 @@ marked.setOptions({
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview">
<img :src="img" class="preview-image" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close"
size="small" color="error" variant="text" />
</div>
<div v-if="stagedAudioUrl" class="audio-preview">
<v-chip color="deep-purple-lighten-4" class="audio-chip">
<v-icon start icon="mdi-microphone" size="small"></v-icon>
新录音
</v-chip>
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small"
color="error" variant="text" />
</div>
</div>
</div>
@@ -206,9 +194,9 @@ export default {
status: {},
statusText: '',
eventSource: null,
// Ctrl key long-press tracking state
ctrlKeyDown: false,
ctrlKeyTimer: null,
@@ -228,18 +216,17 @@ export default {
this.sendMessage();
}
}.bind(this));
// Listen for keyup globally so releasing Ctrl is always detected
document.addEventListener('keyup', this.handleInputKeyUp);
},
beforeUnmount() {
console.log("111")
if (this.eventSource) {
this.eventSource.cancel();
console.log('SSE连接已断开');
}
// keyup
document.removeEventListener('keyup', this.handleInputKeyUp);
},
@@ -265,6 +252,9 @@ export default {
this.eventSource = reader
let in_streaming = false
let message_obj = null
while (true) {
const { done, value } = await reader.read();
if (done) {
@@ -273,24 +263,27 @@ export default {
}
const chunk = decoder.decode(value, { stream: true });
console.log("!!!!", chunk);
if (chunk === '[HB]\n') {
// data: {"type": "plain", "data": "helloworld"}
let chunk_json = JSON.parse(chunk.replace('data: ', ''));
if (chunk_json.type === 'heartbeat') {
continue; //
}
if (chunk === '[ERROR]\n') {
if (chunk_json.type === 'error') {
console.error('Error received:', chunk_json.data);
continue;
}
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
if (chunk_json.type === 'image') {
let img = chunk_json.data.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else if (chunk.startsWith('[RECORD]')) {
let audio = chunk.replace('[RECORD]', '');
} else if (chunk_json.type === 'record') {
let audio = chunk_json.data.replace('[RECORD]', '');
let bot_resp = {
type: 'bot',
message: `<audio controls class="audio-player">
@@ -299,12 +292,20 @@ export default {
</audio>`
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
} else if (chunk_json.type === 'plain') {
if (!in_streaming) {
message_obj = {
type: 'bot',
message: ref(chunk_json.data),
}
this.messages.push(message_obj);
in_streaming = true;
} else {
message_obj.message.value += chunk_json.data;
}
this.messages.push(bot_resp);
} else if (chunk_json.type === 'end') {
in_streaming = false;
continue;
}
this.scrollToBottom();
}
@@ -526,42 +527,6 @@ export default {
this.stagedAudioUrl = "";
this.loadingChat = false;
// const reader = response.body.getReader(); // Reader
// const decoder = new TextDecoder();
// const readStream = async () => {
// const { done, value } = await reader.read(); //
// if (done) {
// console.log("Stream finished.");
// return;
// }
// const chunk = decoder.decode(value, { stream: true });
// // bot_resp.message.value += chunk;
// console.log("!!!!", chunk);
// if (chunk.startsWith('[IMAGE]')) {
// let img = chunk.replace('[IMAGE]', '');
// let bot_resp = {
// type: 'bot',
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
// }
// this.messages.push(bot_resp);
// } else {
// let bot_resp = {
// type: 'bot',
// message: chunk
// }
// this.messages.push(bot_resp);
// }
// this.scrollToBottom();
// readStream(); //
// };
// readStream();
})
.catch(err => {
console.error(err);
@@ -578,9 +543,9 @@ export default {
if (e.keyCode === 17) { // Ctrl
// Ignore repeated keydown events while Ctrl is already held
if (this.ctrlKeyDown) return;
this.ctrlKeyDown = true;
// Start the long-press timer; it fires after ctrlKeyLongPressThreshold ms
this.ctrlKeyTimer = setTimeout(() => {
if (this.ctrlKeyDown && !this.isRecording) {
@@ -589,17 +554,17 @@ export default {
}, this.ctrlKeyLongPressThreshold);
}
},
handleInputKeyUp(e) {
if (e.keyCode === 17) { // Ctrl
this.ctrlKeyDown = false;
// Cancel the pending long-press timer, if any
if (this.ctrlKeyTimer) {
clearTimeout(this.ctrlKeyTimer);
this.ctrlKeyTimer = null;
}
// Releasing Ctrl while recording stops the recording
if (this.isRecording) {
this.stopRecording();
@@ -613,19 +578,41 @@ export default {
<style>
/* 基础动画 */
@keyframes fadeIn {
from { opacity: 0; transform: translateY(10px); }
to { opacity: 1; transform: translateY(0); }
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
0% {
transform: scale(1);
}
50% {
transform: scale(1.05);
}
100% {
transform: scale(1);
}
}
@keyframes slideIn {
from { transform: translateX(20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
from {
transform: translateX(20px);
opacity: 0;
}
to {
transform: translateX(0);
opacity: 1;
}
}
/* 聊天页面布局 */
@@ -828,7 +815,8 @@ export default {
border-top-left-radius: 4px;
}
.user-avatar, .bot-avatar {
.user-avatar,
.bot-avatar {
align-self: flex-end;
}
@@ -881,7 +869,8 @@ export default {
margin: 0 auto;
}
.send-btn, .record-btn {
.send-btn,
.record-btn {
margin-left: 4px;
}
@@ -895,7 +884,8 @@ export default {
flex-wrap: wrap;
}
.image-preview, .audio-preview {
.image-preview,
.audio-preview {
position: relative;
display: inline-flex;
}
@@ -1003,7 +993,7 @@ export default {
margin: 16px 0;
}
.markdown-content th,
.markdown-content th,
.markdown-content td {
border: 1px solid #eee;
padding: 8px 12px;