Merge branch 'master' into Astrbot_session_manage
This commit is contained in:
@@ -16,7 +16,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
@@ -111,10 +111,14 @@ uvx astrbot init
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
#### 在 Replit 上部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### 在 雨云 上部署
|
||||
|
||||
[](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
| 平台 | 支持性 |
|
||||
|
||||
@@ -6,14 +6,14 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.17"
|
||||
VERSION = "3.5.18"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
"platform_settings": {
|
||||
"plugin_enable": [],
|
||||
"plugin_enable": {},
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
@@ -54,6 +54,7 @@ DEFAULT_CONFIG = {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
"identifier": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
@@ -61,6 +62,7 @@ DEFAULT_CONFIG = {
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
@@ -441,7 +443,7 @@ CONFIG_METADATA_2 = {
|
||||
"ignore_bot_self_message": {
|
||||
"description": "是否忽略机器人自身的消息",
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
"hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"ignore_at_all": {
|
||||
"description": "是否忽略 @ 全体成员",
|
||||
@@ -770,17 +772,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
},
|
||||
},
|
||||
"LLMTuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"base_model_path": "",
|
||||
"adapter_model_path": "",
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
@@ -1662,6 +1653,11 @@ CONFIG_METADATA_2 = {
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
},
|
||||
"display_reasoning_text": {
|
||||
"description": "显示思考内容",
|
||||
"type": "bool",
|
||||
"hint": "开启后,将在回复中显示模型的思考过程。",
|
||||
},
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
@@ -1699,10 +1695,15 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"show_tool_use_status": {
|
||||
"description": "函数调用状态输出",
|
||||
"type": "bool",
|
||||
"hint": "在触发函数调用时输出其函数名和内容。",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -24,6 +24,8 @@ class MessageChain:
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
type: Optional[str] = None
|
||||
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
|
||||
|
||||
def message(self, message: str):
|
||||
"""添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -98,6 +100,15 @@ class MessageChain:
|
||||
self.chain.append(Image.fromFileSystem(path))
|
||||
return self
|
||||
|
||||
def base64_image(self, base64_str: str):
|
||||
"""添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。
|
||||
Example:
|
||||
|
||||
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
|
||||
"""
|
||||
self.chain.append(Image.fromBase64(base64_str))
|
||||
return self
|
||||
|
||||
def use_t2i(self, use_t2i: bool):
|
||||
"""设置是否使用文本转图片服务。
|
||||
|
||||
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
STREAMING_FINISH= enum.auto()
|
||||
STREAMING_FINISH = enum.auto()
|
||||
"""流式输出完成"""
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -9,3 +17,91 @@ class PipelineContext:
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
|
||||
async def call_event_hook(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
):
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
async def call_handler(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import abc
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
class BaseAgentRunner:
|
||||
@abc.abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
Check if the agent has completed its task.
|
||||
Returns True if the agent is done, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
"""
|
||||
Get the final observation from the agent.
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,300 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||
from ...context import PipelineContext
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
AssistantMessageSegment,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO:
|
||||
# 1. 处理平台不兼容的处理器
|
||||
|
||||
|
||||
class ToolLoopAgent(BaseAgentRunner):
|
||||
def __init__(
|
||||
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.req = None
|
||||
self.event = event
|
||||
self.pipeline_ctx = pipeline_ctx
|
||||
self._state = AgentState.IDLE
|
||||
self.final_llm_resp = None
|
||||
self.streaming = False
|
||||
|
||||
@override
|
||||
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||
self.req = req
|
||||
self.streaming = streaming
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
This method should return the result of the step.
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text)
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
return
|
||||
|
||||
# 处理 LLM 响应
|
||||
llm_resp = llm_resp_result
|
||||
logger.debug(f"LLMResp: {llm_resp}")
|
||||
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if not llm_resp.tools_call_name:
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
# 执行事件钩子
|
||||
await self.pipeline_ctx.call_event_hook(
|
||||
self.event, EventType.OnLLMResponseEvent, llm_resp
|
||||
)
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=llm_resp.result_chain),
|
||||
)
|
||||
elif llm_resp.completion_text:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_resp.completion_text)
|
||||
),
|
||||
)
|
||||
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
role="assistant",
|
||||
tool_calls=llm_resp.to_openai_tool_calls(),
|
||||
content=llm_resp.completion_text,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
|
||||
"""处理函数工具调用。"""
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||
|
||||
# 执行函数调用
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if not res:
|
||||
continue
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain().base64_image(res.content[0].data)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain().base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
wrapper = self.pipeline_ctx.call_handler(
|
||||
self.event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
# Tool 返回结果
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(chain=res.chain)
|
||||
|
||||
self.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import copy
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
@@ -20,40 +21,27 @@ from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
AssistantMessageSegment,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
|
||||
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
||||
"wake_prefix"
|
||||
] # str
|
||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||
"max_context_length"
|
||||
] # int
|
||||
self.dequeue_context_length = min(
|
||||
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
) # int
|
||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||
"streaming_response"
|
||||
] # bool
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.max_step: int = settings.get("max_agent_step", 10)
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -90,10 +78,7 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
req.contexts = self._process_tool_message_pairs(
|
||||
all_contexts, remove_tags=True
|
||||
)
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
@@ -135,26 +120,7 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req)
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
@@ -184,83 +150,63 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
async def requesting(req: ProviderRequest):
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
|
||||
# 在每次实际请求 LLM 前检查会话级别的启停状态,这可以防止插件或函数工具调用时绕过会话级别的限制
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。")
|
||||
return
|
||||
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
|
||||
final_llm_response = None
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
if self.streaming_response:
|
||||
stream = provider.text_chat_stream(**req.__dict__)
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
yield MessageChain().message(
|
||||
llm_response.completion_text
|
||||
)
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
final_llm_response = await provider.text_chat(
|
||||
**req.__dict__
|
||||
) # 请求 LLM
|
||||
# Call Agent
|
||||
tool_loop_agent = ToolLoopAgent(
|
||||
provider=provider,
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
if not final_llm_response:
|
||||
raise Exception("LLM response is None.")
|
||||
async def requesting():
|
||||
step_idx = 0
|
||||
while step_idx < self.max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
if resp.type == "tool_call_result":
|
||||
continue # 跳过工具调用结果
|
||||
if resp.type == "tool_call":
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if self.show_tool_use or event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
if not self.streaming_response:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, final_llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式输出的处理
|
||||
async for result in self._handle_llm_stream_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
else:
|
||||
# 非流式输出的处理
|
||||
async for result in self._handle_llm_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
@@ -269,44 +215,38 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, final_llm_response)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
if not self.streaming_response:
|
||||
event.set_extra("tool_call_result", None)
|
||||
async for _ in requesting(req):
|
||||
yield
|
||||
else:
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting(req))
|
||||
.set_async_stream(requesting())
|
||||
)
|
||||
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||
yield
|
||||
|
||||
if event.get_extra("tool_call_result"):
|
||||
event.set_result(event.get_extra("tool_call_result"))
|
||||
event.set_extra("tool_call_result", None)
|
||||
if tool_loop_agent.done():
|
||||
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
else:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
)
|
||||
)
|
||||
else:
|
||||
async for _ in requesting():
|
||||
yield
|
||||
|
||||
# 暂时直接发出去
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
# 异步处理 WebChat 特殊情况
|
||||
asyncio.create_task(self._handle_webchat(event, req))
|
||||
|
||||
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
|
||||
|
||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
# 检查会话级别的LLM启停状态,防止标题生成功能绕过会话级别限制
|
||||
@@ -324,10 +264,6 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
# if len(latest_pair) > 1:
|
||||
# cleaned_text += (
|
||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
||||
# )
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await provider.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
@@ -368,322 +304,50 @@ class LLMRequestSubStage(Stage):
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理非流式 LLM 响应。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "tool":
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_llm_stream_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理流式 LLM 响应。
|
||||
|
||||
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "tool":
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理函数工具调用。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
"""
|
||||
# function calling
|
||||
tool_call_result: list[ToolCallMessageSegment] = []
|
||||
logger.info(
|
||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||
)
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if res:
|
||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
else:
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md
|
||||
and platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||
)
|
||||
# 直接跳过,不添加任何消息到tool_call_result
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
res = event.get_result()
|
||||
if res and res.chain:
|
||||
event.set_extra("tool_call_result", res)
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
if tool_call_result:
|
||||
# 函数调用结果
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
assistant_msg_seg = AssistantMessageSegment(
|
||||
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||
)
|
||||
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||
req.tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=assistant_msg_seg,
|
||||
tool_calls_result=tool_call_result,
|
||||
)
|
||||
yield req # 再次执行 LLM 请求
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(
|
||||
MessageEventResult().message(llm_response.completion_text)
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
):
|
||||
if not req or not req.conversation or not llm_response:
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts.copy()
|
||||
contexts.append(await req.assemble_context())
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=messages
|
||||
)
|
||||
logger.debug(f"messages persisted: {messages}")
|
||||
|
||||
# 记录并标记函数调用结果
|
||||
if req.tool_calls_result:
|
||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||
|
||||
# 添加标记
|
||||
for message in tool_calls_messages:
|
||||
message["_tool_call_history"] = True
|
||||
|
||||
processed_tool_messages = self._process_tool_message_pairs(
|
||||
tool_calls_messages, remove_tags=False
|
||||
)
|
||||
|
||||
contexts.extend(processed_tool_messages)
|
||||
|
||||
contexts.append(
|
||||
{"role": "assistant", "content": llm_response.completion_text}
|
||||
)
|
||||
contexts_to_save = list(
|
||||
filter(lambda item: "_no_save" not in item, contexts)
|
||||
)
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||
)
|
||||
|
||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||
|
||||
Args:
|
||||
messages (list): 消息列表
|
||||
remove_tags (bool): 是否移除_tool_call_history标记
|
||||
|
||||
Returns:
|
||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
current_msg = messages[i]
|
||||
|
||||
# 普通消息直接添加
|
||||
if "_tool_call_history" not in current_msg:
|
||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息成对处理
|
||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||
assistant_msg = current_msg.copy()
|
||||
|
||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(messages)
|
||||
and messages[j].get("role") == "tool"
|
||||
and "_tool_call_history" in messages[j]
|
||||
):
|
||||
tool_msg = messages[j].copy()
|
||||
|
||||
if remove_tags:
|
||||
del tool_msg["_tool_call_history"]
|
||||
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 成对的时候添加到结果
|
||||
if related_tools:
|
||||
result.append(assistant_msg)
|
||||
result.extend(related_tools)
|
||||
|
||||
i = j # 跳过已处理
|
||||
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
# 单独的tool消息
|
||||
i += 1
|
||||
|
||||
return result
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
)
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
wrapper = self.ctx.call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import inspect
|
||||
import traceback
|
||||
from astrbot.api import logger
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from typing import List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
|
||||
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
|
||||
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _call_handler(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 待处理的事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
*args: 传递给handler的位置参数
|
||||
**kwargs: 传递给handler的关键字参数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
# 如果是一个异步生成器, 进入洋葱模型
|
||||
_has_yielded = False # 是否返回过值
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到yield分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
|
||||
@@ -158,6 +158,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain.chain:
|
||||
if isinstance(i, Plain):
|
||||
|
||||
@@ -35,6 +35,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
@@ -110,6 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
chain, session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
elif self.tool_calls:
|
||||
if self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
return ret
|
||||
|
||||
@@ -95,19 +95,19 @@ class ProviderRequest:
|
||||
"""提示词"""
|
||||
session_id: str = ""
|
||||
"""会话 ID"""
|
||||
image_urls: List[str] = None
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall = None
|
||||
func_tool: FuncCall | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: List = None
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
"""
|
||||
system_prompt: str = ""
|
||||
"""系统提示词"""
|
||||
conversation: Conversation = None
|
||||
conversation: Conversation | None = None
|
||||
|
||||
tool_calls_result: ToolCallsResult = None
|
||||
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
|
||||
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||
|
||||
def __repr__(self):
|
||||
@@ -116,6 +116,14 @@ class ProviderRequest:
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
|
||||
"""添加工具调用结果到请求中"""
|
||||
if not self.tool_calls_result:
|
||||
self.tool_calls_result = []
|
||||
if isinstance(self.tool_calls_result, ToolCallsResult):
|
||||
self.tool_calls_result = [self.tool_calls_result]
|
||||
self.tool_calls_result.append(tool_calls_result)
|
||||
|
||||
def _print_friendly_context(self):
|
||||
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||
if not self.contexts:
|
||||
|
||||
@@ -190,11 +190,6 @@ class ProviderManager:
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import (
|
||||
LLMTunerModelLoader as LLMTunerModelLoader,
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "dashscope":
|
||||
@@ -330,8 +325,6 @@ class ProviderManager:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.db_helper,
|
||||
self.provider_settings.get("persistant_history", True),
|
||||
self.selected_default_persona,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
@@ -53,15 +52,13 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
self.curr_personality = default_persona
|
||||
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -86,11 +83,11 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,11 +111,11 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import json
|
||||
import anthropic
|
||||
import base64
|
||||
from typing import List
|
||||
from mimetypes import guess_type
|
||||
|
||||
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from typing import AsyncGenerator
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
|
||||
)
|
||||
class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
class ProviderAnthropic(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
# Skip OpenAI's __init__ and call Provider's __init__ directly
|
||||
Provider.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
self.chosen_api_key = None
|
||||
self.chosen_api_key: str = ""
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
|
||||
Args:
|
||||
messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
|
||||
Returns:
|
||||
system_prompt: 系统提示内容
|
||||
new_messages: 处理后的消息列表,去除系统提示
|
||||
"""
|
||||
system_prompt = ""
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
elif message["role"] == "assistant":
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
"type": "tool_use",
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], str)
|
||||
else tool_call["function"]["arguments"],
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
)
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": blocks,
|
||||
}
|
||||
)
|
||||
elif message["role"] == "tool":
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_messages.append(message)
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_anthropic_style()
|
||||
if tool_list:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(**payloads, stream=False)
|
||||
@@ -64,70 +112,157 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
if len(completion.content) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
|
||||
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
|
||||
content = completion.content[-1]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response = LLMResponse(role="assistant")
|
||||
|
||||
if content.type == "text":
|
||||
# text completion
|
||||
completion_text = str(content.text).strip()
|
||||
# llm_response.completion_text = completion_text
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# Anthropic每次只返回一个函数调用
|
||||
if completion.stop_reason == "tool_use":
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
tool_use_ids = []
|
||||
func_name_ls.append(content.name)
|
||||
args_ls.append(content.input)
|
||||
tool_use_ids.append(content.id)
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
llm_response.tools_call_ids = tool_use_ids
|
||||
for content_block in completion.content:
|
||||
if content_block.type == "text":
|
||||
completion_text = str(content_block.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if content_block.type == "tool_use":
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
llm_response.tools_call_ids.append(content_block.id)
|
||||
# TODO(Soulter): 处理 end_turn 情况
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
# 用于累积工具调用信息
|
||||
tool_use_buffer = {}
|
||||
# 用于累积最终结果
|
||||
final_text = ""
|
||||
final_tool_calls = []
|
||||
|
||||
async with self.client.messages.stream(**payloads) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "content_block_start":
|
||||
if event.content_block.type == "text":
|
||||
# 文本块开始
|
||||
yield LLMResponse(
|
||||
role="assistant", completion_text="", is_chunk=True
|
||||
)
|
||||
elif event.content_block.type == "tool_use":
|
||||
# 工具使用块开始,初始化缓冲区
|
||||
tool_use_buffer[event.index] = {
|
||||
"id": event.content_block.id,
|
||||
"name": event.content_block.name,
|
||||
"input": {},
|
||||
}
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "text_delta":
|
||||
# 文本增量
|
||||
final_text += event.delta.text
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=event.delta.text,
|
||||
is_chunk=True,
|
||||
)
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
if event.index in tool_use_buffer:
|
||||
# 累积 JSON 输入
|
||||
if "input_json" not in tool_use_buffer[event.index]:
|
||||
tool_use_buffer[event.index]["input_json"] = ""
|
||||
tool_use_buffer[event.index]["input_json"] += (
|
||||
event.delta.partial_json
|
||||
)
|
||||
|
||||
elif event.type == "content_block_stop":
|
||||
# 内容块结束
|
||||
if event.index in tool_use_buffer:
|
||||
# 解析完整的工具调用
|
||||
tool_info = tool_use_buffer[event.index]
|
||||
try:
|
||||
if "input_json" in tool_info:
|
||||
tool_info["input"] = json.loads(tool_info["input_json"])
|
||||
|
||||
# 添加到最终结果
|
||||
final_tool_calls.append(
|
||||
{
|
||||
"id": tool_info["id"],
|
||||
"name": tool_info["name"],
|
||||
"input": tool_info["input"],
|
||||
}
|
||||
)
|
||||
|
||||
yield LLMResponse(
|
||||
role="tool",
|
||||
completion_text="",
|
||||
tools_call_args=[tool_info["input"]],
|
||||
tools_call_name=[tool_info["name"]],
|
||||
tools_call_ids=[tool_info["id"]],
|
||||
is_chunk=True,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# JSON 解析失败,跳过这个工具调用
|
||||
logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
|
||||
|
||||
# 清理缓冲区
|
||||
del tool_use_buffer[event.index]
|
||||
|
||||
# 返回最终的完整结果
|
||||
final_response = LLMResponse(
|
||||
role="assistant", completion_text=final_text, is_chunk=False
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
final_response.tools_call_args = [
|
||||
call["input"] for call in final_tool_calls
|
||||
]
|
||||
final_response.tools_call_name = [call["name"] for call in final_tool_calls]
|
||||
final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
|
||||
|
||||
yield final_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
# 暂时这样写。
|
||||
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
payloads["system"] = system_prompt
|
||||
@@ -135,32 +270,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
response = await self.client.messages.create(
|
||||
messages=context_query, **model_config
|
||||
)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||
llm_response.raw_completion = response
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -175,21 +287,34 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
payloads["system"] = system_prompt
|
||||
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
@@ -232,3 +357,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
)
|
||||
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""
|
||||
将图片转换为 base64
|
||||
"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.chosen_api_key
|
||||
|
||||
async def get_models(self) -> List[str]:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
return models_str
|
||||
|
||||
def set_key(self, key: str):
|
||||
self.chosen_api_key = key
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality = None,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
|
||||
@@ -12,8 +12,7 @@ from google.genai.errors import APIError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Personality, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
@@ -264,12 +259,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
contents.append(content_cls(parts=part))
|
||||
|
||||
gemini_contents: list[types.Content] = []
|
||||
native_tool_enabled = any(
|
||||
[
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
]
|
||||
)
|
||||
native_tool_enabled = any([
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
])
|
||||
for message in payloads["messages"]:
|
||||
role, content = message["role"], message.get("content")
|
||||
|
||||
@@ -506,12 +499,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
@@ -527,7 +520,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
@@ -631,9 +628,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}}
|
||||
)
|
||||
user_content["content"].append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
|
||||
)
|
||||
class LLMTunerModelLoader(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
||||
provider_config["adapter_model_path"]
|
||||
):
|
||||
raise FileNotFoundError("模型文件路径不存在。")
|
||||
self.base_model_path = provider_config["base_model_path"]
|
||||
self.adapter_model_path = provider_config["adapter_model_path"]
|
||||
self.model = ChatModel(
|
||||
{
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config["llmtuner_template"],
|
||||
"finetuning_type": provider_config["finetuning_type"],
|
||||
"quantization_bit": provider_config["quantization_bit"],
|
||||
}
|
||||
)
|
||||
self.set_model(
|
||||
os.path.basename(self.base_model_path)
|
||||
+ "_"
|
||||
+ os.path.basename(self.adapter_model_path)
|
||||
)
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(query_context):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
|
||||
if "_no_save" in context:
|
||||
del context["_no_save"]
|
||||
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + query_context.pop(idx)["content"]
|
||||
|
||||
conf = {
|
||||
"messages": query_context,
|
||||
"system": system_prompt,
|
||||
}
|
||||
if func_tool:
|
||||
tool_list = func_tool.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
conf["tools"] = tool_list
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
async def set_key(self, key):
|
||||
pass
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List, AsyncGenerator
|
||||
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
@@ -224,12 +218,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def _prepare_chat_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -246,14 +238,18 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if isinstance(tool_calls_result, ToolCallsResult):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
return payloads, context_query, func_tool
|
||||
return payloads, context_query
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
@@ -352,11 +348,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
@@ -422,11 +416,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
db_helper,
|
||||
persistant_history,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
# What's Changed
|
||||
|
||||
> 重构了大模型请求部分,如果发现此部分使用时有问题请提交 issue
|
||||
|
||||
1. 修复: 安装插件按钮被删除、无法自定义安装插件
|
||||
2. 修复: 环境变量中的代理地址无法生效
|
||||
1. 修复: randomize jwt secret
|
||||
2. 修复: 在 Node 消息段发送简单文本信息的问题
|
||||
1. 修复: QQ 官方机器人适配器使用 SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台
|
||||
4. 修复: Discord 适配器无法优雅重载
|
||||
1. 修复: Telegram 适配器无法主动回复
|
||||
1. 修复: 仪表盘的『插件配置』中不显示代码编辑器
|
||||
3. 新增: Gemini TTS API
|
||||
1. 新增: 允许 html_render 方法传入 Playwright.screenshot 配置参数
|
||||
1. 优化: 修复 CommandFilter 支持对布尔类型进行解析
|
||||
4. 新增: WechatPadPro 发送 TTS 时 添加对 MP3 格式音频支持
|
||||
1. 重构: 将大模型请求部分抽象成 AgentRunner,提高可读性和可扩展性,工具调用结果支持持久化保存到数据库,完善 Agent 的多轮工具调用能力。
|
||||
1. 移除: LLMTuner 模型提供商适配器。请使用 Ollama 来加载微调模型
|
||||
@@ -29,7 +29,7 @@ let dashboardCurrentVersion = ref('');
|
||||
let version = ref('');
|
||||
let releases = ref([]);
|
||||
let devCommits = ref([]);
|
||||
|
||||
let updatingDashboardLoading = ref(false);
|
||||
let installLoading = ref(false);
|
||||
|
||||
let tab = ref(0);
|
||||
@@ -217,6 +217,7 @@ function switchVersion(version: string) {
|
||||
}
|
||||
|
||||
function updateDashboard() {
|
||||
updatingDashboardLoading.value = true;
|
||||
updateStatus.value = t('core.header.updateDialog.status.updating');
|
||||
axios.post('/api/update/dashboard')
|
||||
.then((res) => {
|
||||
@@ -230,7 +231,9 @@ function updateDashboard() {
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}).finally(() => {
|
||||
updatingDashboardLoading.value = false;
|
||||
});
|
||||
}
|
||||
|
||||
function toggleDarkMode() {
|
||||
@@ -416,7 +419,7 @@ commonStore.getStartTime();
|
||||
</div>
|
||||
|
||||
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()"
|
||||
:disabled="!dashboardHasNewVersion">
|
||||
:disabled="!dashboardHasNewVersion" :loading="updatingDashboardLoading">
|
||||
{{ t('core.header.updateDialog.dashboardUpdate.downloadAndUpdate') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -39,7 +39,8 @@ const PurpleThemeDark: ThemeTypes = {
|
||||
background: '#111111',
|
||||
overlay: '#111111aa',
|
||||
codeBg: '#282833',
|
||||
code: '#ffffffdd'
|
||||
code: '#ffffffdd',
|
||||
chatMessageBubble: '#2d2e30',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ const PurpleTheme: ThemeTypes = {
|
||||
lightsuccess: '#b9f6ca',
|
||||
lighterror: '#f9d8d8',
|
||||
lightwarning: '#fff8e1',
|
||||
primaryText: '#000000dd',
|
||||
primaryText: '#1b1c1d',
|
||||
secondaryText: '#000000aa',
|
||||
darkprimary: '#1565c0',
|
||||
darksecondary: '#4527a0',
|
||||
@@ -39,7 +39,8 @@ const PurpleTheme: ThemeTypes = {
|
||||
background: '#f9fafcf4',
|
||||
overlay: '#ffffffaa',
|
||||
codeBg: '#f5f0ff',
|
||||
code: '#673ab7'
|
||||
code: '#673ab7',
|
||||
chatMessageBubble: '#e7ebf4',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -35,5 +35,6 @@ export type ThemeTypes = {
|
||||
secondary200?: string;
|
||||
codeBg?: string;
|
||||
code?: string;
|
||||
chatMessageBubble?: string;
|
||||
};
|
||||
};
|
||||
|
||||
+122
-228
@@ -3,13 +3,16 @@
|
||||
<v-card-text class="chat-page-container">
|
||||
<div class="chat-layout">
|
||||
<div class="sidebar-panel" :class="{ 'sidebar-collapsed': sidebarCollapsed }"
|
||||
:style="{ 'background-color': isDark ? sidebarCollapsed ? '#1e1e1e' : '#2d2d2d' : sidebarCollapsed ? '#ffffff' : '#f5f5f5' }"
|
||||
@mouseenter="handleSidebarMouseEnter" @mouseleave="handleSidebarMouseLeave">
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;" v-if="chatboxMode">
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
|
||||
v-if="chatboxMode">
|
||||
<img width="50" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
|
||||
<span v-if="!sidebarCollapsed" style="font-weight: 1000; font-size: 26px; margin-left: 8px;" class="text-secondary">AstrBot</span>
|
||||
<span v-if="!sidebarCollapsed" style="font-weight: 1000; font-size: 26px; margin-left: 8px;"
|
||||
class="text-secondary">AstrBot</span>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
<div class="sidebar-collapse-btn-container">
|
||||
<v-btn icon class="sidebar-collapse-btn" @click="toggleSidebar" variant="text"
|
||||
@@ -21,7 +24,9 @@
|
||||
|
||||
<div style="padding: 16px; padding-top: 8px;">
|
||||
<v-btn block variant="text" class="new-chat-btn" @click="newC" :disabled="!currCid"
|
||||
v-if="!sidebarCollapsed" prepend-icon="mdi-plus" style="box-shadow: 0 1px 2px rgba(0,0,0,0.1); background-color: transparent !important; border-radius: 4px;">{{ tm('actions.newChat') }}</v-btn>
|
||||
v-if="!sidebarCollapsed" prepend-icon="mdi-plus"
|
||||
style="background-color: transparent !important; border-radius: 4px;">{{
|
||||
tm('actions.newChat') }}</v-btn>
|
||||
<v-btn icon="mdi-plus" rounded="lg" @click="newC" :disabled="!currCid" v-if="sidebarCollapsed"
|
||||
elevation="0"></v-btn>
|
||||
</div>
|
||||
@@ -29,11 +34,12 @@
|
||||
<v-divider class="mx-2"></v-divider>
|
||||
</div>
|
||||
|
||||
<div style="overflow-y: auto; flex-grow: 1;" class="sidebar-panel" :class="{ 'fade-in': sidebarHoverExpanded }"
|
||||
|
||||
<div style="overflow-y: auto; flex-grow: 1;" :class="{ 'fade-in': sidebarHoverExpanded }"
|
||||
v-if="!sidebarCollapsed">
|
||||
<v-card class="conversation-list-card" v-if="conversations.length > 0" flat>
|
||||
<v-card v-if="conversations.length > 0" flat style="background-color: transparent;">
|
||||
<v-list density="compact" nav class="conversation-list"
|
||||
@update:selected="getConversationMessages">
|
||||
style="background-color: transparent;" @update:selected="getConversationMessages">
|
||||
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
|
||||
rounded="lg" class="conversation-item" active-color="secondary">
|
||||
<v-list-item-title v-if="!sidebarCollapsed" class="conversation-title">{{ item.title
|
||||
@@ -62,42 +68,13 @@
|
||||
<div v-if="!sidebarCollapsed">
|
||||
<v-divider class="mx-2"></v-divider>
|
||||
</div>
|
||||
<div style="padding: 16px;" :class="{ 'fade-in': sidebarHoverExpanded }"
|
||||
v-if="!sidebarCollapsed">
|
||||
<div class="sidebar-section-title">
|
||||
{{ tm('conversation.systemStatus') }}
|
||||
</div>
|
||||
<div class="status-chips">
|
||||
<v-chip class="status-chip" :color="status?.llm_enabled ? 'primary' : 'grey-lighten-2'"
|
||||
variant="outlined" size="small" rounded="sm">
|
||||
<template v-slot:prepend>
|
||||
<v-icon :icon="status?.llm_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
|
||||
size="x-small"></v-icon>
|
||||
</template>
|
||||
<span>{{ tm('conversation.llmService') }}</span>
|
||||
</v-chip>
|
||||
|
||||
<v-chip class="status-chip" :color="status?.stt_enabled ? 'success' : 'grey-lighten-2'"
|
||||
variant="outlined" size="small" rounded="sm">
|
||||
<template v-slot:prepend>
|
||||
<v-icon :icon="status?.stt_enabled ? 'mdi-check-circle' : 'mdi-alert-circle'"
|
||||
size="x-small"></v-icon>
|
||||
</template>
|
||||
<span>{{ tm('conversation.speechToText') }}</span>
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<transition
|
||||
name="expand"
|
||||
@before-enter="beforeEnter"
|
||||
@enter="enter"
|
||||
@after-enter="afterEnter"
|
||||
@before-leave="beforeLeave"
|
||||
@leave="leave"
|
||||
>
|
||||
<div style="padding: 16px;" :class="{ 'fade-in': sidebarHoverExpanded }" v-if="!sidebarCollapsed">
|
||||
<transition name="expand" @before-enter="beforeEnter" @enter="enter" @after-enter="afterEnter"
|
||||
@before-leave="beforeLeave" @leave="leave">
|
||||
<div v-if="currCid" class="delete-btn-container">
|
||||
<v-btn variant="outlined" rounded="sm" class="delete-chat-btn"
|
||||
@click="deleteConversation(currCid)" color="error" density="comfortable" size="small">
|
||||
@click="deleteConversation(currCid)" color="error" density="comfortable"
|
||||
size="small">
|
||||
<v-icon start size="small">mdi-delete</v-icon>
|
||||
{{ tm('actions.deleteChat') }}
|
||||
</v-btn>
|
||||
@@ -111,14 +88,18 @@
|
||||
|
||||
<div class="conversation-header fade-in">
|
||||
<div class="conversation-header-content" v-if="currCid && getCurrentConversation">
|
||||
<h2 class="conversation-header-title">{{ getCurrentConversation.title || tm('conversation.newConversation') }}</h2>
|
||||
<div class="conversation-header-time">{{ formatDate(getCurrentConversation.updated_at) }}</div>
|
||||
<h2 class="conversation-header-title">{{ getCurrentConversation.title ||
|
||||
tm('conversation.newConversation')
|
||||
}}</h2>
|
||||
<div class="conversation-header-time">{{ formatDate(getCurrentConversation.updated_at) }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="conversation-header-actions">
|
||||
<!-- router 推送到 /chatbox -->
|
||||
<v-tooltip :text="tm('actions.fullscreen')" v-if="!chatboxMode">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon v-bind="props" @click="router.push(currCid ? `/chatbox/${currCid}` : '/chatbox')"
|
||||
<v-icon v-bind="props"
|
||||
@click="router.push(currCid ? `/chatbox/${currCid}` : '/chatbox')"
|
||||
class="fullscreen-icon">mdi-fullscreen</v-icon>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
@@ -131,7 +112,8 @@
|
||||
<!-- 主题切换按钮 -->
|
||||
<v-tooltip :text="isDark ? tm('modes.lightMode') : tm('modes.darkMode')" v-if="chatboxMode">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon @click="toggleTheme" class="theme-toggle-icon" variant="text">
|
||||
<v-btn v-bind="props" icon @click="toggleTheme" class="theme-toggle-icon"
|
||||
variant="text">
|
||||
<v-icon>{{ isDark ? 'mdi-weather-night' : 'mdi-white-balance-sunny' }}</v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
@@ -176,7 +158,8 @@
|
||||
<div class="message-item fade-in" v-for="(msg, index) in messages" :key="index">
|
||||
<!-- 用户消息 -->
|
||||
<div v-if="msg.type == 'user'" class="user-message">
|
||||
<div class="message-bubble user-bubble">
|
||||
<div class="message-bubble user-bubble"
|
||||
:style="{ backgroundColor: isDark ? '#2d2e30' : '#e7ebf4' }">
|
||||
<span>{{ msg.message }}</span>
|
||||
|
||||
<!-- 图片附件 -->
|
||||
@@ -195,15 +178,12 @@
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
<v-avatar class="user-avatar" color="deep-purple-lighten-3" size="36">
|
||||
<v-icon icon="mdi-account" />
|
||||
</v-avatar>
|
||||
</div>
|
||||
|
||||
<!-- 机器人消息 -->
|
||||
<div v-else class="bot-message">
|
||||
<v-avatar class="bot-avatar" color="deep-purple" size="36">
|
||||
<span class="text-h6">✨</span>
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<span class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="message-bubble bot-bubble">
|
||||
<div v-html="marked(msg.message)" class="markdown-content"></div>
|
||||
@@ -215,34 +195,21 @@
|
||||
|
||||
<!-- 输入区域 -->
|
||||
<div class="input-area fade-in">
|
||||
<v-text-field autocomplete="off" id="input-field" variant="outlined" v-model="prompt"
|
||||
:label="inputFieldLabel" :placeholder="tm('input.placeholder')" :loading="loadingChat"
|
||||
clear-icon="mdi-close-circle" clearable @click:clear="clearMessage" class="message-input"
|
||||
@keydown="handleInputKeyDown" hide-details>
|
||||
<template v-slot:loader>
|
||||
<v-progress-linear :active="loadingChat" height="3" color="deep-purple"
|
||||
indeterminate></v-progress-linear>
|
||||
</template>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-tooltip :text="tm('input.send')">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" @click="sendMessage" class="send-btn" icon="mdi-send"
|
||||
variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-tooltip :text="tm('input.voice')">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" @click="isRecording ? stopRecording() : startRecording()"
|
||||
class="record-btn"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
</template>
|
||||
</v-text-field>
|
||||
<div
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; padding: 4px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
|
||||
@click:clear="clearMessage" placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 12px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"
|
||||
:disabled="loadingChat"></textarea>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px;">
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
<v-btn @click="isRecording ? stopRecording() : startRecording()"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" class="record-btn" size="small" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 附件预览区 -->
|
||||
<div class="attachments-preview" v-if="stagedImagesUrl.length > 0 || stagedAudioUrl">
|
||||
@@ -263,16 +230,17 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-card>
|
||||
<!-- 编辑对话标题对话框 -->
|
||||
<v-dialog v-model="editTitleDialog" max-width="400">
|
||||
<v-card>
|
||||
<v-card-title class="dialog-title">{{ tm('actions.editTitle') }}</v-card-title>
|
||||
<v-card-text>
|
||||
<v-text-field v-model="editingTitle" :label="tm('conversation.newConversation')" variant="outlined" hide-details class="mt-2"
|
||||
@keyup.enter="saveTitle" autofocus />
|
||||
<v-text-field v-model="editingTitle" :label="tm('conversation.newConversation')" variant="outlined"
|
||||
hide-details class="mt-2" @keyup.enter="saveTitle" autofocus />
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
@@ -280,9 +248,9 @@
|
||||
<v-btn text @click="saveTitle" color="primary">{{ t('core.common.save') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog> <!-- 连接冲突提示对话框 -->
|
||||
</v-dialog> <!-- 连接冲突提示对话框 -->
|
||||
<v-dialog v-model="connectionConflictDialog" max-width="600" persistent>
|
||||
<v-card class="rounded-lg">
|
||||
<v-card class="rounded-lg">
|
||||
<v-toolbar color="primary" density="comfortable" flat>
|
||||
<v-icon color="white" class="ml-4 mr-2">mdi-information-outline</v-icon>
|
||||
<v-toolbar-title class="text-white">{{ tm('connection.title') }}</v-toolbar-title>
|
||||
@@ -296,13 +264,8 @@
|
||||
<div class="text-body-1 mb-4">
|
||||
{{ tm('connection.message') }}
|
||||
</div>
|
||||
|
||||
<v-alert
|
||||
type="info"
|
||||
variant="tonal"
|
||||
class="mb-4"
|
||||
icon="mdi-lightbulb-outline"
|
||||
>
|
||||
|
||||
<v-alert type="info" variant="tonal" class="mb-4" icon="mdi-lightbulb-outline">
|
||||
<div class="text-body-2 mb-2">
|
||||
<strong>{{ tm('connection.reasons') }}</strong>
|
||||
</div>
|
||||
@@ -313,12 +276,7 @@
|
||||
</ul>
|
||||
</v-alert>
|
||||
|
||||
<v-alert
|
||||
type="warning"
|
||||
variant="tonal"
|
||||
icon="mdi-alert-circle-outline"
|
||||
class="mb-0"
|
||||
>
|
||||
<v-alert type="warning" variant="tonal" icon="mdi-alert-circle-outline" class="mb-0">
|
||||
<div class="text-body-2">
|
||||
{{ tm('connection.notice') }}
|
||||
</div>
|
||||
@@ -327,12 +285,7 @@
|
||||
|
||||
<v-card-actions class="px-6 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="connectionConflictDialog = false"
|
||||
class="px-6"
|
||||
>
|
||||
<v-btn color="primary" variant="elevated" @click="connectionConflictDialog = false" class="px-6">
|
||||
{{ tm('connection.understand') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -340,15 +293,10 @@
|
||||
</v-dialog>
|
||||
|
||||
<!-- 连接状态消息提示 -->
|
||||
<v-snackbar
|
||||
v-model="connectionStatusSnackbar"
|
||||
:color="connectionStatusColor"
|
||||
:timeout="4000"
|
||||
location="top"
|
||||
>
|
||||
<v-snackbar v-model="connectionStatusSnackbar" :color="connectionStatusColor" :timeout="4000" location="top">
|
||||
<v-icon class="mr-2">
|
||||
{{ connectionStatusColor === 'success' ? 'mdi-check-circle' :
|
||||
connectionStatusColor === 'warning' ? 'mdi-alert-circle' : 'mdi-information' }}
|
||||
{{ connectionStatusColor === 'success' ? 'mdi-check-circle' :
|
||||
connectionStatusColor === 'warning' ? 'mdi-alert-circle' : 'mdi-information' }}
|
||||
</v-icon>
|
||||
{{ connectionStatusMessage }}
|
||||
</v-snackbar>
|
||||
@@ -377,10 +325,10 @@ export default {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
}, setup() {
|
||||
}, setup() {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
|
||||
|
||||
return {
|
||||
t,
|
||||
tm,
|
||||
@@ -413,7 +361,7 @@ export default {
|
||||
eventSourceReader: null,
|
||||
sseReconnecting: false, // 添加重连状态标志
|
||||
|
||||
|
||||
|
||||
// // Ctrl键长按相关变量
|
||||
ctrlKeyDown: false,
|
||||
ctrlKeyTimer: null,
|
||||
@@ -440,7 +388,7 @@ export default {
|
||||
connectionStatusColor: 'info',
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
computed: {
|
||||
isDark() {
|
||||
return useCustomizerStore().uiTheme === 'PurpleThemeDark';
|
||||
@@ -458,13 +406,13 @@ export default {
|
||||
immediate: true,
|
||||
handler(to, from) {
|
||||
console.log('Route changed:', to.path, 'from:', from?.path); // 如果是从不同的路由模式切换(chat <-> chatbox),重新建立SSE连接
|
||||
if (from &&
|
||||
if (from &&
|
||||
((from.path.startsWith('/chat') && to.path.startsWith('/chatbox')) ||
|
||||
(from.path.startsWith('/chatbox') && to.path.startsWith('/chat')))) {
|
||||
(from.path.startsWith('/chatbox') && to.path.startsWith('/chat')))) {
|
||||
console.log('Route mode changed, reconnecting SSE...');
|
||||
this.reconnectSSE();
|
||||
}
|
||||
|
||||
|
||||
// Check if the route matches /chat/<cid> or /chatbox/<cid> pattern
|
||||
if (to.path.startsWith('/chat/') || to.path.startsWith('/chatbox/')) {
|
||||
const pathCid = to.path.split('/')[2];
|
||||
@@ -484,7 +432,7 @@ export default {
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
// Watch for conversations loaded to handle pending cid
|
||||
conversations: {
|
||||
handler(newConversations) {
|
||||
@@ -530,7 +478,7 @@ export default {
|
||||
|
||||
beforeUnmount() {
|
||||
this.disconnectSSE();
|
||||
|
||||
|
||||
// 移除keyup事件监听
|
||||
document.removeEventListener('keyup', this.handleInputKeyUp);
|
||||
|
||||
@@ -541,7 +489,7 @@ export default {
|
||||
|
||||
// Cleanup blob URLs
|
||||
this.cleanupMediaCache();
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
// 显示连接冲突对话框
|
||||
showConnectionConflictDialog() {
|
||||
@@ -661,7 +609,7 @@ export default {
|
||||
}
|
||||
this.eventSourceReader = null;
|
||||
}
|
||||
|
||||
|
||||
if (this.eventSource) {
|
||||
try {
|
||||
this.eventSource.cancel();
|
||||
@@ -679,33 +627,33 @@ export default {
|
||||
console.log('SSE reconnection already in progress');
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
this.sseReconnecting = true;
|
||||
console.log('Reconnecting SSE...');
|
||||
this.disconnectSSE();
|
||||
|
||||
|
||||
// 等待更长时间确保后端连接完全清理
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
|
||||
this.startListeningEvent();
|
||||
},
|
||||
|
||||
async startListeningEvent() {
|
||||
// 确保之前的连接已断开
|
||||
this.disconnectSSE();
|
||||
|
||||
|
||||
// 如果正在重连过程中,等待一下
|
||||
if (this.sseReconnecting) {
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
}
|
||||
|
||||
|
||||
let retryCount = 0;
|
||||
const maxRetries = 3;
|
||||
|
||||
|
||||
while (retryCount < maxRetries) {
|
||||
try {
|
||||
console.log(`尝试建立SSE连接 (${retryCount + 1}/${maxRetries})`);
|
||||
|
||||
|
||||
const response = await fetch('/api/chat/listen', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
@@ -719,13 +667,13 @@ export default {
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const decoder = new TextDecoder();
|
||||
this.eventSource = reader;
|
||||
this.eventSourceReader = reader;
|
||||
this.sseReconnecting = false;
|
||||
|
||||
let in_streaming = false;
|
||||
let message_obj = null;
|
||||
let message_obj = null;
|
||||
console.log('SSE连接已建立');
|
||||
// 显示连接成功状态
|
||||
if (retryCount > 0) {
|
||||
@@ -859,28 +807,28 @@ export default {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 如果成功连接并正常结束,跳出重试循环
|
||||
break;
|
||||
|
||||
|
||||
} catch (error) {
|
||||
console.error(`SSE连接错误 (尝试 ${retryCount + 1}):`, error);
|
||||
|
||||
retryCount++;
|
||||
|
||||
retryCount++;
|
||||
if (error.message === 'CONNECTION_CONFLICT' && retryCount < maxRetries) {
|
||||
console.log(`连接冲突,等待 ${2000 * retryCount}ms 后重试...`);
|
||||
this.showConnectionStatus(`${this.tm('connection.status.reconnecting')} (${retryCount}/${maxRetries})`, 'warning');
|
||||
await new Promise(resolve => setTimeout(resolve, 2000 * retryCount));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
if (retryCount >= maxRetries) {
|
||||
console.error('SSE连接重试次数已达上限');
|
||||
this.showConnectionStatus(this.tm('connection.status.failed'), 'error');
|
||||
this.sseReconnecting = false;
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// 等待一段时间后重试
|
||||
await new Promise(resolve => setTimeout(resolve, 1000 * retryCount));
|
||||
} finally {
|
||||
@@ -888,7 +836,7 @@ export default {
|
||||
this.eventSourceReader = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
this.sseReconnecting = false;
|
||||
},
|
||||
|
||||
@@ -984,7 +932,7 @@ export default {
|
||||
getConversations() {
|
||||
axios.get('/api/chat/conversations').then(response => {
|
||||
this.conversations = response.data.data;
|
||||
|
||||
|
||||
// If there's a pending conversation ID from the route
|
||||
if (this.pendingCid) {
|
||||
const conversation = this.conversations.find(c => c.cid === this.pendingCid);
|
||||
@@ -1003,7 +951,7 @@ export default {
|
||||
getConversationMessages(cid) {
|
||||
if (!cid[0])
|
||||
return;
|
||||
|
||||
|
||||
// Update the URL to reflect the selected conversation
|
||||
if (this.$route.path !== `/chat/${cid[0]}` && this.$route.path !== `/chatbox/${cid[0]}`) {
|
||||
if (this.$route.path.startsWith('/chatbox')) {
|
||||
@@ -1013,7 +961,7 @@ export default {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(async response => {
|
||||
this.currCid = cid[0];
|
||||
let message = JSON.parse(response.data.data.history);
|
||||
@@ -1101,9 +1049,9 @@ export default {
|
||||
|
||||
// 检查是否可以发送消息
|
||||
canSendMessage() {
|
||||
return (this.prompt && this.prompt.trim()) ||
|
||||
this.stagedImagesName.length > 0 ||
|
||||
this.stagedAudioUrl;
|
||||
return (this.prompt && this.prompt.trim()) ||
|
||||
this.stagedImagesName.length > 0 ||
|
||||
this.stagedAudioUrl;
|
||||
},
|
||||
|
||||
async sendMessage() {
|
||||
@@ -1183,11 +1131,11 @@ export default {
|
||||
const container = this.$refs.messageContainer;
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
},
|
||||
},
|
||||
handleInputKeyDown(e) {
|
||||
if (e.ctrlKey && e.keyCode === 66) { // Ctrl+B组合键
|
||||
e.preventDefault(); // 防止默认行为
|
||||
|
||||
|
||||
// 防止重复触发
|
||||
if (this.ctrlKeyDown) return;
|
||||
|
||||
@@ -1200,7 +1148,7 @@ export default {
|
||||
}
|
||||
}, this.ctrlKeyLongPressThreshold);
|
||||
}
|
||||
},
|
||||
},
|
||||
handleInputKeyUp(e) {
|
||||
if (e.keyCode === 66) { // B键释放
|
||||
this.ctrlKeyDown = false;
|
||||
@@ -1308,26 +1256,27 @@ export default {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
|
||||
/* 聊天页面布局 */
|
||||
.chat-page-card {
|
||||
margin-bottom: 16px;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
border-radius: 16px;
|
||||
height: calc(100vh - 84px);
|
||||
max-height: 100%;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05) !important;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-page-container {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
max-height: calc(100vh - 120px);
|
||||
max-height: 100%;
|
||||
padding: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-layout {
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
display: flex;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar-panel {
|
||||
@@ -1337,29 +1286,11 @@ export default {
|
||||
flex-direction: column;
|
||||
padding: 0;
|
||||
border-right: 1px solid rgba(0, 0, 0, 0.05);
|
||||
background-color: var(--v-theme-containerBg);
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
position: relative;
|
||||
transition: all 0.3s ease;
|
||||
overflow: hidden;
|
||||
/* 防止内容溢出 */
|
||||
}
|
||||
|
||||
.sidebar-panel ::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.sidebar-panel ::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.sidebar-panel ::-webkit-scrollbar-thumb {
|
||||
background-color: rgba(0, 0, 0, 0.2);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.sidebar-panel ::-webkit-scrollbar-thumb:hover {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
/* 侧边栏折叠状态 */
|
||||
@@ -1442,7 +1373,7 @@ export default {
|
||||
.status-chips .v-chip {
|
||||
flex: 1 1 0;
|
||||
justify-content: center;
|
||||
opacity: 0.7; /* Make border and text slightly transparent */
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.status-chip {
|
||||
@@ -1470,7 +1401,8 @@ export default {
|
||||
}
|
||||
|
||||
.delete-btn-container {
|
||||
/* margin-top: -8px; */ /* Removed for better layout practices */
|
||||
/* margin-top: -8px; */
|
||||
/* Removed for better layout practices */
|
||||
}
|
||||
|
||||
.expand-enter-active,
|
||||
@@ -1495,20 +1427,24 @@ export default {
|
||||
transition: opacity 0.25s ease;
|
||||
}
|
||||
|
||||
/* 聊天内容区域 */
|
||||
.chat-content-panel {
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
width: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.messages-container {
|
||||
height: calc(100% - 80px);
|
||||
height: 100%;
|
||||
max-height: 100%;
|
||||
overflow-y: auto;
|
||||
padding: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
/* 欢迎页样式 */
|
||||
@@ -1574,28 +1510,26 @@ export default {
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
padding: 12px 16px;
|
||||
border-radius: 18px;
|
||||
padding: 8px 16px;
|
||||
border-radius: 12px;
|
||||
max-width: 80%;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.user-bubble {
|
||||
background-color: var(--v-theme-background);
|
||||
color: var(--v-theme-primaryText);
|
||||
border-top-right-radius: 4px;
|
||||
padding: 12px 16px;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.bot-bubble {
|
||||
background-color: var(--v-theme-surface);
|
||||
border: 1px solid var(--v-theme-border);
|
||||
color: var(--v-theme-primaryText);
|
||||
border-top-left-radius: 4px;
|
||||
}
|
||||
|
||||
.user-avatar,
|
||||
.bot-avatar {
|
||||
align-self: flex-end;
|
||||
align-self: flex-start;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
/* 附件样式 */
|
||||
@@ -1639,17 +1573,8 @@ export default {
|
||||
background-color: var(--v-theme-surface);
|
||||
position: relative;
|
||||
border-top: 1px solid var(--v-theme-border);
|
||||
}
|
||||
|
||||
.message-input {
|
||||
border-radius: 24px;
|
||||
max-width: 900px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.send-btn,
|
||||
.record-btn {
|
||||
margin-left: 4px;
|
||||
flex-shrink: 0;
|
||||
/* 防止输入区域被压缩 */
|
||||
}
|
||||
|
||||
/* 附件预览区 */
|
||||
@@ -1803,38 +1728,7 @@ export default {
|
||||
border-bottom: 1px solid var(--v-theme-border);
|
||||
width: 100%;
|
||||
padding-right: 32px;
|
||||
}
|
||||
|
||||
.conversation-header-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.conversation-header-title {
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
margin: 0;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.conversation-header-time {
|
||||
font-size: 12px;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.conversation-header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.fullscreen-icon {
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.fullscreen-icon:hover {
|
||||
opacity: 1;
|
||||
flex-shrink: 0;
|
||||
/* 防止header被压缩 */
|
||||
}
|
||||
</style>
|
||||
@@ -414,7 +414,6 @@ export default {
|
||||
"anthropic_chat_completion": "chat_completion",
|
||||
"googlegenai_chat_completion": "chat_completion",
|
||||
"zhipu_chat_completion": "chat_completion",
|
||||
"llm_tuner": "chat_completion",
|
||||
"dify": "chat_completion",
|
||||
"dashscope": "chat_completion",
|
||||
"openai_whisper_api": "speech_to_text",
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
import re
|
||||
from astrbot.api.event import filter, AstrMessageEvent
|
||||
from astrbot.api.star import Context, Star, register
|
||||
from astrbot.api.provider import LLMResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
|
||||
@register(
|
||||
"thinking_filter",
|
||||
"Soulter",
|
||||
"可选择是否过滤推理模型的思考内容",
|
||||
"1.0.0",
|
||||
"https://astrbot.app",
|
||||
)
|
||||
class R1Filter(Star):
|
||||
def __init__(self, context: Context):
|
||||
super().__init__(context)
|
||||
self.display_reasoning_text = (
|
||||
self.context.get_config()
|
||||
.get("provider_settings", {})
|
||||
.get("display_reasoning_text", False)
|
||||
)
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def resp(self, event: AstrMessageEvent, response: LLMResponse):
|
||||
if self.display_reasoning_text:
|
||||
# 显示推理内容的处理逻辑
|
||||
if (
|
||||
response
|
||||
and response.raw_completion
|
||||
and isinstance(response.raw_completion, ChatCompletion)
|
||||
and len(response.raw_completion.choices) > 0
|
||||
and response.raw_completion.choices[0].message
|
||||
):
|
||||
message = response.raw_completion.choices[0].message
|
||||
reasoning_content = "" # 初始化 reasoning_content
|
||||
|
||||
# 检查 Groq deepseek-r1-distill-llama-70b 模型的 'reasoning' 属性
|
||||
if hasattr(message, "reasoning") and message.reasoning:
|
||||
reasoning_content = message.reasoning
|
||||
# 检查 DeepSeek deepseek-reasoner 模型的 'reasoning_content'
|
||||
elif (
|
||||
hasattr(message, "reasoning_content") and message.reasoning_content
|
||||
):
|
||||
reasoning_content = message.reasoning_content
|
||||
|
||||
if reasoning_content:
|
||||
response.completion_text = (
|
||||
f"🤔思考:{reasoning_content}\n\n{message.content}"
|
||||
)
|
||||
else:
|
||||
response.completion_text = message.content
|
||||
else:
|
||||
# 过滤推理标签的处理逻辑
|
||||
completion_text = response.completion_text
|
||||
|
||||
# 检查并移除 <think> 标签
|
||||
if r"<think>" in completion_text or r"</think>" in completion_text:
|
||||
# 移除配对的标签及其内容
|
||||
completion_text = re.sub(
|
||||
r"<think>.*?</think>", "", completion_text, flags=re.DOTALL
|
||||
).strip()
|
||||
|
||||
# 移除可能残留的单个标签
|
||||
completion_text = (
|
||||
completion_text.replace(r"<think>", "")
|
||||
.replace(r"</think>", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
response.completion_text = completion_text
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "3.5.17"
|
||||
version = "3.5.18"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user