✨ feat: 支持流式输出
This commit is contained in:
@@ -50,6 +50,7 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"prompt_prefix": "",
|
||||
"max_context_length": -1,
|
||||
"streaming_response": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -992,6 +993,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
|
||||
},
|
||||
"streaming_response": {
|
||||
"description": "启用流式回复",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API 以及 Telegram 平台,并且暂不支持工具调用(后续将更新)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.message.components import (
|
||||
BaseMessageComponent,
|
||||
@@ -131,6 +131,8 @@ class ResultContentType(enum.Enum):
|
||||
"""调用 LLM 产生的结果"""
|
||||
GENERAL_RESULT = enum.auto()
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -152,6 +154,9 @@ class MessageEventResult(MessageChain):
|
||||
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
|
||||
async_stream: Optional[AsyncGenerator] = None
|
||||
"""异步流"""
|
||||
|
||||
def stop_event(self) -> "MessageEventResult":
|
||||
"""终止事件传播。"""
|
||||
self.result_type = EventResultType.STOP
|
||||
@@ -168,6 +173,11 @@ class MessageEventResult(MessageChain):
|
||||
"""
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
|
||||
"""设置异步流。"""
|
||||
self.async_stream = stream
|
||||
return self
|
||||
|
||||
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
||||
"""设置事件处理的结果类型。
|
||||
|
||||
|
||||
@@ -37,6 +37,9 @@ class LLMRequestSubStage(Stage):
|
||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||
"max_context_length"
|
||||
] # int
|
||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||
"streaming_response"
|
||||
] # bool
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -137,59 +140,90 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
async def requesting(req: ProviderRequest):
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
final_llm_response = None
|
||||
|
||||
if self.streaming_response:
|
||||
stream = provider.text_chat_stream(
|
||||
**req.__dict__
|
||||
)
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
async for result in self._handle_llm_response(event, req, llm_response):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
logger.debug(llm_response)
|
||||
yield llm_response.result_chain
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
yield
|
||||
final_llm_response = await provider.text_chat(
|
||||
**req.__dict__
|
||||
) # 请求 LLM
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
if not final_llm_response:
|
||||
raise Exception("LLM response is None.")
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, final_llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
async for result in self._handle_llm_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, llm_response)
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, final_llm_response)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
if not self.streaming_response:
|
||||
async for _ in requesting(req):
|
||||
yield
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting(req))
|
||||
)
|
||||
return
|
||||
|
||||
async def _handle_llm_response(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
@@ -79,7 +79,14 @@ class RespondStage(Stage):
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream)
|
||||
await event._post_send()
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
|
||||
if self.enable_seg and (
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
@@ -71,6 +72,9 @@ class ResultDecorateStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None or not result.chain:
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果暂时不进行处理
|
||||
return
|
||||
|
||||
# 回复时检查内容安全
|
||||
if (
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
@@ -202,6 +202,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.role == "admin"
|
||||
|
||||
async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
"""调度器会在执行 send() 前调用该方法"""
|
||||
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Reply,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
@@ -82,3 +91,87 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
message_thread_id = None
|
||||
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
user_name = self.message_obj.group_id
|
||||
else:
|
||||
user_name = self.get_sender_id()
|
||||
|
||||
if "#" in user_name:
|
||||
# it's a supergroup chat with message_thread_id
|
||||
user_name, message_thread_id = user_name.split("#")
|
||||
payload = {
|
||||
"chat_id": user_name,
|
||||
}
|
||||
if message_thread_id:
|
||||
payload["reply_to_message_id"] = message_thread_id
|
||||
|
||||
delta = ""
|
||||
message_id = None
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||
|
||||
async for chain in generator:
|
||||
logger.debug(f"streaming: {chain}")
|
||||
if isinstance(chain, list):
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain:
|
||||
if isinstance(i, Plain):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await self.client.send_photo(photo=image_path, **payload)
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
await self.client.send_document(
|
||||
document=i.file, filename=i.name, **payload
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await self.client.send_voice(voice=path, **payload)
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
# 编辑消息
|
||||
try:
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e}")
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
|
||||
if delta:
|
||||
await self.client.edit_message_text(
|
||||
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||
)
|
||||
|
||||
return await super().send_streaming(generator)
|
||||
|
||||
@@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
os.makedirs(imgs_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _send(message: MessageChain, session_id: str):
|
||||
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||
if not message:
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await web_chat_back_queue.put(
|
||||
{"type": "end", "data": "", "streaming": False}
|
||||
)
|
||||
return
|
||||
|
||||
cid = session_id.split("!")[-1]
|
||||
|
||||
data = ""
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
data = comp.text
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
@@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
data = f"[IMAGE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = str(uuid.uuid4()) + ".wav"
|
||||
@@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
|
||||
data = f"[RECORD]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "record",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"webchat 忽略: {comp.type}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
MessageChain(chain=chain), session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator)
|
||||
|
||||
@@ -204,6 +204,9 @@ class LLMResponse:
|
||||
|
||||
_completion_text: str = ""
|
||||
|
||||
is_chunk: bool = False
|
||||
"""是否是流式输出的单个 Chunk"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
@@ -214,6 +217,7 @@ class LLMResponse:
|
||||
tools_call_ids: List[str] = [],
|
||||
raw_completion: ChatCompletion = None,
|
||||
_new_record: Dict[str, any] = None,
|
||||
is_chunk: bool = False,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
@@ -233,6 +237,7 @@ class LLMResponse:
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.raw_completion = raw_completion
|
||||
self._new_record = _new_record
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
@property
|
||||
def completion_text(self):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
|
||||
from dataclasses import dataclass
|
||||
@@ -108,7 +108,35 @@ class Provider(AbstractProvider):
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
...
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID(此属性已经被废弃)
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
"""
|
||||
...
|
||||
|
||||
async def pop_record(self, context: List):
|
||||
"""
|
||||
|
||||
@@ -160,6 +160,19 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
|
||||
@@ -141,12 +141,29 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
else ref.get("doc_name", "")
|
||||
)
|
||||
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=output_text)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
|
||||
@@ -189,6 +189,19 @@ class ProviderDify(Provider):
|
||||
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
|
||||
@@ -338,6 +338,19 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
|
||||
@@ -95,6 +95,19 @@ class LLMTunerModelLoader(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("This method is not implemented yet.")
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
|
||||
@@ -4,17 +4,20 @@ import os
|
||||
import inspect
|
||||
import random
|
||||
import asyncio
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
@@ -107,12 +110,63 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
llm_response = await self.parse_openai_completion(completion, tools)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
# 不在默认参数中的参数放在 extra_body 中
|
||||
extra_body = {}
|
||||
to_del = []
|
||||
for key in payloads.keys():
|
||||
if key not in self.default_params:
|
||||
extra_body[key] = payloads[key]
|
||||
to_del.append(key)
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
**payloads, stream=True, extra_body=extra_body
|
||||
)
|
||||
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
state = ChatCompletionStreamState()
|
||||
|
||||
async for chunk in stream:
|
||||
state.handle_chunk(chunk)
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
# 处理文本内容
|
||||
if delta.content:
|
||||
completion_text = delta.content
|
||||
llm_response.result_chain = [Comp.Plain(completion_text)]
|
||||
yield llm_response
|
||||
|
||||
final_completion = state.get_final_completion()
|
||||
llm_response = await self.parse_openai_completion(final_completion, tools)
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: FuncCall
|
||||
):
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
@@ -148,7 +202,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
async def _prepare_chat_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
@@ -158,7 +212,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -177,8 +232,117 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
llm_response = None
|
||||
return payloads, context_query, func_tool
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
e: Exception,
|
||||
payloads: dict,
|
||||
context_query: list,
|
||||
func_tool: FuncCall,
|
||||
chosen_key: str,
|
||||
available_api_keys: List[str],
|
||||
retry_cnt: int,
|
||||
max_retries: int,
|
||||
) -> tuple:
|
||||
"""处理API错误并尝试恢复"""
|
||||
if "429" in str(e):
|
||||
logger.warning(
|
||||
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||
)
|
||||
# 最后一次不等待
|
||||
if retry_cnt < max_retries - 1:
|
||||
await asyncio.sleep(1)
|
||||
available_api_keys.remove(chosen_key)
|
||||
if len(available_api_keys) > 0:
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
elif "maximum context length" in str(e):
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
await self.pop_record(context_query)
|
||||
payloads["messages"] = context_query
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
elif (
|
||||
"Function calling is not enabled" in str(e)
|
||||
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||
):
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
llm_response = None
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
@@ -197,64 +361,97 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
if "429" in str(e):
|
||||
logger.warning(
|
||||
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||
)
|
||||
# 最后一次不等待
|
||||
if retry_cnt < max_retries - 1:
|
||||
await asyncio.sleep(1)
|
||||
available_api_keys.remove(chosen_key)
|
||||
if len(available_api_keys) > 0:
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
elif "maximum context length" in str(e):
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
await self.pop_record(context_query)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
elif (
|
||||
"Function calling is not enabled" in str(e)
|
||||
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||
):
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
func_tool = None
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error(
|
||||
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
|
||||
)
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
)
|
||||
|
||||
raise e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
"""
|
||||
从上下文中删除所有带有 image 的记录
|
||||
|
||||
@@ -161,42 +161,53 @@ class ChatRoute(Route):
|
||||
username = g.get("username", "guest")
|
||||
|
||||
if username in self.curr_chat_sse:
|
||||
return "[ERROR]\n"
|
||||
return Response().error("Already connected").__dict__
|
||||
|
||||
self.curr_chat_sse[username] = None
|
||||
|
||||
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
yield "[HB]\n"
|
||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
web_chat_back_queue.get(), timeout=10
|
||||
) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield "[HB]\n" # 心跳包
|
||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
result_text, cid = result
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
cid = result.get("cid")
|
||||
streaming = result.get("streaming", False)
|
||||
if cid != self.curr_user_cid.get(username):
|
||||
# 丢弃
|
||||
continue
|
||||
yield result_text + "\n"
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
conversation = self.db.get_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
)
|
||||
if streaming and type != "end":
|
||||
continue
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
if result_text:
|
||||
conversation = self.db.get_conversation_by_user_id(
|
||||
username, cid
|
||||
)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
)
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
|
||||
+104
-114
@@ -1,6 +1,7 @@
|
||||
<script setup>
|
||||
import axios from 'axios';
|
||||
import { marked } from 'marked';
|
||||
import { ref } from 'vue';
|
||||
|
||||
marked.setOptions({
|
||||
breaks: true
|
||||
@@ -13,27 +14,30 @@ marked.setOptions({
|
||||
<div class="chat-layout">
|
||||
<!-- 左侧对话列表面板 -->
|
||||
<div class="sidebar-panel">
|
||||
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC"
|
||||
:disabled="!currCid">
|
||||
<v-btn variant="tonal" rounded="xl" class="new-chat-btn" @click="newC" :disabled="!currCid">
|
||||
<v-icon class="mr-2">mdi-plus</v-icon>创建对话
|
||||
</v-btn>
|
||||
|
||||
<v-card class="conversation-list-card" v-if="conversations.length > 0">
|
||||
<v-list density="compact" nav class="conversation-list" @update:selected="getConversationMessages">
|
||||
<v-list density="compact" nav class="conversation-list"
|
||||
@update:selected="getConversationMessages">
|
||||
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
|
||||
color="primary" rounded="xl" class="conversation-item">
|
||||
<v-list-item-title>新对话</v-list-item-title>
|
||||
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
|
||||
<v-list-item-subtitle class="timestamp">{{ formatDate(item.updated_at)
|
||||
}}</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-card>
|
||||
|
||||
<div class="status-chips">
|
||||
<v-chip class="status-chip" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
<v-chip class="status-chip" color="primary"
|
||||
:append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
LLM
|
||||
</v-chip>
|
||||
|
||||
<v-chip class="status-chip" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
<v-chip class="status-chip" color="success"
|
||||
:append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
语音转文本
|
||||
</v-chip>
|
||||
</div>
|
||||
@@ -77,14 +81,15 @@ marked.setOptions({
|
||||
<div v-if="msg.type == 'user'" class="user-message">
|
||||
<div class="message-bubble user-bubble">
|
||||
<span>{{ msg.message }}</span>
|
||||
|
||||
|
||||
<!-- 图片附件 -->
|
||||
<div class="image-attachments" v-if="msg.image_url && msg.image_url.length > 0">
|
||||
<div v-for="(img, index) in msg.image_url" :key="index" class="image-attachment">
|
||||
<div v-for="(img, index) in msg.image_url" :key="index"
|
||||
class="image-attachment">
|
||||
<img :src="img" class="attached-image" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- 音频附件 -->
|
||||
<div class="audio-attachment" v-if="msg.audio_url && msg.audio_url.length > 0">
|
||||
<audio controls class="audio-player">
|
||||
@@ -97,7 +102,7 @@ marked.setOptions({
|
||||
<v-icon icon="mdi-account" />
|
||||
</v-avatar>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- 机器人消息 -->
|
||||
<div v-else class="bot-message">
|
||||
<v-avatar class="bot-avatar" color="deep-purple" size="36">
|
||||
@@ -113,49 +118,30 @@ marked.setOptions({
|
||||
|
||||
<!-- 输入区域 -->
|
||||
<div class="input-area fade-in">
|
||||
<v-text-field
|
||||
id="input-field"
|
||||
variant="outlined"
|
||||
v-model="prompt"
|
||||
:label="inputFieldLabel"
|
||||
placeholder="开始输入..."
|
||||
:loading="loadingChat"
|
||||
clear-icon="mdi-close-circle"
|
||||
clearable
|
||||
@click:clear="clearMessage"
|
||||
class="message-input"
|
||||
@keydown="handleInputKeyDown"
|
||||
hide-details
|
||||
>
|
||||
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
|
||||
placeholder="开始输入..." :loading="loadingChat" clear-icon="mdi-close-circle" clearable
|
||||
@click:clear="clearMessage" class="message-input" @keydown="handleInputKeyDown"
|
||||
hide-details>
|
||||
<template v-slot:loader>
|
||||
<v-progress-linear :active="loadingChat" height="3" color="deep-purple" indeterminate></v-progress-linear>
|
||||
<v-progress-linear :active="loadingChat" height="3" color="deep-purple"
|
||||
indeterminate></v-progress-linear>
|
||||
</template>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-tooltip text="发送">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn
|
||||
v-bind="props"
|
||||
@click="sendMessage"
|
||||
class="send-btn"
|
||||
icon="mdi-send"
|
||||
variant="text"
|
||||
color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl"
|
||||
/>
|
||||
<v-btn v-bind="props" @click="sendMessage" class="send-btn" icon="mdi-send"
|
||||
variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesUrl.length === 0 && !stagedAudioUrl" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-tooltip text="语音输入">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn
|
||||
v-bind="props"
|
||||
@click="isRecording ? stopRecording() : startRecording()"
|
||||
class="record-btn"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'"
|
||||
variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'"
|
||||
/>
|
||||
<v-btn v-bind="props" @click="isRecording ? stopRecording() : startRecording()"
|
||||
class="record-btn"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
@@ -165,15 +151,17 @@ marked.setOptions({
|
||||
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
|
||||
<div v-for="(img, index) in stagedImagesUrl" :key="index" class="image-preview">
|
||||
<img :src="img" class="preview-image" />
|
||||
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
|
||||
<v-btn @click="removeImage(index)" class="remove-attachment-btn" icon="mdi-close"
|
||||
size="small" color="error" variant="text" />
|
||||
</div>
|
||||
|
||||
|
||||
<div v-if="stagedAudioUrl" class="audio-preview">
|
||||
<v-chip color="deep-purple-lighten-4" class="audio-chip">
|
||||
<v-icon start icon="mdi-microphone" size="small"></v-icon>
|
||||
新录音
|
||||
</v-chip>
|
||||
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small" color="error" variant="text" />
|
||||
<v-btn @click="removeAudio" class="remove-attachment-btn" icon="mdi-close" size="small"
|
||||
color="error" variant="text" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -206,9 +194,9 @@ export default {
|
||||
|
||||
status: {},
|
||||
statusText: '',
|
||||
|
||||
|
||||
eventSource: null,
|
||||
|
||||
|
||||
// Ctrl键长按相关变量
|
||||
ctrlKeyDown: false,
|
||||
ctrlKeyTimer: null,
|
||||
@@ -228,18 +216,17 @@ export default {
|
||||
this.sendMessage();
|
||||
}
|
||||
}.bind(this));
|
||||
|
||||
|
||||
// 添加keyup事件监听
|
||||
document.addEventListener('keyup', this.handleInputKeyUp);
|
||||
},
|
||||
|
||||
beforeUnmount() {
|
||||
console.log("111")
|
||||
if (this.eventSource) {
|
||||
this.eventSource.cancel();
|
||||
console.log('SSE连接已断开');
|
||||
}
|
||||
|
||||
|
||||
// 移除keyup事件监听
|
||||
document.removeEventListener('keyup', this.handleInputKeyUp);
|
||||
},
|
||||
@@ -265,6 +252,9 @@ export default {
|
||||
|
||||
this.eventSource = reader
|
||||
|
||||
let in_streaming = false
|
||||
let message_obj = null
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
@@ -273,24 +263,27 @@ export default {
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
console.log("!!!!", chunk);
|
||||
|
||||
if (chunk === '[HB]\n') {
|
||||
// data: {"type": "plain", "data": "helloworld"}
|
||||
let chunk_json = JSON.parse(chunk.replace('data: ', ''));
|
||||
|
||||
if (chunk_json.type === 'heartbeat') {
|
||||
continue; // 心跳包
|
||||
}
|
||||
if (chunk === '[ERROR]\n') {
|
||||
if (chunk_json.type === 'error') {
|
||||
console.error('Error received:', chunk_json.data);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
if (chunk_json.type === 'image') {
|
||||
let img = chunk_json.data.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else if (chunk.startsWith('[RECORD]')) {
|
||||
let audio = chunk.replace('[RECORD]', '');
|
||||
} else if (chunk_json.type === 'record') {
|
||||
let audio = chunk_json.data.replace('[RECORD]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<audio controls class="audio-player">
|
||||
@@ -299,12 +292,20 @@ export default {
|
||||
</audio>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
} else if (chunk_json.type === 'plain') {
|
||||
if (!in_streaming) {
|
||||
message_obj = {
|
||||
type: 'bot',
|
||||
message: ref(chunk_json.data),
|
||||
}
|
||||
this.messages.push(message_obj);
|
||||
in_streaming = true;
|
||||
} else {
|
||||
message_obj.message.value += chunk_json.data;
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else if (chunk_json.type === 'end') {
|
||||
in_streaming = false;
|
||||
continue;
|
||||
}
|
||||
this.scrollToBottom();
|
||||
}
|
||||
@@ -526,42 +527,6 @@ export default {
|
||||
this.stagedAudioUrl = "";
|
||||
|
||||
this.loadingChat = false;
|
||||
|
||||
// const reader = response.body.getReader(); // 获取流的 Reader
|
||||
// const decoder = new TextDecoder();
|
||||
|
||||
// const readStream = async () => {
|
||||
// const { done, value } = await reader.read(); // 读取流中的数据
|
||||
// if (done) {
|
||||
// console.log("Stream finished.");
|
||||
// return;
|
||||
// }
|
||||
|
||||
// const chunk = decoder.decode(value, { stream: true });
|
||||
// // bot_resp.message.value += chunk;
|
||||
|
||||
// console.log("!!!!", chunk);
|
||||
// if (chunk.startsWith('[IMAGE]')) {
|
||||
// let img = chunk.replace('[IMAGE]', '');
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
// }
|
||||
// this.messages.push(bot_resp);
|
||||
// } else {
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: chunk
|
||||
// }
|
||||
|
||||
// this.messages.push(bot_resp);
|
||||
// }
|
||||
|
||||
// this.scrollToBottom();
|
||||
// readStream(); // 递归读取流
|
||||
// };
|
||||
|
||||
// readStream();
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
@@ -578,9 +543,9 @@ export default {
|
||||
if (e.keyCode === 17) { // Ctrl键
|
||||
// 防止重复触发
|
||||
if (this.ctrlKeyDown) return;
|
||||
|
||||
|
||||
this.ctrlKeyDown = true;
|
||||
|
||||
|
||||
// 设置定时器识别长按
|
||||
this.ctrlKeyTimer = setTimeout(() => {
|
||||
if (this.ctrlKeyDown && !this.isRecording) {
|
||||
@@ -589,17 +554,17 @@ export default {
|
||||
}, this.ctrlKeyLongPressThreshold);
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
handleInputKeyUp(e) {
|
||||
if (e.keyCode === 17) { // Ctrl键
|
||||
this.ctrlKeyDown = false;
|
||||
|
||||
|
||||
// 清除定时器
|
||||
if (this.ctrlKeyTimer) {
|
||||
clearTimeout(this.ctrlKeyTimer);
|
||||
this.ctrlKeyTimer = null;
|
||||
}
|
||||
|
||||
|
||||
// 如果正在录音,停止录音
|
||||
if (this.isRecording) {
|
||||
this.stopRecording();
|
||||
@@ -613,19 +578,41 @@ export default {
|
||||
<style>
|
||||
/* 基础动画 */
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(10px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% { transform: scale(1); }
|
||||
50% { transform: scale(1.05); }
|
||||
100% { transform: scale(1); }
|
||||
0% {
|
||||
transform: scale(1);
|
||||
}
|
||||
|
||||
50% {
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
100% {
|
||||
transform: scale(1);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
from { transform: translateX(20px); opacity: 0; }
|
||||
to { transform: translateX(0); opacity: 1; }
|
||||
from {
|
||||
transform: translateX(20px);
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
transform: translateX(0);
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
/* 聊天页面布局 */
|
||||
@@ -828,7 +815,8 @@ export default {
|
||||
border-top-left-radius: 4px;
|
||||
}
|
||||
|
||||
.user-avatar, .bot-avatar {
|
||||
.user-avatar,
|
||||
.bot-avatar {
|
||||
align-self: flex-end;
|
||||
}
|
||||
|
||||
@@ -881,7 +869,8 @@ export default {
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.send-btn, .record-btn {
|
||||
.send-btn,
|
||||
.record-btn {
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
@@ -895,7 +884,8 @@ export default {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.image-preview, .audio-preview {
|
||||
.image-preview,
|
||||
.audio-preview {
|
||||
position: relative;
|
||||
display: inline-flex;
|
||||
}
|
||||
@@ -1003,7 +993,7 @@ export default {
|
||||
margin: 16px 0;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid #eee;
|
||||
padding: 8px 12px;
|
||||
|
||||
Reference in New Issue
Block a user