Compare commits

...

13 Commits

Author SHA1 Message Date
Soulter b9de2aef60 chore: bump version to 4.2.1 2025-09-27 23:36:25 +08:00
Soulter 7a47598538 fix: 修复指令无法使用的问题
fixes: #2897
2025-09-27 23:35:35 +08:00
Soulter 3c8c28ebd5 chore: bump version to 4.2.0 2025-09-27 20:45:50 +08:00
Soulter 524285f767 feat: add cancel button with localized text to AddNewPlatform and update close button in AddNewProvider
fixes: #2889
2025-09-27 20:41:45 +08:00
Soulter c2a34475f1 feat: 支持删除指定会话以及部分会话管理优化 (#2895)
* feat: add toast notification system with snackbar component

* feat: add session deletion functionality

* feat: support batch operations for updating session persona, provider, LLM, and TTS statuses

fix: #2263

* feat: 修复对话状态关闭,删除对话管理库会导致对话无法恢复

fixes: #2309
2025-09-27 20:36:30 +08:00
Soulter a69195a02b fix: webchat streaming queue interrupted after user closing tab (#2892)
* feat: add toast notification system with snackbar component

* feat: enhance chat functionality with conversation running state and notifications

* fix: update bot message avatar rendering during streaming

* feat: implement conversation tracking context manager for webchat

* fix: update conversation tracking to remove conversation ID on exit
2025-09-27 17:57:12 +08:00
RC-CHN 19d7438499 fix: unit tests (#2760)
* fix:修复了main和plugin_manager部分单元测试

* fix: 修复了dashboard部分测试

* remove: 删除暂无用的配置测试脚本

* perf:拆分插件增查删改为独立的单元测试

* refactor: 重构插件管理器测试,使用临时环境隔离测试实例

* test: 增加对仪表板文件检查的单元测试,涵盖不同情况

* style: format code

* remove: 删除未使用的导入语句

* delete: remove unused test file for pipeline

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:43:04 +08:00
anka ccb380ce06 feat: 支持接入 Coze (#2858)
* feat: 适配 coze 供应商
1. 支持文件上传
2. 支持多模态
3. 支持流式传输
4. 支持 API 端的上下文保存历史记录
5. 支持类似 dify 的 forget 接口

* style: format code

* fix: type checking error

* fix: 修复:
1. 使用coze api端的上下文时, 现在不会重复传递上下文
2. 使用 AstrBot 的上下文时, 正确处理其中的图片信息
3. 上传图片时, 提供一个非持久化的缓存避免重复上传(在解析上下文并将文件转化为file_id传递给coze api时, 如果没有缓存会导致很多的网络资源浪费)
4. 修复reset等指令不能正确重置上下文的问题

* fix: 移除某些地方多余的针对 dify 的断言, 以兼容 Coze

* style: 修改配置项显示/webchat平台对于非预期的类型的处理

* fix: 让conversation_id放到请求中正确的位置

* refactor: extract coze api client

* refactor: improve image processing logic in ProviderCoze

* chore: remove file ext guessing

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-27 14:23:29 +08:00
Ding Jiatong a35c439bbd fix: 使用增量解码器修复 Dify 流式返回结果偶现的解码错误 (#2888)
* fix: 修复linux下utf-8解码错误的问题

* feat: use incremental decoder

* fix: add type hint for response parameter in _stream_sse and refactor file upload method

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 23:04:58 +08:00
Soulter 09d1f96603 fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check

* fix: 不传递 GroupCommand handler

* perf: alter_cmd 指令支持对子指令、指令组进行配置

* chore: remove test commands and subcommands from test_group

* chore: add cache for complete command names list in CommandFilter and CommandGroupFilter

---------

Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-09-26 14:16:50 +08:00
鸦羽 26aa18d980 Merge pull request #2881 from Raven95676/fix/2879
fix: add missing id field
2025-09-26 11:31:28 +08:00
Raven95676 d10b542797 chore: format 2025-09-26 11:05:32 +08:00
Raven95676 ce4e4fb8dd fix: add missing id field 2025-09-26 10:59:11 +08:00
41 changed files with 2123 additions and 787 deletions
+33 -1
View File
@@ -6,7 +6,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.1.7"
VERSION = "4.2.1"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
# 默认配置
@@ -869,6 +869,18 @@ CONFIG_METADATA_2 = {
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
@@ -1735,6 +1747,26 @@ CONFIG_METADATA_2 = {
"hint": "发送的消息文本内容对应的输入变量名。默认为 astrbot_text_query。",
"obvious": True,
},
"coze_api_key": {
"description": "Coze API Key",
"type": "string",
"hint": "Coze API 密钥,用于访问 Coze 服务。",
},
"bot_id": {
"description": "Bot ID",
"type": "string",
"hint": "Coze 机器人的 ID,在 Coze 平台上创建机器人后获得。",
},
"coze_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Coze API 的基础 URL 地址,默认为 https://api.coze.cn",
},
"auto_save_history": {
"description": "由 Coze 管理对话记录",
"type": "bool",
"hint": "启用后,将由 Coze 进行对话历史记录管理, 此时 AstrBot 本地保存的上下文不会生效(仅供浏览), 对 AstrBot 的上下文进行的操作也不会生效。如果为禁用, 则使用 AstrBot 管理上下文。",
},
},
},
"provider_settings": {
+12 -4
View File
@@ -87,17 +87,25 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
f = False
if not conversation_id:
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
f = True
if conversation_id:
await self.db.delete_conversation(cid=conversation_id)
if f:
curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
if curr_cid == conversation_id:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
"""删除会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
"""
await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin)
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
"""获取会话当前的对话 ID
+5
View File
@@ -154,6 +154,11 @@ class BaseDatabase(abc.ABC):
"""Delete a conversation by its ID."""
...
@abc.abstractmethod
async def delete_conversations_by_user_id(self, user_id: str) -> None:
"""Delete all conversations for a specific user."""
...
@abc.abstractmethod
async def insert_platform_message_history(
self,
+8
View File
@@ -249,6 +249,14 @@ class SQLiteDatabase(BaseDatabase):
delete(ConversationV2).where(ConversationV2.conversation_id == cid)
)
async def delete_conversations_by_user_id(self, user_id: str) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(ConversationV2).where(ConversationV2.user_id == user_id)
)
async def insert_platform_message_history(
self,
platform_id,
@@ -291,13 +291,6 @@ async def run_agent(
else:
astr_event.set_result(MessageEventResult().message(err_msg))
return
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
class LLMRequestSubStage(Stage):
@@ -524,6 +517,14 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
)
)
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -536,7 +537,23 @@ class LLMRequestSubStage(Stage):
latest_pair = messages[-2:]
if not latest_pair:
return
cleaned_text = "User: " + latest_pair[0].get("content", "").strip()
content = latest_pair[0].get("content", "")
if isinstance(content, list):
# 多模态
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "image":
text_parts.append("[图片]")
elif isinstance(item, str):
text_parts.append(item)
cleaned_text = "User: " + " ".join(text_parts).strip()
elif isinstance(content, str):
cleaned_text = "User: " + content.strip()
else:
return
logger.debug(f"WebChat 对话标题生成请求,清理后的文本: {cleaned_text}")
llm_resp = await prov.text_chat(
system_prompt="You are expert in summarizing user's query.",
@@ -11,7 +11,8 @@ class SessionStatusCheckStage(Stage):
"""检查会话是否整体启用"""
async def initialize(self, ctx: PipelineContext) -> None:
pass
self.ctx = ctx
self.conv_mgr = ctx.plugin_manager.context.conversation_manager
async def process(
self, event: AstrMessageEvent
@@ -19,4 +20,14 @@ class SessionStatusCheckStage(Stage):
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
conv_id = await self.conv_mgr.get_curr_conversation_id(
event.unified_msg_origin
)
if not conv_id:
await self.conv_mgr.new_conversation(
event.unified_msg_origin, platform_id=event.get_platform_id()
)
event.stop_event()
+10 -5
View File
@@ -5,6 +5,7 @@ from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if "parsed_params" in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
"parsed_params"
)
is_group_cmd_handler = any(
isinstance(f, CommandGroupFilter) for f in handler.event_filters
)
if not is_group_cmd_handler:
activated_handlers.append(handler)
if "parsed_params" in event.get_extra(default={}):
handlers_parsed_params[handler.handler_full_name] = (
event.get_extra("parsed_params")
)
event._extras.pop("parsed_params", None)
+9 -5
View File
@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -26,6 +26,8 @@ from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -49,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
"""是否唤醒(是否通过 WakingStage)"""
self.is_at_or_wake_command = False
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self._extras: dict[str, Any] = {}
self.session = MessageSesion(
platform_name=platform_meta.id,
message_type=message_obj.type,
@@ -57,7 +59,7 @@ class AstrMessageEvent(abc.ABC):
)
self.unified_msg_origin = str(self.session)
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self._result: MessageEventResult = None
self._result: MessageEventResult | None = None
"""消息事件的结果"""
self._has_send_oper = False
@@ -173,13 +175,15 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(self, key=None):
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
"""
获取额外的信息。
"""
if key is None:
return self._extras
return self._extras.get(key, None)
return self._extras.get(key, default)
def clear_extra(self):
"""
@@ -185,6 +185,7 @@ class WecomPlatformAdapter(Platform):
return PlatformMetadata(
"wecom",
"wecom 适配器",
id=self.config.get("id", "wecom"),
)
@override
@@ -184,6 +184,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
return PlatformMetadata(
"weixin_official_account",
"微信公众平台 适配器",
id=self.config.get("id", "weixin_official_account"),
)
@override
+2
View File
@@ -234,6 +234,8 @@ class ProviderManager:
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
@@ -0,0 +1,314 @@
import json
import asyncio
import aiohttp
import io
from typing import Dict, List, Any, AsyncGenerator
from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
self.api_key = api_key
self.api_base = api_base
self.session = None
async def _ensure_session(self):
"""确保HTTP session存在"""
if self.session is None:
connector = aiohttp.TCPConnector(
ssl=False if self.api_base.startswith("http://") else True,
limit=100,
limit_per_host=30,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
timeout = aiohttp.ClientTimeout(
total=120, # 默认超时时间
connect=30,
sock_read=120,
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "text/event-stream",
}
self.session = aiohttp.ClientSession(
headers=headers, timeout=timeout, connector=connector
)
return self.session
async def upload_file(
self,
file_data: bytes,
) -> str:
"""上传文件到 Coze 并返回 file_id
Args:
file_data (bytes): 文件的二进制数据
Returns:
str: 上传成功后返回的 file_id
"""
session = await self._ensure_session()
url = f"{self.api_base}/v1/files/upload"
try:
file_io = io.BytesIO(file_data)
async with session.post(
url,
data={
"file": file_io,
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
response_text = await response.text()
logger.debug(
f"文件上传响应状态: {response.status}, 内容: {response_text}"
)
if response.status != 200:
raise Exception(
f"文件上传失败,状态码: {response.status}, 响应: {response_text}"
)
try:
result = await response.json()
except json.JSONDecodeError:
raise Exception(f"文件上传响应解析失败: {response_text}")
if result.get("code") != 0:
raise Exception(f"文件上传失败: {result.get('msg', '未知错误')}")
file_id = result["data"]["id"]
logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}")
return file_id
except asyncio.TimeoutError:
logger.error("文件上传超时")
raise Exception("文件上传超时")
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def download_image(self, image_url: str) -> bytes:
"""下载图片并返回字节数据
Args:
image_url (str): 图片的URL
Returns:
bytes: 图片的二进制数据
"""
session = await self._ensure_session()
try:
async with session.get(image_url) as response:
if response.status != 200:
raise Exception(f"下载图片失败,状态码: {response.status}")
image_data = await response.read()
return image_data
except Exception as e:
logger.error(f"下载图片失败 {image_url}: {str(e)}")
raise Exception(f"下载图片失败: {str(e)}")
async def chat_messages(
self,
bot_id: str,
user_id: str,
additional_messages: List[Dict] | None = None,
conversation_id: str | None = None,
auto_save_history: bool = True,
stream: bool = True,
timeout: float = 120,
) -> AsyncGenerator[Dict[str, Any], None]:
"""发送聊天消息并返回流式响应
Args:
bot_id: Bot ID
user_id: 用户ID
additional_messages: 额外消息列表
conversation_id: 会话ID
auto_save_history: 是否自动保存历史
stream: 是否流式响应
timeout: 超时时间
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/chat"
payload = {
"bot_id": bot_id,
"user_id": user_id,
"stream": stream,
"auto_save_history": auto_save_history,
}
if additional_messages:
payload["additional_messages"] = additional_messages
params = {}
if conversation_id:
params["conversation_id"] = conversation_id
logger.debug(f"Coze chat_messages payload: {payload}, params: {params}")
try:
async with session.post(
url,
json=payload,
params=params,
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 流式请求失败,状态码: {response.status}")
# SSE
buffer = ""
event_type = None
event_data = None
async for chunk in response.content:
if chunk:
buffer += chunk.decode("utf-8", errors="ignore")
lines = buffer.split("\n")
buffer = lines[-1]
for line in lines[:-1]:
line = line.strip()
if not line:
if event_type and event_data:
yield {"event": event_type, "data": event_data}
event_type = None
event_data = None
elif line.startswith("event:"):
event_type = line[6:].strip()
elif line.startswith("data:"):
data_str = line[5:].strip()
if data_str and data_str != "[DONE]":
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
event_data = {"content": data_str}
except asyncio.TimeoutError:
raise Exception(f"Coze API 流式请求超时 ({timeout}秒)")
except Exception as e:
raise Exception(f"Coze API 流式请求失败: {str(e)}")
async def clear_context(self, conversation_id: str):
"""清空会话上下文
Args:
conversation_id: 会话ID
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/clear_context"
payload = {"conversation_id": conversation_id}
try:
async with session.post(url, json=payload) as response:
response_text = await response.text()
if response.status == 401:
raise Exception("Coze API 认证失败,请检查 API Key 是否正确")
if response.status != 200:
raise Exception(f"Coze API 请求失败,状态码: {response.status}")
try:
return json.loads(response_text)
except json.JSONDecodeError:
raise Exception("Coze API 返回非JSON格式")
except asyncio.TimeoutError:
raise Exception("Coze API 请求超时")
except aiohttp.ClientError as e:
raise Exception(f"Coze API 请求失败: {str(e)}")
async def get_message_list(
self,
conversation_id: str,
order: str = "desc",
limit: int = 10,
offset: int = 0,
):
"""获取消息列表
Args:
conversation_id: 会话ID
order: 排序方式 (asc/desc)
limit: 限制数量
offset: 偏移量
Returns:
dict: API响应结果
"""
session = await self._ensure_session()
url = f"{self.api_base}/v3/conversation/message/list"
params = {
"conversation_id": conversation_id,
"order": order,
"limit": limit,
"offset": offset,
}
try:
async with session.get(url, params=params) as response:
response.raise_for_status()
return await response.json()
except Exception as e:
logger.error(f"获取Coze消息列表失败: {str(e)}")
raise Exception(f"获取Coze消息列表失败: {str(e)}")
async def close(self):
"""关闭会话"""
if self.session:
await self.session.close()
self.session = None
if __name__ == "__main__":
import os
import asyncio
async def test_coze_api_client():
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
try:
with open("README.md", "rb") as f:
file_data = f.read()
file_id = await client.upload_file(file_data)
print(f"Uploaded file_id: {file_id}")
async for event in client.chat_messages(
bot_id=bot_id,
user_id="test_user",
additional_messages=[
{
"role": "user",
"content": json.dumps(
[
{"type": "text", "text": "这是什么"},
{"type": "file", "file_id": file_id},
],
ensure_ascii=False,
),
"content_type": "object_string",
},
],
stream=True,
):
print(f"Event: {event}")
finally:
await client.close()
asyncio.run(test_coze_api_client())
@@ -0,0 +1,635 @@
import json
import os
import base64
import hashlib
from typing import AsyncGenerator, Dict
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
default_persona=None,
) -> None:
super().__init__(
provider_config,
provider_settings,
default_persona,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://")
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。"
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: Dict[str, str] = {}
self.file_id_cache: Dict[str, Dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
else:
if data.startswith(("http://", "https://")):
# URL图片,使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
else:
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self, image_url: str, session_id: str | None = None
) -> str:
"""下载图片并上传到 Coze,返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {str(e)}")
raise Exception(f"处理图片失败: {str(e)}")
async def _process_context_images(
self, content: str | list, session_id: str
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data, is_base64=image_data.startswith("data:image/")
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id}
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data, session_id
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}..."
)
continue
processed_content.append(
{"type": "image", "file_id": file_id}
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {str(e)}")
if isinstance(content, str):
return content
else:
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
else:
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{"role": "system", "content": system_prompt, "content_type": "text"}
)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content, user_id
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
}
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url, user_id
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url, is_base64=True
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
# 本地文件
if os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data, user_id, cache_key
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {str(e)}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
else:
# 纯文本
if prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
}
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {str(e)}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {str(e)}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
else:
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {str(e)}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型(在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self, session_id: str, page: int = 1, page_size: int = 10
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {str(e)}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()
+23 -11
View File
@@ -32,6 +32,9 @@ class CommandFilter(HandlerFilter):
self.init_handler_md(handler_md)
self.custom_filter_list: List[CustomFilter] = []
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def print_types(self):
result = ""
for k, v in self.handler_params.items():
@@ -136,6 +139,22 @@ class CommandFilter(HandlerFilter):
)
return result
def get_complete_command_names(self):
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
self._cmpl_cmd_names = [
f"{parent} {cmd}" if parent else cmd
for cmd in [self.command_name] + list(self.alias)
for parent in self.parent_command_names or [""]
]
return self._cmpl_cmd_names
def equals(self, message_str: str) -> bool:
for full_cmd in self.get_complete_command_names():
if message_str == full_cmd:
return True
return False
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -145,18 +164,11 @@ class CommandFilter(HandlerFilter):
# 检查是否以指令开头
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
candidates = [self.command_name] + list(self.alias)
ok = False
for candidate in candidates:
for parent_command_name in self.parent_command_names:
if parent_command_name:
_full = f"{parent_command_name} {candidate}"
else:
_full = candidate
if message_str.startswith(f"{_full} ") or message_str == _full:
message_str = message_str[len(_full) :].strip()
ok = True
break
for full_cmd in self.get_complete_command_names():
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
ok = True
message_str = message_str[len(full_cmd) :].strip()
if not ok:
return False
+15 -5
View File
@@ -22,6 +22,9 @@ class CommandGroupFilter(HandlerFilter):
self.custom_filter_list: List[CustomFilter] = []
self.parent_group = parent_group
# Cache for complete command names list
self._cmpl_cmd_names: list | None = None
def add_sub_command_filter(
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
):
@@ -34,6 +37,9 @@ class CommandGroupFilter(HandlerFilter):
"""遍历父节点获取完整的指令名。
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
if self._cmpl_cmd_names is not None:
return self._cmpl_cmd_names
parent_cmd_names = (
self.parent_group.get_complete_command_names() if self.parent_group else []
)
@@ -47,6 +53,7 @@ class CommandGroupFilter(HandlerFilter):
for parent_cmd_name in parent_cmd_names:
for candidate in candidates:
result.append(parent_cmd_name + " " + candidate)
self._cmpl_cmd_names = result
return result
# 以树的形式打印出来
@@ -97,6 +104,12 @@ class CommandGroupFilter(HandlerFilter):
return False
return True
def startswith(self, message_str: str) -> bool:
return message_str.startswith(tuple(self.get_complete_command_names()))
def equals(self, message_str: str) -> bool:
return message_str in self.get_complete_command_names()
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_at_or_wake_command:
return False
@@ -105,8 +118,7 @@ class CommandGroupFilter(HandlerFilter):
if not self.custom_filter_ok(event, cfg):
return False
complete_command_names = self.get_complete_command_names()
if event.message_str.strip() in complete_command_names:
if self.equals(event.message_str.strip()):
tree = (
self.group_name
+ "\n"
@@ -116,6 +128,4 @@ class CommandGroupFilter(HandlerFilter):
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
)
# complete_command_names = [name + " " for name in complete_command_names]
# return event.message_str.startswith(tuple(complete_command_names))
return False
return self.startswith(event.message_str)
-4
View File
@@ -52,10 +52,6 @@ class SessionServiceManager:
"session_service_config", session_config, scope="umo", scope_id=session_id
)
logger.info(
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
)
@staticmethod
def should_process_llm_request(event: AstrMessageEvent) -> bool:
"""检查是否应该处理LLM请求
+44 -57
View File
@@ -1,9 +1,33 @@
import codecs
import json
from astrbot.core import logger
from aiohttp import ClientSession
from aiohttp import ClientSession, ClientResponse
from typing import Dict, List, Any, AsyncGenerator
async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
decoder = codecs.getincrementaldecoder("utf-8")()
buffer = ""
async for chunk in resp.content.iter_chunked(8192):
buffer += decoder.decode(chunk)
while "\n\n" in buffer:
block, buffer = buffer.split("\n\n", 1)
if block.strip().startswith("data:"):
try:
yield json.loads(block[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {block[5:]}")
continue
# flush any remaining text
buffer += decoder.decode(b"", final=True)
if buffer.strip().startswith("data:"):
try:
yield json.loads(buffer[5:])
except json.JSONDecodeError:
logger.warning(f"Drop invalid dify json data: {buffer[5:]}")
pass
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
@@ -33,31 +57,11 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制,防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
raise Exception(
f"Dify /chat-messages 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
async def workflow_run(
self,
@@ -77,31 +81,11 @@ class DifyAPIClient:
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
buffer = ""
while True:
# 保持原有的8192字节限制,防止数据过大导致高水位报错
chunk = await resp.content.read(8192)
if not chunk:
break
buffer += chunk.decode("utf-8")
blocks = buffer.split("\n\n")
# 处理完整的数据块
for block in blocks[:-1]:
if block.strip() and block.startswith("data:"):
try:
json_str = block[5:] # 移除 "data:" 前缀
json_obj = json.loads(json_str)
yield json_obj
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}")
logger.error(f"原始数据块: {json_str}")
# 保留最后一个可能不完整的块
buffer = blocks[-1] if blocks else ""
raise Exception(
f"Dify /workflows/run 接口请求失败:{resp.status}. {text}"
)
async for event in _stream_sse(resp):
yield event
async def file_upload(
self,
@@ -109,12 +93,15 @@ class DifyAPIClient:
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(url, data=payload, headers=self.headers) as resp:
return await resp.json() # {"id": "xxx", ...}
with open(file_path, "rb") as f:
payload = {
"user": user,
"file": f,
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+66 -32
View File
@@ -1,17 +1,27 @@
import uuid
import json
import os
import asyncio
from contextlib import asynccontextmanager
from .route import Route, Response, RouteContext
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from quart import request, Response as QuartResponse, g, make_response
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.platform.astr_message_event import MessageSession
@asynccontextmanager
async def track_conversation(convs: dict, conv_id: str):
convs[conv_id] = True
try:
yield
finally:
convs.pop(conv_id, None)
class ChatRoute(Route):
def __init__(
self,
@@ -40,6 +50,8 @@ class ChatRoute(Route):
self.conv_mgr = core_lifecycle.conversation_manager
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.running_convs: dict[str, bool] = {}
async def get_file(self):
filename = request.args.get("filename")
if not filename:
@@ -139,42 +151,63 @@ class ChatRoute(Route):
)
async def stream():
client_disconnected = False
try:
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=10)
except asyncio.TimeoutError:
continue
async with track_conversation(self.running_convs, webchat_conv_id):
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
except Exception as e:
logger.error(f"WebChat stream error: {e}")
if not result:
continue
if not result:
continue
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
await asyncio.sleep(0.05)
result_text = result["data"]
type = result.get("type")
streaming = result.get("streaming", False)
if type == "end":
break
elif (
(streaming and type == "complete")
or not streaming
or type == "break"
):
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
try:
if not client_disconnected:
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
except Exception as e:
if not client_disconnected:
logger.debug(
f"[WebChat] 用户 {username} 断开聊天长连接。 {e}"
)
client_disconnected = True
except BaseException as _:
logger.debug(f"用户 {username} 断开聊天长连接。")
return
try:
if not client_disconnected:
await asyncio.sleep(0.05)
except asyncio.CancelledError:
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
client_disconnected = True
if type == "end":
break
elif (
(streaming and type == "complete")
or not streaming
or type == "break"
):
# append bot message
new_his = {"type": "bot", "message": result_text}
await self.platform_history_mgr.insert(
platform_id="webchat",
user_id=webchat_conv_id,
content=new_his,
sender_id="bot",
sender_name="bot",
)
except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
# Put message to conversation-specific queue
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
@@ -291,6 +324,7 @@ class ChatRoute(Route):
.ok(
data={
"history": history_res,
"is_running": self.running_convs.get(webchat_conv_id, False),
}
)
.__dict__
+235 -78
View File
@@ -30,6 +30,7 @@ class SessionManagementRoute(Route):
"/session/update_tts": ("POST", self.update_session_tts),
"/session/update_name": ("POST", self.update_session_name),
"/session/update_status": ("POST", self.update_session_status),
"/session/delete": ("POST", self.delete_session),
}
self.conv_mgr = core_lifecycle.conversation_manager
self.core_lifecycle = core_lifecycle
@@ -180,60 +181,132 @@ class SessionManagementRoute(Route):
logger.error(error_msg)
return Response().error(f"获取会话列表失败: {str(e)}").__dict__
async def _update_single_session_persona(self, session_id: str, persona_name: str):
"""更新单个会话的 persona 的内部方法"""
conversation_manager = self.core_lifecycle.star_context.conversation_manager
conversation_id = await conversation_manager.get_curr_conversation_id(
session_id
)
conv = None
if conversation_id:
conv = await conversation_manager.get_conversation(
unified_msg_origin=session_id,
conversation_id=conversation_id,
)
if not conv or not conversation_id:
conversation_id = await conversation_manager.new_conversation(session_id)
# 更新 persona
await conversation_manager.update_conversation_persona_id(
session_id, persona_name
)
async def _handle_batch_operation(
self, session_ids: list, operation_func, operation_name: str, **kwargs
):
"""通用的批量操作处理方法"""
success_count = 0
error_sessions = []
for session_id in session_ids:
try:
await operation_func(session_id, **kwargs)
success_count += 1
except Exception as e:
logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}")
error_sessions.append(session_id)
if error_sessions:
return (
Response()
.ok(
{
"message": f"批量更新完成,成功: {success_count},失败: {len(error_sessions)}",
"success_count": success_count,
"error_count": len(error_sessions),
"error_sessions": error_sessions,
}
)
.__dict__
)
else:
return (
Response()
.ok(
{
"message": f"成功批量{operation_name} {success_count} 个会话",
"success_count": success_count,
}
)
.__dict__
)
async def update_session_persona(self):
"""更新指定会话的 persona"""
"""更新指定会话的 persona,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
persona_name = data.get("persona_name")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if persona_name is None:
return Response().error("缺少必要参数: persona_name").__dict__
# 获取会话当前的对话 ID
conversation_manager = self.core_lifecycle.star_context.conversation_manager
conversation_id = await conversation_manager.get_curr_conversation_id(
session_id
)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
if not conversation_id:
# 如果没有对话,创建一个新的对话
conversation_id = await conversation_manager.new_conversation(
session_id
return await self._handle_batch_operation(
session_ids,
self._update_single_session_persona,
"更新人格",
persona_name=persona_name,
)
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
# 更新 persona
await conversation_manager.update_conversation_persona_id(
session_id, persona_name
)
return (
Response()
.ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"})
.__dict__
)
await self._update_single_session_persona(session_id, persona_name)
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"
}
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话人格失败: {str(e)}").__dict__
async def _update_single_session_provider(
self, session_id: str, provider_id: str, provider_type_enum
):
"""更新单个会话的 provider 的内部方法"""
provider_manager = self.core_lifecycle.star_context.provider_manager
await provider_manager.set_provider(
provider_id=provider_id,
provider_type=provider_type_enum,
umo=session_id,
)
async def update_session_provider(self):
"""更新指定会话的 provider"""
"""更新指定会话的 provider,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
provider_id = data.get("provider_id")
# "chat_completion", "speech_to_text", "text_to_speech"
provider_type = data.get("provider_type")
if not session_id or not provider_id or not provider_type:
if not provider_id or not provider_type:
return (
Response()
.error("缺少必要参数: session_id, provider_id, provider_type")
.error("缺少必要参数: provider_id, provider_type")
.__dict__
)
@@ -251,23 +324,35 @@ class SessionManagementRoute(Route):
.__dict__
)
# 设置 provider
provider_manager = self.core_lifecycle.star_context.provider_manager
await provider_manager.set_provider(
provider_id=provider_id,
provider_type=provider_type_enum,
umo=session_id,
)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id}{provider_type} 提供商为 {provider_id}"
}
return await self._handle_batch_operation(
session_ids,
self._update_single_session_provider,
f"更新 {provider_type} 提供商",
provider_id=provider_id,
provider_type_enum=provider_type_enum,
)
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_provider(
session_id, provider_id, provider_type_enum
)
return (
Response()
.ok(
{
"message": f"成功更新会话 {session_id}{provider_type} 提供商为 {provider_id}"
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}"
@@ -376,66 +461,98 @@ class SessionManagementRoute(Route):
logger.error(error_msg)
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
async def _update_single_session_llm(self, session_id: str, enabled: bool):
"""更新单个会话的LLM状态的内部方法"""
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
async def update_session_llm(self):
"""更新指定会话的LLM启停状态"""
"""更新指定会话的LLM启停状态,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
enabled = data.get("enabled")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if enabled is None:
return Response().error("缺少必要参数: enabled").__dict__
# 使用 SessionServiceManager 更新LLM状态
SessionServiceManager.set_llm_status_for_session(session_id, enabled)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"LLM已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"llm_enabled": enabled,
}
result = await self._handle_batch_operation(
session_ids,
self._update_single_session_llm,
f"{'启用' if enabled else '禁用'}LLM",
enabled=enabled,
)
return result
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_llm(session_id, enabled)
return (
Response()
.ok(
{
"message": f"LLM已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"llm_enabled": enabled,
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
async def _update_single_session_tts(self, session_id: str, enabled: bool):
"""更新单个会话的TTS状态的内部方法"""
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
async def update_session_tts(self):
"""更新指定会话的TTS启停状态"""
"""更新指定会话的TTS启停状态,支持批量操作"""
try:
data = await request.get_json()
session_id = data.get("session_id")
is_batch = data.get("is_batch", False)
enabled = data.get("enabled")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
if enabled is None:
return Response().error("缺少必要参数: enabled").__dict__
# 使用 SessionServiceManager 更新TTS状态
SessionServiceManager.set_tts_status_for_session(session_id, enabled)
if is_batch:
session_ids = data.get("session_ids", [])
if not session_ids:
return Response().error("缺少必要参数: session_ids").__dict__
return (
Response()
.ok(
{
"message": f"TTS已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"tts_enabled": enabled,
}
result = await self._handle_batch_operation(
session_ids,
self._update_single_session_tts,
f"{'启用' if enabled else '禁用'}TTS",
enabled=enabled,
)
return result
else:
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
await self._update_single_session_tts(session_id, enabled)
return (
Response()
.ok(
{
"message": f"TTS已{'启用' if enabled else '禁用'}",
"session_id": session_id,
"tts_enabled": enabled,
}
)
.__dict__
)
.__dict__
)
except Exception as e:
error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}"
@@ -507,3 +624,43 @@ class SessionManagementRoute(Route):
error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__
async def delete_session(self):
"""删除指定会话及其所有相关数据"""
try:
data = await request.get_json()
session_id = data.get("session_id")
if not session_id:
return Response().error("缺少必要参数: session_id").__dict__
# 删除会话的所有相关数据
conversation_manager = self.core_lifecycle.conversation_manager
# 1. 删除会话的所有对话
try:
await conversation_manager.delete_conversations_by_user_id(session_id)
except Exception as e:
logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}")
# 2. 清除会话的偏好设置数据(清空该会话的所有配置)
try:
await sp.clear_async("umo", session_id)
except Exception as e:
logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}")
return (
Response()
.ok(
{
"message": f"会话 {session_id} 及其相关所有对话数据已成功删除",
"session_id": session_id,
}
)
.__dict__
)
except Exception as e:
error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"删除会话失败: {str(e)}").__dict__
+1
View File
@@ -0,0 +1 @@
# What's Changed
+1
View File
@@ -0,0 +1 @@
# What's Changed
+22 -1
View File
@@ -1,7 +1,28 @@
<template>
<RouterView></RouterView>
<!-- 全局唯一 snackbar -->
<v-snackbar v-if="toastStore.current" v-model="snackbarShow" :color="toastStore.current.color"
:timeout="toastStore.current.timeout" :multi-line="toastStore.current.multiLine"
:location="toastStore.current.location" close-on-back>
{{ toastStore.current.message }}
<template #actions v-if="toastStore.current.closable">
<v-btn variant="text" @click="snackbarShow = false">关闭</v-btn>
</template>
</v-snackbar>
</template>
<script setup lang="ts">
<script setup>
import { RouterView } from 'vue-router';
import { computed } from 'vue'
import { useToastStore } from '@/stores/toast'
const toastStore = useToastStore()
const snackbarShow = computed({
get: () => !!toastStore.current,
set: (val) => {
if (!val) toastStore.shift()
}
})
</script>
+42 -13
View File
@@ -9,7 +9,8 @@
<div style="display: flex; align-items: center; justify-content: center; padding: 16px; padding-bottom: 0px;"
v-if="chatboxMode">
<img width="50" src="@/assets/images/astrbot_logo_mini.webp" alt="AstrBot Logo">
<span v-if="!sidebarCollapsed" style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
<span v-if="!sidebarCollapsed"
style="font-weight: 1000; font-size: 26px; margin-left: 8px;">AstrBot</span>
</div>
@@ -46,7 +47,7 @@
|| tm('conversation.newConversation') }}</v-list-item-title>
<v-list-item-subtitle v-if="!sidebarCollapsed" class="timestamp">{{
formatDate(item.updated_at)
}}</v-list-item-subtitle>
}}</v-list-item-subtitle>
<template v-if="!sidebarCollapsed" v-slot:append>
<div class="conversation-actions">
@@ -118,8 +119,9 @@
</div>
<v-divider v-if="currCid && getCurrentConversation" class="conversation-divider"></v-divider>
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark" :isStreaming="isStreaming"
@openImagePreview="openImagePreview" ref="messageList" />
<MessageList v-if="messages && messages.length > 0" :messages="messages" :isDark="isDark"
:isStreaming="isStreaming || isConvRunning" @openImagePreview="openImagePreview"
ref="messageList" />
<div class="welcome-container fade-in" v-else>
<div class="welcome-title">
<span>Hello, I'm</span>
@@ -145,9 +147,10 @@
<!-- 输入区域 -->
<div class="input-area fade-in">
<div
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; padding: 4px;">
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown" :disabled="isStreaming"
@click:clear="clearMessage" placeholder="Ask AstrBot..."
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px;">
<textarea id="input-field" v-model="prompt" @keydown="handleInputKeyDown"
:disabled="isStreaming || isConvRunning" @click:clear="clearMessage"
placeholder="Ask AstrBot..."
style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);"></textarea>
<div
style="display: flex; justify-content: space-between; align-items: center; padding: 0px 8px;">
@@ -155,18 +158,21 @@
<!-- 选择提供商和模型 -->
<ProviderModelSelector ref="providerModelSelector" />
</div>
<div style="display: flex; justify-content: flex-end; margin-top: 8px;">
<div
style="display: flex; justify-content: flex-end; margin-top: 8px; align-items: center;">
<input type="file" ref="imageInput" @change="handleFileSelect" accept="image/*"
style="display: none" multiple />
<v-progress-circular v-if="isStreaming || isConvRunning" indeterminate size="16"
class="mr-1" width="1.5" />
<v-btn @click="triggerImageInput" icon="mdi-plus" variant="text" color="deep-purple"
class="add-btn" size="small" />
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
class="send-btn" size="small" />
<v-btn @click="isRecording ? stopRecording() : startRecording()"
:icon="isRecording ? 'mdi-stop-circle' : 'mdi-microphone'" variant="text"
:color="isRecording ? 'error' : 'deep-purple'" class="record-btn"
size="small" />
<v-btn @click="sendMessage" icon="mdi-send" variant="text" color="deep-purple"
:disabled="!prompt && stagedImagesName.length === 0 && !stagedAudioUrl"
class="send-btn" size="small" />
</div>
</div>
@@ -235,6 +241,7 @@ import LanguageSwitcher from '@/components/shared/LanguageSwitcher.vue';
import ProviderModelSelector from '@/components/chat/ProviderModelSelector.vue';
import MessageList from '@/components/chat/MessageList.vue';
import 'highlight.js/styles/github.css';
import { useToast } from '@/utils/toast';
export default {
name: 'ChatPage',
@@ -301,7 +308,10 @@ export default {
imagePreviewDialog: false,
previewImageUrl: '',
isStreaming: false
isStreaming: false,
isConvRunning: false, // Track if the current conversation is running
isToastedRunningInfo: false, // To avoid multiple toasts
}
},
@@ -379,7 +389,7 @@ export default {
} else {
this.sidebarCollapsed = true; // 默认折叠状态
}
// 设置输入框标签
this.inputFieldLabel = this.tm('input.chatPrompt');
this.getConversations();
@@ -662,6 +672,25 @@ export default {
// Update the selected conversation in the sidebar
this.selectedConversations = [cid[0]];
let history = response.data.data.history;
this.isConvRunning = response.data.data.is_running || false;
if (this.isConvRunning) {
if (!this.isToastedRunningInfo) {
useToast().info("该对话正在运行中。", { timeout: 5000 });
this.isToastedRunningInfo = true;
}
// 如果对话还在运行,3秒后重新获取消息
setTimeout(() => {
this.getConversationMessages([this.currCid]);
}, 3000);
}
// 滚动到底部
this.$nextTick(() => {
this.$refs.messageList.scrollToBottom();
});
for (let i = 0; i < history.length; i++) {
let content = history[i].content;
if (content.message.startsWith('[IMAGE]')) {
@@ -29,12 +29,11 @@
<!-- Bot Messages -->
<div v-else class="bot-message">
<div v-if="isStreaming && index === messages.length - 1" style="width: 36px; height: 36px;">
<v-progress-circular indeterminate size="28" width="2"
style="margin-top: 12px;"></v-progress-circular>
</div>
<v-avatar v-else class="bot-avatar" size="36">
<span class="text-h2"></span>
<v-avatar class="bot-avatar" size="36">
<v-progress-circular :index="index" v-if="isStreaming && index === messages.length - 1" indeterminate size="28"
width="2"></v-progress-circular>
<span v-else-if="messages[index - 1]?.content.type !== 'bot'" class="text-h2"></span>
</v-avatar>
<div class="bot-message-content">
<div class="message-bubble bot-bubble">
@@ -31,6 +31,10 @@
</v-col>
</v-row>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="closeDialog">{{ tm('dialog.cancel') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
@@ -63,9 +63,7 @@
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn text @click="closeDialog">
Close
</v-btn>
<v-btn text @click="closeDialog">{{ tm('dialogs.config.cancel') }}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -7,7 +7,8 @@
"apply": "Apply Batch Settings",
"editName": "Edit Session Name",
"save": "Save",
"cancel": "Cancel"
"cancel": "Cancel",
"delete": "Delete"
},
"sessions": {
"activeSessions": "Active Sessions",
@@ -29,7 +30,8 @@
"ttsProvider": "TTS Provider",
"llmStatus": "LLM Status",
"ttsStatus": "TTS Status",
"pluginManagement": "Plugin Management"
"pluginManagement": "Plugin Management",
"actions": "Actions"
}
},
"status": {
@@ -65,6 +67,10 @@
"fullSessionId": "Full Session ID",
"hint": "Custom names help you easily identify sessions. The small information icon (!) will show the actual UMO when hovering."
},
"deleteConfirm": {
"message": "Are you sure you want to delete session {sessionName}?",
"warning": "This action will permanently delete all chat history and preference settings for this session (except for data linked via plugins), and this cannot be undone. Continue?"
},
"messages": {
"refreshSuccess": "Session list refreshed",
"personaUpdateSuccess": "Persona updated successfully",
@@ -82,6 +88,8 @@
"pluginStatusSuccess": "Plugin {name} {status}",
"pluginStatusError": "Failed to update plugin status",
"nameUpdateSuccess": "Session name updated successfully",
"nameUpdateError": "Failed to update session name"
"nameUpdateError": "Failed to update session name",
"deleteSuccess": "Session deleted successfully",
"deleteError": "Failed to delete session"
}
}
@@ -7,7 +7,8 @@
"apply": "应用批量设置",
"editName": "备注",
"save": "保存",
"cancel": "取消"
"cancel": "取消",
"delete": "删除"
},
"sessions": {
"activeSessions": "活跃会话",
@@ -29,7 +30,8 @@
"ttsProvider": "语音合成模型",
"llmStatus": "启用 LLM",
"ttsStatus": "启用 TTS",
"pluginManagement": "插件管理"
"pluginManagement": "插件管理",
"actions": "操作"
}
},
"status": {
@@ -65,6 +67,10 @@
"fullSessionId": "完整会话ID",
"hint": "自定义名称帮助您轻松识别会话。当设置了自定义名称时,会显示一个小感叹号标识(!),鼠标悬停时会显示实际的UMO。"
},
"deleteConfirm": {
"message": "确定要删除会话 {sessionName} 吗?",
"warning": "此操作将永久删除本次会话的「全部对话记录」与「偏好设置」(插件对会话的关联数据除外),且无法恢复。确认继续?"
},
"messages": {
"refreshSuccess": "会话列表已刷新",
"personaUpdateSuccess": "人格更新成功",
@@ -82,6 +88,8 @@
"pluginStatusSuccess": "插件 {name} {status}",
"pluginStatusError": "插件状态更新失败",
"nameUpdateSuccess": "会话名称更新成功",
"nameUpdateError": "会话名称更新失败"
"nameUpdateError": "会话名称更新失败",
"deleteSuccess": "会话删除成功",
"deleteError": "会话删除失败"
}
}
+31
View File
@@ -0,0 +1,31 @@
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
export const useToastStore = defineStore('toast', () => {
const queue = ref([])
const current = computed(() => queue.value[0])
function add({
message,
color = 'info', // Vuetify 颜色
timeout = 3000,
closable = true,
multiLine = false,
location = 'top center'
}) {
queue.value.push({
message,
color,
timeout,
closable,
multiLine,
location
})
}
function shift() {
queue.value.shift()
}
return { current, add, shift }
})
+1
View File
@@ -22,6 +22,7 @@ export function getProviderIcon(type) {
'moonshot': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/kimi.svg',
'ppio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/ppio.svg',
'dify': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/dify-color.svg',
"coze": "https://registry.npmmirror.com/@lobehub/icons-static-svg/1.66.0/files/icons/coze.svg",
'dashscope': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/alibabacloud-color.svg',
'fastgpt': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/fastgpt-color.svg',
'lm_studio': 'https://registry.npmmirror.com/@lobehub/icons-static-svg/latest/files/icons/lmstudio.svg',
+16
View File
@@ -0,0 +1,16 @@
import { useToastStore } from '@/stores/toast'
export function useToast() {
const store = useToastStore()
const toast = (message, color = 'info', opts = {}) =>
store.add({ message, color, ...opts })
return {
toast,
success: (msg, opts) => toast(msg, 'success', opts),
error: (msg, opts) => toast(msg, 'error', opts),
info: (msg, opts) => toast(msg, 'primary', opts),
warning: (msg, opts) => toast(msg, 'warning', opts)
}
}
+1 -5
View File
@@ -601,19 +601,15 @@ export default {
}
if (this.search) {
params.search = this.search;
params.search = this.search.trim();
}
// 添加排除条件
params.exclude_ids = 'astrbot';
params.exclude_platforms = 'webchat';
console.log(`正在请求对话列表: /api/conversation/list 参数:`, params);
const response = await axios.get('/api/conversation/list', { params });
console.log('收到对话列表响应:', response.data);
this.lastAppliedFilters = { ...this.currentFilters }; // 记录已应用的筛选条件
if (response.data.status === "ok") {
+1
View File
@@ -303,6 +303,7 @@ export default {
"googlegenai_chat_completion": "chat_completion",
"zhipu_chat_completion": "chat_completion",
"dify": "chat_completion",
"coze": "chat_completion",
"dashscope": "chat_completion",
"openai_whisper_api": "speech_to_text",
"openai_whisper_selfhost": "speech_to_text",
+268 -129
View File
@@ -141,6 +141,17 @@
</v-btn>
</template>
<!-- 操作按钮 -->
<template v-slot:item.actions="{ item }">
<v-btn size="x-small" variant="tonal" color="error" @click="deleteSession(item)"
:loading="item.deleting" icon>
<v-icon>mdi-delete</v-icon>
<v-tooltip activator="parent" location="top">
{{ tm('buttons.delete') }}
</v-tooltip>
</v-btn>
</template>
<!-- 空状态 -->
<template v-slot:no-data>
<div class="text-center py-8">
@@ -409,6 +420,7 @@ export default {
{ title: this.tm('table.headers.llmStatus'), key: 'llm_enabled', sortable: false, minWidth: '120px' },
{ title: this.tm('table.headers.ttsStatus'), key: 'tts_enabled', sortable: false, minWidth: '120px' },
{ title: this.tm('table.headers.pluginManagement'), key: 'plugins', sortable: false, minWidth: '120px' },
{ title: this.tm('table.headers.actions'), key: 'actions', sortable: false, minWidth: '100px' },
]
},
@@ -418,12 +430,13 @@ export default {
//
if (this.searchQuery) {
const query = this.searchQuery.toLowerCase();
const query = this.searchQuery.toLowerCase().trim();
filtered = filtered.filter(session =>
session.session_name.toLowerCase().includes(query) ||
session.platform.toLowerCase().includes(query) ||
session.persona_name?.toLowerCase().includes(query) ||
session.chat_provider_name?.toLowerCase().includes(query)
session.chat_provider_name?.toLowerCase().includes(query) ||
session.session_id.toLowerCase().includes(query)
);
}
@@ -487,7 +500,8 @@ export default {
this.sessions = data.sessions.map(session => ({
...session,
updating: false, //
loadingPlugins: false //
loadingPlugins: false, //
deleting: false //
}));
this.availablePersonas = data.available_personas;
this.availableChatProviders = data.available_chat_providers;
@@ -508,60 +522,131 @@ export default {
},
async updatePersona(session, personaName) {
session.updating = true;
try {
const response = await axios.post('/api/session/update_persona', {
session_id: session.session_id,
persona_name: personaName
});
if (response.data.status === 'ok') {
session.persona_id = personaName;
session.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
return this._updateSession('persona', session, { persona_name: personaName }, (s, success) => {
if (success) {
s.persona_id = personaName;
s.persona_name = personaName === '[%None]' ? this.tm('persona.none') :
this.availablePersonas.find(p => p.name === personaName)?.name || personaName;
this.showSuccess(this.tm('messages.personaUpdateSuccess'));
} else {
this.showError(response.data.message || this.tm('messages.personaUpdateError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.personaUpdateError'));
}
session.updating = false;
});
},
async updateProvider(session, providerId, providerType) {
session.updating = true;
try {
const response = await axios.post('/api/session/update_provider', {
session_id: session.session_id,
provider_id: providerId,
provider_type: providerType
});
if (response.data.status === 'ok') {
//
return this._updateSession('provider', session, {
provider_id: providerId,
provider_type: providerType
}, (s, success) => {
if (success) {
if (providerType === 'chat_completion') {
session.chat_provider_id = providerId;
s.chat_provider_id = providerId;
const provider = this.availableChatProviders.find(p => p.id === providerId);
session.chat_provider_name = provider?.name || providerId;
s.chat_provider_name = provider?.name || providerId;
} else if (providerType === 'speech_to_text') {
session.stt_provider_id = providerId;
s.stt_provider_id = providerId;
const provider = this.availableSttProviders.find(p => p.id === providerId);
session.stt_provider_name = provider?.name || providerId;
s.stt_provider_name = provider?.name || providerId;
} else if (providerType === 'text_to_speech') {
session.tts_provider_id = providerId;
s.tts_provider_id = providerId;
const provider = this.availableTtsProviders.find(p => p.id === providerId);
session.tts_provider_name = provider?.name || providerId;
s.tts_provider_name = provider?.name || providerId;
}
this.showSuccess(this.tm('messages.providerUpdateSuccess'));
} else {
this.showError(response.data.message || this.tm('messages.providerUpdateError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.providerUpdateError'));
} session.updating = false;
});
},
async updateLLM(session, enabled) {
return this._updateSession('llm', session, { enabled }, (s, success) => {
if (success) s.llm_enabled = enabled;
});
},
async updateTTS(session, enabled) {
return this._updateSession('tts', session, { enabled }, (s, success) => {
if (success) s.tts_enabled = enabled;
});
},
//
async _updateSession(type, sessionOrSessions, params, updateLocalData) {
const isBatch = Array.isArray(sessionOrSessions);
if (!isBatch) {
//
const session = sessionOrSessions;
session.updating = true;
try {
const payload = {
is_batch: false,
session_id: session.session_id,
...params
};
const response = await axios.post(`/api/session/update_${type}`, payload);
if (response.data.status === 'ok') {
updateLocalData(session, true);
this.showSuccess(this.tm(`messages.${type}UpdateSuccess`));
return { success: true };
} else {
this.showError(response.data.message || this.tm(`messages.${type}UpdateError`));
return { success: false, error: response.data.message };
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm(`messages.${type}UpdateError`));
return { success: false, error: error.message };
} finally {
session.updating = false;
}
} else {
//
const sessions = sessionOrSessions;
const sessionIds = sessions.map(s => s.session_id);
try {
const payload = {
is_batch: true,
session_ids: sessionIds,
...params
};
const response = await axios.post(`/api/session/update_${type}`, payload);
if (response.data.status === 'ok') {
const data = response.data.data;
//
sessions.forEach(session => {
const wasSuccessful = !data.error_sessions || !data.error_sessions.includes(session.session_id);
updateLocalData(session, wasSuccessful);
});
return {
success: true,
successCount: data.success_count || 0,
errorCount: data.error_count || 0,
errorSessions: data.error_sessions || []
};
} else {
return {
success: false,
error: response.data.message,
errorCount: sessionIds.length,
successCount: 0
};
}
} catch (error) {
return {
success: false,
error: error.response?.data?.message || error.message,
errorCount: sessionIds.length,
successCount: 0
};
}
}
},
//
async updateSessionStatus(session, enabled) {
session.updating = true;
try {
@@ -572,47 +657,9 @@ export default {
if (response.data.status === 'ok') {
session.session_enabled = enabled;
this.showSuccess(this.tm('messages.sessionStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
} else {
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
}
session.updating = false;
},
async updateLLM(session, enabled) {
session.updating = true;
try {
const response = await axios.post('/api/session/update_llm', {
session_id: session.session_id,
enabled: enabled
});
if (response.data.status === 'ok') {
session.llm_enabled = enabled;
this.showSuccess(this.tm('messages.llmStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
} else {
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.statusUpdateError'));
}
session.updating = false;
},
async updateTTS(session, enabled) {
session.updating = true;
try {
const response = await axios.post('/api/session/update_tts', {
session_id: session.session_id,
enabled: enabled
});
if (response.data.status === 'ok') {
session.tts_enabled = enabled;
this.showSuccess(this.tm('messages.ttsStatusSuccess', { status: enabled ? this.tm('status.enabled') : this.tm('status.disabled') }));
this.showSuccess(this.tm('messages.sessionStatusSuccess', {
status: enabled ? this.tm('status.enabled') : this.tm('status.disabled')
}));
} else {
this.showError(response.data.message || this.tm('messages.statusUpdateError'));
}
@@ -628,60 +675,120 @@ export default {
}
this.batchUpdating = true;
let successCount = 0;
let errorCount = 0;
let totalSuccessCount = 0;
let totalErrorCount = 0;
let allErrorSessions = [];
// 使
for (const session of this.filteredSessions) {
try {
//
if (this.batchPersona) {
await this.updatePersona(session, this.batchPersona);
successCount++;
}
const sessions = this.filteredSessions;
// Chat Provider
if (this.batchChatProvider) {
await this.updateProvider(session, this.batchChatProvider, 'chat_completion');
successCount++;
}
try {
//
const batchTasks = [];
// STT Provider
if (this.batchSttProvider) {
await this.updateProvider(session, this.batchSttProvider, 'speech_to_text');
successCount++;
}
// TTS Provider
if (this.batchTtsProvider) {
await this.updateProvider(session, this.batchTtsProvider, 'text_to_speech');
successCount++;
}
// LLM
if (this.batchLlmStatus !== null) {
await this.updateLLM(session, this.batchLlmStatus);
successCount++;
}
// TTS
if (this.batchTtsStatus !== null) {
await this.updateTTS(session, this.batchTtsStatus);
successCount++;
}
} catch (error) {
errorCount++;
if (this.batchPersona) {
batchTasks.push({
type: 'persona',
params: { persona_name: this.batchPersona }
});
}
if (this.batchChatProvider) {
batchTasks.push({
type: 'provider',
params: { provider_id: this.batchChatProvider, provider_type: 'chat_completion' }
});
}
if (this.batchSttProvider) {
batchTasks.push({
type: 'provider',
params: { provider_id: this.batchSttProvider, provider_type: 'speech_to_text' }
});
}
if (this.batchTtsProvider) {
batchTasks.push({
type: 'provider',
params: { provider_id: this.batchTtsProvider, provider_type: 'text_to_speech' }
});
}
if (this.batchLlmStatus !== null) {
batchTasks.push({
type: 'llm',
params: { enabled: this.batchLlmStatus }
});
}
if (this.batchTtsStatus !== null) {
batchTasks.push({
type: 'tts',
params: { enabled: this.batchTtsStatus }
});
}
//
for (const task of batchTasks) {
let updateLocalData;
//
switch (task.type) {
case 'persona':
updateLocalData = (s, success) => {
if (success) s.persona_id = task.params.persona_name;
};
break;
case 'provider':
updateLocalData = (s, success) => {
if (!success) return;
const { provider_id, provider_type } = task.params;
if (provider_type === 'chat_completion') {
s.chat_provider_id = provider_id;
} else if (provider_type === 'speech_to_text') {
s.stt_provider_id = provider_id;
} else if (provider_type === 'text_to_speech') {
s.tts_provider_id = provider_id;
}
};
break;
case 'llm':
updateLocalData = (s, success) => {
if (success) s.llm_enabled = task.params.enabled;
};
break;
case 'tts':
updateLocalData = (s, success) => {
if (success) s.tts_enabled = task.params.enabled;
};
break;
}
const result = await this._updateSession(task.type, sessions, task.params, updateLocalData);
totalSuccessCount += result.successCount || 0;
totalErrorCount += result.errorCount || 0;
if (result.errorSessions) {
allErrorSessions.push(...result.errorSessions);
}
}
//
if (totalErrorCount === 0) {
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: totalSuccessCount }));
} else {
const uniqueErrorSessions = [...new Set(allErrorSessions)];
this.showError(this.tm('messages.batchUpdatePartial', {
success: totalSuccessCount,
error: uniqueErrorSessions.length
}));
}
} catch (error) {
this.showError(this.tm('messages.batchUpdateError'));
}
this.batchUpdating = false;
if (errorCount === 0) {
this.showSuccess(this.tm('messages.batchUpdateSuccess', { count: successCount }));
} else {
this.showError(this.tm('messages.batchUpdatePartial', { success: successCount, error: errorCount }));
}
//
this.batchPersona = null;
this.batchChatProvider = null;
@@ -797,6 +904,38 @@ export default {
this.snackbarColor = 'error';
this.snackbar = true;
},
async deleteSession(session) {
const confirmMessage = this.tm('deleteConfirm.message', {
sessionName: session.session_name || session.session_id
}) + '\n\n' + this.tm('deleteConfirm.warning');
if (!confirm(confirmMessage)) {
return;
}
session.deleting = true;
try {
const response = await axios.post('/api/session/delete', {
session_id: session.session_id
});
if (response.data.status === 'ok') {
this.showSuccess(response.data.data.message || this.tm('messages.deleteSuccess'));
//
const index = this.sessions.findIndex(s => s.session_id === session.session_id);
if (index > -1) {
this.sessions.splice(index, 1);
}
} else {
this.showError(response.data.message || this.tm('messages.deleteError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.deleteError'));
}
session.deleting = false;
},
},
}
</script>
+21 -21
View File
@@ -527,12 +527,11 @@ UID: {user_id} 此 ID 可用于设置管理员。
return
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message(
"已重置当前 Dify 会话,新聊天将更换到新的会话。"
"已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。"
)
)
return
@@ -755,8 +754,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
创建新对话
"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。")
@@ -783,8 +781,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
"""创建新群聊对话"""
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
if provider and provider.meta().type in ["dify", "coze"]:
await provider.forget(message.unified_msg_origin)
message.set_result(
MessageEventResult().message("成功,下次聊天将是新对话。")
@@ -823,7 +820,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
provider = self.context.get_using_provider(message.unified_msg_origin)
if provider and provider.meta().type == "dify":
assert isinstance(provider, ProviderDify)
data = await provider.api_client.get_chat_convs(message.unified_msg_origin)
if not data["data"]:
message.set_result(MessageEventResult().message("未找到任何对话。"))
@@ -1348,22 +1344,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
logger.error(f"ltm: {e}")
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd")
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent):
# token = event.message_str.split(" ")
token = self.parse_commands(event.message_str)
if token.len < 2:
if token.len < 3:
yield event.plain_result(
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
"该指令用于设置指令或指令组的权限。\n"
"格式: /alter_cmd <cmd_name> <admin/member>\n"
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
"/alter_cmd reset config 打开 reset 权限配置"
)
return
cmd_name = token.get(1)
cmd_type = token.get(2)
cmd_name = " ".join(token.tokens[1:-1])
cmd_type = token.get(-1)
# ============================
# 对reset权限进行特殊处理
# ============================
if cmd_name == "reset" and cmd_type == "config":
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get("astrbot", {})
@@ -1413,16 +1409,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
# 查找指令
found_command = None
cmd_group = False
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.command_name == cmd_name:
if filter_.equals(cmd_name):
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if cmd_name == filter_.group_name:
if filter_.equals(cmd_name):
found_command = handler
cmd_group = True
break
if not found_command:
@@ -1459,8 +1457,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
else filter.PermissionType.MEMBER
),
)
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
cmd_group_str = "指令组" if cmd_group else "指令"
yield event.plain_result(
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}"
)
async def update_reset_permission(self, scene_key: str, perm_type: str):
"""更新reset命令在特定场景下的权限设置
-2
View File
@@ -375,5 +375,3 @@ class Main(star.Star):
tool_set.add_tool(tavily_extract_web_page)
tool_set.remove_tool("web_search")
tool_set.remove_tool("fetch_url")
print(req.func_tool)
+1 -1
View File
@@ -1,6 +1,6 @@
[project]
name = "AstrBot"
version = "4.1.7"
version = "4.2.1"
description = "易上手的多平台 LLM 聊天机器人及开发框架"
readme = "README.md"
requires-python = ">=3.10"
+79 -43
View File
@@ -1,5 +1,7 @@
import pytest
import pytest_asyncio
import os
import asyncio
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -9,36 +11,46 @@ from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
@pytest_asyncio.fixture(scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
db = SQLiteDatabase(str(tmp_db_path))
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
await core_lifecycle.initialize()
return core_lifecycle
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
def app(core_lifecycle_td: AstrBotCoreLifecycle):
"""Creates a Quart app instance for testing."""
shutdown_event = asyncio.Event()
# The db instance is already part of the core_lifecycle_td
server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest_asyncio.fixture(scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login",
json={
"username": core_lifecycle_td.astrbot_config["dashboard"]["username"],
"password": core_lifecycle_td.astrbot_config["dashboard"]["password"],
},
)
data = await response.get_json()
assert data["status"] == "ok"
token = data["data"]["token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(
app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict
):
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Tests the login functionality with both wrong and correct credentials."""
test_client = app.test_client()
response = await test_client.post(
"/api/auth/login", json={"username": "wrong", "password": "password"}
@@ -55,31 +67,32 @@ async def test_auth_login(
)
data = await response.get_json()
assert data["status"] == "ok" and "token" in data["data"]
header["Authorization"] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
async def test_get_stat(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/stat/get")
assert response.status_code == 401
response = await test_client.get("/api/stat/get", headers=header)
response = await test_client.get("/api/stat/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok" and "platform" in data["data"]
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
async def test_plugins(app: Quart, authenticated_header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=header)
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
# 插件市场
response = await test_client.get("/api/plugin/market_list", headers=header)
response = await test_client.get(
"/api/plugin/market_list", headers=authenticated_header
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
@@ -88,7 +101,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/install",
json={"url": "https://github.com/Soulter/astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -102,7 +115,9 @@ async def test_plugins(app: Quart, header: dict):
# 插件更新
response = await test_client.post(
"/api/plugin/update", json={"name": "astrbot_plugin_essential"}, headers=header
"/api/plugin/update",
json={"name": "astrbot_plugin_essential"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -112,7 +127,7 @@ async def test_plugins(app: Quart, header: dict):
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": "astrbot_plugin_essential"},
headers=header,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
@@ -132,9 +147,9 @@ async def test_plugins(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
async def test_check_update(app: Quart, authenticated_header: dict):
test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=header)
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
@@ -142,24 +157,45 @@ async def test_check_update(app: Quart, header: dict):
@pytest.mark.asyncio
async def test_do_update(
app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
tmp_path_factory,
):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "latest"}
# Use a temporary path for the mock update to avoid side effects
temp_release_dir = tmp_path_factory.mktemp("release")
release_path = temp_release_dir / "astrbot"
async def mock_update(*args, **kwargs):
"""Mocks the update process by creating a directory in the temp path."""
os.makedirs(release_path, exist_ok=True)
return
async def mock_download_dashboard(*args, **kwargs):
"""Mocks the dashboard download to prevent network access."""
return
async def mock_pip_install(*args, **kwargs):
"""Mocks pip install to prevent actual installation."""
return
monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard
)
monkeypatch.setattr(
"astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error" # 已经是最新版本
response = await test_client.post(
"/api/update/do", headers=header, json={"version": "v3.4.0", "reboot": False}
"/api/update/do",
headers=authenticated_header,
json={"version": "v3.4.0", "reboot": False},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert os.path.exists("data/astrbot_release/astrbot")
assert os.path.exists(release_path)
+51 -18
View File
@@ -1,5 +1,9 @@
import os
import sys
# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
@@ -27,29 +31,58 @@ def test_check_env(monkeypatch):
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
async def test_check_dashboard_files_not_exists(monkeypatch):
"""Tests dashboard download when files do not exist."""
monkeypatch.setattr(os.path, "exists", lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
mock_download.assert_called_once()
async def read(self):
return b"content"
return MockResponse()
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_and_version_match(monkeypatch):
"""Tests that dashboard is not downloaded when it exists and version matches."""
# Mock os.path.exists to return True
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("aiohttp.ClientSession.get", new=mock_get):
with mock.patch("builtins.open", mock.mock_open()) as mock_file:
with mock.patch("zipfile.ZipFile.extractall") as mock_extractall:
# Mock get_dashboard_version to return the current version
with mock.patch("main.get_dashboard_version") as mock_get_version:
# We need to import VERSION from main's context
from main import VERSION
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
mock_get_version.return_value = f"v{VERSION}"
async def mock_aexit(obj, exc_type, exc, tb):
return
with mock.patch("main.download_dashboard") as mock_download:
await check_dashboard_files()
# Assert that download_dashboard was NOT called
mock_download.assert_not_called()
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
@pytest.mark.asyncio
async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch):
"""Tests that a warning is logged when dashboard version mismatches."""
monkeypatch.setattr(os.path, "exists", lambda x: True)
with mock.patch("main.get_dashboard_version") as mock_get_version:
mock_get_version.return_value = "v0.0.1" # A different version
with mock.patch("main.logger.warning") as mock_logger_warning:
await check_dashboard_files()
mock_logger_warning.assert_called_once()
call_args, _ = mock_logger_warning.call_args
assert "不符" in call_args[0]
@pytest.mark.asyncio
async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch):
"""Tests that providing a valid webui_dir skips all checks."""
valid_dir = "/tmp/my-custom-webui"
monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir)
with mock.patch("main.download_dashboard") as mock_download:
with mock.patch("main.get_dashboard_version") as mock_get_version:
result = await check_dashboard_files(webui_dir=valid_dir)
assert result == valid_dir
mock_download.assert_not_called()
mock_get_version.assert_not_called()
-285
View File
@@ -1,285 +0,0 @@
import pytest
import logging
import os
import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import (
AstrBotMessage,
MessageMember,
MessageType,
)
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"],
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id,
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456",
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg["platform_settings"]["id_whitelist"] = [
"test_platform:FriendMessage:test_sid_wl",
"test_platform:GroupMessage:test_sid_wl",
]
cfg["admins_id"] = ["123456"]
cfg["content_safety"]["internal_keywords"]["extra_keywords"] = ["^TEST_NEGATIVE"]
cfg["provider"] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
# await plugin_manager.reload()
asyncio.run(plugin_manager.reload())
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
"""测试唤醒"""
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any(
"执行阶段 WhitelistCheckStage" not in message for message in caplog.messages
)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", is_group=True, is_at=True
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event(
"/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST
)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(
pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog
):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"test", SESSION_ID_IN_WHITELIST, sender_id="123"
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" not in message
for message in caplog.messages
), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any(
"不在会话白名单中,已终止事件传播。" in message for message in caplog.messages
), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"色情", session_id=SESSION_ID_IN_WHITELIST
) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event(
"TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), (
"日志中未找到预期的消息"
)
mock_event = FakeAstrMessageEvent.create_fake_event(
"_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
"help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any(
"web_searcher - search_from_search_engine" in message
for message in caplog.messages
)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(
command[0], session_id=SESSION_ID_IN_WHITELIST
)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)
+105 -43
View File
@@ -1,5 +1,6 @@
import pytest
import os
from unittest.mock import MagicMock
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@@ -8,18 +9,51 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def plugin_manager_pm(tmp_path):
"""
Provides a fully isolated PluginManager instance for testing.
- Uses a temporary directory for plugins.
- Uses a temporary database.
- Creates a fresh context for each test.
"""
# Create temporary resources
temp_plugins_path = tmp_path / "plugins"
temp_plugins_path.mkdir()
temp_db_path = tmp_path / "test_db.db"
# Create fresh, isolated instances for the context
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase(str(temp_db_path))
# Set the plugin store path in the config to the temporary directory
config.plugin_store_path = str(temp_plugins_path)
# Mock dependencies for the context
provider_manager = MagicMock()
platform_manager = MagicMock()
conversation_manager = MagicMock()
message_history_manager = MagicMock()
persona_manager = MagicMock()
astrbot_config_mgr = MagicMock()
star_context = Context(
event_queue,
config,
db,
provider_manager,
platform_manager,
conversation_manager,
message_history_manager,
persona_manager,
astrbot_config_mgr,
)
# Create the PluginManager instance
manager = PluginManager(star_context, config)
yield manager
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
@@ -36,48 +70,76 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
"""测试插件安装和重载"""
os.makedirs("data/plugins", exist_ok=True)
async def test_install_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin installation in an isolated environment."""
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
plugin_info = await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert plugin_info is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
assert any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not loaded into star_registry."
)
# install plugin which is not exists
@pytest.mark.asyncio
async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that installing a non-existent plugin raises an exception."""
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
await plugin_manager_pm.install_plugin(
"https://github.com/Soulter/non_existent_repo"
)
# update
@pytest.mark.asyncio
async def test_update_plugin(plugin_manager_pm: PluginManager):
"""Tests updating an existing plugin in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
# Then, update it
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
@pytest.mark.asyncio
async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that updating a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("non_existent_plugin")
@pytest.mark.asyncio
async def test_uninstall_plugin(plugin_manager_pm: PluginManager):
"""Tests successful plugin uninstallation in an isolated environment."""
# First, install the plugin
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
await plugin_manager_pm.install_plugin(test_repo)
plugin_path = os.path.join(
plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential"
)
assert os.path.exists(plugin_path) # Pre-condition
# Then, uninstall it
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), (
"Plugin 'astrbot_plugin_essential' was not unloaded from star_registry."
)
assert not any(
"astrbot_plugin_essential" in md.handler_module_path
for md in star_handlers_registry
), (
"Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry."
)
@pytest.mark.asyncio
async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager):
"""Tests that uninstalling a non-existent plugin raises an exception."""
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation
await plugin_manager_pm.uninstall_plugin("non_existent_plugin")