Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c8c28ebd5 | |||
| 524285f767 | |||
| c2a34475f1 | |||
| a69195a02b | |||
| 19d7438499 | |||
| ccb380ce06 | |||
| a35c439bbd | |||
| 09d1f96603 | |||
| 26aa18d980 | |||
| d10b542797 | |||
| ce4e4fb8dd |
@@ -6,7 +6,7 @@ import os
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.1.7"
|
||||
VERSION = "4.2.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
@@ -869,6 +869,18 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
"provider": "coze",
|
||||
"provider_type": "chat_completion",
|
||||
"type": "coze",
|
||||
"enable": True,
|
||||
"coze_api_key": "",
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
"id": "dashscope",
|
||||
"provider": "dashscope",
|
||||
@@ -1735,6 +1747,26 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
|
||||
"obvious": True,
|
||||
},
|
||||
"coze_api_key": {
|
||||
"description": "Coze API Key",
|
||||
"type": "string",
|
||||
"hint": "Coze API 密钥,用于访问 Coze 服务。",
|
||||
},
|
||||
"bot_id": {
|
||||
"description": "Bot ID",
|
||||
"type": "string",
|
||||
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
|
||||
},
|
||||
"coze_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
|
||||
},
|
||||
"auto_save_history": {
|
||||
"description": "由 Coze 管理对话记录",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
|
||||
@@ -87,17 +87,25 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
f = False
|
||||
if not conversation_id:
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
f = True
|
||||
if conversation_id:
|
||||
await self.db.delete_conversation(cid=conversation_id)
|
||||
if f:
|
||||
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
|
||||
if curr_cid == conversation_id:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
||||
"""删除会话的所有对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
"""
|
||||
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
|
||||
@@ -154,6 +154,11 @@ class BaseDatabase(abc.ABC):
|
||||
"""Delete a conversation by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
"""Delete all conversations for a specific user."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
|
||||
@@ -249,6 +249,14 @@ class SQLiteDatabase(BaseDatabase):
|
||||
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
|
||||
)
|
||||
|
||||
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(ConversationV2.user_id == user_id)
|
||||
)
|
||||
|
||||
async def insert_platform_message_history(
|
||||
self,
|
||||
platform_id,
|
||||
|
||||
@@ -291,13 +291,6 @@ async def run_agent(
|
||||
else:
|
||||
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||
return
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -524,6 +517,14 @@ class LLMRequestSubStage(Stage):
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
@@ -536,7 +537,23 @@ class LLMRequestSubStage(Stage):
|
||||
latest_pair = messages[-2:]
|
||||
if not latest_pair:
|
||||
return
|
||||
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
|
||||
content = latest_pair[0].get("content", "")
|
||||
if isinstance(content, list):
|
||||
# 多模态
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "image":
|
||||
text_parts.append("[图片]")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
cleaned_text = "User: " + " ".join(text_parts).strip()
|
||||
elif isinstance(content, str):
|
||||
cleaned_text = "User: " + content.strip()
|
||||
else:
|
||||
return
|
||||
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt="You are expert in summarizing user's query.",
|
||||
|
||||
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
|
||||
"""检查会话是否整体启用"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
pass
|
||||
self.ctx = ctx
|
||||
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
conv_id = await self.conv_mgr.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
if not conv_id:
|
||||
await self.conv_mgr.new_conversation(
|
||||
event.unified_msg_origin, platform_id=event.get_platform_id()
|
||||
)
|
||||
|
||||
event.stop_event()
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra():
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
|
||||
"parsed_params"
|
||||
)
|
||||
is_group_cmd_handler = any(
|
||||
isinstance(f, CommandGroupFilter) for f in handler.event_filters
|
||||
)
|
||||
if not is_group_cmd_handler:
|
||||
activated_handlers.append(handler)
|
||||
if "parsed_params" in event.get_extra(default={}):
|
||||
handlers_parsed_params[handler.handler_full_name] = (
|
||||
event.get_extra("parsed_params")
|
||||
)
|
||||
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""是否唤醒(是否通过 WakingStage)"""
|
||||
self.is_at_or_wake_command = False
|
||||
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
|
||||
self._extras = {}
|
||||
self._extras: dict[str, Any] = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.id,
|
||||
message_type=message_obj.type,
|
||||
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||
self._result: MessageEventResult = None
|
||||
self._result: MessageEventResult | None = None
|
||||
"""消息事件的结果"""
|
||||
|
||||
self._has_send_oper = False
|
||||
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key=None):
|
||||
def get_extra(
|
||||
self, key: str | None = None, default: _VT = None
|
||||
) -> dict[str, Any] | _VT:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
if key is None:
|
||||
return self._extras
|
||||
return self._extras.get(key, None)
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def clear_extra(self):
|
||||
"""
|
||||
|
||||
@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"wecom",
|
||||
"wecom 适配器",
|
||||
id=self.config.get("id", "wecom"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"weixin_official_account",
|
||||
"微信公众平台 适配器",
|
||||
id=self.config.get("id", "weixin_official_account"),
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -234,6 +234,8 @@ class ProviderManager:
|
||||
)
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify as ProviderDify
|
||||
case "coze":
|
||||
from .sources.coze_source import ProviderCoze as ProviderCoze
|
||||
case "dashscope":
|
||||
from .sources.dashscope_source import (
|
||||
ProviderDashscope as ProviderDashscope,
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import io
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class CozeAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = None
|
||||
|
||||
async def _ensure_session(self):
|
||||
"""确保HTTP session存在"""
|
||||
if self.session is None:
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=False if self.api_base.startswith("http://") else True,
|
||||
limit=100,
|
||||
limit_per_host=30,
|
||||
keepalive_timeout=30,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=120, # 默认超时时间
|
||||
connect=30,
|
||||
sock_read=120,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "text/event-stream",
|
||||
}
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=headers, timeout=timeout, connector=connector
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id
|
||||
|
||||
Args:
|
||||
file_data (bytes): 文件的二进制数据
|
||||
Returns:
|
||||
str: 上传成功后返回的 file_id
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v1/files/upload"
|
||||
|
||||
try:
|
||||
file_io = io.BytesIO(file_data)
|
||||
async with session.post(
|
||||
url,
|
||||
data={
|
||||
"file": file_io,
|
||||
},
|
||||
timeout=aiohttp.ClientTimeout(total=60),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
response_text = await response.text()
|
||||
logger.debug(
|
||||
f"文件上传响应状态: {response.status}, 内容: {response_text}"
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await response.json()
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"文件上传响应解析失败: {response_text}")
|
||||
|
||||
if result.get("code") != 0:
|
||||
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
|
||||
|
||||
file_id = result["data"]["id"]
|
||||
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
|
||||
return file_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("文件上传超时")
|
||||
raise Exception("文件上传超时")
|
||||
except Exception as e:
|
||||
logger.error(f"文件上传失败: {str(e)}")
|
||||
raise Exception(f"文件上传失败: {str(e)}")
|
||||
|
||||
async def download_image(self, image_url: str) -> bytes:
|
||||
"""下载图片并返回字节数据
|
||||
|
||||
Args:
|
||||
image_url (str): 图片的URL
|
||||
Returns:
|
||||
bytes: 图片的二进制数据
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
|
||||
try:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"下载图片失败,状态码: {response.status}")
|
||||
|
||||
image_data = await response.read()
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"下载图片失败: {str(e)}")
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
bot_id: str,
|
||||
user_id: str,
|
||||
additional_messages: List[Dict] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
auto_save_history: bool = True,
|
||||
stream: bool = True,
|
||||
timeout: float = 120,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""发送聊天消息并返回流式响应
|
||||
|
||||
Args:
|
||||
bot_id: Bot ID
|
||||
user_id: 用户ID
|
||||
additional_messages: 额外消息列表
|
||||
conversation_id: 会话ID
|
||||
auto_save_history: 是否自动保存历史
|
||||
stream: 是否流式响应
|
||||
timeout: 超时时间
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/chat"
|
||||
|
||||
payload = {
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
"stream": stream,
|
||||
"auto_save_history": auto_save_history,
|
||||
}
|
||||
|
||||
if additional_messages:
|
||||
payload["additional_messages"] = additional_messages
|
||||
|
||||
params = {}
|
||||
if conversation_id:
|
||||
params["conversation_id"] = conversation_id
|
||||
|
||||
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
|
||||
|
||||
try:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
|
||||
|
||||
# SSE
|
||||
buffer = ""
|
||||
event_type = None
|
||||
event_data = None
|
||||
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
buffer += chunk.decode("utf-8", errors="ignore")
|
||||
lines = buffer.split("\n")
|
||||
buffer = lines[-1]
|
||||
|
||||
for line in lines[:-1]:
|
||||
line = line.strip()
|
||||
|
||||
if not line:
|
||||
if event_type and event_data:
|
||||
yield {"event": event_type, "data": event_data}
|
||||
event_type = None
|
||||
event_data = None
|
||||
elif line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if data_str and data_str != "[DONE]":
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
event_data = {"content": data_str}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
|
||||
except Exception as e:
|
||||
raise Exception(f"Coze API 流式请求失败: {str(e)}")
|
||||
|
||||
async def clear_context(self, conversation_id: str):
|
||||
"""清空会话上下文
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/clear_context"
|
||||
payload = {"conversation_id": conversation_id}
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
if response.status == 401:
|
||||
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
|
||||
|
||||
if response.status != 200:
|
||||
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
|
||||
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
raise Exception("Coze API 返回非JSON格式")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception("Coze API 请求超时")
|
||||
except aiohttp.ClientError as e:
|
||||
raise Exception(f"Coze API 请求失败: {str(e)}")
|
||||
|
||||
async def get_message_list(
|
||||
self,
|
||||
conversation_id: str,
|
||||
order: str = "desc",
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
):
|
||||
"""获取消息列表
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
order: 排序方式 (asc/desc)
|
||||
limit: 限制数量
|
||||
offset: 偏移量
|
||||
Returns:
|
||||
dict: API响应结果
|
||||
"""
|
||||
session = await self._ensure_session()
|
||||
url = f"{self.api_base}/v3/conversation/message/list"
|
||||
params = {
|
||||
"conversation_id": conversation_id,
|
||||
"order": order,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
try:
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Coze消息列表失败: {str(e)}")
|
||||
raise Exception(f"获取Coze消息列表失败: {str(e)}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭会话"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
async def test_coze_api_client():
|
||||
api_key = os.getenv("COZE_API_KEY", "")
|
||||
bot_id = os.getenv("COZE_BOT_ID", "")
|
||||
client = CozeAPIClient(api_key=api_key)
|
||||
|
||||
try:
|
||||
with open("README.md", "rb") as f:
|
||||
file_data = f.read()
|
||||
file_id = await client.upload_file(file_data)
|
||||
print(f"Uploaded file_id: {file_id}")
|
||||
async for event in client.chat_messages(
|
||||
bot_id=bot_id,
|
||||
user_id="test_user",
|
||||
additional_messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
[
|
||||
{"type": "text", "text": "这是什么"},
|
||||
{"type": "file", "file_id": file_id},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
"content_type": "object_string",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
print(f"Event: {event}")
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(test_coze_api_client())
|
||||
@@ -0,0 +1,635 @@
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import AsyncGenerator, Dict
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from ..register import register_provider_adapter
|
||||
from .coze_api_client import CozeAPIClient
|
||||
|
||||
|
||||
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
|
||||
class ProviderCoze(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona=None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config,
|
||||
provider_settings,
|
||||
default_persona,
|
||||
)
|
||||
self.api_key = provider_config.get("coze_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Coze API Key 不能为空。")
|
||||
self.bot_id = provider_config.get("bot_id", "")
|
||||
if not self.bot_id:
|
||||
raise Exception("Coze Bot ID 不能为空。")
|
||||
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
|
||||
|
||||
if not isinstance(self.api_base, str) or not self.api_base.startswith(
|
||||
("http://", "https://")
|
||||
):
|
||||
raise Exception(
|
||||
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
|
||||
)
|
||||
|
||||
self.timeout = provider_config.get("timeout", 120)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.auto_save_history = provider_config.get("auto_save_history", True)
|
||||
self.conversation_ids: Dict[str, str] = {}
|
||||
self.file_id_cache: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
# 创建 API 客户端
|
||||
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
|
||||
|
||||
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
|
||||
"""生成统一的缓存键
|
||||
|
||||
Args:
|
||||
data: 图片数据或路径
|
||||
is_base64: 是否是 base64 数据
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
|
||||
try:
|
||||
if is_base64 and data.startswith("data:image/"):
|
||||
try:
|
||||
header, encoded = data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
cache_key = hashlib.md5(image_bytes).hexdigest()
|
||||
return cache_key
|
||||
except Exception:
|
||||
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
if data.startswith(("http://", "https://")):
|
||||
# URL图片,使用URL作为缓存键
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
clean_path = (
|
||||
data.split("_")[0]
|
||||
if "_" in data and len(data.split("_")) >= 3
|
||||
else data
|
||||
)
|
||||
|
||||
if os.path.exists(clean_path):
|
||||
with open(clean_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
cache_key = hashlib.md5(file_content).hexdigest()
|
||||
return cache_key
|
||||
else:
|
||||
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
|
||||
return cache_key
|
||||
|
||||
except Exception as e:
|
||||
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
|
||||
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
|
||||
return cache_key
|
||||
|
||||
async def _upload_file(
|
||||
self,
|
||||
file_data: bytes,
|
||||
session_id: str | None = None,
|
||||
cache_key: str | None = None,
|
||||
) -> str:
|
||||
"""上传文件到 Coze 并返回 file_id"""
|
||||
# 使用 API 客户端上传文件
|
||||
file_id = await self.api_client.upload_file(file_data)
|
||||
|
||||
# 缓存 file_id
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
|
||||
|
||||
return file_id
|
||||
|
||||
async def _download_and_upload_image(
|
||||
self, image_url: str, session_id: str | None = None
|
||||
) -> str:
|
||||
"""下载图片并上传到 Coze,返回 file_id"""
|
||||
# 计算哈希实现缓存
|
||||
cache_key = self._generate_cache_key(image_url) if session_id else None
|
||||
|
||||
if session_id and cache_key:
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
return file_id
|
||||
|
||||
try:
|
||||
image_data = await self.api_client.download_image(image_url)
|
||||
|
||||
file_id = await self._upload_file(image_data, session_id, cache_key)
|
||||
|
||||
if session_id and cache_key:
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
|
||||
return file_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {image_url}: {str(e)}")
|
||||
raise Exception(f"处理图片失败: {str(e)}")
|
||||
|
||||
async def _process_context_images(
|
||||
self, content: str | list, session_id: str
|
||||
) -> str:
|
||||
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
processed_content = []
|
||||
if session_id not in self.file_id_cache:
|
||||
self.file_id_cache[session_id] = {}
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
processed_content.append(item)
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
processed_content.append(item)
|
||||
elif item.get("type") == "image_url":
|
||||
# 处理图片逻辑
|
||||
if "file_id" in item:
|
||||
# 已经有 file_id
|
||||
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
|
||||
processed_content.append(item)
|
||||
else:
|
||||
# 获取图片数据
|
||||
image_data = ""
|
||||
if "image_url" in item and isinstance(item["image_url"], dict):
|
||||
image_data = item["image_url"].get("url", "")
|
||||
elif "data" in item:
|
||||
image_data = item.get("data", "")
|
||||
elif "url" in item:
|
||||
image_data = item.get("url", "")
|
||||
|
||||
if not image_data:
|
||||
continue
|
||||
# 计算哈希用于缓存
|
||||
cache_key = self._generate_cache_key(
|
||||
image_data, is_base64=image_data.startswith("data:image/")
|
||||
)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self.file_id_cache[session_id]:
|
||||
file_id = self.file_id_cache[session_id][cache_key]
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
else:
|
||||
# 上传图片并缓存
|
||||
if image_data.startswith("data:image/"):
|
||||
# base64 处理
|
||||
_, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
elif image_data.startswith(("http://", "https://")):
|
||||
# URL 图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
image_data, session_id
|
||||
)
|
||||
# 为URL图片也添加缓存
|
||||
self.file_id_cache[session_id][cache_key] = file_id
|
||||
elif os.path.exists(image_data):
|
||||
# 本地文件
|
||||
with open(image_data, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
file_id = await self._upload_file(
|
||||
image_bytes,
|
||||
session_id,
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法处理的图片格式: {image_data[:50]}..."
|
||||
)
|
||||
continue
|
||||
|
||||
processed_content.append(
|
||||
{"type": "image", "file_id": file_id}
|
||||
)
|
||||
|
||||
result = json.dumps(processed_content, ensure_ascii=False)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"处理上下文图片失败: {str(e)}")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
else:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""文本对话, 内部使用流式接口实现非流式
|
||||
|
||||
Args:
|
||||
prompt (str): 用户提示词
|
||||
session_id (str): 会话ID
|
||||
image_urls (List[str]): 图片URL列表
|
||||
func_tool (FuncCall): 函数调用工具(不支持)
|
||||
contexts (List): 上下文列表
|
||||
system_prompt (str): 系统提示语
|
||||
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
|
||||
model (str): 模型名称(不支持)
|
||||
Returns:
|
||||
LLMResponse: LLM响应对象
|
||||
"""
|
||||
accumulated_content = ""
|
||||
final_response = None
|
||||
|
||||
async for llm_response in self.text_chat_stream(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.completion_text:
|
||||
accumulated_content += llm_response.completion_text
|
||||
else:
|
||||
final_response = llm_response
|
||||
|
||||
if final_response:
|
||||
return final_response
|
||||
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
else:
|
||||
return LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话接口"""
|
||||
# 用户ID参数(参考文档, 可以自定义)
|
||||
user_id = session_id or kwargs.get("user", "default_user")
|
||||
|
||||
# 获取或创建会话ID
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
# 构建消息
|
||||
additional_messages = []
|
||||
|
||||
if system_prompt:
|
||||
if not self.auto_save_history or not conversation_id:
|
||||
additional_messages.append(
|
||||
{"role": "system", "content": system_prompt, "content_type": "text"}
|
||||
)
|
||||
|
||||
if not self.auto_save_history and contexts:
|
||||
# 如果关闭了自动保存历史,传入上下文
|
||||
for ctx in contexts:
|
||||
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
|
||||
content = ctx["content"]
|
||||
content_type = ctx.get("content_type", "text")
|
||||
|
||||
# 处理可能包含图片的上下文
|
||||
if (
|
||||
content_type == "object_string"
|
||||
or (isinstance(content, str) and content.startswith("["))
|
||||
or (
|
||||
isinstance(content, list)
|
||||
and any(
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "image_url"
|
||||
for item in content
|
||||
)
|
||||
)
|
||||
):
|
||||
processed_content = await self._process_context_images(
|
||||
content, user_id
|
||||
)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": processed_content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": ctx["role"],
|
||||
"content": (
|
||||
content
|
||||
if isinstance(content, str)
|
||||
else json.dumps(content, ensure_ascii=False)
|
||||
),
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
|
||||
|
||||
if prompt or image_urls:
|
||||
if image_urls:
|
||||
# 多模态
|
||||
object_string_content = []
|
||||
if prompt:
|
||||
object_string_content.append({"type": "text", "text": prompt})
|
||||
|
||||
for url in image_urls:
|
||||
try:
|
||||
if url.startswith(("http://", "https://")):
|
||||
# 网络图片
|
||||
file_id = await self._download_and_upload_image(
|
||||
url, user_id
|
||||
)
|
||||
else:
|
||||
# 本地文件或 base64
|
||||
if url.startswith("data:image/"):
|
||||
# base64
|
||||
_, encoded = url.split(",", 1)
|
||||
image_data = base64.b64decode(encoded)
|
||||
cache_key = self._generate_cache_key(
|
||||
url, is_base64=True
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
# 本地文件
|
||||
if os.path.exists(url):
|
||||
with open(url, "rb") as f:
|
||||
image_data = f.read()
|
||||
# 用文件路径和修改时间来缓存
|
||||
file_stat = os.stat(url)
|
||||
cache_key = self._generate_cache_key(
|
||||
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
|
||||
is_base64=False,
|
||||
)
|
||||
file_id = await self._upload_file(
|
||||
image_data, user_id, cache_key
|
||||
)
|
||||
else:
|
||||
logger.warning(f"图片文件不存在: {url}")
|
||||
continue
|
||||
|
||||
object_string_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"file_id": file_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败 {url}: {str(e)}")
|
||||
continue
|
||||
|
||||
if object_string_content:
|
||||
content = json.dumps(object_string_content, ensure_ascii=False)
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
"content_type": "object_string",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 纯文本
|
||||
if prompt:
|
||||
additional_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"content_type": "text",
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_content = ""
|
||||
message_started = False
|
||||
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=additional_messages,
|
||||
conversation_id=conversation_id,
|
||||
auto_save_history=self.auto_save_history,
|
||||
stream=True,
|
||||
timeout=self.timeout,
|
||||
):
|
||||
event_type = chunk.get("event")
|
||||
data = chunk.get("data", {})
|
||||
|
||||
if event_type == "conversation.chat.created":
|
||||
if isinstance(data, dict) and "conversation_id" in data:
|
||||
self.conversation_ids[user_id] = data["conversation_id"]
|
||||
|
||||
elif event_type == "conversation.message.delta":
|
||||
if isinstance(data, dict):
|
||||
content = data.get("content", "")
|
||||
if not content and "delta" in data:
|
||||
content = data["delta"].get("content", "")
|
||||
if not content and "text" in data:
|
||||
content = data.get("text", "")
|
||||
|
||||
if content:
|
||||
message_started = True
|
||||
accumulated_content += content
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=content,
|
||||
is_chunk=True,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.message.completed":
|
||||
if isinstance(data, dict):
|
||||
msg_type = data.get("type")
|
||||
if msg_type == "answer" and data.get("role") == "assistant":
|
||||
final_content = data.get("content", "")
|
||||
if not accumulated_content and final_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(final_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
elif event_type == "conversation.chat.completed":
|
||||
if accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
elif event_type == "done":
|
||||
break
|
||||
|
||||
elif event_type == "error":
|
||||
error_msg = (
|
||||
data.get("message", "未知错误")
|
||||
if isinstance(data, dict)
|
||||
else str(data)
|
||||
)
|
||||
logger.error(f"Coze 流式响应错误: {error_msg}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 错误: {error_msg}",
|
||||
is_chunk=False,
|
||||
)
|
||||
break
|
||||
|
||||
if not message_started and not accumulated_content:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="LLM 未响应任何内容。",
|
||||
is_chunk=False,
|
||||
)
|
||||
elif message_started and accumulated_content:
|
||||
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
result_chain=chain,
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Coze 流式请求失败: {str(e)}")
|
||||
yield LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"Coze 流式请求失败: {str(e)}",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
async def forget(self, session_id: str):
|
||||
"""清空指定会话的上下文"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if user_id in self.file_id_cache:
|
||||
self.file_id_cache.pop(user_id, None)
|
||||
|
||||
if not conversation_id:
|
||||
return True
|
||||
|
||||
try:
|
||||
response = await self.api_client.clear_context(conversation_id)
|
||||
|
||||
if "code" in response and response["code"] == 0:
|
||||
self.conversation_ids.pop(user_id, None)
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"清空 Coze 会话上下文失败: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空 Coze 会话失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_current_key(self):
|
||||
"""获取当前API Key"""
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key: str):
|
||||
"""设置新的API Key"""
|
||||
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
"""获取可用模型列表"""
|
||||
return [f"bot_{self.bot_id}"]
|
||||
|
||||
def get_model(self):
|
||||
"""获取当前模型"""
|
||||
return f"bot_{self.bot_id}"
|
||||
|
||||
def set_model(self, model: str):
|
||||
"""设置模型(在Coze中是Bot ID)"""
|
||||
if model.startswith("bot_"):
|
||||
self.bot_id = model[4:]
|
||||
else:
|
||||
self.bot_id = model
|
||||
|
||||
async def get_human_readable_context(
|
||||
self, session_id: str, page: int = 1, page_size: int = 10
|
||||
):
|
||||
"""获取人类可读的上下文历史"""
|
||||
user_id = session_id
|
||||
conversation_id = self.conversation_ids.get(user_id)
|
||||
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await self.api_client.get_message_list(
|
||||
conversation_id=conversation_id,
|
||||
order="desc",
|
||||
limit=page_size,
|
||||
offset=(page - 1) * page_size,
|
||||
)
|
||||
|
||||
if data.get("code") != 0:
|
||||
logger.warning(f"获取 Coze 消息历史失败: {data}")
|
||||
return []
|
||||
|
||||
messages = data.get("data", {}).get("messages", [])
|
||||
|
||||
readable_history = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
msg_type = msg.get("type", "")
|
||||
|
||||
if role == "user":
|
||||
readable_history.append(f"用户: {content}")
|
||||
elif role == "assistant" and msg_type == "answer":
|
||||
readable_history.append(f"助手: {content}")
|
||||
|
||||
return readable_history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
|
||||
return []
|
||||
|
||||
async def terminate(self):
|
||||
"""清理资源"""
|
||||
await self.api_client.close()
|
||||
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
|
||||
self.init_handler_md(handler_md)
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def print_types(self):
|
||||
result = ""
|
||||
for k, v in self.handler_params.items():
|
||||
@@ -136,6 +139,28 @@ class CommandFilter(HandlerFilter):
|
||||
)
|
||||
return result
|
||||
|
||||
def get_complete_command_names(self):
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
self._cmpl_cmd_names = [
|
||||
f"{parent} {cmd}" if parent else cmd
|
||||
for cmd in [self.command_name] + list(self.alias)
|
||||
for parent in self.parent_command_names or [""]
|
||||
]
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
def startswith(self, message_str: str) -> bool:
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
|
||||
return True
|
||||
return False
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
for full_cmd in self.get_complete_command_names():
|
||||
if message_str == full_cmd:
|
||||
return True
|
||||
return False
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -145,19 +170,7 @@ class CommandFilter(HandlerFilter):
|
||||
|
||||
# 检查是否以指令开头
|
||||
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
|
||||
candidates = [self.command_name] + list(self.alias)
|
||||
ok = False
|
||||
for candidate in candidates:
|
||||
for parent_command_name in self.parent_command_names:
|
||||
if parent_command_name:
|
||||
_full = f"{parent_command_name} {candidate}"
|
||||
else:
|
||||
_full = candidate
|
||||
if message_str.startswith(f"{_full} ") or message_str == _full:
|
||||
message_str = message_str[len(_full) :].strip()
|
||||
ok = True
|
||||
break
|
||||
if not ok:
|
||||
if not self.startswith(message_str):
|
||||
return False
|
||||
|
||||
# 分割为列表
|
||||
|
||||
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
self.custom_filter_list: List[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
# Cache for complete command names list
|
||||
self._cmpl_cmd_names: list | None = None
|
||||
|
||||
def add_sub_command_filter(
|
||||
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
|
||||
):
|
||||
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
|
||||
"""遍历父节点获取完整的指令名。
|
||||
|
||||
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
|
||||
if self._cmpl_cmd_names is not None:
|
||||
return self._cmpl_cmd_names
|
||||
|
||||
parent_cmd_names = (
|
||||
self.parent_group.get_complete_command_names() if self.parent_group else []
|
||||
)
|
||||
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
for parent_cmd_name in parent_cmd_names:
|
||||
for candidate in candidates:
|
||||
result.append(parent_cmd_name + " " + candidate)
|
||||
self._cmpl_cmd_names = result
|
||||
return result
|
||||
|
||||
# 以树的形式打印出来
|
||||
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
|
||||
return False
|
||||
return True
|
||||
|
||||
def startswith(self, message_str: str) -> bool:
|
||||
return message_str.startswith(tuple(self.get_complete_command_names()))
|
||||
|
||||
def equals(self, message_str: str) -> bool:
|
||||
return message_str in self.get_complete_command_names()
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_at_or_wake_command:
|
||||
return False
|
||||
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
if not self.custom_filter_ok(event, cfg):
|
||||
return False
|
||||
|
||||
complete_command_names = self.get_complete_command_names()
|
||||
if event.message_str.strip() in complete_command_names:
|
||||
if self.equals(event.message_str.strip()):
|
||||
tree = (
|
||||
self.group_name
|
||||
+ "\n"
|
||||
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
|
||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||
)
|
||||
|
||||
# complete_command_names = [name + " " for name in complete_command_names]
|
||||
# return event.message_str.startswith(tuple(complete_command_names))
|
||||
return False
|
||||
return self.startswith(event.message_str)
|
||||
|
||||
@@ -52,10 +52,6 @@ class SessionServiceManager:
|
||||
"session_service_config", session_config, scope="umo", scope_id=session_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
@@ -1,9 +1,33 @@
|
||||
import codecs
|
||||
import json
|
||||
from astrbot.core import logger
|
||||
from aiohttp import ClientSession
|
||||
from aiohttp import ClientSession, ClientResponse
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
|
||||
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
|
||||
decoder = codecs.getincrementaldecoder("utf-8")()
|
||||
buffer = ""
|
||||
async for chunk in resp.content.iter_chunked(8192):
|
||||
buffer += decoder.decode(chunk)
|
||||
while "\n\n" in buffer:
|
||||
block, buffer = buffer.split("\n\n", 1)
|
||||
if block.strip().startswith("data:"):
|
||||
try:
|
||||
yield json.loads(block[5:])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Drop invalid dify json data: {block[5:]}")
|
||||
continue
|
||||
# flush any remaining text
|
||||
buffer += decoder.decode(b"", final=True)
|
||||
if buffer.strip().startswith("data:"):
|
||||
try:
|
||||
yield json.loads(buffer[5:])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
|
||||
pass
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
@@ -33,31 +57,11 @@ class DifyAPIClient:
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode("utf-8")
|
||||
blocks = buffer.split("\n\n")
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith("data:"):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
raise Exception(
|
||||
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
|
||||
)
|
||||
async for event in _stream_sse(resp):
|
||||
yield event
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
@@ -77,31 +81,11 @@ class DifyAPIClient:
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
|
||||
|
||||
buffer = ""
|
||||
while True:
|
||||
# 保持原有的8192字节限制,防止数据过大导致高水位报错
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
buffer += chunk.decode("utf-8")
|
||||
blocks = buffer.split("\n\n")
|
||||
|
||||
# 处理完整的数据块
|
||||
for block in blocks[:-1]:
|
||||
if block.strip() and block.startswith("data:"):
|
||||
try:
|
||||
json_str = block[5:] # 移除 "data:" 前缀
|
||||
json_obj = json.loads(json_str)
|
||||
yield json_obj
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}")
|
||||
logger.error(f"原始数据块: {json_str}")
|
||||
|
||||
# 保留最后一个可能不完整的块
|
||||
buffer = blocks[-1] if blocks else ""
|
||||
raise Exception(
|
||||
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
|
||||
)
|
||||
async for event in _stream_sse(resp):
|
||||
yield event
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
@@ -109,12 +93,15 @@ class DifyAPIClient:
|
||||
user: str,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}/files/upload"
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": open(file_path, "rb"),
|
||||
}
|
||||
async with self.session.post(url, data=payload, headers=self.headers) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
with open(file_path, "rb") as f:
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": f,
|
||||
}
|
||||
async with self.session.post(
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
|
||||
@@ -1,17 +1,27 @@
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from quart import request, Response as QuartResponse, g, make_response
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def track_conversation(convs: dict, conv_id: str):
|
||||
convs[conv_id] = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
convs.pop(conv_id, None)
|
||||
|
||||
|
||||
class ChatRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -40,6 +50,8 @@ class ChatRoute(Route):
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
|
||||
self.running_convs: dict[str, bool] = {}
|
||||
|
||||
async def get_file(self):
|
||||
filename = request.args.get("filename")
|
||||
if not filename:
|
||||
@@ -139,42 +151,63 @@ class ChatRoute(Route):
|
||||
)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
except Exception as e:
|
||||
logger.error(f"WebChat stream error: {e}")
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.05)
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
|
||||
if type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
):
|
||||
# append bot message
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
try:
|
||||
if not client_disconnected:
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
if not client_disconnected:
|
||||
logger.debug(
|
||||
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
|
||||
)
|
||||
client_disconnected = True
|
||||
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
return
|
||||
try:
|
||||
if not client_disconnected:
|
||||
await asyncio.sleep(0.05)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
|
||||
if type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and type == "complete")
|
||||
or not streaming
|
||||
or type == "break"
|
||||
):
|
||||
# append bot message
|
||||
new_his = {"type": "bot", "message": result_text}
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
|
||||
# Put message to conversation-specific queue
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
@@ -291,6 +324,7 @@ class ChatRoute(Route):
|
||||
.ok(
|
||||
data={
|
||||
"history": history_res,
|
||||
"is_running": self.running_convs.get(webchat_conv_id, False),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
|
||||
@@ -30,6 +30,7 @@ class SessionManagementRoute(Route):
|
||||
"/session/update_tts": ("POST", self.update_session_tts),
|
||||
"/session/update_name": ("POST", self.update_session_name),
|
||||
"/session/update_status": ("POST", self.update_session_status),
|
||||
"/session/delete": ("POST", self.delete_session),
|
||||
}
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
self.core_lifecycle = core_lifecycle
|
||||
@@ -180,60 +181,132 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_persona(self, session_id: str, persona_name: str):
|
||||
"""更新单个会话的 persona 的内部方法"""
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
|
||||
conv = None
|
||||
if conversation_id:
|
||||
conv = await conversation_manager.get_conversation(
|
||||
unified_msg_origin=session_id,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if not conv or not conversation_id:
|
||||
conversation_id = await conversation_manager.new_conversation(session_id)
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
async def _handle_batch_operation(
|
||||
self, session_ids: list, operation_func, operation_name: str, **kwargs
|
||||
):
|
||||
"""通用的批量操作处理方法"""
|
||||
success_count = 0
|
||||
error_sessions = []
|
||||
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
await operation_func(session_id, **kwargs)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}")
|
||||
error_sessions.append(session_id)
|
||||
|
||||
if error_sessions:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
|
||||
"success_count": success_count,
|
||||
"error_count": len(error_sessions),
|
||||
"error_sessions": error_sessions,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功批量{operation_name} {success_count} 个会话",
|
||||
"success_count": success_count,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def update_session_persona(self):
|
||||
"""更新指定会话的 persona"""
|
||||
"""更新指定会话的 persona,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
persona_name = data.get("persona_name")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if persona_name is None:
|
||||
return Response().error("缺少必要参数: persona_name").__dict__
|
||||
|
||||
# 获取会话当前的对话 ID
|
||||
conversation_manager = self.core_lifecycle.star_context.conversation_manager
|
||||
conversation_id = await conversation_manager.get_curr_conversation_id(
|
||||
session_id
|
||||
)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
if not conversation_id:
|
||||
# 如果没有对话,创建一个新的对话
|
||||
conversation_id = await conversation_manager.new_conversation(
|
||||
session_id
|
||||
return await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_persona,
|
||||
"更新人格",
|
||||
persona_name=persona_name,
|
||||
)
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 更新 persona
|
||||
await conversation_manager.update_conversation_persona_id(
|
||||
session_id, persona_name
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
|
||||
.__dict__
|
||||
)
|
||||
await self._update_single_session_persona(session_id, persona_name)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_provider(
|
||||
self, session_id: str, provider_id: str, provider_type_enum
|
||||
):
|
||||
"""更新单个会话的 provider 的内部方法"""
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
|
||||
async def update_session_provider(self):
|
||||
"""更新指定会话的 provider"""
|
||||
"""更新指定会话的 provider,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
provider_id = data.get("provider_id")
|
||||
# "chat_completion", "speech_to_text", "text_to_speech"
|
||||
provider_type = data.get("provider_type")
|
||||
|
||||
if not session_id or not provider_id or not provider_type:
|
||||
if not provider_id or not provider_type:
|
||||
return (
|
||||
Response()
|
||||
.error("缺少必要参数: session_id, provider_id, provider_type")
|
||||
.error("缺少必要参数: provider_id, provider_type")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
@@ -251,23 +324,35 @@ class SessionManagementRoute(Route):
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 设置 provider
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
await provider_manager.set_provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type_enum,
|
||||
umo=session_id,
|
||||
)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
return await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_provider,
|
||||
f"更新 {provider_type} 提供商",
|
||||
provider_id=provider_id,
|
||||
provider_type_enum=provider_type_enum,
|
||||
)
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_provider(
|
||||
session_id, provider_id, provider_type_enum
|
||||
)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}"
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
|
||||
@@ -376,66 +461,98 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_llm(self, session_id: str, enabled: bool):
|
||||
"""更新单个会话的LLM状态的内部方法"""
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
|
||||
async def update_session_llm(self):
|
||||
"""更新指定会话的LLM启停状态"""
|
||||
"""更新指定会话的LLM启停状态,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新LLM状态
|
||||
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
result = await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_llm,
|
||||
f"{'启用' if enabled else '禁用'}LLM",
|
||||
enabled=enabled,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_llm(session_id, enabled)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||
|
||||
async def _update_single_session_tts(self, session_id: str, enabled: bool):
|
||||
"""更新单个会话的TTS状态的内部方法"""
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
|
||||
async def update_session_tts(self):
|
||||
"""更新指定会话的TTS启停状态"""
|
||||
"""更新指定会话的TTS启停状态,支持批量操作"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
is_batch = data.get("is_batch", False)
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionServiceManager 更新TTS状态
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
|
||||
if is_batch:
|
||||
session_ids = data.get("session_ids", [])
|
||||
if not session_ids:
|
||||
return Response().error("缺少必要参数: session_ids").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
result = await self._handle_batch_operation(
|
||||
session_ids,
|
||||
self._update_single_session_tts,
|
||||
f"{'启用' if enabled else '禁用'}TTS",
|
||||
enabled=enabled,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
await self._update_single_session_tts(session_id, enabled)
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"TTS已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"tts_enabled": enabled,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
@@ -507,3 +624,43 @@ class SessionManagementRoute(Route):
|
||||
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_session(self):
|
||||
"""删除指定会话及其所有相关数据"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 删除会话的所有相关数据
|
||||
conversation_manager = self.core_lifecycle.conversation_manager
|
||||
|
||||
# 1. 删除会话的所有对话
|
||||
try:
|
||||
await conversation_manager.delete_conversations_by_user_id(session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}")
|
||||
|
||||
# 2. 清除会话的偏好设置数据(清空该会话的所有配置)
|
||||
try:
|
||||
await sp.clear_async("umo", session_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}")
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"删除会话失败: {str(e)}").__dict__
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# What's Changed
|
||||
+22
-1
@@ -1,7 +1,28 @@
|
||||
<template>
|
||||
<RouterView></RouterView>
|
||||
|
||||
<!-- 全局唯一 snackbar -->
|
||||
<v-snackbar v-if="toastStore.current" v-model="snackbarShow" :color="toastStore.current.color"
|
||||
:timeout="toastStore.current.timeout" :multi-line="toastStore.current.multiLine"
|
||||
:location="toastStore.current.location" close-on-back>
|
||||
{{ toastStore.current.message }}
|
||||
<template #actions v-if="toastStore.current.closable">
|
||||
<v-btn variant="text" @click="snackbarShow = false">关闭</v-btn>
|
||||
</template>
|
||||
</v-snackbar>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
<script setup>
|
||||
import { RouterView } from 'vue-router';
|
||||
import { computed } from 'vue'
|
||||
import { useToastStore } from '@/stores/toast'
|
||||
|
||||
const toastStore = useToastStore()
|
||||
|
||||
const snackbarShow = computed({
|
||||
get: () => !!toastStore.current,
|
||||
set: (val) => {
|
||||
if (!val) toastStore.shift()
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
|
||||
v-if="chatboxMode">
|
||||
<img width="50" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
|
||||
<span v-if="!sidebarCollapsed" style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
<span v-if="!sidebarCollapsed"
|
||||
style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -46,7 +47,7 @@
|
||||
|| tm('conversation.newConversation') }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="!sidebarCollapsed" class="timestamp">{{
|
||||
formatDate(item.updated_at)
|
||||
}}</v-list-item-subtitle>
|
||||
}}</v-list-item-subtitle>
|
||||
|
||||
<template v-if="!sidebarCollapsed" v-slot:append>
|
||||
<div class="conversation-actions">
|
||||
@@ -118,8 +119,9 @@
|
||||
</div>
|
||||
<v-divider v-if="currCid && getCurrentConversation" class="conversation-divider"></v-divider>
|
||||
|
||||
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark" :isStreaming="isStreaming"
|
||||
@openImagePreview="openImagePreview" ref="messageList" />
|
||||
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark"
|
||||
:isStreaming="isStreaming || isConvRunning" @openImagePreview="openImagePreview"
|
||||
ref="messageList" />
|
||||
<div class="welcome-container fade-in" v-else>
|
||||
<div class="welcome-title">
|
||||
<span>Hello, I'm</span>
|
||||
@@ -145,9 +147,10 @@
|
||||
<!-- 输入区域 -->
|
||||
<div class="input-area fade-in">
|
||||
<div
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; padding: 4px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown" :disabled="isStreaming"
|
||||
@click:clear="clearMessage" placeholder="Ask AstrBot..."
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px;">
|
||||
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
|
||||
:disabled="isStreaming || isConvRunning" @click:clear="clearMessage"
|
||||
placeholder="Ask AstrBot..."
|
||||
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
|
||||
<div
|
||||
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 8px;">
|
||||
@@ -155,18 +158,21 @@
|
||||
<!-- 选择提供商和模型 -->
|
||||
<ProviderModelSelector ref="providerModelSelector" />
|
||||
</div>
|
||||
<div style="display: flex; justify-content: flex-end; margin-top: 8px;">
|
||||
<div
|
||||
style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
|
||||
<input type="file" ref="imageInput" @change="handleFileSelect" accept="image/*"
|
||||
style="display: none" multiple />
|
||||
<v-progress-circular v-if="isStreaming || isConvRunning" indeterminate size="16"
|
||||
class="mr-1" width="1.5" />
|
||||
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
|
||||
class="add-btn" size="small" />
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
<v-btn @click="isRecording ? stopRecording() : startRecording()"
|
||||
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
|
||||
:color="isRecording ? 'error' : 'deep-purple'" class="record-btn"
|
||||
size="small" />
|
||||
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
|
||||
class="send-btn" size="small" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -235,6 +241,7 @@ import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
|
||||
import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue';
|
||||
import MessageList from '@/components/chat/MessageList.vue';
|
||||
import 'highlight.js/styles/github.css';
|
||||
import { useToast } from '@/utils/toast';
|
||||
|
||||
export default {
|
||||
name: 'ChatPage',
|
||||
@@ -301,7 +308,10 @@ export default {
|
||||
imagePreviewDialog: false,
|
||||
previewImageUrl: '',
|
||||
|
||||
isStreaming: false
|
||||
isStreaming: false,
|
||||
isConvRunning: false, // Track if the current conversation is running
|
||||
|
||||
isToastedRunningInfo: false, // To avoid multiple toasts
|
||||
}
|
||||
},
|
||||
|
||||
@@ -379,7 +389,7 @@ export default {
|
||||
} else {
|
||||
this.sidebarCollapsed = true; // 默认折叠状态
|
||||
}
|
||||
|
||||
|
||||
// 设置输入框标签
|
||||
this.inputFieldLabel = this.tm('input.chatPrompt');
|
||||
this.getConversations();
|
||||
@@ -662,6 +672,25 @@ export default {
|
||||
// Update the selected conversation in the sidebar
|
||||
this.selectedConversations = [cid[0]];
|
||||
let history = response.data.data.history;
|
||||
this.isConvRunning = response.data.data.is_running || false;
|
||||
|
||||
if (this.isConvRunning) {
|
||||
if (!this.isToastedRunningInfo) {
|
||||
useToast().info("该对话正在运行中。", { timeout: 5000 });
|
||||
this.isToastedRunningInfo = true;
|
||||
}
|
||||
|
||||
// 如果对话还在运行,3秒后重新获取消息
|
||||
setTimeout(() => {
|
||||
this.getConversationMessages([this.currCid]);
|
||||
}, 3000);
|
||||
}
|
||||
|
||||
// 滚动到底部
|
||||
this.$nextTick(() => {
|
||||
this.$refs.messageList.scrollToBottom();
|
||||
});
|
||||
|
||||
for (let i = 0; i < history.length; i++) {
|
||||
let content = history[i].content;
|
||||
if (content.message.startsWith('[IMAGE]')) {
|
||||
|
||||
@@ -29,12 +29,11 @@
|
||||
|
||||
<!-- Bot Messages -->
|
||||
<div v-else class="bot-message">
|
||||
<div v-if="isStreaming && index === messages.length - 1" style="width: 36px; height: 36px;">
|
||||
<v-progress-circular indeterminate size="28" width="2"
|
||||
style="margin-top: 12px;"></v-progress-circular>
|
||||
</div>
|
||||
<v-avatar v-else class="bot-avatar" size="36">
|
||||
<span class="text-h2">✨</span>
|
||||
|
||||
<v-avatar class="bot-avatar" size="36">
|
||||
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
|
||||
width="2"></v-progress-circular>
|
||||
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2">✨</span>
|
||||
</v-avatar>
|
||||
<div class="bot-message-content">
|
||||
<div class="message-bubble bot-bubble">
|
||||
|
||||
@@ -31,6 +31,10 @@
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="closeDialog">{{ tm('dialog.cancel') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
@@ -63,9 +63,7 @@
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn text @click="closeDialog">
|
||||
Close
|
||||
</v-btn>
|
||||
<v-btn text @click="closeDialog">{{ tm('dialogs.config.cancel') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
"apply": "Apply Batch Settings",
|
||||
"editName": "Edit Session Name",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel"
|
||||
"cancel": "Cancel",
|
||||
"delete": "Delete"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "Active Sessions",
|
||||
@@ -29,7 +30,8 @@
|
||||
"ttsProvider": "TTS Provider",
|
||||
"llmStatus": "LLM Status",
|
||||
"ttsStatus": "TTS Status",
|
||||
"pluginManagement": "Plugin Management"
|
||||
"pluginManagement": "Plugin Management",
|
||||
"actions": "Actions"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
@@ -65,6 +67,10 @@
|
||||
"fullSessionId": "Full Session ID",
|
||||
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
|
||||
},
|
||||
"deleteConfirm": {
|
||||
"message": "Are you sure you want to delete session {sessionName}?",
|
||||
"warning": "This action will permanently delete all chat history and preference settings for this session (except for data linked via plugins), and this cannot be undone. Continue?"
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "Session list refreshed",
|
||||
"personaUpdateSuccess": "Persona updated successfully",
|
||||
@@ -82,6 +88,8 @@
|
||||
"pluginStatusSuccess": "Plugin {name} {status}",
|
||||
"pluginStatusError": "Failed to update plugin status",
|
||||
"nameUpdateSuccess": "Session name updated successfully",
|
||||
"nameUpdateError": "Failed to update session name"
|
||||
"nameUpdateError": "Failed to update session name",
|
||||
"deleteSuccess": "Session deleted successfully",
|
||||
"deleteError": "Failed to delete session"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
"apply": "应用批量设置",
|
||||
"editName": "备注",
|
||||
"save": "保存",
|
||||
"cancel": "取消"
|
||||
"cancel": "取消",
|
||||
"delete": "删除"
|
||||
},
|
||||
"sessions": {
|
||||
"activeSessions": "活跃会话",
|
||||
@@ -29,7 +30,8 @@
|
||||
"ttsProvider": "语音合成模型",
|
||||
"llmStatus": "启用 LLM",
|
||||
"ttsStatus": "启用 TTS",
|
||||
"pluginManagement": "插件管理"
|
||||
"pluginManagement": "插件管理",
|
||||
"actions": "操作"
|
||||
}
|
||||
},
|
||||
"status": {
|
||||
@@ -65,6 +67,10 @@
|
||||
"fullSessionId": "完整会话ID",
|
||||
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
|
||||
},
|
||||
"deleteConfirm": {
|
||||
"message": "确定要删除会话 {sessionName} 吗?",
|
||||
"warning": "此操作将永久删除本次会话的「全部对话记录」与「偏好设置」(插件对会话的关联数据除外),且无法恢复。确认继续?"
|
||||
},
|
||||
"messages": {
|
||||
"refreshSuccess": "会话列表已刷新",
|
||||
"personaUpdateSuccess": "人格更新成功",
|
||||
@@ -82,6 +88,8 @@
|
||||
"pluginStatusSuccess": "插件 {name} {status}",
|
||||
"pluginStatusError": "插件状态更新失败",
|
||||
"nameUpdateSuccess": "会话名称更新成功",
|
||||
"nameUpdateError": "会话名称更新失败"
|
||||
"nameUpdateError": "会话名称更新失败",
|
||||
"deleteSuccess": "会话删除成功",
|
||||
"deleteError": "会话删除失败"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
|
||||
export const useToastStore = defineStore('toast', () => {
|
||||
const queue = ref([])
|
||||
const current = computed(() => queue.value[0])
|
||||
|
||||
function add({
|
||||
message,
|
||||
color = 'info', // Vuetify 颜色
|
||||
timeout = 3000,
|
||||
closable = true,
|
||||
multiLine = false,
|
||||
location = 'top center'
|
||||
}) {
|
||||
queue.value.push({
|
||||
message,
|
||||
color,
|
||||
timeout,
|
||||
closable,
|
||||
multiLine,
|
||||
location
|
||||
})
|
||||
}
|
||||
|
||||
function shift() {
|
||||
queue.value.shift()
|
||||
}
|
||||
|
||||
return { current, add, shift }
|
||||
})
|
||||
@@ -22,6 +22,7 @@ export function getProviderIcon(type) {
|
||||
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
|
||||
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
|
||||
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
|
||||
"coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg",
|
||||
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
|
||||
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
|
||||
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
import { useToastStore } from '@/stores/toast'
|
||||
|
||||
export function useToast() {
|
||||
const store = useToastStore()
|
||||
|
||||
const toast = (message, color = 'info', opts = {}) =>
|
||||
store.add({ message, color, ...opts })
|
||||
|
||||
return {
|
||||
toast,
|
||||
success: (msg, opts) => toast(msg, 'success', opts),
|
||||
error: (msg, opts) => toast(msg, 'error', opts),
|
||||
info: (msg, opts) => toast(msg, 'primary', opts),
|
||||
warning: (msg, opts) => toast(msg, 'warning', opts)
|
||||
}
|
||||
}
|
||||
@@ -601,19 +601,15 @@ export default {
|
||||
}
|
||||
|
||||
if (this.search) {
|
||||
params.search = this.search;
|
||||
params.search = this.search.trim();
|
||||
}
|
||||
|
||||
// 添加排除条件
|
||||
params.exclude_ids = 'astrbot';
|
||||
params.exclude_platforms = 'webchat';
|
||||
|
||||
console.log(`正在请求对话列表: /api/conversation/list 参数:`, params);
|
||||
|
||||
const response = await axios.get('/api/conversation/list', { params });
|
||||
|
||||
console.log('收到对话列表响应:', response.data);
|
||||
|
||||
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
|
||||
|
||||
if (response.data.status === "ok") {
|
||||
|
||||
@@ -303,6 +303,7 @@ export default {
|
||||
"googlegenai_chat_completion": "chat_completion",
|
||||
"zhipu_chat_completion": "chat_completion",
|
||||
"dify": "chat_completion",
|
||||
"coze": "chat_completion",
|
||||
"dashscope": "chat_completion",
|
||||
"openai_whisper_api": "speech_to_text",
|
||||
"openai_whisper_selfhost": "speech_to_text",
|
||||
|
||||
@@ -141,6 +141,17 @@
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<v-btn size="x-small" variant="tonal" color="error" @click="deleteSession(item)"
|
||||
:loading="item.deleting" icon>
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
<v-tooltip activator="parent" location="top">
|
||||
{{ tm('buttons.delete') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
</template>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<template v-slot:no-data>
|
||||
<div class="text-center py-8">
|
||||
@@ -409,6 +420,7 @@ export default {
|
||||
{ title: this.tm('table.headers.llmStatus'), key: 'llm_enabled', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.ttsStatus'), key: 'tts_enabled', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.pluginManagement'), key: 'plugins', sortable: false, minWidth: '120px' },
|
||||
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, minWidth: '100px' },
|
||||
]
|
||||
},
|
||||
|
||||
@@ -418,12 +430,13 @@ export default {
|
||||
|
||||
// 搜索筛选
|
||||
if (this.searchQuery) {
|
||||
const query = this.searchQuery.toLowerCase();
|
||||
const query = this.searchQuery.toLowerCase().trim();
|
||||
filtered = filtered.filter(session =>
|
||||
session.session_name.toLowerCase().includes(query) ||
|
||||
session.platform.toLowerCase().includes(query) ||
|
||||
session.persona_name?.toLowerCase().includes(query) ||
|
||||
session.chat_provider_name?.toLowerCase().includes(query)
|
||||
session.chat_provider_name?.toLowerCase().includes(query) ||
|
||||
session.session_id.toLowerCase().includes(query)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -487,7 +500,8 @@ export default {
|
||||
this.sessions = data.sessions.map(session => ({
|
||||
...session,
|
||||
updating: false, // 添加更新状态标志
|
||||
loadingPlugins: false // 添加插件加载状态标志
|
||||
loadingPlugins: false, // 添加插件加载状态标志
|
||||
deleting: false // 添加删除状态标志
|
||||
}));
|
||||
this.availablePersonas = data.available_personas;
|
||||
this.availableChatProviders = data.available_chat_providers;
|
||||
@@ -508,60 +522,131 @@ export default {
|
||||
},
|
||||
|
||||
async updatePersona(session, personaName) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_persona', {
|
||||
session_id: session.session_id,
|
||||
persona_name: personaName
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.persona_id = personaName;
|
||||
session.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
|
||||
return this._updateSession('persona', session, { persona_name: personaName }, (s, success) => {
|
||||
if (success) {
|
||||
s.persona_id = personaName;
|
||||
s.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
|
||||
this.availablePersonas.find(p => p.name === personaName)?.name || personaName;
|
||||
this.showSuccess(this.tm('messages.personaUpdateSuccess'));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.personaUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.personaUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
});
|
||||
},
|
||||
|
||||
async updateProvider(session, providerId, providerType) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_provider', {
|
||||
session_id: session.session_id,
|
||||
provider_id: providerId,
|
||||
provider_type: providerType
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
// 更新本地数据
|
||||
return this._updateSession('provider', session, {
|
||||
provider_id: providerId,
|
||||
provider_type: providerType
|
||||
}, (s, success) => {
|
||||
if (success) {
|
||||
if (providerType === 'chat_completion') {
|
||||
session.chat_provider_id = providerId;
|
||||
s.chat_provider_id = providerId;
|
||||
const provider = this.availableChatProviders.find(p => p.id === providerId);
|
||||
session.chat_provider_name = provider?.name || providerId;
|
||||
s.chat_provider_name = provider?.name || providerId;
|
||||
} else if (providerType === 'speech_to_text') {
|
||||
session.stt_provider_id = providerId;
|
||||
s.stt_provider_id = providerId;
|
||||
const provider = this.availableSttProviders.find(p => p.id === providerId);
|
||||
session.stt_provider_name = provider?.name || providerId;
|
||||
s.stt_provider_name = provider?.name || providerId;
|
||||
} else if (providerType === 'text_to_speech') {
|
||||
session.tts_provider_id = providerId;
|
||||
s.tts_provider_id = providerId;
|
||||
const provider = this.availableTtsProviders.find(p => p.id === providerId);
|
||||
session.tts_provider_name = provider?.name || providerId;
|
||||
s.tts_provider_name = provider?.name || providerId;
|
||||
}
|
||||
this.showSuccess(this.tm('messages.providerUpdateSuccess'));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.providerUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.providerUpdateError'));
|
||||
} session.updating = false;
|
||||
});
|
||||
},
|
||||
|
||||
async updateLLM(session, enabled) {
|
||||
return this._updateSession('llm', session, { enabled }, (s, success) => {
|
||||
if (success) s.llm_enabled = enabled;
|
||||
});
|
||||
},
|
||||
|
||||
async updateTTS(session, enabled) {
|
||||
return this._updateSession('tts', session, { enabled }, (s, success) => {
|
||||
if (success) s.tts_enabled = enabled;
|
||||
});
|
||||
},
|
||||
|
||||
// 通用的更新会话方法,支持单个和批量操作
|
||||
async _updateSession(type, sessionOrSessions, params, updateLocalData) {
|
||||
const isBatch = Array.isArray(sessionOrSessions);
|
||||
|
||||
if (!isBatch) {
|
||||
// 单个操作
|
||||
const session = sessionOrSessions;
|
||||
session.updating = true;
|
||||
|
||||
try {
|
||||
const payload = {
|
||||
is_batch: false,
|
||||
session_id: session.session_id,
|
||||
...params
|
||||
};
|
||||
|
||||
const response = await axios.post(`/api/session/update_${type}`, payload);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
updateLocalData(session, true);
|
||||
this.showSuccess(this.tm(`messages.${type}UpdateSuccess`));
|
||||
return { success: true };
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm(`messages.${type}UpdateError`));
|
||||
return { success: false, error: response.data.message };
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm(`messages.${type}UpdateError`));
|
||||
return { success: false, error: error.message };
|
||||
} finally {
|
||||
session.updating = false;
|
||||
}
|
||||
} else {
|
||||
// 批量操作
|
||||
const sessions = sessionOrSessions;
|
||||
const sessionIds = sessions.map(s => s.session_id);
|
||||
|
||||
try {
|
||||
const payload = {
|
||||
is_batch: true,
|
||||
session_ids: sessionIds,
|
||||
...params
|
||||
};
|
||||
|
||||
const response = await axios.post(`/api/session/update_${type}`, payload);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data;
|
||||
|
||||
// 更新成功的会话的本地数据
|
||||
sessions.forEach(session => {
|
||||
const wasSuccessful = !data.error_sessions || !data.error_sessions.includes(session.session_id);
|
||||
updateLocalData(session, wasSuccessful);
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
successCount: data.success_count || 0,
|
||||
errorCount: data.error_count || 0,
|
||||
errorSessions: data.error_sessions || []
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
error: response.data.message,
|
||||
errorCount: sessionIds.length,
|
||||
successCount: 0
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
success: false,
|
||||
error: error.response?.data?.message || error.message,
|
||||
errorCount: sessionIds.length,
|
||||
successCount: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// 单独的会话状态更新方法(不支持批量操作)
|
||||
async updateSessionStatus(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
@@ -572,47 +657,9 @@ export default {
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.session_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.sessionStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
},
|
||||
|
||||
async updateLLM(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_llm', {
|
||||
session_id: session.session_id,
|
||||
enabled: enabled
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.llm_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.llmStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
session.updating = false;
|
||||
},
|
||||
|
||||
async updateTTS(session, enabled) {
|
||||
session.updating = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/update_tts', {
|
||||
session_id: session.session_id,
|
||||
enabled: enabled
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
session.tts_enabled = enabled;
|
||||
this.showSuccess(this.tm('messages.ttsStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
|
||||
this.showSuccess(this.tm('messages.sessionStatusSuccess', {
|
||||
status: enabled ? this.tm('status.enabled') : this.tm('status.disabled')
|
||||
}));
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
|
||||
}
|
||||
@@ -628,60 +675,120 @@ export default {
|
||||
}
|
||||
|
||||
this.batchUpdating = true;
|
||||
let successCount = 0;
|
||||
let errorCount = 0;
|
||||
let totalSuccessCount = 0;
|
||||
let totalErrorCount = 0;
|
||||
let allErrorSessions = [];
|
||||
|
||||
// 使用过滤后的会话数据进行批量操作
|
||||
for (const session of this.filteredSessions) {
|
||||
try {
|
||||
// 批量更新人格
|
||||
if (this.batchPersona) {
|
||||
await this.updatePersona(session, this.batchPersona);
|
||||
successCount++;
|
||||
}
|
||||
const sessions = this.filteredSessions;
|
||||
|
||||
// 批量更新 Chat Provider
|
||||
if (this.batchChatProvider) {
|
||||
await this.updateProvider(session, this.batchChatProvider, 'chat_completion');
|
||||
successCount++;
|
||||
}
|
||||
try {
|
||||
// 定义批量操作任务
|
||||
const batchTasks = [];
|
||||
|
||||
// 批量更新 STT Provider
|
||||
if (this.batchSttProvider) {
|
||||
await this.updateProvider(session, this.batchSttProvider, 'speech_to_text');
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 TTS Provider
|
||||
if (this.batchTtsProvider) {
|
||||
await this.updateProvider(session, this.batchTtsProvider, 'text_to_speech');
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 LLM 状态
|
||||
if (this.batchLlmStatus !== null) {
|
||||
await this.updateLLM(session, this.batchLlmStatus);
|
||||
successCount++;
|
||||
}
|
||||
|
||||
// 批量更新 TTS 状态
|
||||
if (this.batchTtsStatus !== null) {
|
||||
await this.updateTTS(session, this.batchTtsStatus);
|
||||
successCount++;
|
||||
}
|
||||
} catch (error) {
|
||||
errorCount++;
|
||||
if (this.batchPersona) {
|
||||
batchTasks.push({
|
||||
type: 'persona',
|
||||
params: { persona_name: this.batchPersona }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchChatProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchChatProvider, provider_type: 'chat_completion' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchSttProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchSttProvider, provider_type: 'speech_to_text' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchTtsProvider) {
|
||||
batchTasks.push({
|
||||
type: 'provider',
|
||||
params: { provider_id: this.batchTtsProvider, provider_type: 'text_to_speech' }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchLlmStatus !== null) {
|
||||
batchTasks.push({
|
||||
type: 'llm',
|
||||
params: { enabled: this.batchLlmStatus }
|
||||
});
|
||||
}
|
||||
|
||||
if (this.batchTtsStatus !== null) {
|
||||
batchTasks.push({
|
||||
type: 'tts',
|
||||
params: { enabled: this.batchTtsStatus }
|
||||
});
|
||||
}
|
||||
|
||||
// 执行所有批量任务
|
||||
for (const task of batchTasks) {
|
||||
let updateLocalData;
|
||||
|
||||
// 定义本地数据更新逻辑
|
||||
switch (task.type) {
|
||||
case 'persona':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.persona_id = task.params.persona_name;
|
||||
};
|
||||
break;
|
||||
case 'provider':
|
||||
updateLocalData = (s, success) => {
|
||||
if (!success) return;
|
||||
const { provider_id, provider_type } = task.params;
|
||||
if (provider_type === 'chat_completion') {
|
||||
s.chat_provider_id = provider_id;
|
||||
} else if (provider_type === 'speech_to_text') {
|
||||
s.stt_provider_id = provider_id;
|
||||
} else if (provider_type === 'text_to_speech') {
|
||||
s.tts_provider_id = provider_id;
|
||||
}
|
||||
};
|
||||
break;
|
||||
case 'llm':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.llm_enabled = task.params.enabled;
|
||||
};
|
||||
break;
|
||||
case 'tts':
|
||||
updateLocalData = (s, success) => {
|
||||
if (success) s.tts_enabled = task.params.enabled;
|
||||
};
|
||||
break;
|
||||
}
|
||||
|
||||
const result = await this._updateSession(task.type, sessions, task.params, updateLocalData);
|
||||
|
||||
totalSuccessCount += result.successCount || 0;
|
||||
totalErrorCount += result.errorCount || 0;
|
||||
if (result.errorSessions) {
|
||||
allErrorSessions.push(...result.errorSessions);
|
||||
}
|
||||
}
|
||||
|
||||
// 显示最终结果
|
||||
if (totalErrorCount === 0) {
|
||||
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: totalSuccessCount }));
|
||||
} else {
|
||||
const uniqueErrorSessions = [...new Set(allErrorSessions)];
|
||||
this.showError(this.tm('messages.batchUpdatePartial', {
|
||||
success: totalSuccessCount,
|
||||
error: uniqueErrorSessions.length
|
||||
}));
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
this.showError(this.tm('messages.batchUpdateError'));
|
||||
}
|
||||
|
||||
this.batchUpdating = false;
|
||||
|
||||
if (errorCount === 0) {
|
||||
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: successCount }));
|
||||
} else {
|
||||
this.showError(this.tm('messages.batchUpdatePartial', { success: successCount, error: errorCount }));
|
||||
}
|
||||
|
||||
// 清空批量设置
|
||||
this.batchPersona = null;
|
||||
this.batchChatProvider = null;
|
||||
@@ -797,6 +904,38 @@ export default {
|
||||
this.snackbarColor = 'error';
|
||||
this.snackbar = true;
|
||||
},
|
||||
|
||||
async deleteSession(session) {
|
||||
const confirmMessage = this.tm('deleteConfirm.message', {
|
||||
sessionName: session.session_name || session.session_id
|
||||
}) + '\n\n' + this.tm('deleteConfirm.warning');
|
||||
|
||||
if (!confirm(confirmMessage)) {
|
||||
return;
|
||||
}
|
||||
|
||||
session.deleting = true;
|
||||
try {
|
||||
const response = await axios.post('/api/session/delete', {
|
||||
session_id: session.session_id
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.data.message || this.tm('messages.deleteSuccess'));
|
||||
// 从列表中移除已删除的会话
|
||||
const index = this.sessions.findIndex(s => s.session_id === session.session_id);
|
||||
if (index > -1) {
|
||||
this.sessions.splice(index, 1);
|
||||
}
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
|
||||
session.deleting = false;
|
||||
},
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
||||
+21
-21
@@ -527,12 +527,11 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
return
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"已重置当前 Dify 会话,新聊天将更换到新的会话。"
|
||||
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
@@ -755,8 +754,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
创建新对话
|
||||
"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
@@ -783,8 +781,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
|
||||
"""创建新群聊对话"""
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
if provider and provider.meta().type in ["dify", "coze"]:
|
||||
await provider.forget(message.unified_msg_origin)
|
||||
message.set_result(
|
||||
MessageEventResult().message("成功,下次聊天将是新对话。")
|
||||
@@ -823,7 +820,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
provider = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if provider and provider.meta().type == "dify":
|
||||
assert isinstance(provider, ProviderDify)
|
||||
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
|
||||
if not data["data"]:
|
||||
message.set_result(MessageEventResult().message("未找到任何对话。"))
|
||||
@@ -1348,22 +1344,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("alter_cmd")
|
||||
@filter.command("alter_cmd", alias={"alter"})
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
# token = event.message_str.split(" ")
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 2:
|
||||
if token.len < 3:
|
||||
yield event.plain_result(
|
||||
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
|
||||
"该指令用于设置指令或指令组的权限。\n"
|
||||
"格式: /alter_cmd <cmd_name> <admin/member>\n"
|
||||
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
|
||||
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
|
||||
"/alter_cmd reset config 打开 reset 权限配置"
|
||||
)
|
||||
return
|
||||
|
||||
cmd_name = token.get(1)
|
||||
cmd_type = token.get(2)
|
||||
cmd_name = " ".join(token.tokens[1:-1])
|
||||
cmd_type = token.get(-1)
|
||||
|
||||
# ============================
|
||||
# 对reset权限进行特殊处理
|
||||
# ============================
|
||||
if cmd_name == "reset" and cmd_type == "config":
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get("astrbot", {})
|
||||
@@ -1413,16 +1409,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
|
||||
# 查找指令
|
||||
found_command = None
|
||||
cmd_group = False
|
||||
for handler in star_handlers_registry:
|
||||
assert isinstance(handler, StarHandlerMetadata)
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, CommandFilter):
|
||||
if filter_.command_name == cmd_name:
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
break
|
||||
elif isinstance(filter_, CommandGroupFilter):
|
||||
if cmd_name == filter_.group_name:
|
||||
if filter_.equals(cmd_name):
|
||||
found_command = handler
|
||||
cmd_group = True
|
||||
break
|
||||
|
||||
if not found_command:
|
||||
@@ -1459,8 +1457,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
else filter.PermissionType.MEMBER
|
||||
),
|
||||
)
|
||||
|
||||
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
|
||||
cmd_group_str = "指令组" if cmd_group else "指令"
|
||||
yield event.plain_result(
|
||||
f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。"
|
||||
)
|
||||
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
"""更新reset命令在特定场景下的权限设置
|
||||
|
||||
@@ -375,5 +375,3 @@ class Main(star.Star):
|
||||
tool_set.add_tool(tavily_extract_web_page)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
|
||||
print(req.func_tool)
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.1.7"
|
||||
version = "4.2.0"
|
||||
description = "易上手的多平台 LLM 聊天机器人及开发框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
+79
-43
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import os
|
||||
import asyncio
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle_td():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def core_lifecycle_td(tmp_path_factory):
|
||||
"""Creates and initializes a core lifecycle instance with a temporary database."""
|
||||
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
|
||||
db = SQLiteDatabase(str(tmp_db_path))
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle_td
|
||||
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||
await core_lifecycle.initialize()
|
||||
return core_lifecycle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle_td):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
||||
def app(core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Creates a Quart app instance for testing."""
|
||||
shutdown_event = asyncio.Event()
|
||||
# The db instance is already part of the core_lifecycle_td
|
||||
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
|
||||
return server.app
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Handles login and returns an authenticated header."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login",
|
||||
json={
|
||||
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
|
||||
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
|
||||
},
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
||||
await core_lifecycle_td.initialize()
|
||||
assert core_lifecycle_td is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(
|
||||
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
|
||||
):
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
"""Tests the login functionality with both wrong and correct credentials."""
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post(
|
||||
"/api/auth/login", json={"username": "wrong", "password": "password"}
|
||||
@@ -55,31 +67,32 @@ async def test_auth_login(
|
||||
)
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "token" in data["data"]
|
||||
header["Authorization"] = f"Bearer {data['data']['token']}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
async def test_get_stat(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/stat/get")
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get("/api/stat/get", headers=header)
|
||||
response = await test_client.get("/api/stat/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok" and "platform" in data["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
async def test_plugins(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get("/api/plugin/get", headers=header)
|
||||
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get("/api/plugin/market_list", headers=header)
|
||||
response = await test_client.get(
|
||||
"/api/plugin/market_list", headers=authenticated_header
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/install",
|
||||
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post(
|
||||
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
|
||||
"/api/plugin/update",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
|
||||
response = await test_client.post(
|
||||
"/api/plugin/uninstall",
|
||||
json={"name": "astrbot_plugin_essential"},
|
||||
headers=header,
|
||||
headers=authenticated_header,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
async def test_check_update(app: Quart, authenticated_header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get("/api/update/check", headers=header)
|
||||
response = await test_client.get("/api/update/check", headers=authenticated_header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "success"
|
||||
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(
|
||||
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
|
||||
app: Quart,
|
||||
authenticated_header: dict,
|
||||
core_lifecycle_td: AstrBotCoreLifecycle,
|
||||
monkeypatch,
|
||||
tmp_path_factory,
|
||||
):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "latest"}
|
||||
|
||||
# Use a temporary path for the mock update to avoid side effects
|
||||
temp_release_dir = tmp_path_factory.mktemp("release")
|
||||
release_path = temp_release_dir / "astrbot"
|
||||
|
||||
async def mock_update(*args, **kwargs):
|
||||
"""Mocks the update process by creating a directory in the temp path."""
|
||||
os.makedirs(release_path, exist_ok=True)
|
||||
return
|
||||
|
||||
async def mock_download_dashboard(*args, **kwargs):
|
||||
"""Mocks the dashboard download to prevent network access."""
|
||||
return
|
||||
|
||||
async def mock_pip_install(*args, **kwargs):
|
||||
"""Mocks pip install to prevent actual installation."""
|
||||
return
|
||||
|
||||
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "error" # 已经是最新版本
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
|
||||
"/api/update/do",
|
||||
headers=authenticated_header,
|
||||
json={"version": "v3.4.0", "reboot": False},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data["status"] == "ok"
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
assert os.path.exists(release_path)
|
||||
|
||||
+51
-18
@@ -1,5 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
async def test_check_dashboard_files_not_exists(monkeypatch):
|
||||
"""Tests dashboard download when files do not exist."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: False)
|
||||
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
mock_download.assert_called_once()
|
||||
|
||||
async def read(self):
|
||||
return b"content"
|
||||
|
||||
return MockResponse()
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
|
||||
"""Tests that dashboard is not downloaded when it exists and version matches."""
|
||||
# Mock os.path.exists to return True
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
|
||||
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
|
||||
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
|
||||
# Mock get_dashboard_version to return the current version
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
# We need to import VERSION from main's context
|
||||
from main import VERSION
|
||||
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
mock_get_version.return_value = f"v{VERSION}"
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
await check_dashboard_files()
|
||||
# Assert that download_dashboard was NOT called
|
||||
mock_download.assert_not_called()
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
|
||||
"""Tests that a warning is logged when dashboard version mismatches."""
|
||||
monkeypatch.setattr(os.path, "exists", lambda x: True)
|
||||
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
mock_get_version.return_value = "v0.0.1" # A different version
|
||||
|
||||
with mock.patch("main.logger.warning") as mock_logger_warning:
|
||||
await check_dashboard_files()
|
||||
mock_logger_warning.assert_called_once()
|
||||
call_args, _ = mock_logger_warning.call_args
|
||||
assert "不符" in call_args[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
|
||||
"""Tests that providing a valid webui_dir skips all checks."""
|
||||
valid_dir = "/tmp/my-custom-webui"
|
||||
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
|
||||
|
||||
with mock.patch("main.download_dashboard") as mock_download:
|
||||
with mock.patch("main.get_dashboard_version") as mock_get_version:
|
||||
result = await check_dashboard_files(webui_dir=valid_dir)
|
||||
assert result == valid_dir
|
||||
mock_download.assert_not_called()
|
||||
mock_get_version.assert_not_called()
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"],
|
||||
]
|
||||
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id,
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456",
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg["platform_settings"]["id_whitelist"] = [
|
||||
"test_platform:FriendMessage:test_sid_wl",
|
||||
"test_platform:GroupMessage:test_sid_wl",
|
||||
]
|
||||
cfg["admins_id"] = ["123456"]
|
||||
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
|
||||
cfg["provider"] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
# await plugin_manager.reload()
|
||||
asyncio.run(plugin_manager.reload())
|
||||
return plugin_manager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
"""测试唤醒"""
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
|
||||
)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", is_group=True, is_at=True
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(
|
||||
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
|
||||
):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" not in message
|
||||
for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any(
|
||||
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
|
||||
), "日志中未找到预期的消息"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"色情", session_id=SESSION_ID_IN_WHITELIST
|
||||
) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), (
|
||||
"日志中未找到预期的消息"
|
||||
)
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any(
|
||||
"web_searcher - search_from_search_engine" in message
|
||||
for message in caplog.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(
|
||||
command[0], session_id=SESSION_ID_IN_WHITELIST
|
||||
)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
+105
-43
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from asyncio import Queue
|
||||
|
||||
event_queue = Queue()
|
||||
|
||||
config = AstrBotConfig()
|
||||
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
star_context = Context(event_queue, config, db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm():
|
||||
return PluginManager(star_context, config)
|
||||
def plugin_manager_pm(tmp_path):
|
||||
"""
|
||||
Provides a fully isolated PluginManager instance for testing.
|
||||
- Uses a temporary directory for plugins.
|
||||
- Uses a temporary database.
|
||||
- Creates a fresh context for each test.
|
||||
"""
|
||||
# Create temporary resources
|
||||
temp_plugins_path = tmp_path / "plugins"
|
||||
temp_plugins_path.mkdir()
|
||||
temp_db_path = tmp_path / "test_db.db"
|
||||
|
||||
# Create fresh, isolated instances for the context
|
||||
event_queue = Queue()
|
||||
config = AstrBotConfig()
|
||||
db = SQLiteDatabase(str(temp_db_path))
|
||||
|
||||
# Set the plugin store path in the config to the temporary directory
|
||||
config.plugin_store_path = str(temp_plugins_path)
|
||||
|
||||
# Mock dependencies for the context
|
||||
provider_manager = MagicMock()
|
||||
platform_manager = MagicMock()
|
||||
conversation_manager = MagicMock()
|
||||
message_history_manager = MagicMock()
|
||||
persona_manager = MagicMock()
|
||||
astrbot_config_mgr = MagicMock()
|
||||
|
||||
star_context = Context(
|
||||
event_queue,
|
||||
config,
|
||||
db,
|
||||
provider_manager,
|
||||
platform_manager,
|
||||
conversation_manager,
|
||||
message_history_manager,
|
||||
persona_manager,
|
||||
astrbot_config_mgr,
|
||||
)
|
||||
|
||||
# Create the PluginManager instance
|
||||
manager = PluginManager(star_context, config)
|
||||
yield manager
|
||||
|
||||
|
||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
"""测试插件安装和重载"""
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
async def test_install_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests successful plugin installation in an isolated environment."""
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert plugin_path is not None
|
||||
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
|
||||
plugin_path = os.path.join(
|
||||
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||
)
|
||||
|
||||
assert plugin_info is not None
|
||||
assert os.path.exists(plugin_path)
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
# shutil.rmtree(plugin_path)
|
||||
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
|
||||
)
|
||||
|
||||
# install plugin which is not exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that installing a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
||||
await plugin_manager_pm.install_plugin(
|
||||
"https://github.com/Soulter/non_existent_repo"
|
||||
)
|
||||
|
||||
# update
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests updating an existing plugin in an isolated environment."""
|
||||
# First, install the plugin
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
await plugin_manager_pm.install_plugin(test_repo)
|
||||
|
||||
# Then, update it
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# uninstall
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that updating a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("non_existent_plugin")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests successful plugin uninstallation in an isolated environment."""
|
||||
# First, install the plugin
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
await plugin_manager_pm.install_plugin(test_repo)
|
||||
plugin_path = os.path.join(
|
||||
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
|
||||
)
|
||||
assert os.path.exists(plugin_path) # Pre-condition
|
||||
|
||||
# Then, uninstall it
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||
|
||||
assert not os.path.exists(plugin_path)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
|
||||
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
|
||||
)
|
||||
assert not any(
|
||||
"astrbot_plugin_essential" in md.handler_module_path
|
||||
for md in star_handlers_registry
|
||||
), (
|
||||
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
|
||||
"""Tests that uninstalling a non-existent plugin raises an exception."""
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")
|
||||
|
||||
Reference in New Issue
Block a user