Compare commits
76 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7cfcba29a6 | |||
| 9bf8aadca9 | |||
| 714d4af63d | |||
| 8203fdb4f0 | |||
| 5e1e2d1a4f | |||
| 2f941de65b | |||
| 777c503002 | |||
| e9b23f68fd | |||
| efa45e6203 | |||
| 638f55f83c | |||
| 8b2fc29d5b | |||
| b516fb0550 | |||
| efef34c01e | |||
| 5f1dfa7599 | |||
| 8e9c7544cf | |||
| 4e3d5641c8 | |||
| 20b760529e | |||
| a55a07c5ff | |||
| 94ee8ea297 | |||
| 010f082fbb | |||
| 073cdf6d51 | |||
| 71b233fe5f | |||
| 770dec9ed6 | |||
| 2ca95a988e | |||
| cf1e7ee08a | |||
| d14513ddfd | |||
| 9a9017bc6c | |||
| 3c9b654713 | |||
| 80d2ad40bc | |||
| f5bff00b1f | |||
| 27c9717445 | |||
| 863a1ba8ef | |||
| cb04dd2b83 | |||
| 8c7cf51958 | |||
| 244fb1fed6 | |||
| 25f7a68a13 | |||
| 62d8cf79ef | |||
| 2f81b2e381 | |||
| 1f5a7e7885 | |||
| 80fca470f2 | |||
| 6e9d9ac856 | |||
| 8d6fada1eb | |||
| 3e715399a1 | |||
| 81cc8831f9 | |||
| f7370044a7 | |||
| 51b015a629 | |||
| 392af7a553 | |||
| d2dd07bad7 | |||
| cebcd6925a | |||
| e7b4357fc7 | |||
| dc279dde4a | |||
| c0810a674f | |||
| 0760cabbbe | |||
| 3b149c520b | |||
| 3d19fc89ff | |||
| cd1b1919f4 | |||
| 0ed646eb27 | |||
| c0c5859c99 | |||
| a47121b849 | |||
| d9dd20e89a | |||
| ed4609ebe5 | |||
| 01ef86d658 | |||
| cd4802da04 | |||
| 2aca65780f | |||
| 2c435f7387 | |||
| cc1afd1a9c | |||
| 6f098cdba6 | |||
| d03e9fb90a | |||
| 9f2966abe9 | |||
| 4e28ea1883 | |||
| 289214e85c | |||
| a20d98bf93 | |||
| 7c3d98acbe | |||
| ccb95f803c | |||
| 1ce95c473d | |||
| eb365e398d |
@@ -0,0 +1,30 @@
|
||||
---
|
||||
name: '🥳 发布插件'
|
||||
title: "[Plugin] 插件名"
|
||||
about: 提交插件到插件市场
|
||||
labels: [ "plugin-publish" ]
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
欢迎发布插件到插件市场!
|
||||
|
||||
## 插件基本信息
|
||||
|
||||
请将插件信息填写到下方的 Json 代码块中。`tags`(插件标签)和 `social_link`(社交链接)选填。
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "插件名",
|
||||
"desc": "插件介绍",
|
||||
"repo": "插件仓库链接",
|
||||
"tags": [],
|
||||
"social_link": ""
|
||||
}
|
||||
```
|
||||
|
||||
## 检查
|
||||
|
||||
- [ ] 我的插件经过完整的测试
|
||||
- [ ] 我的插件不包含恶意代码
|
||||
- [ ] 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
@@ -1,40 +0,0 @@
|
||||
name: '🥳 发布插件'
|
||||
title: "[Plugin] 插件名"
|
||||
description: 提交插件到插件市场
|
||||
labels: [ "plugin-publish" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 插件仓库
|
||||
description: 插件的 GitHub 仓库链接
|
||||
placeholder: >
|
||||
如 https://github.com/Soulter/astrbot-github-cards
|
||||
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 描述
|
||||
value: |
|
||||
插件名:
|
||||
插件作者:
|
||||
插件简介:
|
||||
支持的消息平台:(必填,如 QQ、微信、飞书)
|
||||
标签:(可选)
|
||||
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
|
||||
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Code of Conduct
|
||||
options:
|
||||
- label: >
|
||||
我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: "❤️"
|
||||
@@ -223,7 +223,7 @@ _✨ WebUI ✨_
|
||||
|
||||
此外,本项目的诞生离不开以下开源项目:
|
||||
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ)
|
||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
||||
|
||||
## ⭐ Star History
|
||||
@@ -237,6 +237,9 @@ _✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||

|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
1. The project is protected under the `AGPL-v3` opensource license.
|
||||
|
||||
@@ -13,7 +13,6 @@ from .utils.astrbot_path import get_astrbot_data_path
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
@@ -31,4 +30,3 @@ pip_installer = PipInstaller(
|
||||
)
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
|
||||
|
||||
@@ -3,15 +3,17 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "3.5.16"
|
||||
VERSION = "3.5.18"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
"platform_settings": {
|
||||
"plugin_enable": [],
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
@@ -59,6 +61,7 @@ DEFAULT_CONFIG = {
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"show_tool_use_status": False,
|
||||
"streaming_segmented": False,
|
||||
"separate_provider": False,
|
||||
},
|
||||
@@ -102,6 +105,7 @@ DEFAULT_CONFIG = {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||
"jwt_secret": "",
|
||||
"host": "0.0.0.0",
|
||||
"port": 6185,
|
||||
},
|
||||
@@ -367,15 +371,15 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"discord_token":{
|
||||
"discord_token": {
|
||||
"description": "Discord Bot Token",
|
||||
"type": "string",
|
||||
"hint": "在此处填入你的Discord Bot Token"
|
||||
"hint": "在此处填入你的Discord Bot Token",
|
||||
},
|
||||
"discord_proxy":{
|
||||
"discord_proxy": {
|
||||
"description": "Discord 代理地址",
|
||||
"type": "string",
|
||||
"hint": "可选的代理地址:http://ip:port"
|
||||
"hint": "可选的代理地址:http://ip:port",
|
||||
},
|
||||
"discord_command_register": {
|
||||
"description": "是否自动将插件指令注册为 Discord 斜杠指令",
|
||||
@@ -386,10 +390,6 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
|
||||
},
|
||||
"discord_guild_id_for_debug": {
|
||||
"description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效,便于调试。留空则注册为全局指令。",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -442,7 +442,7 @@ CONFIG_METADATA_2 = {
|
||||
"ignore_bot_self_message": {
|
||||
"description": "是否忽略机器人自身的消息",
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
"hint": "某些平台会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"ignore_at_all": {
|
||||
"description": "是否忽略 @ 全体成员",
|
||||
@@ -771,17 +771,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
},
|
||||
},
|
||||
"LLMTuner": {
|
||||
"id": "llmtuner_default",
|
||||
"type": "llm_tuner",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"base_model_path": "",
|
||||
"adapter_model_path": "",
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
@@ -974,6 +963,18 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
},
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
"type": "gemini_tts",
|
||||
"provider_type": "text_to_speech",
|
||||
"enable": False,
|
||||
"gemini_tts_api_key": "",
|
||||
"gemini_tts_api_base": "",
|
||||
"gemini_tts_timeout": 20,
|
||||
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
|
||||
"gemini_tts_prefix": "",
|
||||
"gemini_tts_voice_name": "Leda",
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
"type": "openai_embedding",
|
||||
@@ -1688,10 +1689,15 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"show_tool_use_status": {
|
||||
"description": "函数调用状态输出",
|
||||
"type": "bool",
|
||||
"hint": "在触发函数调用时输出其函数名和内容。",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -46,9 +46,12 @@ class AstrBotCoreLifecycle:
|
||||
self.astrbot_config = astrbot_config # 初始化配置
|
||||
self.db = db # 初始化数据库
|
||||
|
||||
# 根据环境变量设置代理
|
||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||
# 设置代理
|
||||
if self.astrbot_config.get("http_proxy", ""):
|
||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||
if proxy := os.environ.get("https_proxy"):
|
||||
logger.debug(f"Using proxy: {proxy}")
|
||||
os.environ["no_proxy"] = "localhost"
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -125,6 +125,8 @@ class Plain(BaseMessageComponent):
|
||||
def toDict(self):
|
||||
return {"type": "text", "data": {"text": self.text.strip()}}
|
||||
|
||||
async def to_dict(self):
|
||||
return {"type": "text", "data": {"text": self.text}}
|
||||
|
||||
class Face(BaseMessageComponent):
|
||||
type: ComponentType = "Face"
|
||||
@@ -610,6 +612,10 @@ class Node(BaseMessageComponent):
|
||||
"data": {"file": f"base64://{bs64}"},
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Plain):
|
||||
# For Plain segments, we need to handle the plain differently
|
||||
d = await comp.to_dict()
|
||||
data_content.append(d)
|
||||
elif isinstance(comp, File):
|
||||
# For File segments, we need to handle the file differently
|
||||
d = await comp.to_dict()
|
||||
|
||||
@@ -24,6 +24,8 @@ class MessageChain:
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
type: Optional[str] = None
|
||||
"""消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。"""
|
||||
|
||||
def message(self, message: str):
|
||||
"""添加一条文本消息到消息链 `chain` 中。
|
||||
@@ -98,6 +100,15 @@ class MessageChain:
|
||||
self.chain.append(Image.fromFileSystem(path))
|
||||
return self
|
||||
|
||||
def base64_image(self, base64_str: str):
|
||||
"""添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。
|
||||
Example:
|
||||
|
||||
CommandResult().base64_image("iVBORw0KGgoAAAANSUhEUgAAAAUA...")
|
||||
"""
|
||||
self.chain.append(Image.fromBase64(base64_str))
|
||||
return self
|
||||
|
||||
def use_t2i(self, use_t2i: bool):
|
||||
"""设置是否使用文本转图片服务。
|
||||
|
||||
@@ -157,7 +168,7 @@ class ResultContentType(enum.Enum):
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
STREAMING_FINISH= enum.auto()
|
||||
STREAMING_FINISH = enum.auto()
|
||||
"""流式输出完成"""
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -9,3 +17,91 @@ class PipelineContext:
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
|
||||
async def call_event_hook(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
hook_type: EventType,
|
||||
*args,
|
||||
):
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
hook_type, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, *args)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
async def call_handler(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次 yield 都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到 yield 分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
else:
|
||||
yield ret
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import abc
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ....message.message_event_result import MessageChain
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class AgentState(Enum):
|
||||
"""Agent 状态枚举"""
|
||||
IDLE = auto() # 初始状态
|
||||
RUNNING = auto() # 运行中
|
||||
DONE = auto() # 完成
|
||||
ERROR = auto() # 错误状态
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
chain: MessageChain
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
class BaseAgentRunner:
|
||||
@abc.abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the agent to its initial state.
|
||||
This method should be called before starting a new run.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
Check if the agent has completed its task.
|
||||
Returns True if the agent is done, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
"""
|
||||
Get the final observation from the agent.
|
||||
This method should be called after the agent is done.
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,300 @@
|
||||
import sys
|
||||
import traceback
|
||||
import typing as T
|
||||
from .base import BaseAgentRunner, AgentResponse, AgentResponseData, AgentState
|
||||
from ...context import PipelineContext
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
AssistantMessageSegment,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot import logger
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO:
|
||||
# 1. 处理平台不兼容的处理器
|
||||
|
||||
|
||||
class ToolLoopAgent(BaseAgentRunner):
|
||||
def __init__(
|
||||
self, provider: Provider, event: AstrMessageEvent, pipeline_ctx: PipelineContext
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.req = None
|
||||
self.event = event
|
||||
self.pipeline_ctx = pipeline_ctx
|
||||
self._state = AgentState.IDLE
|
||||
self.final_llm_resp = None
|
||||
self.streaming = False
|
||||
|
||||
@override
|
||||
async def reset(self, req: ProviderRequest, streaming: bool) -> None:
|
||||
self.req = req
|
||||
self.streaming = streaming
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
|
||||
def _transition_state(self, new_state: AgentState) -> None:
|
||||
"""转换 Agent 状态"""
|
||||
if self._state != new_state:
|
||||
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
|
||||
self._state = new_state
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""
|
||||
Process a single step of the agent.
|
||||
This method should return the result of the step.
|
||||
"""
|
||||
if not self.req:
|
||||
raise ValueError("Request is not set. Please call reset() first.")
|
||||
|
||||
# 开始处理,转换到运行状态
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(chain=llm_response.result_chain),
|
||||
)
|
||||
else:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_response.completion_text)
|
||||
),
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
return
|
||||
|
||||
# 处理 LLM 响应
|
||||
llm_resp = llm_resp_result
|
||||
logger.debug(f"LLMResp: {llm_resp}")
|
||||
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(
|
||||
f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
if not llm_resp.tools_call_name:
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
|
||||
# 执行事件钩子
|
||||
await self.pipeline_ctx.call_event_hook(
|
||||
self.event, EventType.OnLLMResponseEvent, llm_resp
|
||||
)
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(chain=llm_resp.result_chain),
|
||||
)
|
||||
elif llm_resp.completion_text:
|
||||
yield AgentResponse(
|
||||
type="llm_result",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(llm_resp.completion_text)
|
||||
),
|
||||
)
|
||||
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}")
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
role="assistant",
|
||||
tool_calls=llm_resp.to_openai_tool_calls(),
|
||||
content=llm_resp.completion_text,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
|
||||
"""处理函数工具调用。"""
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||
|
||||
# 执行函数调用
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if not res:
|
||||
continue
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain().base64_image(res.content[0].data)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain().base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
wrapper = self.pipeline_ctx.call_handler(
|
||||
self.event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
# Tool 返回结果
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(chain=res.chain)
|
||||
|
||||
self.event.clear_result()
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
def done(self) -> bool:
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import copy
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
@@ -20,39 +21,27 @@ from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
AssistantMessageSegment,
|
||||
ToolCallsResult,
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core import web_chat_back_queue
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.bot_wake_prefixs = ctx.astrbot_config["wake_prefix"] # list
|
||||
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
|
||||
"wake_prefix"
|
||||
] # str
|
||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||
"max_context_length"
|
||||
] # int
|
||||
self.dequeue_context_length = min(
|
||||
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
|
||||
self.provider_wake_prefix: str = settings["wake_prefix"] # str
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
) # int
|
||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||
"streaming_response"
|
||||
] # bool
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.max_step: int = settings.get("max_agent_step", 10)
|
||||
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -83,10 +72,7 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
req.contexts = self._process_tool_message_pairs(
|
||||
all_contexts, remove_tags=True
|
||||
)
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
@@ -127,26 +113,7 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
await self.ctx.call_event_hook(event, EventType.OnLLMRequestEvent, req)
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
@@ -176,77 +143,62 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
async def requesting(req: ProviderRequest):
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
final_llm_response = None
|
||||
# Call Agent
|
||||
tool_loop_agent = ToolLoopAgent(
|
||||
provider=provider,
|
||||
event=event,
|
||||
pipeline_ctx=self.ctx,
|
||||
)
|
||||
await tool_loop_agent.reset(req=req, streaming=self.streaming_response)
|
||||
|
||||
if self.streaming_response:
|
||||
stream = provider.text_chat_stream(**req.__dict__)
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
yield MessageChain().message(
|
||||
llm_response.completion_text
|
||||
)
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
final_llm_response = await provider.text_chat(
|
||||
**req.__dict__
|
||||
) # 请求 LLM
|
||||
async def requesting():
|
||||
step_idx = 0
|
||||
while step_idx < self.max_step:
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
if resp.type == "tool_call_result":
|
||||
continue # 跳过工具调用结果
|
||||
if resp.type == "tool_call":
|
||||
if self.streaming_response:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if self.show_tool_use or event.get_platform_name() == "webchat":
|
||||
resp.data["chain"].type = "tool_call"
|
||||
await event.send(resp.data["chain"])
|
||||
continue
|
||||
|
||||
if not final_llm_response:
|
||||
raise Exception("LLM response is None.")
|
||||
if not self.streaming_response:
|
||||
content_typ = (
|
||||
ResultContentType.LLM_RESULT
|
||||
if resp.type == "llm_result"
|
||||
else ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=resp.data["chain"].chain,
|
||||
result_content_type=content_typ,
|
||||
)
|
||||
)
|
||||
yield
|
||||
event.clear_result()
|
||||
else:
|
||||
if resp.type == "streaming_delta":
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if tool_loop_agent.done():
|
||||
break
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, final_llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式输出的处理
|
||||
async for result in self._handle_llm_stream_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
else:
|
||||
# 非流式输出的处理
|
||||
async for result in self._handle_llm_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
@@ -255,44 +207,38 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, final_llm_response)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
if not self.streaming_response:
|
||||
event.set_extra("tool_call_result", None)
|
||||
async for _ in requesting(req):
|
||||
yield
|
||||
else:
|
||||
if self.streaming_response:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting(req))
|
||||
.set_async_stream(requesting())
|
||||
)
|
||||
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||
yield
|
||||
|
||||
if event.get_extra("tool_call_result"):
|
||||
event.set_result(event.get_extra("tool_call_result"))
|
||||
event.set_extra("tool_call_result", None)
|
||||
if tool_loop_agent.done():
|
||||
if final_llm_resp := tool_loop_agent.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
else:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
)
|
||||
)
|
||||
else:
|
||||
async for _ in requesting():
|
||||
yield
|
||||
|
||||
# 暂时直接发出去
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
# 异步处理 WebChat 特殊情况
|
||||
asyncio.create_task(self._handle_webchat(event, req))
|
||||
|
||||
await self._save_to_history(event, req, tool_loop_agent.get_final_llm_resp())
|
||||
|
||||
async def _handle_webchat(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
@@ -305,10 +251,6 @@ class LLMRequestSubStage(Stage):
|
||||
return
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
# if len(latest_pair) > 1:
|
||||
# cleaned_text += (
|
||||
# "\nAssistant: " + latest_pair[1].get("content", "").strip()
|
||||
# )
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await provider.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
@@ -349,322 +291,50 @@ class LLMRequestSubStage(Stage):
|
||||
}
|
||||
)
|
||||
|
||||
async def _handle_llm_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理非流式 LLM 响应。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "tool":
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_llm_stream_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理流式 LLM 响应。
|
||||
|
||||
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "tool":
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理函数工具调用。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
"""
|
||||
# function calling
|
||||
tool_call_result: list[ToolCallMessageSegment] = []
|
||||
logger.info(
|
||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||
)
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if res:
|
||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
else:
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md
|
||||
and platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||
)
|
||||
# 直接跳过,不添加任何消息到tool_call_result
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
res = event.get_result()
|
||||
if res and res.chain:
|
||||
event.set_extra("tool_call_result", res)
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
if tool_call_result:
|
||||
# 函数调用结果
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
assistant_msg_seg = AssistantMessageSegment(
|
||||
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||
)
|
||||
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||
req.tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=assistant_msg_seg,
|
||||
tool_calls_result=tool_call_result,
|
||||
)
|
||||
yield req # 再次执行 LLM 请求
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(
|
||||
MessageEventResult().message(llm_response.completion_text)
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
):
|
||||
if not req or not req.conversation or not llm_response:
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts.copy()
|
||||
contexts.append(await req.assemble_context())
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=messages
|
||||
)
|
||||
logger.debug(f"messages persisted: {messages}")
|
||||
|
||||
# 记录并标记函数调用结果
|
||||
if req.tool_calls_result:
|
||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||
|
||||
# 添加标记
|
||||
for message in tool_calls_messages:
|
||||
message["_tool_call_history"] = True
|
||||
|
||||
processed_tool_messages = self._process_tool_message_pairs(
|
||||
tool_calls_messages, remove_tags=False
|
||||
)
|
||||
|
||||
contexts.extend(processed_tool_messages)
|
||||
|
||||
contexts.append(
|
||||
{"role": "assistant", "content": llm_response.completion_text}
|
||||
)
|
||||
contexts_to_save = list(
|
||||
filter(lambda item: "_no_save" not in item, contexts)
|
||||
)
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||
)
|
||||
|
||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||
|
||||
Args:
|
||||
messages (list): 消息列表
|
||||
remove_tags (bool): 是否移除_tool_call_history标记
|
||||
|
||||
Returns:
|
||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
current_msg = messages[i]
|
||||
|
||||
# 普通消息直接添加
|
||||
if "_tool_call_history" not in current_msg:
|
||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息成对处理
|
||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||
assistant_msg = current_msg.copy()
|
||||
|
||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(messages)
|
||||
and messages[j].get("role") == "tool"
|
||||
and "_tool_call_history" in messages[j]
|
||||
):
|
||||
tool_msg = messages[j].copy()
|
||||
|
||||
if remove_tags:
|
||||
del tool_msg["_tool_call_history"]
|
||||
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 成对的时候添加到结果
|
||||
if related_tools:
|
||||
result.append(assistant_msg)
|
||||
result.extend(related_tools)
|
||||
|
||||
i = j # 跳过已处理
|
||||
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
# 单独的tool消息
|
||||
i += 1
|
||||
|
||||
return result
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
@@ -50,7 +50,7 @@ class StarRequestSubStage(Stage):
|
||||
logger.debug(
|
||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
||||
)
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
wrapper = self.ctx.call_handler(event, handler.handler, **params)
|
||||
async for ret in wrapper:
|
||||
yield ret
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
@@ -128,9 +128,7 @@ class RespondStage(Stage):
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
await event._post_send()
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
# 检查路径映射
|
||||
@@ -141,8 +139,6 @@ class RespondStage(Stage):
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
await event._pre_send()
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
if await self._is_empty_message_chain(result.chain):
|
||||
@@ -158,9 +154,14 @@ class RespondStage(Stage):
|
||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
||||
]
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
if (
|
||||
self.enable_seg
|
||||
and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
)
|
||||
and event.get_platform_name()
|
||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
||||
):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
@@ -208,7 +209,6 @@ class RespondStage(Stage):
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||
)
|
||||
|
||||
@@ -141,7 +141,11 @@ class ResultDecorateStage(Stage):
|
||||
break
|
||||
|
||||
# 分段回复
|
||||
if self.enable_segmented_reply:
|
||||
if self.enable_segmented_reply and event.get_platform_name() not in [
|
||||
"qq_official",
|
||||
"weixin_official_account",
|
||||
"dingtalk",
|
||||
]:
|
||||
if (
|
||||
self.only_llm_result and result.is_llm_result()
|
||||
) or not self.only_llm_result:
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import inspect
|
||||
import traceback
|
||||
from astrbot.api import logger
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from typing import List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
|
||||
@@ -41,70 +37,3 @@ class Stage(abc.ABC):
|
||||
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _call_handler(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 待处理的事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
*args: 传递给handler的位置参数
|
||||
**kwargs: 传递给handler的关键字参数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||
|
||||
trace_ = None
|
||||
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
# 如果是一个异步生成器, 进入洋葱模型
|
||||
_has_yielded = False # 是否返回过值
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到yield分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
|
||||
@@ -135,7 +135,6 @@ class WakingCheckStage(Stage):
|
||||
f"插件 {star_map[handler.handler_module_path].name}: {e}"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
event.stop_event()
|
||||
passed = False
|
||||
break
|
||||
@@ -150,7 +149,6 @@ class WakingCheckStage(Stage):
|
||||
f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。"
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||
)
|
||||
|
||||
@@ -235,10 +235,10 @@ class AstrMessageEvent(abc.ABC):
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
"""调度器会在执行 send() 前调用该方法"""
|
||||
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
async def _post_send(self):
|
||||
"""调度器会在执行 send() 后调用该方法"""
|
||||
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
def set_result(self, result: Union[MessageEventResult, str]):
|
||||
"""设置消息事件的结果。
|
||||
|
||||
@@ -168,9 +168,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if "sub_type" in event:
|
||||
if event["sub_type"] == "poke" and "target_id" in event:
|
||||
abm.message.append(
|
||||
Poke(qq=str(event["target_id"]), type="poke")
|
||||
) # noqa: F405
|
||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||
|
||||
return abm
|
||||
|
||||
@@ -273,6 +271,8 @@ class AiocqhttpAdapter(Platform):
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
|
||||
reply_event_data["post_type"] = "message"
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
@@ -307,7 +307,7 @@ class AiocqhttpAdapter(Platform):
|
||||
user_id=int(m["data"]["qq"]),
|
||||
)
|
||||
if at_info:
|
||||
nickname = at_info.get("nick", "")
|
||||
nickname = at_info.get("nick", "") or at_info.get("nickname", "")
|
||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||
|
||||
abm.message.append(
|
||||
@@ -322,7 +322,7 @@ class AiocqhttpAdapter(Platform):
|
||||
first_at_self_processed = True
|
||||
else:
|
||||
# 非第一个@机器人或@其他用户,添加到message_str
|
||||
message_str += f" @{nickname} "
|
||||
message_str += f" @{nickname}({m['data']['qq']}) "
|
||||
else:
|
||||
abm.message.append(At(qq=str(m["data"]["qq"]), name=""))
|
||||
except ActionFailed as e:
|
||||
|
||||
@@ -46,6 +46,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
self.guild_id = self.config.get("discord_guild_id_for_debug", None)
|
||||
self.activity_name = self.config.get("discord_activity_name", None)
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self._polling_task = None
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
@@ -137,7 +139,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
self.client.on_ready_once_callback = callback
|
||||
|
||||
try:
|
||||
await self.client.start_polling()
|
||||
self._polling_task = asyncio.create_task(self.client.start_polling())
|
||||
await self.shutdown_event.wait()
|
||||
except discord.errors.LoginFailure:
|
||||
logger.error("[Discord] 登录失败。请检查你的 Bot Token 是否正确。")
|
||||
except discord.errors.ConnectionClosed:
|
||||
@@ -162,42 +165,47 @@ class DiscordPlatformAdapter(Platform):
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
is_mentioned = data.get("is_mentioned", False)
|
||||
|
||||
content = message.content
|
||||
|
||||
# 如果机器人被@,移除@部分
|
||||
if (
|
||||
is_mentioned
|
||||
and self.client
|
||||
and self.client.user
|
||||
and self.client.user in message.mentions
|
||||
):
|
||||
# 构建机器人的@字符串,格式为 <@USER_ID> 或 <@!USER_ID>
|
||||
# 剥离 User Mention (<@id>, <@!id>)
|
||||
if self.client and self.client.user:
|
||||
mention_str = f"<@{self.client.user.id}>"
|
||||
mention_str_nickname = (
|
||||
f"<@!{self.client.user.id}>" # 有些客户端会使用带!的格式
|
||||
)
|
||||
|
||||
mention_str_nickname = f"<@!{self.client.user.id}>"
|
||||
if content.startswith(mention_str):
|
||||
content = content[len(mention_str) :].lstrip()
|
||||
elif content.startswith(mention_str_nickname):
|
||||
content = content[len(mention_str_nickname) :].lstrip()
|
||||
|
||||
abm = AstrBotMessage()
|
||||
# 剥离 Role Mention(bot 拥有的任一角色被提及,<@&role_id>)
|
||||
if (
|
||||
hasattr(message, "role_mentions")
|
||||
and hasattr(message, "guild")
|
||||
and message.guild
|
||||
):
|
||||
bot_member = (
|
||||
message.guild.get_member(self.client.user.id)
|
||||
if self.client and self.client.user
|
||||
else None
|
||||
)
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
for role in bot_member.roles:
|
||||
role_mention_str = f"<@&{role.id}>"
|
||||
if content.startswith(role_mention_str):
|
||||
content = content[len(role_mention_str) :].lstrip()
|
||||
break # 只剥离第一个匹配的角色 mention
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(message.channel)
|
||||
abm.group_id = self._get_channel_id(message.channel)
|
||||
|
||||
abm.message_str = content
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(message.author.id), nickname=message.author.display_name
|
||||
)
|
||||
|
||||
message_chain = []
|
||||
if abm.message_str:
|
||||
message_chain.append(Plain(text=abm.message_str))
|
||||
|
||||
if message.attachments:
|
||||
for attachment in message.attachments:
|
||||
if attachment.content_type and attachment.content_type.startswith(
|
||||
@@ -210,7 +218,6 @@ class DiscordPlatformAdapter(Platform):
|
||||
message_chain.append(
|
||||
File(name=attachment.filename, url=attachment.url)
|
||||
)
|
||||
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
@@ -237,13 +244,35 @@ class DiscordPlatformAdapter(Platform):
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 检查是否被@
|
||||
is_mention = (
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
# User Mention
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
and self.client.user in message.raw_message.mentions
|
||||
)
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
is_mention = True
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
bot_member = None
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
try:
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
self.client.user.id
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
and bot_roles.intersection(mentioned_roles)
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
@@ -255,23 +284,37 @@ class DiscordPlatformAdapter(Platform):
|
||||
@override
|
||||
async def terminate(self):
|
||||
"""终止适配器"""
|
||||
logger.info("[Discord] 正在终止适配器...")
|
||||
|
||||
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
|
||||
self.shutdown_event.set()
|
||||
# 优先 cancel polling_task
|
||||
if self._polling_task:
|
||||
self._polling_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._polling_task, timeout=10)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[Discord] polling_task 已取消。")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] polling_task 取消异常: {e}")
|
||||
logger.info("[Discord] 正在清理已注册的斜杠指令... (step 2)")
|
||||
# 清理指令
|
||||
if self.enable_command_register and self.client:
|
||||
logger.info("[Discord] 正在清理已注册的斜杠指令...")
|
||||
try:
|
||||
# 传入空的列表来清除所有全局指令
|
||||
# 如果指定了 guild_id,则只清除该服务器的指令
|
||||
await self.client.sync_commands(
|
||||
commands=[], guild_ids=[self.guild_id] if self.guild_id else None
|
||||
await asyncio.wait_for(
|
||||
self.client.sync_commands(
|
||||
commands=[],
|
||||
guild_ids=[self.guild_id] if self.guild_id else None,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
logger.info("[Discord] 指令清理完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True)
|
||||
|
||||
logger.info("[Discord] 正在关闭 Discord 客户端... (step 3)")
|
||||
if self.client and hasattr(self.client, "close"):
|
||||
await self.client.close()
|
||||
try:
|
||||
await asyncio.wait_for(self.client.close(), timeout=10)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] 客户端关闭异常: {e}")
|
||||
logger.info("[Discord] 适配器已终止。")
|
||||
|
||||
def register_handler(self, handler_info):
|
||||
|
||||
@@ -28,10 +28,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
self.send_buffer = None
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = message
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
self.send_buffer = message
|
||||
await self._post_send()
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
"""流式输出仅支持消息列表私聊"""
|
||||
|
||||
@@ -40,20 +40,21 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
def _split_message(self, text: str) -> list[str]:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
@classmethod
|
||||
def _split_message(cls, text: str) -> list[str]:
|
||||
if len(text) <= cls.MAX_MESSAGE_LENGTH:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= self.MAX_MESSAGE_LENGTH:
|
||||
if len(text) <= cls.MAX_MESSAGE_LENGTH:
|
||||
chunks.append(text)
|
||||
break
|
||||
|
||||
split_point = self.MAX_MESSAGE_LENGTH
|
||||
segment = text[: self.MAX_MESSAGE_LENGTH]
|
||||
split_point = cls.MAX_MESSAGE_LENGTH
|
||||
segment = text[: cls.MAX_MESSAGE_LENGTH]
|
||||
|
||||
for _, pattern in self.SPLIT_PATTERNS.items():
|
||||
for _, pattern in cls.SPLIT_PATTERNS.items():
|
||||
if matches := list(pattern.finditer(segment)):
|
||||
last_match = matches[-1]
|
||||
split_point = last_match.end()
|
||||
@@ -64,9 +65,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
return chunks
|
||||
|
||||
async def send_with_client(
|
||||
self, client: ExtBot, message: MessageChain, user_name: str
|
||||
):
|
||||
@classmethod
|
||||
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
||||
image_path = None
|
||||
|
||||
has_reply = False
|
||||
@@ -97,7 +97,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} {i.text}"
|
||||
at_flag = True
|
||||
chunks = self._split_message(i.text)
|
||||
chunks = cls._split_message(i.text)
|
||||
for chunk in chunks:
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
@@ -158,6 +158,12 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain.chain:
|
||||
if isinstance(i, Plain):
|
||||
|
||||
@@ -35,6 +35,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
@@ -110,6 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
chain, session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64
|
||||
from astrbot.core.utils.tencent_record_helper import audio_to_tencent_silk_base64
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .wechatpadpro_adapter import WeChatPadProAdapter
|
||||
@@ -113,7 +113,7 @@ class WeChatPadProMessageEvent(AstrMessageEvent):
|
||||
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 默认已经存在 data/temp 中
|
||||
b64, duration = await wav_to_tencent_silk_base64(record_path)
|
||||
b64, duration = await audio_to_tencent_silk_base64(record_path)
|
||||
payload = {
|
||||
"ToUserName": self.session_id,
|
||||
"VoiceData": b64,
|
||||
|
||||
@@ -58,7 +58,7 @@ class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list)
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
@@ -67,7 +67,7 @@ class AssistantMessageSegment:
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
elif self.tool_calls:
|
||||
if self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
return ret
|
||||
|
||||
@@ -95,19 +95,19 @@ class ProviderRequest:
|
||||
"""提示词"""
|
||||
session_id: str = ""
|
||||
"""会话 ID"""
|
||||
image_urls: List[str] = None
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall = None
|
||||
func_tool: FuncCall | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: List = None
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
"""
|
||||
system_prompt: str = ""
|
||||
"""系统提示词"""
|
||||
conversation: Conversation = None
|
||||
conversation: Conversation | None = None
|
||||
|
||||
tool_calls_result: ToolCallsResult = None
|
||||
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
|
||||
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||
|
||||
def __repr__(self):
|
||||
@@ -116,6 +116,14 @@ class ProviderRequest:
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
|
||||
"""添加工具调用结果到请求中"""
|
||||
if not self.tool_calls_result:
|
||||
self.tool_calls_result = []
|
||||
if isinstance(self.tool_calls_result, ToolCallsResult):
|
||||
self.tool_calls_result = [self.tool_calls_result]
|
||||
self.tool_calls_result.append(tool_calls_result)
|
||||
|
||||
def _print_friendly_context(self):
|
||||
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||
if not self.contexts:
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entities import ProviderType
|
||||
import traceback
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from .register import provider_cls_map, llm_tools
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
@@ -38,13 +40,11 @@ class ProviderManager:
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
bd_processed.append({
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
})
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
@@ -190,11 +190,6 @@ class ProviderManager:
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import (
|
||||
LLMTunerModelLoader as LLMTunerModelLoader,
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "dashscope":
|
||||
@@ -253,6 +248,10 @@ class ProviderManager:
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
@@ -326,8 +325,6 @@ class ProviderManager:
|
||||
inst = provider_metadata.cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
self.db_helper,
|
||||
self.provider_settings.get("persistant_history", True),
|
||||
self.selected_default_persona,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
@@ -53,15 +52,13 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None,
|
||||
default_persona: Personality = None,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
self.curr_personality = default_persona
|
||||
"""维护了当前的使用的 persona,即人格。可能为 None"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -86,11 +83,11 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,11 +111,11 @@ class Provider(AbstractProvider):
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import json
|
||||
import anthropic
|
||||
import base64
|
||||
from typing import List
|
||||
from mimetypes import guess_type
|
||||
|
||||
@@ -5,41 +8,33 @@ from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from typing import AsyncGenerator
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"anthropic_chat_completion", "Anthropic Claude API 提供商适配器"
|
||||
)
|
||||
class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
class ProviderAnthropic(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
# Skip OpenAI's __init__ and call Provider's __init__ directly
|
||||
Provider.__init__(
|
||||
self,
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
self.chosen_api_key = None
|
||||
self.chosen_api_key: str = ""
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
|
||||
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
@@ -51,10 +46,63 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
|
||||
Args:
|
||||
messages: OpenAI 格式的消息列表,包含用户输入和系统提示等信息
|
||||
Returns:
|
||||
system_prompt: 系统提示内容
|
||||
new_messages: 处理后的消息列表,去除系统提示
|
||||
"""
|
||||
system_prompt = ""
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
elif message["role"] == "assistant":
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
"type": "tool_use",
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": json.loads(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], str)
|
||||
else tool_call["function"]["arguments"],
|
||||
"id": tool_call["id"],
|
||||
}
|
||||
)
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": blocks,
|
||||
}
|
||||
)
|
||||
elif message["role"] == "tool":
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_messages.append(message)
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_anthropic_style()
|
||||
if tool_list:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(**payloads, stream=False)
|
||||
@@ -64,70 +112,157 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
if len(completion.content) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
# TODO: 如果进行函数调用,思维链被截断,用户可能需要思维链的内容
|
||||
# 选最后一条消息,如果要进行函数调用,anthropic会先返回文本消息的思维链,然后再返回函数调用请求
|
||||
content = completion.content[-1]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response = LLMResponse(role="assistant")
|
||||
|
||||
if content.type == "text":
|
||||
# text completion
|
||||
completion_text = str(content.text).strip()
|
||||
# llm_response.completion_text = completion_text
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# Anthropic每次只返回一个函数调用
|
||||
if completion.stop_reason == "tool_use":
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
tool_use_ids = []
|
||||
func_name_ls.append(content.name)
|
||||
args_ls.append(content.input)
|
||||
tool_use_ids.append(content.id)
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args = args_ls
|
||||
llm_response.tools_call_name = func_name_ls
|
||||
llm_response.tools_call_ids = tool_use_ids
|
||||
for content_block in completion.content:
|
||||
if content_block.type == "text":
|
||||
completion_text = str(content_block.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if content_block.type == "tool_use":
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
llm_response.tools_call_ids.append(content_block.id)
|
||||
# TODO(Soulter): 处理 end_turn 情况
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
# 用于累积工具调用信息
|
||||
tool_use_buffer = {}
|
||||
# 用于累积最终结果
|
||||
final_text = ""
|
||||
final_tool_calls = []
|
||||
|
||||
async with self.client.messages.stream(**payloads) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "content_block_start":
|
||||
if event.content_block.type == "text":
|
||||
# 文本块开始
|
||||
yield LLMResponse(
|
||||
role="assistant", completion_text="", is_chunk=True
|
||||
)
|
||||
elif event.content_block.type == "tool_use":
|
||||
# 工具使用块开始,初始化缓冲区
|
||||
tool_use_buffer[event.index] = {
|
||||
"id": event.content_block.id,
|
||||
"name": event.content_block.name,
|
||||
"input": {},
|
||||
}
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
if event.delta.type == "text_delta":
|
||||
# 文本增量
|
||||
final_text += event.delta.text
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=event.delta.text,
|
||||
is_chunk=True,
|
||||
)
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
if event.index in tool_use_buffer:
|
||||
# 累积 JSON 输入
|
||||
if "input_json" not in tool_use_buffer[event.index]:
|
||||
tool_use_buffer[event.index]["input_json"] = ""
|
||||
tool_use_buffer[event.index]["input_json"] += (
|
||||
event.delta.partial_json
|
||||
)
|
||||
|
||||
elif event.type == "content_block_stop":
|
||||
# 内容块结束
|
||||
if event.index in tool_use_buffer:
|
||||
# 解析完整的工具调用
|
||||
tool_info = tool_use_buffer[event.index]
|
||||
try:
|
||||
if "input_json" in tool_info:
|
||||
tool_info["input"] = json.loads(tool_info["input_json"])
|
||||
|
||||
# 添加到最终结果
|
||||
final_tool_calls.append(
|
||||
{
|
||||
"id": tool_info["id"],
|
||||
"name": tool_info["name"],
|
||||
"input": tool_info["input"],
|
||||
}
|
||||
)
|
||||
|
||||
yield LLMResponse(
|
||||
role="tool",
|
||||
completion_text="",
|
||||
tools_call_args=[tool_info["input"]],
|
||||
tools_call_name=[tool_info["name"]],
|
||||
tools_call_ids=[tool_info["id"]],
|
||||
is_chunk=True,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# JSON 解析失败,跳过这个工具调用
|
||||
logger.warning(f"工具调用参数 JSON 解析失败: {tool_info}")
|
||||
|
||||
# 清理缓冲区
|
||||
del tool_use_buffer[event.index]
|
||||
|
||||
# 返回最终的完整结果
|
||||
final_response = LLMResponse(
|
||||
role="assistant", completion_text=final_text, is_chunk=False
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
final_response.tools_call_args = [
|
||||
call["input"] for call in final_tool_calls
|
||||
]
|
||||
final_response.tools_call_name = [call["name"] for call in final_tool_calls]
|
||||
final_response.tools_call_ids = [call["id"] for call in final_tool_calls]
|
||||
|
||||
yield final_response
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
if not prompt:
|
||||
prompt = "<image>"
|
||||
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
# 暂时这样写。
|
||||
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
payloads["system"] = system_prompt
|
||||
@@ -135,32 +270,9 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
response = await self.client.messages.create(
|
||||
messages=context_query, **model_config
|
||||
)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||
llm_response.raw_completion = response
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
return LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -175,21 +287,34 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
payloads["system"] = system_prompt
|
||||
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
@@ -232,3 +357,28 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
)
|
||||
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""
|
||||
将图片转换为 base64
|
||||
"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.chosen_api_key
|
||||
|
||||
async def get_models(self) -> List[str]:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
models = sorted(models.data, key=lambda x: x.id)
|
||||
for model in models:
|
||||
models_str.append(model.id)
|
||||
return models_str
|
||||
|
||||
def set_key(self, key: str):
|
||||
self.chosen_api_key = key
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
@@ -19,16 +18,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality = None,
|
||||
default_persona: Personality | None = None,
|
||||
) -> None:
|
||||
Provider.__init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dashscope_api_key", "")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
@@ -17,17 +16,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
|
||||
@@ -12,8 +12,7 @@ from google.genai.errors import APIError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Personality, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
@@ -52,17 +51,13 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.api_keys: list = provider_config.get("key", [])
|
||||
@@ -264,12 +259,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
contents.append(content_cls(parts=part))
|
||||
|
||||
gemini_contents: list[types.Content] = []
|
||||
native_tool_enabled = any(
|
||||
[
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
]
|
||||
)
|
||||
native_tool_enabled = any([
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
])
|
||||
for message in payloads["messages"]:
|
||||
role, content = message["role"], message.get("content")
|
||||
|
||||
@@ -506,12 +499,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
@@ -527,7 +520,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if not isinstance(tool_calls_result, list):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
@@ -631,9 +628,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}}
|
||||
)
|
||||
user_content["content"].append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import uuid
|
||||
import wave
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderGeminiTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
api_key: str = provider_config.get("gemini_tts_api_key", "")
|
||||
api_base: str | None = provider_config.get("gemini_tts_api_base")
|
||||
timeout: int = int(provider_config.get("gemini_tts_timeout", 20))
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
|
||||
if api_base:
|
||||
if api_base.endswith("/"):
|
||||
api_base = api_base[:-1]
|
||||
http_options.base_url = api_base
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
self.model: str = provider_config.get(
|
||||
"gemini_tts_model", "gemini-2.5-flash-preview-tts"
|
||||
)
|
||||
self.prefix: str | None = provider_config.get(
|
||||
"gemini_tts_prefix",
|
||||
)
|
||||
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
|
||||
prompt = f"{self.prefix}: {text}" if self.prefix else text
|
||||
response = await self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=["AUDIO"],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name=self.voice_name,
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# 不想看类型检查报错
|
||||
if (
|
||||
not response.candidates
|
||||
or not response.candidates[0].content
|
||||
or not response.candidates[0].content.parts
|
||||
or not response.candidates[0].content.parts[0].inline_data
|
||||
or not response.candidates[0].content.parts[0].inline_data.data
|
||||
):
|
||||
raise Exception("No audio content returned from Gemini TTS API.")
|
||||
|
||||
with wave.open(path, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
|
||||
|
||||
return path
|
||||
@@ -1,134 +0,0 @@
|
||||
import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
|
||||
)
|
||||
class LLMTunerModelLoader(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
|
||||
provider_config["adapter_model_path"]
|
||||
):
|
||||
raise FileNotFoundError("模型文件路径不存在。")
|
||||
self.base_model_path = provider_config["base_model_path"]
|
||||
self.adapter_model_path = provider_config["adapter_model_path"]
|
||||
self.model = ChatModel(
|
||||
{
|
||||
"model_name_or_path": self.base_model_path,
|
||||
"adapter_name_or_path": self.adapter_model_path,
|
||||
"template": provider_config["llmtuner_template"],
|
||||
"finetuning_type": provider_config["finetuning_type"],
|
||||
"quantization_bit": provider_config["quantization_bit"],
|
||||
}
|
||||
)
|
||||
self.set_model(
|
||||
os.path.basename(self.base_model_path)
|
||||
+ "_"
|
||||
+ os.path.basename(self.adapter_model_path)
|
||||
)
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
for idx, context in enumerate(query_context):
|
||||
if context["role"] == "system":
|
||||
system_idxs.append(idx)
|
||||
|
||||
if "_no_save" in context:
|
||||
del context["_no_save"]
|
||||
|
||||
for idx in reversed(system_idxs):
|
||||
system_prompt += " " + query_context.pop(idx)["content"]
|
||||
|
||||
conf = {
|
||||
"messages": query_context,
|
||||
"system": system_prompt,
|
||||
}
|
||||
if func_tool:
|
||||
tool_list = func_tool.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
conf["tools"] = tool_list
|
||||
|
||||
responses = await self.model.achat(**conf)
|
||||
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
async def set_key(self, key):
|
||||
pass
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
@@ -9,14 +9,12 @@ import astrbot.core.message.components as Comp
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List, AsyncGenerator
|
||||
@@ -30,17 +28,13 @@ from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona: Personality = None,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
persistant_history,
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
@@ -224,12 +218,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def _prepare_chat_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -246,14 +238,18 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
if isinstance(tool_calls_result, ToolCallsResult):
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
else:
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
return payloads, context_query, func_tool
|
||||
return payloads, context_query
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
@@ -352,11 +348,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
@@ -422,11 +416,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
@@ -13,15 +12,11 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=True,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
db_helper,
|
||||
persistant_history,
|
||||
default_persona,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,10 +18,12 @@ class Star(CommandParserMixin):
|
||||
"""将文本转换为图片"""
|
||||
return await html_renderer.render_t2i(text, return_url=return_url)
|
||||
|
||||
async def html_render(self, tmpl: str, data: dict, return_url=True) -> str:
|
||||
async def html_render(
|
||||
self, tmpl: str, data: dict, return_url=True, options: dict = None
|
||||
) -> str:
|
||||
"""渲染 HTML"""
|
||||
return await html_renderer.render_custom_template(
|
||||
tmpl, data, return_url=return_url
|
||||
tmpl, data, return_url=return_url, options=options
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
|
||||
@@ -7,10 +7,13 @@ from astrbot.core.config import AstrBotConfig
|
||||
from .custom_filter import CustomFilter
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
|
||||
class GreedyStr(str):
|
||||
"""标记指令完成其他参数接收后的所有剩余文本。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter):
|
||||
"""标准指令过滤器"""
|
||||
@@ -18,8 +21,8 @@ class CommandFilter(HandlerFilter):
|
||||
def __init__(
|
||||
self,
|
||||
command_name: str,
|
||||
alias: set = None,
|
||||
handler_md: StarHandlerMetadata = None,
|
||||
alias: set | None = None,
|
||||
handler_md: StarHandlerMetadata | None = None,
|
||||
parent_command_names: List[str] = [""],
|
||||
):
|
||||
self.command_name = command_name
|
||||
@@ -110,6 +113,17 @@ class CommandFilter(HandlerFilter):
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, bool):
|
||||
# 处理布尔类型
|
||||
lower_param = str(params[i]).lower()
|
||||
if lower_param in ["true", "yes", "1"]:
|
||||
result[param_name] = True
|
||||
elif lower_param in ["false", "no", "0"]:
|
||||
result[param_name] = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"参数 {param_name} 必须是布尔值(true/false, yes/no, 1/0)。"
|
||||
)
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
|
||||
@@ -11,7 +11,7 @@ ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||
|
||||
|
||||
class NetworkRenderStrategy(RenderStrategy):
|
||||
def __init__(self, base_url: str = ASTRBOT_T2I_DEFAULT_ENDPOINT) -> None:
|
||||
def __init__(self, base_url: str | None = None) -> None:
|
||||
super().__init__()
|
||||
if not base_url:
|
||||
base_url = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||
@@ -34,18 +34,22 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
self.BASE_RENDER_URL += "/text2img"
|
||||
|
||||
async def render_custom_template(
|
||||
self, tmpl_str: str, tmpl_data: dict, return_url: bool = True
|
||||
self,
|
||||
tmpl_str: str,
|
||||
tmpl_data: dict,
|
||||
return_url: bool = True,
|
||||
options: dict | None = None,
|
||||
) -> str:
|
||||
"""使用自定义文转图模板"""
|
||||
default_options = {"full_page": True, "type": "jpeg", "quality": 40}
|
||||
if options:
|
||||
default_options |= options
|
||||
|
||||
post_data = {
|
||||
"tmpl": tmpl_str,
|
||||
"json": return_url,
|
||||
"tmpldata": tmpl_data,
|
||||
"options": {
|
||||
"full_page": True,
|
||||
"type": "jpeg",
|
||||
"quality": 40,
|
||||
},
|
||||
"options": default_options,
|
||||
}
|
||||
if return_url:
|
||||
ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
@@ -6,7 +6,7 @@ logger = LogManager.GetLogger(log_name="astrbot")
|
||||
|
||||
|
||||
class HtmlRenderer:
|
||||
def __init__(self, endpoint_url: str = None):
|
||||
def __init__(self, endpoint_url: str | None = None):
|
||||
self.network_strategy = NetworkRenderStrategy(endpoint_url)
|
||||
self.local_strategy = LocalRenderStrategy()
|
||||
|
||||
@@ -16,19 +16,24 @@ class HtmlRenderer:
|
||||
self.network_strategy.set_endpoint(endpoint_url)
|
||||
|
||||
async def render_custom_template(
|
||||
self, tmpl_str: str, tmpl_data: dict, return_url: bool = False
|
||||
self,
|
||||
tmpl_str: str,
|
||||
tmpl_data: dict,
|
||||
return_url: bool = False,
|
||||
options: dict | None = None,
|
||||
):
|
||||
"""使用自定义文转图模板。该方法会通过网络调用 t2i 终结点图文渲染API。
|
||||
@param tmpl_str: HTML Jinja2 模板。
|
||||
@param tmpl_data: jinja2 模板数据。
|
||||
@param options: 渲染选项。
|
||||
|
||||
@return: 图片 URL 或者文件路径,取决于 return_url 参数。
|
||||
|
||||
@example: 参见 https://astrbot.app 插件开发部分。
|
||||
"""
|
||||
local = locals()
|
||||
local.pop("self")
|
||||
return await self.network_strategy.render_custom_template(**local)
|
||||
return await self.network_strategy.render_custom_template(
|
||||
tmpl_str, tmpl_data, return_url, options
|
||||
)
|
||||
|
||||
async def render_t2i(
|
||||
self, text: str, use_network: bool = True, return_url: bool = False
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import base64
|
||||
import wave
|
||||
import os
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
import tempfile
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@@ -57,33 +59,89 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
|
||||
return duration
|
||||
|
||||
|
||||
async def wav_to_tencent_silk_base64(wav_path: str) -> str:
|
||||
async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
|
||||
"""
|
||||
将 WAV 文件转为 Silk,并返回 Base64 字符串。
|
||||
默认采样率为 24000,输出临时文件为 temp/output.silk。
|
||||
将 MP3 或其他音频格式转换为 PCM 16bit WAV,采样率24000Hz,单声道。
|
||||
若转换失败则抛出异常。
|
||||
"""
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=input_path, output=output_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
input_path,
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ar",
|
||||
"24000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-af",
|
||||
"apad=pad_dur=2",
|
||||
"-fflags",
|
||||
"+genpts",
|
||||
"-hide_banner",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await p.communicate()
|
||||
logger.info(f"[FFmpeg] stdout: {stdout.decode().strip()}")
|
||||
logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}")
|
||||
logger.info(f"[FFmpeg] return code: {p.returncode}")
|
||||
|
||||
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError("生成的WAV文件不存在或为空")
|
||||
|
||||
|
||||
async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
|
||||
"""
|
||||
将 MP3/WAV 文件转为 Tencent Silk 并返回 base64 编码与时长(秒)。
|
||||
|
||||
参数:
|
||||
- wav_path: 输入 .wav 文件路径(需为 PCM 16bit)
|
||||
- audio_path: 输入音频文件路径(.mp3 或 .wav)
|
||||
|
||||
返回:
|
||||
- Base64 编码的 Silk 字符串
|
||||
- silk_b64: Base64 编码的 Silk 字符串
|
||||
- duration: 音频时长(秒)
|
||||
"""
|
||||
try:
|
||||
import pilk
|
||||
except ImportError as e:
|
||||
raise Exception("pysilk 模块未安装,请安装 pysilk") from e
|
||||
raise Exception("未安装 pysilk,请执行: pip install pysilk") from e
|
||||
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
with wave.open(wav_path, "rb") as wav:
|
||||
rate = wav.getframerate()
|
||||
# 是否需要转换为 WAV
|
||||
ext = os.path.splitext(audio_path)[1].lower()
|
||||
temp_wav = tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", delete=False, dir=temp_dir
|
||||
).name
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
if ext != ".wav":
|
||||
await convert_to_pcm_wav(audio_path, temp_wav)
|
||||
# 删除原文件
|
||||
os.remove(audio_path)
|
||||
wav_path = temp_wav
|
||||
else:
|
||||
wav_path = audio_path
|
||||
|
||||
with wave.open(wav_path, "rb") as wav_file:
|
||||
rate = wav_file.getframerate()
|
||||
|
||||
silk_path = tempfile.NamedTemporaryFile(
|
||||
suffix=".silk", delete=False, dir=temp_dir
|
||||
) as tmp_file:
|
||||
silk_path = tmp_file.name
|
||||
).name
|
||||
|
||||
try:
|
||||
duration = await asyncio.to_thread(
|
||||
@@ -96,5 +154,7 @@ async def wav_to_tencent_silk_base64(wav_path: str) -> str:
|
||||
|
||||
return silk_b64, duration # 已是秒
|
||||
finally:
|
||||
if os.path.exists(wav_path) and wav_path != audio_path:
|
||||
os.remove(wav_path)
|
||||
if os.path.exists(silk_path):
|
||||
os.remove(silk_path)
|
||||
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
import asyncio
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core import WEBUI_SK, DEMO_MODE
|
||||
from astrbot.core import DEMO_MODE
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@@ -80,5 +80,8 @@ class AuthRoute(Route):
|
||||
"username": username,
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=7),
|
||||
}
|
||||
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
|
||||
jwt_token = self.config["dashboard"].get("jwt_secret", None)
|
||||
if not jwt_token:
|
||||
raise ValueError("JWT secret is not set in the cmd_config.")
|
||||
token = jwt.encode(payload, jwt_token, algorithm="HS256")
|
||||
return token
|
||||
|
||||
@@ -41,10 +41,15 @@ class StatRoute(Route):
|
||||
await self.core_lifecycle.restart()
|
||||
return Response().ok().__dict__
|
||||
|
||||
def format_sec(self, sec: int):
|
||||
m, s = divmod(sec, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return f"{h}小时{m}分{s}秒"
|
||||
def _get_running_time_components(self, total_seconds: int):
|
||||
"""将总秒数转换为时分秒组件"""
|
||||
minutes, seconds = divmod(total_seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return {
|
||||
"hours": hours,
|
||||
"minutes": minutes,
|
||||
"seconds": seconds
|
||||
}
|
||||
|
||||
def is_default_cred(self):
|
||||
username = self.config["dashboard"]["username"]
|
||||
@@ -107,6 +112,11 @@ class StatRoute(Route):
|
||||
}
|
||||
plugin_info.append(info)
|
||||
|
||||
# 计算运行时长组件
|
||||
running_time = self._get_running_time_components(
|
||||
int(time.time()) - self.core_lifecycle.start_time
|
||||
)
|
||||
|
||||
stat_dict.update(
|
||||
{
|
||||
"platform": self.db_helper.get_grouped_base_stats(
|
||||
@@ -119,9 +129,7 @@ class StatRoute(Route):
|
||||
"plugin_count": len(plugins),
|
||||
"plugins": plugin_info,
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": self.format_sec(
|
||||
int(time.time()) - self.core_lifecycle.start_time
|
||||
),
|
||||
"running": running_time, # 现在返回时间组件而不是格式化的字符串
|
||||
"memory": {
|
||||
"process": psutil.Process().memory_info().rss >> 20,
|
||||
"system": psutil.virtual_memory().total >> 20,
|
||||
|
||||
@@ -10,7 +10,7 @@ from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from .routes.route import RouteContext, Response
|
||||
from astrbot.core import logger, WEBUI_SK
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
@@ -62,6 +62,8 @@ class AstrBotDashboard:
|
||||
|
||||
self.shutdown_event = shutdown_event
|
||||
|
||||
self._init_jwt_secret()
|
||||
|
||||
async def srv_plug_route(self, subpath, *args, **kwargs):
|
||||
"""
|
||||
插件路由
|
||||
@@ -88,7 +90,7 @@ class AstrBotDashboard:
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
try:
|
||||
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
|
||||
g.username = payload["username"]
|
||||
except jwt.ExpiredSignatureError:
|
||||
r = jsonify(Response().error("Token 过期").__dict__)
|
||||
@@ -140,6 +142,15 @@ class AstrBotDashboard:
|
||||
except Exception as e:
|
||||
return f"获取进程信息失败: {str(e)}"
|
||||
|
||||
def _init_jwt_secret(self):
|
||||
if not self.config.get("dashboard", {}).get("jwt_secret", None):
|
||||
# 如果没有设置 JWT 密钥,则生成一个新的密钥
|
||||
jwt_secret = os.urandom(32).hex()
|
||||
self.config["dashboard"]["jwt_secret"] = jwt_secret
|
||||
self.config.save_config()
|
||||
logger.info("Initialized random JWT secret for dashboard.")
|
||||
self._jwt_secret = self.config["dashboard"]["jwt_secret"]
|
||||
|
||||
def run(self):
|
||||
ip_addr = []
|
||||
if p := os.environ.get("DASHBOARD_PORT"):
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# What's Changed
|
||||
|
||||
> 对 v3.5.16 的修订版本
|
||||
|
||||
1. 新增:支持接入 Slack
|
||||
2. 新增:支持接入 Discord
|
||||
3. 新增:支持接入 KOOK
|
||||
4. 新增:支持接入 VoceChat
|
||||
5. 新增:微信客服支持语音的收发
|
||||
6. 新增:实现 WebUI 的 i18n 模型,WebUI 现已支持 English。
|
||||
7. 新增:支持接入 GPT SoVITS
|
||||
8. 优化:支持通过引用 Bot 消息来唤醒 Bot
|
||||
9. 优化:WebUI 滚动条、侧边栏样式优化
|
||||
10. 优化:WebUI ChatBox 的样式优化,添加切换夜间模式按钮
|
||||
11. 优化:WebUI Chat 页面的 SSE 连接优化及一些其他样式优化
|
||||
12. 优化:钉钉发送图片支持使用 AstrBot 自带的文件服务器
|
||||
13. 优化:新建服务提供商时,如果没有添加 Key,会弹出警告提示框
|
||||
14. 修复:会话隔离模式下,WeChatPadPro 会话 ID 为自身 ID
|
||||
15. 修复:会话隔离模式下,WeChatPadPro 无法回复群聊消息
|
||||
16. 修复:使用 uvx 启动 AstrBot 时,插件依赖无法正常安装
|
||||
@@ -0,0 +1,18 @@
|
||||
# What's Changed
|
||||
|
||||
> 重构了大模型请求部分,如果发现此部分使用时有问题请提交 issue
|
||||
|
||||
1. 修复: 安装插件按钮被删除、无法自定义安装插件
|
||||
2. 修复: 环境变量中的代理地址无法生效
|
||||
1. 修复: randomize jwt secret
|
||||
2. 修复: 在 Node 消息段发送简单文本信息的问题
|
||||
1. 修复: QQ 官方机器人适配器使用 SessionController(会话控制)功能时机器人回复消息无法发送到聊天平台
|
||||
4. 修复: Discord 适配器无法优雅重载
|
||||
1. 修复: Telegram 适配器无法主动回复
|
||||
1. 修复: 仪表盘的『插件配置』中不显示代码编辑器
|
||||
3. 新增: Gemini TTS API
|
||||
1. 新增: 允许 html_render 方法传入 Playwright.screenshot 配置参数
|
||||
1. 优化: 修复 CommandFilter 支持对布尔类型进行解析
|
||||
4. 新增: WechatPadPro 发送 TTS 时 添加对 MP3 格式音频支持
|
||||
1. 重构: 将大模型请求部分抽象成 AgentRunner,提高可读性和可扩展性,工具调用结果支持持久化保存到数据库,完善 Agent 的多轮工具调用能力。
|
||||
1. 移除: LLMTuner 模型提供商适配器。请使用 Ollama 来加载微调模型
|
||||
@@ -1,6 +1,25 @@
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
defineProps({
|
||||
metadata: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
iterable: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
metadataKey: {
|
||||
type: String,
|
||||
required: true
|
||||
}
|
||||
})
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
const dialog = ref(false)
|
||||
const currentEditingKey = ref('')
|
||||
@@ -307,35 +326,7 @@ function saveEditedContent() {
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import ListConfigItem from './ListConfigItem.vue';
|
||||
import { useI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
name: 'AstrBotConfig',
|
||||
components: {
|
||||
ListConfigItem
|
||||
},
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
return { t };
|
||||
},
|
||||
props: {
|
||||
metadata: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
iterable: {
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
metadataKey: {
|
||||
type: String,
|
||||
required: true
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.config-section {
|
||||
@@ -464,4 +455,4 @@ export default {
|
||||
padding: 4px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</style>
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
},
|
||||
"runningTime": {
|
||||
"title": "Uptime",
|
||||
"subtitle": "System uptime duration"
|
||||
"subtitle": "System uptime duration",
|
||||
"format": "{hours}h {minutes}m {seconds}s"
|
||||
},
|
||||
"memoryUsage": {
|
||||
"title": "Memory Usage",
|
||||
|
||||
@@ -31,7 +31,8 @@
|
||||
"saveAndClose": "Save and Close",
|
||||
"cancel": "Cancel",
|
||||
"actions": "Actions",
|
||||
"back": "Back"
|
||||
"back": "Back",
|
||||
"selectFile": "Select File"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "Enabled",
|
||||
@@ -106,6 +107,11 @@
|
||||
"uninstall": {
|
||||
"title": "Confirm Deletion",
|
||||
"message": "Are you sure you want to delete this extension?"
|
||||
},
|
||||
"install": {
|
||||
"title": "Install Extension",
|
||||
"fromFile": "Install from File",
|
||||
"fromUrl": "Install from URL"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -127,7 +133,8 @@
|
||||
"hasUpdate": "New version available:",
|
||||
"confirmDelete": "Are you sure you want to delete this extension?",
|
||||
"fillUrlOrFile": "Please fill in extension URL or upload extension file",
|
||||
"dontFillBoth": "Please don't fill in both extension URL and upload file"
|
||||
"dontFillBoth": "Please don't fill in both extension URL and upload file",
|
||||
"supportedFormats": "Supports .zip extension files"
|
||||
},
|
||||
"upload": {
|
||||
"fromFile": "Install from File",
|
||||
|
||||
@@ -21,7 +21,14 @@
|
||||
"refresh": "Refresh",
|
||||
"cancel": "Cancel",
|
||||
"save": "Save",
|
||||
"addPlatform": "Add Platform Adapter"
|
||||
"addPlatform": "Add Platform Adapter",
|
||||
"connectTitle": "Connect {name}",
|
||||
"viewTutorial": "View Tutorial",
|
||||
"idConflict": {
|
||||
"title": "ID Conflict Warning",
|
||||
"message": "Detected duplicate ID \"{id}\". Please use a new ID.",
|
||||
"confirm": "OK"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
"updateSuccess": "Update successful!",
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
},
|
||||
"runningTime": {
|
||||
"title": "运行时间",
|
||||
"subtitle": "系统已运行时长"
|
||||
"subtitle": "系统已运行时长",
|
||||
"format": "{hours}小时{minutes}分{seconds}秒"
|
||||
},
|
||||
"memoryUsage": {
|
||||
"title": "内存占用",
|
||||
|
||||
@@ -31,7 +31,8 @@
|
||||
"saveAndClose": "保存并关闭",
|
||||
"cancel": "取消",
|
||||
"actions": "操作",
|
||||
"back": "返回"
|
||||
"back": "返回",
|
||||
"selectFile": "选择文件"
|
||||
},
|
||||
"status": {
|
||||
"enabled": "启用",
|
||||
@@ -106,6 +107,11 @@
|
||||
"uninstall": {
|
||||
"title": "删除确认",
|
||||
"message": "你确定要删除当前插件吗?"
|
||||
},
|
||||
"install": {
|
||||
"title": "安装插件",
|
||||
"fromFile": "从文件安装",
|
||||
"fromUrl": "从链接安装"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -127,7 +133,8 @@
|
||||
"hasUpdate": "有新版本:",
|
||||
"confirmDelete": "确定要删除插件吗?",
|
||||
"fillUrlOrFile": "请填写插件链接或上传插件文件",
|
||||
"dontFillBoth": "请不要同时填写插件链接和上传插件文件"
|
||||
"dontFillBoth": "请不要同时填写插件链接和上传文件",
|
||||
"supportedFormats": "支持 .zip 格式的插件文件"
|
||||
},
|
||||
"upload": {
|
||||
"fromFile": "从文件安装",
|
||||
|
||||
@@ -21,7 +21,14 @@
|
||||
"refresh": "刷新",
|
||||
"cancel": "取消",
|
||||
"save": "保存",
|
||||
"addPlatform": "添加平台适配器"
|
||||
"addPlatform": "添加平台适配器",
|
||||
"connectTitle": "接入 {name}",
|
||||
"viewTutorial": "查看接入教程",
|
||||
"idConflict": {
|
||||
"title": "ID 冲突警告",
|
||||
"message": "检测到 ID \"{id}\" 重复。请使用一个新的 ID。",
|
||||
"confirm": "好的"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
"updateSuccess": "更新成功!",
|
||||
|
||||
@@ -29,7 +29,7 @@ let dashboardCurrentVersion = ref('');
|
||||
let version = ref('');
|
||||
let releases = ref([]);
|
||||
let devCommits = ref([]);
|
||||
|
||||
let updatingDashboardLoading = ref(false);
|
||||
let installLoading = ref(false);
|
||||
|
||||
let tab = ref(0);
|
||||
@@ -217,6 +217,7 @@ function switchVersion(version: string) {
|
||||
}
|
||||
|
||||
function updateDashboard() {
|
||||
updatingDashboardLoading.value = true;
|
||||
updateStatus.value = t('core.header.updateDialog.status.updating');
|
||||
axios.post('/api/update/dashboard')
|
||||
.then((res) => {
|
||||
@@ -230,7 +231,9 @@ function updateDashboard() {
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}).finally(() => {
|
||||
updatingDashboardLoading.value = false;
|
||||
});
|
||||
}
|
||||
|
||||
function toggleDarkMode() {
|
||||
@@ -416,7 +419,7 @@ commonStore.getStartTime();
|
||||
</div>
|
||||
|
||||
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()"
|
||||
:disabled="!dashboardHasNewVersion">
|
||||
:disabled="!dashboardHasNewVersion" :loading="updatingDashboardLoading">
|
||||
{{ t('core.header.updateDialog.dashboardUpdate.downloadAndUpdate') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -39,7 +39,8 @@ const PurpleThemeDark: ThemeTypes = {
|
||||
background: '#111111',
|
||||
overlay: '#111111aa',
|
||||
codeBg: '#282833',
|
||||
code: '#ffffffdd'
|
||||
code: '#ffffffdd',
|
||||
chatMessageBubble: '#2d2e30',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -39,7 +39,8 @@ const PurpleTheme: ThemeTypes = {
|
||||
background: '#f9fafcf4',
|
||||
overlay: '#ffffffaa',
|
||||
codeBg: '#f5f0ff',
|
||||
code: '#673ab7'
|
||||
code: '#673ab7',
|
||||
chatMessageBubble: '#e7ebf4',
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -35,5 +35,6 @@ export type ThemeTypes = {
|
||||
secondary200?: string;
|
||||
codeBg?: string;
|
||||
code?: string;
|
||||
chatMessageBubble?: string;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -175,7 +175,7 @@
|
||||
<div v-else class="message-list">
|
||||
<div class="message-item fade-in" v-for="(msg, index) in messages" :key="index">
|
||||
<!-- 用户消息 -->
|
||||
<div v-if="msg.type == 'user'" class="user-message">
|
||||
<div v-if="msg.type == 'user'" class="user-message" style="background-color: var(--v-theme-chatMessageBubble);">
|
||||
<div class="message-bubble user-bubble">
|
||||
<span>{{ msg.message }}</span>
|
||||
|
||||
@@ -195,15 +195,12 @@
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
<v-avatar class="user-avatar" color="deep-purple-lighten-3" size="36">
|
||||
<v-icon icon="mdi-account" />
|
||||
</v-avatar>
|
||||
</div>
|
||||
|
||||
<!-- 机器人消息 -->
|
||||
<div v-else class="bot-message">
|
||||
<v-avatar class="bot-avatar" color="deep-purple" size="36">
|
||||
<span class="text-h6">✨</span>
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<span class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="message-bubble bot-bubble">
|
||||
<div v-html="marked(msg.message)" class="markdown-content"></div>
|
||||
@@ -1574,28 +1571,25 @@ export default {
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
padding: 12px 16px;
|
||||
border-radius: 18px;
|
||||
padding: 8px 16px;
|
||||
border-radius: 12px;
|
||||
max-width: 80%;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.user-bubble {
|
||||
background-color: var(--v-theme-background);
|
||||
background-color: var(--v-theme-chatMessageBubble);
|
||||
color: var(--v-theme-primaryText);
|
||||
border-top-right-radius: 4px;
|
||||
}
|
||||
|
||||
.bot-bubble {
|
||||
background-color: var(--v-theme-surface);
|
||||
border: 1px solid var(--v-theme-border);
|
||||
color: var(--v-theme-primaryText);
|
||||
border-top-left-radius: 4px;
|
||||
}
|
||||
|
||||
.user-avatar,
|
||||
.bot-avatar {
|
||||
align-self: flex-end;
|
||||
align-self: flex-start;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
/* 附件样式 */
|
||||
|
||||
@@ -13,6 +13,7 @@ import { ref, computed, onMounted, reactive } from 'vue';
|
||||
const commonStore = useCommonStore();
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/extension');
|
||||
const fileInput = ref(null);
|
||||
const activeTab = ref('installed');
|
||||
const extension_data = reactive({
|
||||
data: [],
|
||||
@@ -61,6 +62,7 @@ const loading_ = ref(false);
|
||||
const extension_url = ref("");
|
||||
const dialog = ref(false);
|
||||
const upload_file = ref(null);
|
||||
const uploadTab = ref('file');
|
||||
const showPluginFullName = ref(false);
|
||||
const marketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
@@ -629,16 +631,20 @@ onMounted(async () => {
|
||||
</v-btn>
|
||||
</v-btn-group>
|
||||
|
||||
<v-btn @click="toggleShowReserved" prepend-icon="mdi-eye-settings-outline"
|
||||
:color="showReserved ? 'primary' : undefined" :variant="showReserved ? 'flat' : 'outlined'"
|
||||
class="flex-shrink-0">
|
||||
<v-btn class="ml-2" variant="tonal" @click="toggleShowReserved">
|
||||
<v-icon>{{ showReserved ? 'mdi-eye-off' : 'mdi-eye' }}</v-icon>
|
||||
{{ showReserved ? tm('buttons.hideSystemPlugins') : tm('buttons.showSystemPlugins') }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn prepend-icon="mdi-tune-vertical" color="primary" variant="outlined"
|
||||
@click="getPlatformEnableConfig" class="flex-shrink-0">
|
||||
<v-btn class="ml-2" variant="tonal" @click="getPlatformEnableConfig">
|
||||
<v-icon>mdi-cog</v-icon>
|
||||
{{ tm('buttons.platformConfig') }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn class="ml-2" color="primary" variant="tonal" @click="dialog = true">
|
||||
<v-icon>mdi-plus</v-icon>
|
||||
{{ tm('buttons.install') }}
|
||||
</v-btn>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="auto" md="6" class="ml-auto">
|
||||
@@ -808,12 +814,16 @@ onMounted(async () => {
|
||||
|
||||
<!-- <small style="color: var(--v-theme-secondaryText);">每个插件都是作者无偿提供的的劳动成果。如果您喜欢某个插件,请 Star!</small> -->
|
||||
|
||||
<v-btn icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;" @click="dialog = true"
|
||||
color="darkprimary">
|
||||
</v-btn>
|
||||
|
||||
<div v-if="pinnedPlugins.length > 0" class="mt-4">
|
||||
<h2>{{ tm('market.recommended') }}</h2>
|
||||
<v-row style="margin-top: 8px;">
|
||||
<v-col cols="12" md="6" lg="6" v-for="plugin in pinnedPlugins" :key="plugin.name">
|
||||
<ExtensionCard :extension="plugin" class="h-120 rounded-lg" market-mode="true" :highlight="true"
|
||||
@install="extension_url = plugin.repo; newExtension()" @view-readme="open(plugin.repo)">
|
||||
@install="extension_url = plugin.repo; dialog = true; uploadTab = 'url'" @view-readme="open(plugin.repo)">
|
||||
</ExtensionCard>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -866,7 +876,7 @@ onMounted(async () => {
|
||||
</template>
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<v-btn v-if="!item.installed" class="text-none mr-2" size="x-small" variant="flat"
|
||||
@click="extension_url = item.repo; newExtension()">
|
||||
@click="extension_url = item.repo; dialog = true; uploadTab = 'url'">
|
||||
<v-icon>mdi-download</v-icon></v-btn>
|
||||
<v-btn v-else class="text-none mr-2" size="x-small" variant="flat" border
|
||||
disabled><v-icon>mdi-check</v-icon></v-btn>
|
||||
@@ -1067,6 +1077,75 @@ onMounted(async () => {
|
||||
|
||||
<ReadmeDialog v-model:show="readmeDialog.show" :plugin-name="readmeDialog.pluginName"
|
||||
:repo-url="readmeDialog.repoUrl" />
|
||||
|
||||
<!-- 上传插件对话框 -->
|
||||
<v-dialog v-model="dialog" width="500">
|
||||
<v-card>
|
||||
<v-card-title class="text-h5">{{ tm('dialogs.install.title') }}</v-card-title>
|
||||
<v-card-text>
|
||||
<v-tabs v-model="uploadTab">
|
||||
<v-tab value="file">{{ tm('dialogs.install.fromFile') }}</v-tab>
|
||||
<v-tab value="url">{{ tm('dialogs.install.fromUrl') }}</v-tab>
|
||||
</v-tabs>
|
||||
|
||||
<v-window v-model="uploadTab" class="mt-4">
|
||||
<v-window-item value="file">
|
||||
<div class="d-flex flex-column align-center justify-center pa-4">
|
||||
<v-file-input
|
||||
ref="fileInput"
|
||||
v-model="upload_file"
|
||||
:label="tm('upload.selectFile')"
|
||||
accept=".zip"
|
||||
hide-details
|
||||
hide-input
|
||||
class="d-none"
|
||||
></v-file-input>
|
||||
|
||||
<v-btn
|
||||
color="primary"
|
||||
size="large"
|
||||
prepend-icon="mdi-upload"
|
||||
@click="$refs.fileInput.click()"
|
||||
>
|
||||
{{ tm('buttons.selectFile') }}
|
||||
</v-btn>
|
||||
|
||||
<div class="text-body-2 text-medium-emphasis mt-2">
|
||||
{{ tm('messages.supportedFormats') }}
|
||||
</div>
|
||||
|
||||
<div v-if="upload_file" class="mt-4 text-center">
|
||||
<v-chip color="primary" size="large" closable @click:close="upload_file = null">
|
||||
{{ upload_file.name }}
|
||||
<template v-slot:append>
|
||||
<span class="text-caption ml-2">({{ (upload_file.size / 1024).toFixed(1) }}KB)</span>
|
||||
</template>
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
</v-window-item>
|
||||
|
||||
<v-window-item value="url">
|
||||
<div class="pa-4">
|
||||
<v-text-field
|
||||
v-model="extension_url"
|
||||
:label="tm('upload.enterUrl')"
|
||||
variant="outlined"
|
||||
prepend-inner-icon="mdi-link"
|
||||
hide-details
|
||||
placeholder="https://github.com/username/repo"
|
||||
></v-text-field>
|
||||
</div>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="dialog = false">{{ tm('buttons.cancel') }}</v-btn>
|
||||
<v-btn color="primary" variant="text" @click="newExtension">{{ tm('buttons.install') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -93,7 +93,7 @@
|
||||
<v-card variant="outlined" hover class="platform-card" @click="selectPlatformTemplate(name)">
|
||||
<div class="platform-card-content">
|
||||
<div class="platform-card-text">
|
||||
<v-card-title class="platform-card-title">接入 {{ name }}</v-card-title>
|
||||
<v-card-title class="platform-card-title">{{ tm('dialog.connectTitle', { name }) }}</v-card-title>
|
||||
<v-card-text class="text-caption text-medium-emphasis platform-card-description">
|
||||
{{ getPlatformDescription(template, name) }}
|
||||
</v-card-text>
|
||||
@@ -139,7 +139,7 @@
|
||||
<v-col cols="12" class="text-center">
|
||||
<v-btn color="info" variant="outlined" @click="openTutorial">
|
||||
<v-icon start>mdi-book-open-variant</v-icon>
|
||||
查看接入教程
|
||||
{{ tm('dialog.viewTutorial') }}
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
@@ -172,14 +172,14 @@
|
||||
<v-card>
|
||||
<v-card-title class="text-h6 bg-warning d-flex align-center">
|
||||
<v-icon start class="me-2">mdi-alert-circle-outline</v-icon>
|
||||
ID 冲突警告
|
||||
{{ tm('dialog.idConflict.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
|
||||
检测到 ID "{{ conflictId }}" 重复。请使用一个新的 ID。
|
||||
{{ tm('dialog.idConflict.message', { id: conflictId }) }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">好的</v-btn>
|
||||
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -275,25 +275,25 @@ export default {
|
||||
|
||||
getPlatformIcon(name) {
|
||||
if (name.includes('QQ')) {
|
||||
return '/src/assets/images/platform_logos/qq.png'
|
||||
return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href
|
||||
} else if (name.includes('企业微信')) {
|
||||
return '/src/assets/images/platform_logos/wecom.png'
|
||||
return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href
|
||||
} else if (name.includes('微信')) {
|
||||
return '/src/assets/images/platform_logos/wechat.png';
|
||||
return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href
|
||||
} else if (name.includes('Lark')) {
|
||||
return '/src/assets/images/platform_logos/lark.png';
|
||||
return new URL('@/assets/images/platform_logos/lark.png', import.meta.url).href
|
||||
} else if (name.includes('DingTalk')) {
|
||||
return '/src/assets/images/platform_logos/dingtalk.svg';
|
||||
return new URL('@/assets/images/platform_logos/dingtalk.svg', import.meta.url).href
|
||||
} else if (name.includes('Telegram')) {
|
||||
return '/src/assets/images/platform_logos/telegram.svg';
|
||||
return new URL('@/assets/images/platform_logos/telegram.svg', import.meta.url).href
|
||||
} else if (name.includes('Discord')) {
|
||||
return '/src/assets/images/platform_logos/discord.svg';
|
||||
return new URL('@/assets/images/platform_logos/discord.svg', import.meta.url).href
|
||||
} else if (name.includes('Slack')) {
|
||||
return '/src/assets/images/platform_logos/slack.svg';
|
||||
return new URL('@/assets/images/platform_logos/slack.svg', import.meta.url).href
|
||||
} else if (name.includes('kook')) {
|
||||
return '/src/assets/images/platform_logos/kook.png';
|
||||
return new URL('@/assets/images/platform_logos/kook.png', import.meta.url).href
|
||||
} else if (name.includes('vocechat')) {
|
||||
return '/src/assets/images/platform_logos/vocechat.png';
|
||||
return new URL('@/assets/images/platform_logos/vocechat.png', import.meta.url).href
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -414,7 +414,6 @@ export default {
|
||||
"anthropic_chat_completion": "chat_completion",
|
||||
"googlegenai_chat_completion": "chat_completion",
|
||||
"zhipu_chat_completion": "chat_completion",
|
||||
"llm_tuner": "chat_completion",
|
||||
"dify": "chat_completion",
|
||||
"dashscope": "chat_completion",
|
||||
"openai_whisper_api": "speech_to_text",
|
||||
|
||||
@@ -30,7 +30,16 @@ export default {
|
||||
},
|
||||
computed: {
|
||||
formattedTime() {
|
||||
return this.stat?.running || this.t('status.loading');
|
||||
if (!this.stat?.running) {
|
||||
return this.t('status.loading');
|
||||
}
|
||||
|
||||
const { hours, minutes, seconds } = this.stat.running;
|
||||
return this.t('stats.runningTime.format', {
|
||||
hours,
|
||||
minutes,
|
||||
seconds
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "3.5.16"
|
||||
version = "3.5.18"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
+2
-1
@@ -39,4 +39,5 @@ faiss-cpu
|
||||
aiosqlite
|
||||
nh3
|
||||
py-cord>=2.6.1
|
||||
slack-sdk
|
||||
slack-sdk
|
||||
pydub
|
||||
@@ -204,7 +204,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "astrbot"
|
||||
version = "3.5.16"
|
||||
version = "3.5.17"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiocqhttp" },
|
||||
@@ -636,34 +636,34 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "faiss-cpu"
|
||||
version = "1.11.0"
|
||||
version = "1.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/9a/e33fc563f007924dd4ec3c5101fe5320298d6c13c158a24a9ed849058569/faiss_cpu-1.11.0.tar.gz", hash = "sha256:44877b896a2b30a61e35ea4970d008e8822545cb340eca4eff223ac7f40a1db9", size = 70218 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/1b/6fe5dbe5be0240cfd82b52bd7c186655c578d935c0ce2e713c100e6f8cce/faiss_cpu-1.10.0.tar.gz", hash = "sha256:5bdca555f24bc036f4d67f8a5a4d6cc91b8d2126d4e78de496ca23ccd46e479d", size = 69159 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/e5/7490368ec421e44efd60a21aa88d244653c674d8d6ee6bc455d8ee3d02ed/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1995119152928c68096b0c1e5816e3ee5b1eebcf615b80370874523be009d0f6", size = 3307996 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/ac/a94fbbbf4f38c2ad11862af92c071ff346630ebf33f3d36fe75c3817c2f0/faiss_cpu-1.11.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:788d7bf24293fdecc1b93f1414ca5cc62ebd5f2fecfcbb1d77f0e0530621c95d", size = 7886309 },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/48/ad79f34f1b9eba58c32399ad4fbedec3f2a717d72fb03648e906aab48a52/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:73408d52429558f67889581c0c6d206eedcf6fabe308908f2bdcd28fd5e8be4a", size = 3778443 },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/67/3c6b94dd3223a8ecaff1c10c11b4ac6f3f13f1ba8ab6b6109c24b6e9b23d/faiss_cpu-1.11.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1f53513682ca94c76472544fa5f071553e428a1453e0b9755c9673f68de45f12", size = 31295174 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/2c/d843256aabdb7f20f0f87f61efe3fb7c2c8e7487915f560ba523cfcbab57/faiss_cpu-1.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:30489de0356d3afa0b492ca55da164d02453db2f7323c682b69334fde9e8d48e", size = 15003860 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/83/8aefc4d07624a868e046cc23ede8a59bebda57f09f72aee2150ef0855a82/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a90d1c81d0ecf2157e1d2576c482d734d10760652a5b2fcfa269916611e41f1c", size = 3307997 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/64/f97e91d89dc6327e08f619fe387d7d9945bc4be3b0f1ca1e494a41c92ebe/faiss_cpu-1.11.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2c39a388b059fb82cd97fbaa7310c3580ced63bf285be531453bfffbe89ea3dd", size = 7886308 },
|
||||
{ url = "https://files.pythonhosted.org/packages/44/0a/7c17b6df017b0bc127c6aa4066b028281e67ab83d134c7433c4e75cd6bb6/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a4e3433ffc7f9b8707a7963db04f8676a5756868d325644db2db9d67a618b7a0", size = 3778441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/45/7c85551025d9f0237d891b5cffdc5d4a366011d53b4b0a423b972cc52cea/faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:926645f1b6829623bc88e93bc8ca872504d604718ada3262e505177939aaee0a", size = 31295136 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/9a/accade34b8668b21206c0c4cf0b96cd0b750b693ba5b255c1c10cfee460f/faiss_cpu-1.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:931db6ed2197c03a7fdf833b057c13529afa2cec8a827aa081b7f0543e4e671b", size = 15003710 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/d3/7178fa07047fd770964a83543329bb5e3fc1447004cfd85186ccf65ec3ee/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:356437b9a46f98c25831cdae70ca484bd6c05065af6256d87f6505005e9135b9", size = 3313807 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/71/25f5f7b70a9f22a3efe19e7288278da460b043a3b60ad98e4e47401ed5aa/faiss_cpu-1.11.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c4a3d35993e614847f3221c6931529c0bac637a00eff0d55293e1db5cb98c85f", size = 7913537 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/c8/a5cb8466c981ad47750e1d5fda3d4223c82f9da947538749a582b3a2d35c/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f9af33e0b8324e8199b93eb70ac4a951df02802a9dcff88e9afc183b11666f0", size = 3785180 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/37/eaf15a7d80e1aad74f56cf737b31b4547a1a664ad3c6e4cfaf90e82454a8/faiss_cpu-1.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:48b7e7876829e6bdf7333041800fa3c1753bb0c47e07662e3ef55aca86981430", size = 31287630 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/5c/902a78347e9c47baaf133e47863134e564c39f9afe105795b16ee986b0df/faiss_cpu-1.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:bdc199311266d2be9d299da52361cad981393327b2b8aa55af31a1b75eaaf522", size = 15005398 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/90/d2329ce56423cc61f4c20ae6b4db001c6f88f28bf5a7ef7f8bbc246fd485/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:0c98e5feff83b87348e44eac4d578d6f201780dae6f27f08a11d55536a20b3a8", size = 3313807 },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/14/8af8f996d54e6097a86e6048b1a2c958c52dc985eb4f935027615079939e/faiss_cpu-1.11.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:796e90389427b1c1fb06abdb0427bb343b6350f80112a2e6090ac8f176ff7416", size = 7913539 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/2b/437c2f36c3aa3cffe041479fced1c76420d3e92e1f434f1da3be3e6f32b1/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b6e355dda72b3050991bc32031b558b8f83a2b3537a2b9e905a84f28585b47e", size = 3785181 },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/75/955527414371843f558234df66fa0b62c6e86e71e4022b1be9333ac6004c/faiss_cpu-1.11.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6c482d07194638c169b4422774366e7472877d09181ea86835e782e6304d4185", size = 31287635 },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/51/35b7a3f47f7859363a367c344ae5d415ea9eda65db0a7d497c7ea2c0b576/faiss_cpu-1.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:13eac45299532b10e911bff1abbb19d1bf5211aa9e72afeade653c3f1e50e042", size = 15005455 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/56/87eb506d8634f08fc7c63d1ca5631aeec7d6b9afbfabedf2cb7a2a804b13/faiss_cpu-1.10.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:6693474be296a7142ade1051ea18e7d85cedbfdee4b7eac9c52f83fed0467855", size = 7693034 },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/46/f4d9de34ed1b06300b1a75b824d4857963216f5826de33f291af78088e39/faiss_cpu-1.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:70ebe60a560414dc8dd6cfe8fed105c8f002c0d11f765f5adfe8d63d42c0467f", size = 3234656 },
|
||||
{ url = "https://files.pythonhosted.org/packages/74/3a/e146861019d9290e0198b3470b8d13a658c3b5f228abefc3658ce0afd63d/faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:74c5712d4890f15c661ab7b1b75867812e9596e1469759956fad900999bedbb5", size = 3663789 },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/40/624f0002bb777e37aac1aadfadec1eb4391be6ad05b7fcfbf66049b99a48/faiss_cpu-1.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:473d158fbd638d6ad5fb64469ba79a9f09d3494b5f4e8dfb4f40ce2fc335dca4", size = 30673545 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/39/298ffcbefd899e84a43e63df217a6dc800d52bca37ebe0d1155ff367886a/faiss_cpu-1.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:dcd0cb2ec84698cbe3df9ed247d2392f09bda041ad34b92d38fa916cd019ad4b", size = 13684176 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/93/81800f41cb2c719c199d3eb534fcc154853123261d841e37482e8e468619/faiss_cpu-1.10.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ff6924b0f00df278afe70940ae86302066466580724c2f3238860039e9946f1", size = 7693037 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/83/fc9028f6d6aec2c2f219f53a5d4a2b279434715643242e59a2e9755b1ce0/faiss_cpu-1.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cb80b530a9ded44a7d4031a7355a237aaa0ff1f150c1176df050e0254ea5f6f6", size = 3234657 },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/45/588a02e60daa73f6052611334fbbdffcedf37122320f1c91cb90f3e69b96/faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:7a9fef4039ed877d40e41d5563417b154c7f8cd57621487dad13c4eb4f32515f", size = 3663710 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/cf/9caa08ca4e21ab935f82be0713e5d60566140414c3fff7932d9427c8fd72/faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:49b6647aa9e159a2c4603cbff2e1b313becd98ad6e851737ab325c74fe8e0278", size = 30673629 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/2d/d2a4171a9cca9a7c04cd9d6f9441a37f1e0558724b90bf7fc7db08553601/faiss_cpu-1.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:6f8c0ef8b615c12c7bf612bd1fc51cffa49c1ddaa6207c6981f01ab6782e6b3b", size = 13683966 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/cc/f6aa1288dbb40b2a4f101d16900885e056541f37d8d08ec70462e92cf277/faiss_cpu-1.10.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:2aca486fe2d680ea64a18d356206c91ff85db99fd34c19a757298c67c23262b1", size = 7720242 },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/56/40901306324a17fbc1eee8a6e86ba67bd99a67e768ce9908f271e648e9e0/faiss_cpu-1.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c1108a4059c66c37c403183e566ca1ed0974a6af7557c92d49207639aab661bc", size = 3239223 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2e/34/5b1463c450c9a6de3109caf8f38fbf0c329ef940ed1973fcf8c8ec7fa27e/faiss_cpu-1.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:449f3eb778d6d937e01a16a3170de4bb8aabfe87c7cb479b458fb790276310c5", size = 3671461 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/d9/0b78c474289f23b31283d8fb64c8e6a522a7fa47b131a3c6c141c8e6639d/faiss_cpu-1.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9899c340f92bd94071d6faf4bef0ccb5362843daea42144d4ba857a2a1f67511", size = 30663859 },
|
||||
{ url = "https://files.pythonhosted.org/packages/17/f0/194727b9e6e282e2877bc001ba886228f6af52e9a6730bbdb223e38591c3/faiss_cpu-1.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:345a52dbfa980d24b93c94410eadf82d1eef359c6a42e5e0768cca96539f1c3c", size = 13687087 },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/25/23239a83142faa319c4f8c025e25fec6cccc7418995eba3515218a57a45b/faiss_cpu-1.10.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:cb8473d69c3964c1bf3f8eb3e04287bb3275f536e6d9635ef32242b5f506b45d", size = 7720240 },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/f1/0e979277831af337739dbacf386d8a359a05eef9642df23d36e6c7d1b1a9/faiss_cpu-1.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82ca5098de694e7b8495c1a8770e2c08df6e834922546dad0ae1284ff519ced6", size = 3239224 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/fa/c2ad85b017a5754f6cdb09c179f8c4f4198d2a264046a8daa7a4d080521f/faiss_cpu-1.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:035e4d797e2db7fc0d0c90531d4a655d089ad5d1382b7a49358c1f2307b3a309", size = 3671236 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/9b/759962f2c34800058f6a76457df3b0ab93b24f383650ea1ef0231acd322c/faiss_cpu-1.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e02af3696a6b9e1f9072e502f48095a305de2163c42ceb1f6f6b1db9e7ffe574", size = 30663948 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/9a/6c496e0189897761978653177386452d62f4060579413d109bff05f458f2/faiss_cpu-1.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:e71f7e24d5b02d3a51df47b77bd10f394a1b48a8331d5c817e71e9e27a8a75ac", size = 13687212 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user