diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index 540171f1d..5d15dedc2 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -1,20 +1,19 @@ -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger -from astrbot.core import html_renderer -from astrbot.core import sp -from astrbot.core.star.register import register_llm_tool as llm_tool -from astrbot.core.star.register import register_agent as agent -from astrbot.core.agent.tool import ToolSet, FunctionTool +from astrbot.core import html_renderer, sp +from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.star.register import register_agent as agent +from astrbot.core.star.register import register_llm_tool as llm_tool __all__ = [ "AstrBotConfig", - "logger", + "BaseFunctionToolExecutor", + "FunctionTool", + "ToolSet", + "agent", "html_renderer", "llm_tool", - "agent", + "logger", "sp", - "ToolSet", - "FunctionTool", - "BaseFunctionToolExecutor", ] diff --git a/astrbot/api/event/__init__.py b/astrbot/api/event/__init__.py index 1f2fce640..2b8dd5a9b 100644 --- a/astrbot/api/event/__init__.py +++ b/astrbot/api/event/__init__.py @@ -1,18 +1,17 @@ from astrbot.core.message.message_event_result import ( - MessageEventResult, - MessageChain, CommandResult, EventResultType, + MessageChain, + MessageEventResult, ResultContentType, ) - from astrbot.core.platform import AstrMessageEvent __all__ = [ - "MessageEventResult", - "MessageChain", + "AstrMessageEvent", "CommandResult", "EventResultType", - "AstrMessageEvent", + "MessageChain", + "MessageEventResult", "ResultContentType", ] diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index d63850e4e..a8d2b4269 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -1,51 +1,52 @@ -from astrbot.core.star.register import ( - register_command as command, - register_command_group as command_group, - register_event_message_type as event_message_type, - register_regex as regex, - register_platform_adapter_type as platform_adapter_type, - register_permission_type as permission_type, - register_custom_filter as custom_filter, - register_on_astrbot_loaded as on_astrbot_loaded, - register_on_platform_loaded as on_platform_loaded, - register_on_llm_request as on_llm_request, - register_on_llm_response as on_llm_response, - register_llm_tool as llm_tool, - register_on_decorating_result as on_decorating_result, - register_after_message_sent as after_message_sent, -) - -from astrbot.core.star.filter.event_message_type import ( - EventMessageTypeFilter, - EventMessageType, -) -from astrbot.core.star.filter.platform_adapter_type import ( - PlatformAdapterTypeFilter, - PlatformAdapterType, -) -from astrbot.core.star.filter.permission import PermissionTypeFilter, PermissionType from astrbot.core.star.filter.custom_filter import CustomFilter +from astrbot.core.star.filter.event_message_type import ( + EventMessageType, + EventMessageTypeFilter, +) +from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter +from astrbot.core.star.filter.platform_adapter_type import ( + PlatformAdapterType, + PlatformAdapterTypeFilter, +) +from astrbot.core.star.register import register_after_message_sent as after_message_sent +from astrbot.core.star.register import register_command as command +from astrbot.core.star.register import register_command_group as command_group +from astrbot.core.star.register import register_custom_filter as custom_filter +from astrbot.core.star.register import register_event_message_type as event_message_type +from astrbot.core.star.register import register_llm_tool as llm_tool +from astrbot.core.star.register import register_on_astrbot_loaded as on_astrbot_loaded +from astrbot.core.star.register import ( + register_on_decorating_result as on_decorating_result, +) +from astrbot.core.star.register import register_on_llm_request as on_llm_request +from astrbot.core.star.register import register_on_llm_response as on_llm_response +from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded +from astrbot.core.star.register import register_permission_type as permission_type +from astrbot.core.star.register import ( + register_platform_adapter_type as platform_adapter_type, +) +from astrbot.core.star.register import register_regex as regex __all__ = [ + "CustomFilter", + "EventMessageType", + "EventMessageTypeFilter", + "PermissionType", + "PermissionTypeFilter", + "PlatformAdapterType", + "PlatformAdapterTypeFilter", + "after_message_sent", "command", "command_group", - "event_message_type", - "regex", - "platform_adapter_type", - "permission_type", - "EventMessageTypeFilter", - "EventMessageType", - "PlatformAdapterTypeFilter", - "PlatformAdapterType", - "PermissionTypeFilter", - "CustomFilter", "custom_filter", - "PermissionType", - "on_astrbot_loaded", - "on_platform_loaded", - "on_llm_request", + "event_message_type", "llm_tool", + "on_astrbot_loaded", "on_decorating_result", - "after_message_sent", + "on_llm_request", "on_llm_response", + "on_platform_loaded", + "permission_type", + "platform_adapter_type", + "regex", ] diff --git a/astrbot/api/platform/__init__.py b/astrbot/api/platform/__init__.py index 5a98c5903..6a182c32b 100644 --- a/astrbot/api/platform/__init__.py +++ b/astrbot/api/platform/__init__.py @@ -1,23 +1,22 @@ +from astrbot.core.message.components import * from astrbot.core.platform import ( - AstrMessageEvent, - Platform, AstrBotMessage, + AstrMessageEvent, + Group, MessageMember, MessageType, + Platform, PlatformMetadata, - Group, ) - from astrbot.core.platform.register import register_platform_adapter -from astrbot.core.message.components import * __all__ = [ - "AstrMessageEvent", - "Platform", "AstrBotMessage", + "AstrMessageEvent", + "Group", "MessageMember", "MessageType", + "Platform", "PlatformMetadata", "register_platform_adapter", - "Group", ] diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 9b1ade50a..2008c7bcf 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,17 +1,17 @@ -from astrbot.core.provider import Provider, STTProvider, Personality +from astrbot.core.provider import Personality, Provider, STTProvider from astrbot.core.provider.entities import ( + LLMResponse, + ProviderMetaData, ProviderRequest, ProviderType, - ProviderMetaData, - LLMResponse, ) __all__ = [ - "Provider", - "STTProvider", + "LLMResponse", "Personality", + "Provider", + "ProviderMetaData", "ProviderRequest", "ProviderType", - "ProviderMetaData", - "LLMResponse", + "STTProvider", ] diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py index 1b33923fe..63db07a72 100644 --- a/astrbot/api/star/__init__.py +++ b/astrbot/api/star/__init__.py @@ -1,8 +1,7 @@ +from astrbot.core.star import Context, Star, StarTools +from astrbot.core.star.config import * from astrbot.core.star.register import ( register_star as register, # 注册插件(Star) ) -from astrbot.core.star import Context, Star, StarTools -from astrbot.core.star.config import * - -__all__ = ["register", "Context", "Star", "StarTools"] +__all__ = ["Context", "Star", "StarTools", "register"] diff --git a/astrbot/api/util/__init__.py b/astrbot/api/util/__init__.py index a66206e05..1be3152d0 100644 --- a/astrbot/api/util/__init__.py +++ b/astrbot/api/util/__init__.py @@ -1,7 +1,7 @@ from astrbot.core.utils.session_waiter import ( - SessionWaiter, SessionController, + SessionWaiter, session_waiter, ) -__all__ = ["SessionWaiter", "SessionController", "session_waiter"] +__all__ = ["SessionController", "SessionWaiter", "session_waiter"] diff --git a/astrbot/cli/__main__.py b/astrbot/cli/__main__.py index f2b6651f5..40c46de79 100644 --- a/astrbot/cli/__main__.py +++ b/astrbot/cli/__main__.py @@ -1,11 +1,11 @@ -""" -AstrBot CLI入口 -""" +"""AstrBot CLI入口""" + +import sys import click -import sys + from . import __version__ -from .commands import init, run, plug, conf +from .commands import conf, init, plug, run logo_tmpl = r""" ___ _______.___________..______ .______ ______ .___________. diff --git a/astrbot/cli/commands/__init__.py b/astrbot/cli/commands/__init__.py index 9fa9149e2..1d3e0bca2 100644 --- a/astrbot/cli/commands/__init__.py +++ b/astrbot/cli/commands/__init__.py @@ -1,6 +1,6 @@ -from .cmd_init import init -from .cmd_run import run -from .cmd_plug import plug from .cmd_conf import conf +from .cmd_init import init +from .cmd_plug import plug +from .cmd_run import run -__all__ = ["init", "run", "plug", "conf"] +__all__ = ["conf", "init", "plug", "run"] diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index fea654f20..86f78cbaa 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -1,9 +1,12 @@ -import json -import click import hashlib +import json import zoneinfo -from typing import Any, Callable -from ..utils import get_astrbot_root, check_astrbot_root +from collections.abc import Callable +from typing import Any + +import click + +from ..utils import check_astrbot_root, get_astrbot_root def _validate_log_level(value: str) -> str: @@ -11,7 +14,7 @@ def _validate_log_level(value: str) -> str: value = value.upper() if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: raise click.ClickException( - "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一" + "日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一", ) return value @@ -73,7 +76,7 @@ def _load_config() -> dict[str, Any]: root = get_astrbot_root() if not check_astrbot_root(root): raise click.ClickException( - f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init" + f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", ) config_path = root / "data" / "cmd_config.json" @@ -88,7 +91,7 @@ def _load_config() -> dict[str, Any]: try: return json.loads(config_path.read_text(encoding="utf-8-sig")) except json.JSONDecodeError as e: - raise click.ClickException(f"配置文件解析失败: {str(e)}") + raise click.ClickException(f"配置文件解析失败: {e!s}") def _save_config(config: dict[str, Any]) -> None: @@ -96,7 +99,8 @@ def _save_config(config: dict[str, Any]) -> None: config_path = get_astrbot_root() / "data" / "cmd_config.json" config_path.write_text( - json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8-sig" + json.dumps(config, ensure_ascii=False, indent=2), + encoding="utf-8-sig", ) @@ -108,7 +112,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: obj[part] = {} elif not isinstance(obj[part], dict): raise click.ClickException( - f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典" + f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典", ) obj = obj[part] obj[parts[-1]] = value @@ -140,7 +144,6 @@ def conf(): - callback_api_base: 回调接口基址 """ - pass @conf.command(name="set") @@ -148,7 +151,7 @@ def conf(): @click.argument("value") def set_config(key: str, value: str): """设置配置项的值""" - if key not in CONFIG_VALIDATORS.keys(): + if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") config = _load_config() @@ -170,7 +173,7 @@ def set_config(key: str, value: str): except KeyError: raise click.ClickException(f"未知的配置项: {key}") except Exception as e: - raise click.UsageError(f"设置配置失败: {str(e)}") + raise click.UsageError(f"设置配置失败: {e!s}") @conf.command(name="get") @@ -180,7 +183,7 @@ def get_config(key: str = None): config = _load_config() if key: - if key not in CONFIG_VALIDATORS.keys(): + if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") try: @@ -191,10 +194,10 @@ def get_config(key: str = None): except KeyError: raise click.ClickException(f"未知的配置项: {key}") except Exception as e: - raise click.UsageError(f"获取配置失败: {str(e)}") + raise click.UsageError(f"获取配置失败: {e!s}") else: click.echo("当前配置:") - for key in CONFIG_VALIDATORS.keys(): + for key in CONFIG_VALIDATORS: try: value = ( "********" diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py index d9a42f822..993995a66 100644 --- a/astrbot/cli/commands/cmd_init.py +++ b/astrbot/cli/commands/cmd_init.py @@ -13,7 +13,7 @@ async def initialize_astrbot(astrbot_root) -> None: if not dot_astrbot.exists(): click.echo(f"Current Directory: {astrbot_root}") click.echo( - "如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。" + "如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。", ) if click.confirm( f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}", diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index b250ede4b..a1099de1d 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -1,31 +1,29 @@ import re +import shutil from pathlib import Path import click -import shutil - from ..utils import ( - get_git_repo, - build_plug_list, - manage_plugin, PluginStatus, + build_plug_list, check_astrbot_root, get_astrbot_root, + get_git_repo, + manage_plugin, ) @click.group() def plug(): """插件管理""" - pass def _get_data_path() -> Path: base = get_astrbot_root() if not check_astrbot_root(base): raise click.ClickException( - f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init" + f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", ) return (base / "data").resolve() @@ -41,7 +39,7 @@ def display_plugins(plugins, title=None, color=None): desc = p["desc"][:30] + ("..." if len(p["desc"]) > 30 else "") click.echo( f"{p['name']:<20} {p['version']:<10} {p['status']:<10} " - f"{p['author']:<15} {desc:<30}" + f"{p['author']:<15} {desc:<30}", ) @@ -78,7 +76,7 @@ def new(name: str): f"desc: {desc}\n" f"version: {version}\n" f"author: {author}\n" - f"repo: {repo}\n" + f"repo: {repo}\n", ) # 重写 README.md @@ -86,7 +84,7 @@ def new(name: str): f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n") # 重写 main.py - with open(plug_path / "main.py", "r", encoding="utf-8") as f: + with open(plug_path / "main.py", encoding="utf-8") as f: content = f.read() new_content = content.replace( diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 38113744f..9333f1b87 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -1,19 +1,18 @@ +import asyncio import os import sys +import traceback from pathlib import Path import click -import asyncio -import traceback - from filelock import FileLock, Timeout -from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root +from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root async def run_astrbot(astrbot_root: Path): """运行 AstrBot""" - from astrbot.core import logger, LogManager, LogBroker, db_helper + from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader await check_dashboard(astrbot_root / "data") @@ -38,7 +37,7 @@ def run(reload: bool, port: str) -> None: if not check_astrbot_root(astrbot_root): raise click.ClickException( - f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init" + f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", ) os.environ["ASTRBOT_ROOT"] = str(astrbot_root) diff --git a/astrbot/cli/utils/__init__.py b/astrbot/cli/utils/__init__.py index 9989dcf26..3830682f0 100644 --- a/astrbot/cli/utils/__init__.py +++ b/astrbot/cli/utils/__init__.py @@ -1,18 +1,18 @@ from .basic import ( - get_astrbot_root, check_astrbot_root, check_dashboard, + get_astrbot_root, ) -from .plugin import get_git_repo, manage_plugin, build_plug_list, PluginStatus +from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin from .version_comparator import VersionComparator __all__ = [ - "get_astrbot_root", + "PluginStatus", + "VersionComparator", + "build_plug_list", "check_astrbot_root", "check_dashboard", + "get_astrbot_root", "get_git_repo", "manage_plugin", - "build_plug_list", - "VersionComparator", - "PluginStatus", ] diff --git a/astrbot/cli/utils/basic.py b/astrbot/cli/utils/basic.py index fabced48a..5dbe29006 100644 --- a/astrbot/cli/utils/basic.py +++ b/astrbot/cli/utils/basic.py @@ -21,8 +21,9 @@ def get_astrbot_root() -> Path: async def check_dashboard(astrbot_root: Path) -> None: """检查是否安装了dashboard""" - from astrbot.core.utils.io import get_dashboard_version, download_dashboard from astrbot.core.config.default import VERSION + from astrbot.core.utils.io import download_dashboard, get_dashboard_version + from .version_comparator import VersionComparator try: @@ -48,19 +49,18 @@ async def check_dashboard(astrbot_root: Path) -> None: if VersionComparator.compare_version(VERSION, dashboard_version) <= 0: click.echo("管理面板已是最新版本") return - else: - try: - version = dashboard_version.split("v")[1] - click.echo(f"管理面板版本: {version}") - await download_dashboard( - path="data/dashboard.zip", - extract_path=str(astrbot_root), - version=f"v{VERSION}", - latest=False, - ) - except Exception as e: - click.echo(f"下载管理面板失败: {e}") - return + try: + version = dashboard_version.split("v")[1] + click.echo(f"管理面板版本: {version}") + await download_dashboard( + path="data/dashboard.zip", + extract_path=str(astrbot_root), + version=f"v{VERSION}", + latest=False, + ) + except Exception as e: + click.echo(f"下载管理面板失败: {e}") + return except FileNotFoundError: click.echo("初始化管理面板目录...") try: diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd1fcd97b..55edf4de2 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -1,14 +1,14 @@ import shutil import tempfile - -import httpx -import yaml from enum import Enum from io import BytesIO from pathlib import Path from zipfile import ZipFile import click +import httpx +import yaml + from .version_comparator import VersionComparator @@ -32,7 +32,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None): release_url = f"https://api.github.com/repos/{author}/{repo}/releases" try: with httpx.Client( - proxy=proxy if proxy else None, follow_redirects=True + proxy=proxy if proxy else None, + follow_redirects=True, ) as client: resp = client.get(release_url) resp.raise_for_status() @@ -55,7 +56,8 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None): # 下载并解压 with httpx.Client( - proxy=proxy if proxy else None, follow_redirects=True + proxy=proxy if proxy else None, + follow_redirects=True, ) as client: resp = client.get(download_url) if ( @@ -89,6 +91,7 @@ def load_yaml_metadata(plugin_dir: Path) -> dict: Returns: dict: 包含元数据的字典,如果读取失败则返回空字典 + """ yaml_path = plugin_dir / "metadata.yaml" if yaml_path.exists(): @@ -107,6 +110,7 @@ def build_plug_list(plugins_dir: Path) -> list: Returns: list: 包含插件信息的字典列表 + """ # 获取本地插件信息 result = [] @@ -133,7 +137,7 @@ def build_plug_list(plugins_dir: Path) -> list: "repo": str(metadata.get("repo", "")), "status": PluginStatus.INSTALLED, "local_path": str(plugin_dir), - } + }, ) # 获取在线插件列表 @@ -153,7 +157,7 @@ def build_plug_list(plugins_dir: Path) -> list: "repo": str(plugin_info.get("repo", "")), "status": PluginStatus.NOT_INSTALLED, "local_path": None, - } + }, ) except Exception as e: click.echo(f"获取在线插件列表失败: {e}", err=True) @@ -168,7 +172,8 @@ def build_plug_list(plugins_dir: Path) -> list: ) if ( VersionComparator.compare_version( - local_plugin["version"], online_plugin["version"] + local_plugin["version"], + online_plugin["version"], ) < 0 ): @@ -186,7 +191,10 @@ def build_plug_list(plugins_dir: Path) -> list: def manage_plugin( - plugin: dict, plugins_dir: Path, is_update: bool = False, proxy: str | None = None + plugin: dict, + plugins_dir: Path, + is_update: bool = False, + proxy: str | None = None, ) -> None: """安装或更新插件 @@ -195,6 +203,7 @@ def manage_plugin( plugins_dir (Path): 插件目录 is_update (bool, optional): 是否为更新操作. 默认为 False proxy (str, optional): 代理服务器地址 + """ plugin_name = plugin["name"] repo_url = plugin["repo"] @@ -219,7 +228,7 @@ def manage_plugin( try: click.echo( - f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}..." + f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...", ) get_git_repo(repo_url, target_path, proxy) @@ -233,5 +242,5 @@ def manage_plugin( if is_update and backup_path.exists(): shutil.move(backup_path, target_path) raise click.ClickException( - f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}" + f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}", ) diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index fecab885e..99d71d34d 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -1,6 +1,4 @@ -""" -拷贝自 astrbot.core.utils.version_comparator -""" +"""拷贝自 astrbot.core.utils.version_comparator""" import re @@ -42,15 +40,15 @@ class VersionComparator: for i in range(length): if v1_parts[i] > v2_parts[i]: return 1 - elif v1_parts[i] < v2_parts[i]: + if v1_parts[i] < v2_parts[i]: return -1 # 比较预发布标签 if v1_prerelease is None and v2_prerelease is not None: return 1 # 没有预发布标签的版本高于有预发布标签的版本 - elif v1_prerelease is not None and v2_prerelease is None: + if v1_prerelease is not None and v2_prerelease is None: return -1 # 有预发布标签的版本低于没有预发布标签的版本 - elif v1_prerelease is not None and v2_prerelease is not None: + if v1_prerelease is not None and v2_prerelease is not None: len_pre = max(len(v1_prerelease), len(v2_prerelease)) for i in range(len_pre): p1 = v1_prerelease[i] if i < len(v1_prerelease) else None @@ -58,21 +56,18 @@ class VersionComparator: if p1 is None and p2 is not None: return -1 - elif p1 is not None and p2 is None: + if p1 is not None and p2 is None: return 1 - elif isinstance(p1, int) and isinstance(p2, str): + if isinstance(p1, int) and isinstance(p2, str): return -1 - elif isinstance(p1, str) and isinstance(p2, int): + if isinstance(p1, str) and isinstance(p2, int): return 1 - elif isinstance(p1, int) and isinstance(p2, int): + if (isinstance(p1, int) and isinstance(p2, int)) or ( + isinstance(p1, str) and isinstance(p2, str) + ): if p1 > p2: return 1 - elif p1 < p2: - return -1 - elif isinstance(p1, str) and isinstance(p2, str): - if p1 > p2: - return 1 - elif p1 < p2: + if p1 < p2: return -1 return 0 # 预发布标签完全相同 diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 235a8284b..30b81af60 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,12 +1,14 @@ import os -from .log import LogManager, LogBroker # noqa -from astrbot.core.utils.t2i.renderer import HtmlRenderer -from astrbot.core.utils.shared_preferences import SharedPreferences -from astrbot.core.utils.pip_installer import PipInstaller -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.config.default import DB_PATH + from astrbot.core.config import AstrBotConfig +from astrbot.core.config.default import DB_PATH +from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.file_token_service import FileTokenService +from astrbot.core.utils.pip_installer import PipInstaller +from astrbot.core.utils.shared_preferences import SharedPreferences +from astrbot.core.utils.t2i.renderer import HtmlRenderer + +from .log import LogBroker, LogManager # noqa from .utils.astrbot_path import get_astrbot_data_path # 初始化数据存储文件夹 diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index 061ffde09..e2206829e 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -1,8 +1,9 @@ from dataclasses import dataclass -from .tool import FunctionTool from typing import Generic -from .run_context import TContext + from .hooks import BaseAgentRunHooks +from .run_context import TContext +from .tool import FunctionTool @dataclass diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index d26463147..85276540b 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,14 +1,18 @@ from typing import Generic -from .tool import FunctionTool + from .agent import Agent from .run_context import TContext +from .tool import FunctionTool class HandoffTool(FunctionTool, Generic[TContext]): """Handoff tool for delegating tasks to another agent.""" def __init__( - self, agent: Agent[TContext], parameters: dict | None = None, **kwargs + self, + agent: Agent[TContext], + parameters: dict | None = None, + **kwargs, ): self.agent = agent super().__init__( diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 884fe6bd4..949ebd3fe 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,9 +1,12 @@ -import mcp from dataclasses import dataclass -from .run_context import ContextWrapper, TContext from typing import Generic -from astrbot.core.provider.entities import LLMResponse + +import mcp + from astrbot.core.agent.tool import FunctionTool +from astrbot.core.provider.entities import LLMResponse + +from .run_context import ContextWrapper, TContext @dataclass @@ -23,5 +26,7 @@ class BaseAgentRunHooks(Generic[TContext]): tool_result: mcp.types.CallToolResult | None, ): ... async def on_agent_done( - self, run_context: ContextWrapper[TContext], llm_response: LLMResponse + self, + run_context: ContextWrapper[TContext], + llm_response: LLMResponse, ): ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 8db9d6f26..303973a0d 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -1,8 +1,8 @@ import asyncio import logging -from datetime import timedelta -from typing import Optional from contextlib import AsyncExitStack +from datetime import timedelta + from astrbot import logger from astrbot.core.utils.log_pipe import LogPipe @@ -16,13 +16,13 @@ try: from mcp.client.streamable_http import streamablehttp_client except (ModuleNotFoundError, ImportError): logger.warning( - "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。" + "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。", ) def _prepare_config(config: dict) -> dict: """准备配置,处理嵌套格式""" - if "mcpServers" in config and config["mcpServers"]: + if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] config.pop("active", None) @@ -71,8 +71,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - else: - return False, f"HTTP {response.status}: {response.reason}" + return False, f"HTTP {response.status}: {response.reason}" else: async with session.get( url, @@ -84,8 +83,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - else: - return False, f"HTTP {response.status}: {response.reason}" + return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: return False, f"连接超时: {timeout}秒" @@ -96,7 +94,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: def __init__(self): # Initialize session and client objects - self.session: Optional[mcp.ClientSession] = None + self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() self.name: str | None = None @@ -115,6 +113,7 @@ class MCPClient: Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server + """ cfg = _prepare_config(mcp_server_config.copy()) @@ -144,7 +143,7 @@ class MCPClient: sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), ) streams = await self.exit_stack.enter_async_context( - self._streams_context + self._streams_context, ) # Create a new client session @@ -154,12 +153,12 @@ class MCPClient: *streams, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore - ) + ), ) else: timeout = timedelta(seconds=cfg.get("timeout", 30)) sse_read_timeout = timedelta( - seconds=cfg.get("sse_read_timeout", 60 * 5) + seconds=cfg.get("sse_read_timeout", 60 * 5), ) self._streams_context = streamablehttp_client( url=cfg["url"], @@ -169,7 +168,7 @@ class MCPClient: terminate_on_close=cfg.get("terminate_on_close", True), ) read_s, write_s, _ = await self.exit_stack.enter_async_context( - self._streams_context + self._streams_context, ) # Create a new client session @@ -180,7 +179,7 @@ class MCPClient: write_stream=write_s, read_timeout_seconds=read_timeout, logging_callback=logging_callback, # type: ignore - ) + ), ) else: @@ -206,7 +205,7 @@ class MCPClient: # Create a new client session self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(*stdio_transport) + mcp.ClientSession(*stdio_transport), ) await self.session.initialize() diff --git a/astrbot/core/agent/response.py b/astrbot/core/agent/response.py index 8eb1854f6..3f3430c87 100644 --- a/astrbot/core/agent/response.py +++ b/astrbot/core/agent/response.py @@ -1,5 +1,6 @@ -from dataclasses import dataclass import typing as T +from dataclasses import dataclass + from astrbot.core.message.message_event_result import MessageChain diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index a0febf8c9..634735ccc 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import Any, Generic + from typing_extensions import TypeVar from astrbot.core.platform.astr_message_event import AstrMessageEvent diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 83821ae29..c7cd36d96 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -1,13 +1,15 @@ import abc import typing as T from enum import Enum, auto -from ..run_context import ContextWrapper, TContext -from ..response import AgentResponse -from ..hooks import BaseAgentRunHooks -from ..tool_executor import BaseFunctionToolExecutor + from astrbot.core.provider import Provider from astrbot.core.provider.entities import LLMResponse +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponse +from ..run_context import ContextWrapper, TContext +from ..tool_executor import BaseFunctionToolExecutor + class AgentState(Enum): """Defines the state of the agent.""" @@ -28,31 +30,26 @@ class BaseAgentRunner(T.Generic[TContext]): agent_hooks: BaseAgentRunHooks[TContext], **kwargs: T.Any, ) -> None: - """ - Reset the agent to its initial state. + """Reset the agent to its initial state. This method should be called before starting a new run. """ ... @abc.abstractmethod async def step(self) -> T.AsyncGenerator[AgentResponse, None]: - """ - Process a single step of the agent. - """ + """Process a single step of the agent.""" ... @abc.abstractmethod def done(self) -> bool: - """ - Check if the agent has completed its task. + """Check if the agent has completed its task. Returns True if the agent is done, False otherwise. """ ... @abc.abstractmethod def get_final_llm_resp(self) -> LLMResponse | None: - """ - Get the final observation from the agent. + """Get the final observation from the agent. This method should be called after the agent is done. """ ... diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 33298e895..cb89fb612 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,31 +1,34 @@ import sys import traceback import typing as T -from .base import BaseAgentRunner, AgentResponse, AgentState -from ..hooks import BaseAgentRunHooks -from ..tool_executor import BaseFunctionToolExecutor -from ..run_context import ContextWrapper, TContext -from ..response import AgentResponseData -from astrbot.core.provider.provider import Provider + +from mcp.types import ( + BlobResourceContents, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + TextResourceContents, +) + +from astrbot import logger from astrbot.core.message.message_event_result import ( MessageChain, ) from astrbot.core.provider.entities import ( - ProviderRequest, - LLMResponse, - ToolCallMessageSegment, AssistantMessageSegment, + LLMResponse, + ProviderRequest, + ToolCallMessageSegment, ToolCallsResult, ) -from mcp.types import ( - TextContent, - ImageContent, - EmbeddedResource, - TextResourceContents, - BlobResourceContents, - CallToolResult, -) -from astrbot import logger +from astrbot.core.provider.provider import Provider + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from ..tool_executor import BaseFunctionToolExecutor +from .base import AgentResponse, AgentState, BaseAgentRunner if sys.version_info >= (3, 12): from typing import override @@ -70,8 +73,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override async def step(self): - """ - Process a single step of the agent. + """Process a single step of the agent. This method should return the result of the step. """ if not self.req: @@ -99,7 +101,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): yield AgentResponse( type="streaming_delta", data=AgentResponseData( - chain=MessageChain().message(llm_response.completion_text) + chain=MessageChain().message(llm_response.completion_text), ), ) continue @@ -120,8 +122,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): type="err", data=AgentResponseData( chain=MessageChain().message( - f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}" - ) + f"LLM 响应错误: {llm_resp.completion_text or '未知错误'}", + ), ), ) @@ -144,7 +146,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): yield AgentResponse( type="llm_result", data=AgentResponseData( - chain=MessageChain().message(llm_resp.completion_text) + chain=MessageChain().message(llm_resp.completion_text), ), ) @@ -155,7 +157,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): yield AgentResponse( type="tool_call", data=AgentResponseData( - chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}") + chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"), ), ) async for result in self._handle_function_tools(self.req, llm_resp): @@ -205,7 +207,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content=f"error: 未找到工具 {func_tool_name}", - ) + ), ) continue @@ -214,7 +216,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 获取实际的 handler 函数 if func_tool.handler: logger.debug( - f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}" + f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}", ) if func_tool.parameters and func_tool.parameters.get("properties"): expected_params = set(func_tool.parameters["properties"].keys()) @@ -227,11 +229,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 记录被忽略的参数 ignored_params = set(func_tool_args.keys()) - set( - valid_params.keys() + valid_params.keys(), ) if ignored_params: logger.warning( - f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}" + f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}", ) else: # 如果没有 handler(如 MCP 工具),使用所有参数 @@ -240,7 +242,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): try: await self.agent_hooks.on_tool_start( - self.run_context, func_tool, valid_params + self.run_context, + func_tool, + valid_params, ) except Exception as e: logger.error(f"Error in on_tool_start hook: {e}", exc_info=True) @@ -262,7 +266,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content=res.content[0].text, - ) + ), ) yield MessageChain().message(res.content[0].text) elif isinstance(res.content[0], ImageContent): @@ -271,10 +275,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content="返回了图片(已直接发送给用户)", - ) + ), ) yield MessageChain(type="tool_direct_result").base64_image( - res.content[0].data + res.content[0].data, ) elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource @@ -284,7 +288,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content=resource.text, - ) + ), ) yield MessageChain().message(resource.text) elif ( @@ -297,10 +301,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content="返回了图片(已直接发送给用户)", - ) + ), ) yield MessageChain( - type="tool_direct_result" + type="tool_direct_result", ).base64_image(resource.blob) else: tool_call_result_blocks.append( @@ -308,7 +312,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): role="tool", tool_call_id=func_tool_id, content="返回的数据类型不受支持", - ) + ), ) yield MessageChain().message("返回的数据类型不受支持。") @@ -319,17 +323,21 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): if res := self.run_context.event.get_result(): if res.chain: yield MessageChain( - chain=res.chain, type="tool_direct_result" + chain=res.chain, + type="tool_direct_result", ) else: # 不应该出现其他类型 logger.warning( - f"Tool 返回了不支持的类型: {type(resp)},将忽略。" + f"Tool 返回了不支持的类型: {type(resp)},将忽略。", ) try: await self.agent_hooks.on_tool_end( - self.run_context, func_tool, func_tool_args, _final_resp + self.run_context, + func_tool, + func_tool_args, + _final_resp, ) except Exception as e: logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) @@ -341,8 +349,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content=f"error: {str(e)}", - ) + content=f"error: {e!s}", + ), ) # 处理函数调用响应 diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index ae0ab761c..3c36def63 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,9 @@ +from collections.abc import Awaitable, Callable from dataclasses import dataclass +from typing import Any, Literal + from deprecated import deprecated -from typing import Awaitable, Callable, Literal, Any, Optional + from .mcp_client import MCPClient @@ -49,7 +52,8 @@ class ToolSet: """A set of function tools that can be used in function calling. This class provides methods to add, remove, and retrieve tools, as well as - convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" + convert the tools to different API formats (OpenAI, Anthropic, Google GenAI). + """ def __init__(self, tools: list[FunctionTool] | None = None): self.tools: list[FunctionTool] = tools or [] @@ -71,7 +75,7 @@ class ToolSet: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] - def get_tool(self, name: str) -> Optional[FunctionTool]: + def get_tool(self, name: str) -> FunctionTool | None: """Get a tool by its name.""" for tool in self.tools: if tool.name == name: @@ -132,10 +136,8 @@ class ToolSet: } if ( - tool.parameters - and tool.parameters.get("properties") - or not omit_empty_parameter_field - ): + tool.parameters and tool.parameters.get("properties") + ) or not omit_empty_parameter_field: func_def["function"]["parameters"] = tool.parameters result.append(func_def) @@ -185,7 +187,8 @@ class ToolSet: if "type" in schema and schema["type"] in supported_types: result["type"] = schema["type"] if "format" in schema and schema["format"] in supported_formats.get( - result["type"], set() + result["type"], + set(), ): result["format"] = schema["format"] else: diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 34a2f5e77..2704119d4 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,11 +1,17 @@ +from collections.abc import AsyncGenerator +from typing import Any, Generic + import mcp -from typing import Any, Generic, AsyncGenerator -from .run_context import TContext, ContextWrapper + +from .run_context import ContextWrapper, TContext from .tool import FunctionTool class BaseFunctionToolExecutor(Generic[TContext]): @classmethod async def execute( - cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args + cls, + tool: FunctionTool, + run_context: ContextWrapper[TContext], + **tool_args, ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... diff --git a/astrbot/core/astr_agent_context.py b/astrbot/core/astr_agent_context.py index 008c3a435..e21ddb9c6 100644 --- a/astrbot/core/astr_agent_context.py +++ b/astrbot/core/astr_agent_context.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 0ee3f4fe6..3a1353ce5 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -1,13 +1,14 @@ import os import uuid +from typing import TypedDict, TypeVar + from astrbot.core import AstrBotConfig, logger -from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH from astrbot.core.config.default import DEFAULT_CONFIG from astrbot.core.platform.message_session import MessageSession from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.utils.astrbot_path import get_astrbot_config_path -from typing import TypeVar, TypedDict +from astrbot.core.utils.shared_preferences import SharedPreferences _VT = TypeVar("_VT") @@ -48,7 +49,10 @@ class AstrBotConfigManager: """获取所有的 abconf 数据""" if self.abconf_data is None: self.abconf_data = self.sp.get( - "abconf_mapping", {}, scope="global", scope_id="global" + "abconf_mapping", + {}, + scope="global", + scope_id="global", ) return self.abconf_data @@ -64,7 +68,7 @@ class AstrBotConfigManager: self.confs[uuid_] = conf else: logger.warning( - f"Config file {conf_path} for UUID {uuid_} does not exist, skipping." + f"Config file {conf_path} for UUID {uuid_} does not exist, skipping.", ) continue @@ -73,6 +77,7 @@ class AstrBotConfigManager: Returns: ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型 + """ # uuid -> { "path": str, "name": str } abconf_data = self._get_abconf_data() @@ -103,7 +108,10 @@ class AstrBotConfigManager: ) -> None: """保存配置文件的映射关系""" abconf_data = self.sp.get( - "abconf_mapping", {}, scope="global", scope_id="global" + "abconf_mapping", + {}, + scope="global", + scope_id="global", ) random_word = abconf_name or uuid.uuid4().hex[:8] abconf_data[abconf_id] = { @@ -177,13 +185,17 @@ class AstrBotConfigManager: Raises: ValueError: 如果试图删除默认配置文件 + """ if conf_id == "default": raise ValueError("不能删除默认配置文件") # 从映射中移除 abconf_data = self.sp.get( - "abconf_mapping", {}, scope="global", scope_id="global" + "abconf_mapping", + {}, + scope="global", + scope_id="global", ) if conf_id not in abconf_data: logger.warning(f"配置文件 {conf_id} 不存在于映射中") @@ -191,7 +203,8 @@ class AstrBotConfigManager: # 获取配置文件路径 conf_path = os.path.join( - get_astrbot_config_path(), abconf_data[conf_id]["path"] + get_astrbot_config_path(), + abconf_data[conf_id]["path"], ) # 删除配置文件 @@ -224,12 +237,16 @@ class AstrBotConfigManager: Returns: bool: 更新是否成功 + """ if conf_id == "default": raise ValueError("不能更新默认配置文件的信息") abconf_data = self.sp.get( - "abconf_mapping", {}, scope="global", scope_id="global" + "abconf_mapping", + {}, + scope="global", + scope_id="global", ) if conf_id not in abconf_data: logger.warning(f"配置文件 {conf_id} 不存在于映射中") @@ -246,7 +263,10 @@ class AstrBotConfigManager: return True def g( - self, umo: str | None = None, key: str | None = None, default: _VT = None + self, + umo: str | None = None, + key: str | None = None, + default: _VT = None, ) -> _VT: """获取配置项。umo 为 None 时使用默认配置""" if umo is None: diff --git a/astrbot/core/config/__init__.py b/astrbot/core/config/__init__.py index e49ac88a5..839aeef3e 100644 --- a/astrbot/core/config/__init__.py +++ b/astrbot/core/config/__init__.py @@ -1,9 +1,9 @@ -from .default import DEFAULT_CONFIG, VERSION, DB_PATH from .astrbot_config import * +from .default import DB_PATH, DEFAULT_CONFIG, VERSION __all__ = [ + "DB_PATH", "DEFAULT_CONFIG", "VERSION", - "DB_PATH", "AstrBotConfig", ] diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 5d1f6fbe7..68b73cd29 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -1,11 +1,12 @@ -import os +import enum import json import logging -import enum -from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP -from typing import Dict +import os + from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP + ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") logger = logging.getLogger("astrbot") @@ -45,7 +46,7 @@ class AstrBotConfig(dict): json.dump(default_config, f, indent=4, ensure_ascii=False) object.__setattr__(self, "first_deploy", True) # 标记第一次部署 - with open(config_path, "r", encoding="utf-8-sig") as f: + with open(config_path, encoding="utf-8-sig") as f: conf_str = f.read() conf = json.loads(conf_str) @@ -65,7 +66,7 @@ class AstrBotConfig(dict): for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( - f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}" + f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}", ) if "default" in v: default = v["default"] @@ -82,7 +83,7 @@ class AstrBotConfig(dict): return conf - def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): + def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -97,27 +98,28 @@ class AstrBotConfig(dict): logger.info(f"检查到配置项 {path_} 不存在,已插入默认值 {value}") new_conf[key] = value has_new = True - else: - if conf[key] is None: - # 配置项为 None,使用默认值 + elif conf[key] is None: + # 配置项为 None,使用默认值 + new_conf[key] = value + has_new = True + elif isinstance(value, dict): + # 递归检查子配置项 + if not isinstance(conf[key], dict): + # 类型不匹配,使用默认值 new_conf[key] = value has_new = True - elif isinstance(value, dict): - # 递归检查子配置项 - if not isinstance(conf[key], dict): - # 类型不匹配,使用默认值 - new_conf[key] = value - has_new = True - else: - # 递归检查并同步顺序 - child_has_new = self.check_config_integrity( - value, conf[key], path + "." + key if path else key - ) - new_conf[key] = conf[key] - has_new |= child_has_new else: - # 直接使用现有配置 + # 递归检查并同步顺序 + child_has_new = self.check_config_integrity( + value, + conf[key], + path + "." + key if path else key, + ) new_conf[key] = conf[key] + has_new |= child_has_new + else: + # 直接使用现有配置 + new_conf[key] = conf[key] # 检查是否存在参考配置中没有的配置项 for key in list(conf.keys()): @@ -140,7 +142,7 @@ class AstrBotConfig(dict): return has_new - def save_config(self, replace_config: Dict = None): + def save_config(self, replace_config: dict = None): """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index da8e2a732..d8ccb1a22 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1,6 +1,4 @@ -""" -如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 -""" +"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。""" import os @@ -2707,9 +2705,9 @@ CONFIG_METADATA_3_SYSTEM = { "items": {"type": "string"}, }, }, - } + }, }, - } + }, } diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 8f8e2e0e9..2be406100 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -1,13 +1,13 @@ -""" -AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库 +"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库. 在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话, 在一个会话中可以建立多个对话, 并且支持对话的切换和删除 """ import json +from collections.abc import Awaitable, Callable + from astrbot.core import sp -from typing import Dict, List, Callable, Awaitable from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -16,31 +16,34 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - self.session_conversations: Dict[str, str] = {} + self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 # 会话删除回调函数列表(用于级联清理,如知识库配置) - self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = [] + self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = [] def register_on_session_deleted( - self, callback: Callable[[str], Awaitable[None]] + self, + callback: Callable[[str], Awaitable[None]], ) -> None: - """注册会话删除回调函数 + """注册会话删除回调函数. 其他模块可以注册回调来响应会话删除事件,实现级联清理。 例如:知识库模块可以注册回调来清理会话的知识库配置。 Args: callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数 + """ self._on_session_deleted_callbacks.append(callback) async def _trigger_session_deleted(self, unified_msg_origin: str) -> None: - """触发会话删除回调 + """触发会话删除回调. Args: unified_msg_origin: 会话ID + """ for callback in self._on_session_deleted_callbacks: try: @@ -49,7 +52,7 @@ class ConversationManager: from astrbot.core import logger logger.error( - f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}" + f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}", ) def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: @@ -75,12 +78,13 @@ class ConversationManager: title: str | None = None, persona_id: str | None = None, ) -> str: - """新建对话,并将当前会话的对话转移到新对话 + """新建对话,并将当前会话的对话转移到新对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ if not platform_id: # 如果没有提供 platform_id,则从 unified_msg_origin 中解析 @@ -106,18 +110,22 @@ class ConversationManager: Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ self.session_conversations[unified_msg_origin] = conversation_id await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id) async def delete_conversation( - self, unified_msg_origin: str, conversation_id: str | None = None + self, + unified_msg_origin: str, + conversation_id: str | None = None, ): """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ if not conversation_id: conversation_id = self.session_conversations.get(unified_msg_origin) @@ -133,6 +141,7 @@ class ConversationManager: Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + """ await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin) self.session_conversations.pop(unified_msg_origin, None) @@ -148,6 +157,7 @@ class ConversationManager: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ ret = self.session_conversations.get(unified_msg_origin, None) if not ret: @@ -162,13 +172,15 @@ class ConversationManager: conversation_id: str, create_if_not_exists: bool = False, ) -> Conversation | None: - """获取会话的对话 + """获取会话的对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话 Returns: conversation (Conversation): 对话对象 + """ conv = await self.db.get_conversation_by_id(cid=conversation_id) if not conv and create_if_not_exists: @@ -181,18 +193,22 @@ class ConversationManager: return conv_res async def get_conversations( - self, unified_msg_origin: str | None = None, platform_id: str | None = None - ) -> List[Conversation]: - """获取对话列表 + self, + unified_msg_origin: str | None = None, + platform_id: str | None = None, + ) -> list[Conversation]: + """获取对话列表. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选 platform_id (str): 平台 ID, 可选参数, 用于过滤对话 Returns: conversations (List[Conversation]): 对话对象列表 + """ convs = await self.db.get_conversations( - user_id=unified_msg_origin, platform_id=platform_id + user_id=unified_msg_origin, + platform_id=platform_id, ) convs_res = [] for conv in convs: @@ -208,7 +224,7 @@ class ConversationManager: search_query: str = "", **kwargs, ) -> tuple[list[Conversation], int]: - """获取过滤后的对话列表 + """获取过滤后的对话列表. Args: page (int): 页码, 默认为 1 @@ -217,6 +233,7 @@ class ConversationManager: search_query (str): 搜索查询字符串, 可选 Returns: conversations (list[Conversation]): 对话对象列表 + """ convs, cnt = await self.db.get_filtered_conversations( page=page, @@ -238,13 +255,14 @@ class ConversationManager: history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, - ): - """更新会话的对话 + ) -> None: + """更新会话的对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 + """ if not conversation_id: # 如果没有提供 conversation_id,则获取当前的 @@ -258,16 +276,20 @@ class ConversationManager: ) async def update_conversation_title( - self, unified_msg_origin: str, title: str, conversation_id: str | None = None - ): - """更新会话的对话标题 + self, + unified_msg_origin: str, + title: str, + conversation_id: str | None = None, + ) -> None: + """更新会话的对话标题. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id title (str): 对话标题 - + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 Deprecated: Use `update_conversation` with `title` parameter instead. + """ await self.update_conversation( unified_msg_origin=unified_msg_origin, @@ -280,15 +302,16 @@ class ConversationManager: unified_msg_origin: str, persona_id: str, conversation_id: str | None = None, - ): - """更新会话的对话 Persona ID + ) -> None: + """更新会话的对话 Persona ID. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id persona_id (str): 对话 Persona ID - + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 Deprecated: Use `update_conversation` with `persona_id` parameter instead. + """ await self.update_conversation( unified_msg_origin=unified_msg_origin, @@ -297,39 +320,49 @@ class ConversationManager: ) async def get_human_readable_context( - self, unified_msg_origin, conversation_id, page=1, page_size=10 - ): - """获取人类可读的上下文 + self, + unified_msg_origin: str, + conversation_id: str, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[str], int]: + """获取人类可读的上下文. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 page (int): 页码 page_size (int): 每页大小 + """ conversation = await self.get_conversation(unified_msg_origin, conversation_id) + if not conversation: + return [], 0 history = json.loads(conversation.history) - contexts = [] - temp_contexts = [] + # contexts_groups 存放按顺序的段落(每个段落是一个 str 列表), + # 之后会被展平成一个扁平的 str 列表返回。 + contexts_groups: list[list[str]] = [] + temp_contexts: list[str] = [] for record in history: if record["role"] == "user": temp_contexts.append(f"User: {record['content']}") elif record["role"] == "assistant": - if "content" in record and record["content"]: + if record.get("content"): temp_contexts.append(f"Assistant: {record['content']}") elif "tool_calls" in record: tool_calls_str = json.dumps( - record["tool_calls"], ensure_ascii=False + record["tool_calls"], + ensure_ascii=False, ) temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}") else: temp_contexts.append("Assistant: [未知的内容]") - contexts.insert(0, temp_contexts) + contexts_groups.insert(0, temp_contexts) temp_contexts = [] - # 展平 contexts 列表 - contexts = [item for sublist in contexts for item in sublist] + # 展平分组后的 contexts 列表为单层字符串列表 + contexts = [item for sublist in contexts_groups for item in sublist] # 计算分页 paged_contexts = contexts[(page - 1) * page_size : page * page_size] diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 3d4b28c03..2a6ac4273 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -1,5 +1,5 @@ -""" -Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 +"""Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 @@ -9,44 +9,44 @@ Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、 3. 执行启动完成事件钩子 """ -import traceback import asyncio -import time -import threading import os -from .event_bus import EventBus -from . import astrbot_config, html_renderer +import threading +import time +import traceback from asyncio import Queue -from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext -from astrbot.core.star import PluginManager -from astrbot.core.platform.manager import PlatformManager -from astrbot.core.star.context import Context -from astrbot.core.persona_mgr import PersonaManager -from astrbot.core.provider.manager import ProviderManager -from astrbot.core import LogBroker -from astrbot.core.db import BaseDatabase -from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 -from astrbot.core.updator import AstrBotUpdator -from astrbot.core import logger, sp + +from astrbot.core import LogBroker, logger, sp +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager -from astrbot.core.umop_config_router import UmopConfigRouter -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star_handler import star_map +from astrbot.core.db import BaseDatabase +from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager +from astrbot.core.persona_mgr import PersonaManager +from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler +from astrbot.core.platform.manager import PlatformManager +from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.provider.manager import ProviderManager +from astrbot.core.star import PluginManager +from astrbot.core.star.context import Context +from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map +from astrbot.core.umop_config_router import UmopConfigRouter +from astrbot.core.updator import AstrBotUpdator + +from . import astrbot_config, html_renderer +from .event_bus import EventBus class AstrBotCoreLifecycle: - """ - AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 + """AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作. + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、 EventBus 等。 该类还负责加载和执行插件, 以及处理事件总线的分发。 """ - def __init__(self, log_broker: LogBroker, db: BaseDatabase): + def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.log_broker = log_broker # 初始化日志代理 self.astrbot_config = astrbot_config # 初始化配置 self.db = db # 初始化数据库 @@ -70,11 +70,11 @@ class AstrBotCoreLifecycle: del os.environ["no_proxy"] logger.debug("HTTP proxy cleared") - async def initialize(self): - """ - 初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 - """ + async def initialize(self) -> None: + """初始化 AstrBot 核心生命周期管理类. + 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 + """ # 初始化日志代理 logger.info("AstrBot v" + VERSION) if os.environ.get("TESTING", ""): @@ -91,7 +91,9 @@ class AstrBotCoreLifecycle: # 初始化 AstrBot 配置管理器 self.astrbot_config_mgr = AstrBotConfigManager( - default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp + default_config=self.astrbot_config, + ucr=self.umop_config_router, + sp=sp, ) # 4.5 to 4.6 migration for umop_config_router @@ -110,7 +112,9 @@ class AstrBotCoreLifecycle: # 初始化供应商管理器 self.provider_manager = ProviderManager( - self.astrbot_config_mgr, self.db, self.persona_mgr + self.astrbot_config_mgr, + self.db, + self.persona_mgr, ) # 初始化平台管理器 @@ -158,7 +162,9 @@ class AstrBotCoreLifecycle: # 初始化事件总线 self.event_bus = EventBus( - self.event_queue, self.pipeline_scheduler_mapping, self.astrbot_config_mgr + self.event_queue, + self.pipeline_scheduler_mapping, + self.astrbot_config_mgr, ) # 记录启动时间 @@ -173,13 +179,13 @@ class AstrBotCoreLifecycle: # 初始化关闭控制面板的事件 self.dashboard_shutdown_event = asyncio.Event() - def _load(self): - """加载事件总线和任务并初始化""" - + def _load(self) -> None: + """加载事件总线和任务并初始化.""" # 创建一个异步任务来执行事件总线的 dispatch() 方法 # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理 event_bus_task = asyncio.create_task( - self.event_bus.dispatch(), name="event_bus" + self.event_bus.dispatch(), + name="event_bus", ) # 把插件中注册的所有协程函数注册到事件总线中并执行 @@ -190,16 +196,17 @@ class AstrBotCoreLifecycle: tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: self.curr_tasks.append( - asyncio.create_task(self._task_wrapper(task), name=task.get_name()) + asyncio.create_task(self._task_wrapper(task), name=task.get_name()), ) self.start_time = int(time.time()) - async def _task_wrapper(self, task: asyncio.Task): - """异步任务包装器, 用于处理异步任务执行中出现的各种异常 + async def _task_wrapper(self, task: asyncio.Task) -> None: + """异步任务包装器, 用于处理异步任务执行中出现的各种异常. Args: task (asyncio.Task): 要执行的异步任务 + """ try: await task @@ -212,19 +219,22 @@ class AstrBotCoreLifecycle: logger.error(f"| {line}") logger.error("-------") - async def start(self): - """启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子""" + async def start(self) -> None: + """启动 AstrBot 核心生命周期管理类. + + 用load加载事件总线和任务并初始化, 执行启动完成事件钩子 + """ self._load() logger.info("AstrBot 启动完成。") # 执行启动完成事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnAstrBotLoadedEvent + EventType.OnAstrBotLoadedEvent, ) for handler in handlers: try: logger.info( - f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + f"hook(on_astrbot_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) await handler.handler() except BaseException: @@ -233,8 +243,8 @@ class AstrBotCoreLifecycle: # 同时运行curr_tasks中的所有任务 await asyncio.gather(*self.curr_tasks, return_exceptions=True) - async def stop(self): - """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器""" + async def stop(self) -> None: + """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器.""" # 请求停止所有正在运行的异步任务 for task in self.curr_tasks: task.cancel() @@ -245,7 +255,7 @@ class AstrBotCoreLifecycle: except Exception as e: logger.warning(traceback.format_exc()) logger.warning( - f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。" + f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。", ) await self.provider_manager.terminate() @@ -262,14 +272,16 @@ class AstrBotCoreLifecycle: except Exception as e: logger.error(f"任务 {task.get_name()} 发生错误: {e}") - async def restart(self): + async def restart(self) -> None: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" await self.provider_manager.terminate() await self.platform_manager.terminate() await self.kb_manager.terminate() self.dashboard_shutdown_event.set() threading.Thread( - target=self.astrbot_updator._reboot, name="restart", daemon=True + target=self.astrbot_updator._reboot, + name="restart", + daemon=True, ).start() def load_platform(self) -> list[asyncio.Task]: @@ -281,36 +293,38 @@ class AstrBotCoreLifecycle: asyncio.create_task( platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})", - ) + ), ) return tasks async def load_pipeline_scheduler(self) -> dict[str, PipelineScheduler]: - """加载消息事件流水线调度器 + """加载消息事件流水线调度器. Returns: dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 + """ mapping = {} for conf_id, ab_config in self.astrbot_config_mgr.confs.items(): scheduler = PipelineScheduler( - PipelineContext(ab_config, self.plugin_manager, conf_id) + PipelineContext(ab_config, self.plugin_manager, conf_id), ) await scheduler.initialize() mapping[conf_id] = scheduler return mapping - async def reload_pipeline_scheduler(self, conf_id: str): - """重新加载消息事件流水线调度器 + async def reload_pipeline_scheduler(self, conf_id: str) -> None: + """重新加载消息事件流水线调度器. Returns: dict[str, PipelineScheduler]: 平台 ID 到流水线调度器的映射 + """ ab_config = self.astrbot_config_mgr.confs.get(conf_id) if not ab_config: raise ValueError(f"配置文件 {conf_id} 不存在") scheduler = PipelineScheduler( - PipelineContext(ab_config, self.plugin_manager, conf_id) + PipelineContext(ab_config, self.plugin_manager, conf_id), ) await scheduler.initialize() self.pipeline_scheduler_mapping[conf_id] = scheduler diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 0abd3ad49..c62e49289 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,27 +1,27 @@ import abc import datetime import typing as T -from deprecated import deprecated -from dataclasses import dataclass -from astrbot.core.db.po import ( - Stats, - PlatformStat, - ConversationV2, - PlatformMessageHistory, - Attachment, - Persona, - Preference, -) from contextlib import asynccontextmanager +from dataclasses import dataclass + +from deprecated import deprecated from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker +from astrbot.core.db.po import ( + Attachment, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformStat, + Preference, + Stats, +) + @dataclass class BaseDatabase(abc.ABC): - """ - 数据库基类 - """ + """数据库基类""" DATABASE_URL = "" @@ -32,12 +32,13 @@ class BaseDatabase(abc.ABC): future=True, ) self.AsyncSessionLocal = sessionmaker( - self.engine, class_=AsyncSession, expire_on_commit=False + self.engine, + class_=AsyncSession, + expire_on_commit=False, ) async def initialize(self): """初始化数据库连接""" - pass @asynccontextmanager async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: @@ -91,7 +92,9 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def get_conversations( - self, user_id: str | None = None, platform_id: str | None = None + self, + user_id: str | None = None, + platform_id: str | None = None, ) -> list[ConversationV2]: """Get all conversations for a specific user and platform_id(optional). @@ -106,7 +109,9 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def get_all_conversations( - self, page: int = 1, page_size: int = 20 + self, + page: int = 1, + page_size: int = 20, ) -> list[ConversationV2]: """Get all conversations with pagination.""" ... @@ -173,7 +178,10 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def delete_platform_message_offset( - self, platform_id: str, user_id: str, offset_sec: int = 86400 + self, + platform_id: str, + user_id: str, + offset_sec: int = 86400, ) -> None: """Delete platform message history records older than the specified offset.""" ... @@ -243,7 +251,11 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def insert_preference_or_update( - self, scope: str, scope_id: str, key: str, value: dict + self, + scope: str, + scope_id: str, + key: str, + value: dict, ) -> Preference: """Insert a new preference record.""" ... @@ -255,7 +267,10 @@ class BaseDatabase(abc.ABC): @abc.abstractmethod async def get_preferences( - self, scope: str, scope_id: str | None = None, key: str | None = None + self, + scope: str, + scope_id: str | None = None, + key: str | None = None, ) -> list[Preference]: """Get all preferences for a specific scope ID or key.""" ... diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 0b7888548..d7bca3067 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -1,20 +1,21 @@ import os -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.db import BaseDatabase -from astrbot.core.config import AstrBotConfig + from astrbot.api import logger, sp +from astrbot.core.config import AstrBotConfig +from astrbot.core.db import BaseDatabase +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + from .migra_3_to_4 import ( migration_conversation_table, - migration_platform_table, - migration_webchat_data, migration_persona_data, + migration_platform_table, migration_preferences, + migration_webchat_data, ) async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: - """ - 检查是否需要进行数据库迁移 + """检查是否需要进行数据库迁移 如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。 """ # 仅当 data 目录下存在旧版本数据(data_v3.db 文件)时才考虑迁移 @@ -24,7 +25,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: if not os.path.exists(data_v3_db): return False migration_done = await db_helper.get_preference( - "global", "global", "migration_done_v4" + "global", + "global", + "migration_done_v4", ) if migration_done: return False @@ -36,8 +39,7 @@ async def do_migration_v4( platform_id_map: dict[str, dict[str, str]], astrbot_config: AstrBotConfig, ) -> None: - """ - 执行数据库迁移 + """执行数据库迁移 迁移旧的 webchat_conversation 表到新的 conversation 表。 迁移旧的 platform 到新的 platform_stats 表。 """ diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 4aa5082db..13a14c327 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -1,15 +1,18 @@ -import json import datetime -from .. import BaseDatabase -from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 -from .shared_preferences_v3 import sp as sp_v3 -from astrbot.core.config.default import DB_PATH +import json + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + from astrbot.api import logger, sp from astrbot.core.config import AstrBotConfig -from astrbot.core.platform.astr_message_event import MessageSesion -from sqlalchemy.ext.asyncio import AsyncSession +from astrbot.core.config.default import DB_PATH from astrbot.core.db.po import ConversationV2, PlatformMessageHistory -from sqlalchemy import text +from astrbot.core.platform.astr_message_event import MessageSesion + +from .. import BaseDatabase +from .shared_preferences_v3 import sp as sp_v3 +from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 """ 1. 迁移旧的 webchat_conversation 表到新的 conversation 表。 @@ -18,7 +21,8 @@ from sqlalchemy import text def get_platform_id( - platform_id_map: dict[str, dict[str, str]], old_platform_name: str + platform_id_map: dict[str, dict[str, str]], + old_platform_name: str, ) -> str: return platform_id_map.get( old_platform_name, @@ -27,7 +31,8 @@ def get_platform_id( def get_platform_type( - platform_id_map: dict[str, dict[str, str]], old_platform_name: str + platform_id_map: dict[str, dict[str, str]], + old_platform_name: str, ) -> str: return platform_id_map.get( old_platform_name, @@ -36,13 +41,15 @@ def get_platform_type( async def migration_conversation_table( - db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], ): db_helper_v3 = SQLiteV3DatabaseV3( - db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) conversations, total_cnt = db_helper_v3.get_all_conversations( - page=1, page_size=10000000 + page=1, + page_size=10000000, ) logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...") @@ -61,13 +68,14 @@ async def migration_conversation_table( ) if not conv: logger.info( - f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) if ":" not in conv.user_id: continue session = MessageSesion.from_str(session_str=conv.user_id) platform_id = get_platform_id( - platform_id_map, session.platform_name + platform_id_map, + session.platform_name, ) session.platform_id = platform_id # 更新平台名称为新的 ID conv_v2 = ConversationV2( @@ -90,10 +98,11 @@ async def migration_conversation_table( async def migration_platform_table( - db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], ): db_helper_v3 = SQLiteV3DatabaseV3( - db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) secs_from_2023_4_10_to_now = ( datetime.datetime.now(datetime.timezone.utc) @@ -134,10 +143,12 @@ async def migration_platform_table( if cnt == 0: continue platform_id = get_platform_id( - platform_id_map, platform_stats_v3[idx].name + platform_id_map, + platform_stats_v3[idx].name, ) platform_type = get_platform_type( - platform_id_map, platform_stats_v3[idx].name + platform_id_map, + platform_stats_v3[idx].name, ) try: await dbsession.execute( @@ -149,7 +160,8 @@ async def migration_platform_table( """), { "timestamp": datetime.datetime.fromtimestamp( - bucket_end, tz=datetime.timezone.utc + bucket_end, + tz=datetime.timezone.utc, ), "platform_id": platform_id, "platform_type": platform_type, @@ -165,14 +177,16 @@ async def migration_platform_table( async def migration_webchat_data( - db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], ): """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( - db_path=DB_PATH.replace("data_v4.db", "data_v3.db") + db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) conversations, total_cnt = db_helper_v3.get_all_conversations( - page=1, page_size=10000000 + page=1, + page_size=10000000, ) logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...") @@ -191,7 +205,7 @@ async def migration_webchat_data( ) if not conv: logger.info( - f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" + f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) if ":" in conv.user_id: continue @@ -218,10 +232,10 @@ async def migration_webchat_data( async def migration_persona_data( - db_helper: BaseDatabase, astrbot_config: AstrBotConfig + db_helper: BaseDatabase, + astrbot_config: AstrBotConfig, ): - """ - 迁移 Persona 数据到新的表中。 + """迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 """ v3_persona_config: list[dict] = astrbot_config.get("persona", []) @@ -253,14 +267,15 @@ async def migration_persona_data( begin_dialogs=begin_dialogs, ) logger.info( - f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。" + f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。", ) except Exception as e: logger.error(f"解析 Persona 配置失败:{e}") async def migration_preferences( - db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], ): # 1. global scope migration keys = [ @@ -329,10 +344,13 @@ async def migration_preferences( for provider_type, provider_id in perf.items(): await sp.put_async( - "umo", str(session), f"provider_perf_{provider_type}", provider_id + "umo", + str(session), + f"provider_perf_{provider_type}", + provider_id, ) logger.info( - f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}" + f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}", ) except Exception as e: logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True) diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index 8a1dc5de7..dc70026f9 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -9,7 +9,7 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): if not isinstance(abconf_data, dict): # should be unreachable logger.warning( - f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}" + f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}", ) return diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 6a661bd3d..3abcb1a66 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -1,6 +1,7 @@ import json import os from typing import TypeVar + from astrbot.core.utils.astrbot_path import get_astrbot_data_path _VT = TypeVar("_VT") @@ -16,7 +17,7 @@ class SharedPreferences: def _load_preferences(self): if os.path.exists(self.path): try: - with open(self.path, "r") as f: + with open(self.path) as f: return json.load(f) except json.JSONDecodeError: os.remove(self.path) diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index ad86c51f3..7b341c316 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -1,8 +1,9 @@ import sqlite3 import time -from astrbot.core.db.po import Platform, Stats -from typing import Tuple, List, Dict, Any from dataclasses import dataclass +from typing import Any + +from astrbot.core.db.po import Platform, Stats @dataclass @@ -94,7 +95,7 @@ class SQLiteDatabase: c.execute( """ PRAGMA table_info(webchat_conversation) - """ + """, ) res = c.fetchall() has_title = False @@ -108,14 +109,14 @@ class SQLiteDatabase: c.execute( """ ALTER TABLE webchat_conversation ADD COLUMN title TEXT; - """ + """, ) self.conn.commit() if not has_persona_id: c.execute( """ ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT; - """ + """, ) self.conn.commit() @@ -126,7 +127,7 @@ class SQLiteDatabase: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: Tuple = None): + def _exec_sql(self, sql: str, params: tuple = None): conn = self.conn try: c = self.conn.cursor() @@ -174,7 +175,7 @@ class SQLiteDatabase: """ SELECT * FROM platform """ - + where_clause + + where_clause, ) platform = [] @@ -194,7 +195,7 @@ class SQLiteDatabase: c.execute( """ SELECT SUM(count) FROM platform - """ + """, ) res = c.fetchone() c.close() @@ -214,7 +215,7 @@ class SQLiteDatabase: SELECT name, SUM(count), timestamp FROM platform """ + where_clause - + " GROUP BY name" + + " GROUP BY name", ) platform = [] @@ -242,7 +243,7 @@ class SQLiteDatabase: c.close() if not res: - return + return None return Conversation(*res) @@ -257,7 +258,7 @@ class SQLiteDatabase: (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> Tuple: + def get_conversations(self, user_id: str) -> tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -280,7 +281,7 @@ class SQLiteDatabase: title = row[3] persona_id = row[4] conversations.append( - Conversation("", cid, "[]", created_at, updated_at, title, persona_id) + Conversation("", cid, "[]", created_at, updated_at, title, persona_id), ) return conversations @@ -319,8 +320,10 @@ class SQLiteDatabase: ) def get_all_conversations( - self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: + self, + page: int = 1, + page_size: int = 20, + ) -> tuple[list[dict[str, Any]], int]: """获取所有对话,支持分页,按更新时间降序排序""" try: c = self.conn.cursor() @@ -366,7 +369,7 @@ class SQLiteDatabase: "persona_id": persona_id or "", "created_at": created_at or 0, "updated_at": updated_at or 0, - } + }, ) return conversations, total_count @@ -381,12 +384,12 @@ class SQLiteDatabase: self, page: int = 1, page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, + platforms: list[str] = None, + message_types: list[str] = None, search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: + exclude_ids: list[str] = None, + exclude_platforms: list[str] = None, + ) -> tuple[list[dict[str, Any]], int]: """获取筛选后的对话列表""" try: c = self.conn.cursor() @@ -422,7 +425,7 @@ class SQLiteDatabase: if search_query: search_query = search_query.encode("unicode_escape").decode("utf-8") where_clauses.append( - "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)" + "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)", ) search_param = f"%{search_query}%" params.extend([search_param, search_param, search_param, search_param]) @@ -482,7 +485,7 @@ class SQLiteDatabase: "persona_id": persona_id or "", "created_at": created_at or 0, "updated_at": updated_at or 0, - } + }, ) return conversations, total_count diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 24a05f947..1e7245976 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,15 +1,15 @@ import uuid - -from datetime import datetime, timezone from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import TypedDict + from sqlmodel import ( + JSON, + Field, SQLModel, Text, - JSON, UniqueConstraint, - Field, ) -from typing import Optional, TypedDict class PlatformStat(SQLModel, table=True): @@ -40,7 +40,8 @@ class ConversationV2(SQLModel, table=True): __tablename__ = "conversations" inner_conversation_id: int = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True} + primary_key=True, + sa_column_kwargs={"autoincrement": True}, ) conversation_id: str = Field( max_length=36, @@ -50,14 +51,14 @@ class ConversationV2(SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) - content: Optional[list] = Field(default=None, sa_type=JSON) + content: list | None = Field(default=None, sa_type=JSON) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) - title: Optional[str] = Field(default=None, max_length=255) - persona_id: Optional[str] = Field(default=None) + title: str | None = Field(default=None, max_length=255) + persona_id: str | None = Field(default=None) __table_args__ = ( UniqueConstraint( @@ -76,13 +77,15 @@ class Persona(SQLModel, table=True): __tablename__ = "personas" id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) persona_id: str = Field(max_length=255, nullable=False) system_prompt: str = Field(sa_type=Text, nullable=False) - begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) + begin_dialogs: list | None = Field(default=None, sa_type=JSON) """a list of strings, each representing a dialog to start with""" - tools: Optional[list] = Field(default=None, sa_type=JSON) + tools: list | None = Field(default=None, sa_type=JSON) """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( @@ -104,7 +107,9 @@ class Preference(SQLModel, table=True): __tablename__ = "preferences" id: int | None = Field( - default=None, primary_key=True, sa_column_kwargs={"autoincrement": True} + default=None, + primary_key=True, + sa_column_kwargs={"autoincrement": True}, ) scope: str = Field(nullable=False) """Scope of the preference, such as 'global', 'umo', 'plugin'.""" @@ -138,13 +143,15 @@ class PlatformMessageHistory(SQLModel, table=True): __tablename__ = "platform_message_history" id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) # An id of group, user in platform - sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform - sender_name: Optional[str] = Field( - default=None + sender_id: str | None = Field(default=None) # ID of the sender in the platform + sender_name: str | None = Field( + default=None, ) # Name of the sender in the platform content: dict = Field(sa_type=JSON, nullable=False) # a message chain list created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @@ -163,7 +170,9 @@ class Attachment(SQLModel, table=True): __tablename__ = "attachments" inner_attachment_id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) attachment_id: str = Field( max_length=36, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index f9faede19..457a4ab3f 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,22 +1,27 @@ import asyncio -import typing as T import threading +import typing as T from datetime import datetime, timedelta + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import col, delete, desc, func, or_, select, text, update + from astrbot.core.db import BaseDatabase from astrbot.core.db.po import ( - ConversationV2, - PlatformStat, - PlatformMessageHistory, Attachment, + ConversationV2, Persona, + PlatformMessageHistory, + PlatformStat, Preference, - Stats as DeprecatedStats, - Platform as DeprecatedPlatformStat, SQLModel, ) - -from sqlmodel import select, update, delete, text, func, or_, desc, col -from sqlalchemy.ext.asyncio import AsyncSession +from astrbot.core.db.po import ( + Platform as DeprecatedPlatformStat, +) +from astrbot.core.db.po import ( + Stats as DeprecatedStats, +) NOT_GIVEN = T.TypeVar("NOT_GIVEN") @@ -57,7 +62,9 @@ class SQLiteDatabase(BaseDatabase): async with session.begin(): if timestamp is None: timestamp = datetime.now().replace( - minute=0, second=0, microsecond=0 + minute=0, + second=0, + microsecond=0, ) current_hour = timestamp await session.execute( @@ -81,13 +88,13 @@ class SQLiteDatabase(BaseDatabase): session: AsyncSession result = await session.execute( select(func.count(col(PlatformStat.platform_id))).select_from( - PlatformStat - ) + PlatformStat, + ), ) count = result.scalar_one_or_none() return count if count is not None else 0 - async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]: + async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: """Get platform statistics within the specified offset in seconds and group by platform_id.""" async with self.get_db() as session: session: AsyncSession @@ -138,7 +145,7 @@ class SQLiteDatabase(BaseDatabase): select(ConversationV2) .order_by(desc(ConversationV2.created_at)) .offset(offset) - .limit(page_size) + .limit(page_size), ) return result.scalars().all() @@ -157,7 +164,7 @@ class SQLiteDatabase(BaseDatabase): if platform_ids: base_query = base_query.where( - col(ConversationV2.platform_id).in_(platform_ids) + col(ConversationV2.platform_id).in_(platform_ids), ) if search_query: search_query = search_query.encode("unicode_escape").decode("utf-8") @@ -167,16 +174,16 @@ class SQLiteDatabase(BaseDatabase): col(ConversationV2.content).ilike(f"%{search_query}%"), col(ConversationV2.user_id).ilike(f"%{search_query}%"), col(ConversationV2.conversation_id).ilike(f"%{search_query}%"), - ) + ), ) if "message_types" in kwargs and len(kwargs["message_types"]) > 0: for msg_type in kwargs["message_types"]: base_query = base_query.where( - col(ConversationV2.user_id).ilike(f"%:{msg_type}:%") + col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"), ) if "platforms" in kwargs and len(kwargs["platforms"]) > 0: base_query = base_query.where( - col(ConversationV2.platform_id).in_(kwargs["platforms"]) + col(ConversationV2.platform_id).in_(kwargs["platforms"]), ) # Get total count matching the filters @@ -233,7 +240,7 @@ class SQLiteDatabase(BaseDatabase): session: AsyncSession async with session.begin(): query = update(ConversationV2).where( - col(ConversationV2.conversation_id) == cid + col(ConversationV2.conversation_id) == cid, ) values = {} if title is not None: @@ -243,7 +250,7 @@ class SQLiteDatabase(BaseDatabase): if content is not None: values["content"] = content if not values: - return + return None query = query.values(**values) await session.execute(query) return await self.get_conversation_by_id(cid) @@ -254,8 +261,8 @@ class SQLiteDatabase(BaseDatabase): async with session.begin(): await session.execute( delete(ConversationV2).where( - col(ConversationV2.conversation_id) == cid - ) + col(ConversationV2.conversation_id) == cid, + ), ) async def delete_conversations_by_user_id(self, user_id: str) -> None: @@ -263,7 +270,9 @@ class SQLiteDatabase(BaseDatabase): session: AsyncSession async with session.begin(): await session.execute( - delete(ConversationV2).where(col(ConversationV2.user_id) == user_id) + delete(ConversationV2).where( + col(ConversationV2.user_id) == user_id + ), ) async def get_session_conversations( @@ -282,7 +291,7 @@ class SQLiteDatabase(BaseDatabase): select( col(Preference.scope_id).label("session_id"), func.json_extract(Preference.value, "$.val").label( - "conversation_id" + "conversation_id", ), # type: ignore col(ConversationV2.persona_id).label("persona_id"), col(ConversationV2.title).label("title"), @@ -295,7 +304,8 @@ class SQLiteDatabase(BaseDatabase): == ConversationV2.conversation_id, ) .outerjoin( - Persona, col(ConversationV2.persona_id) == Persona.persona_id + Persona, + col(ConversationV2.persona_id) == Persona.persona_id, ) .where(Preference.scope == "umo", Preference.key == "sel_conv_id") ) @@ -308,14 +318,14 @@ class SQLiteDatabase(BaseDatabase): col(Preference.scope_id).ilike(search_pattern), col(ConversationV2.title).ilike(search_pattern), col(Persona.persona_id).ilike(search_pattern), - ) + ), ) # 平台筛选 if platform: platform_pattern = f"{platform}:%" base_query = base_query.where( - col(Preference.scope_id).like(platform_pattern) + col(Preference.scope_id).like(platform_pattern), ) # 排序 @@ -336,7 +346,8 @@ class SQLiteDatabase(BaseDatabase): == ConversationV2.conversation_id, ) .outerjoin( - Persona, col(ConversationV2.persona_id) == Persona.persona_id + Persona, + col(ConversationV2.persona_id) == Persona.persona_id, ) .where(Preference.scope == "umo", Preference.key == "sel_conv_id") ) @@ -349,13 +360,13 @@ class SQLiteDatabase(BaseDatabase): col(Preference.scope_id).ilike(search_pattern), col(ConversationV2.title).ilike(search_pattern), col(Persona.persona_id).ilike(search_pattern), - ) + ), ) if platform: platform_pattern = f"{platform}:%" count_base_query = count_base_query.where( - col(Preference.scope_id).like(platform_pattern) + col(Preference.scope_id).like(platform_pattern), ) total_result = await session.execute(count_base_query) @@ -396,7 +407,10 @@ class SQLiteDatabase(BaseDatabase): return new_history async def delete_platform_message_offset( - self, platform_id, user_id, offset_sec=86400 + self, + platform_id, + user_id, + offset_sec=86400, ): """Delete platform message history records older than the specified offset.""" async with self.get_db() as session: @@ -409,11 +423,15 @@ class SQLiteDatabase(BaseDatabase): col(PlatformMessageHistory.platform_id) == platform_id, col(PlatformMessageHistory.user_id) == user_id, col(PlatformMessageHistory.created_at) < cutoff_time, - ) + ), ) async def get_platform_message_history( - self, platform_id, user_id, page=1, page_size=20 + self, + platform_id, + user_id, + page=1, + page_size=20, ): """Get platform message history records.""" async with self.get_db() as session: @@ -452,7 +470,11 @@ class SQLiteDatabase(BaseDatabase): return result.scalar_one_or_none() async def insert_persona( - self, persona_id, system_prompt, begin_dialogs=None, tools=None + self, + persona_id, + system_prompt, + begin_dialogs=None, + tools=None, ): """Insert a new persona record.""" async with self.get_db() as session: @@ -484,7 +506,11 @@ class SQLiteDatabase(BaseDatabase): return result.scalars().all() async def update_persona( - self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN + self, + persona_id, + system_prompt=None, + begin_dialogs=None, + tools=NOT_GIVEN, ): """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: @@ -499,7 +525,7 @@ class SQLiteDatabase(BaseDatabase): if tools is not NOT_GIVEN: values["tools"] = tools if not values: - return + return None query = query.values(**values) await session.execute(query) return await self.get_persona_by_id(persona_id) @@ -510,7 +536,7 @@ class SQLiteDatabase(BaseDatabase): session: AsyncSession async with session.begin(): await session.execute( - delete(Persona).where(col(Persona.persona_id) == persona_id) + delete(Persona).where(col(Persona.persona_id) == persona_id), ) async def insert_preference_or_update(self, scope, scope_id, key, value): @@ -529,7 +555,10 @@ class SQLiteDatabase(BaseDatabase): existing_preference.value = value else: new_preference = Preference( - scope=scope, scope_id=scope_id, key=key, value=value + scope=scope, + scope_id=scope_id, + key=key, + value=value, ) session.add(new_preference) return existing_preference or new_preference @@ -568,7 +597,7 @@ class SQLiteDatabase(BaseDatabase): col(Preference.scope) == scope, col(Preference.scope_id) == scope_id, col(Preference.key) == key, - ) + ), ) await session.commit() @@ -581,7 +610,7 @@ class SQLiteDatabase(BaseDatabase): delete(Preference).where( col(Preference.scope) == scope, col(Preference.scope_id) == scope_id, - ) + ), ) await session.commit() @@ -598,7 +627,7 @@ class SQLiteDatabase(BaseDatabase): now = datetime.now() start_time = now - timedelta(seconds=offset_sec) result = await session.execute( - select(PlatformStat).where(PlatformStat.timestamp >= start_time) + select(PlatformStat).where(PlatformStat.timestamp >= start_time), ) all_datas = result.scalars().all() deprecated_stats = DeprecatedStats() @@ -608,7 +637,7 @@ class SQLiteDatabase(BaseDatabase): name=data.platform_id, count=data.count, timestamp=int(data.timestamp.timestamp()), - ) + ), ) return deprecated_stats @@ -630,7 +659,7 @@ class SQLiteDatabase(BaseDatabase): async with self.get_db() as session: session: AsyncSession result = await session.execute( - select(func.sum(PlatformStat.count)).select_from(PlatformStat) + select(func.sum(PlatformStat.count)).select_from(PlatformStat), ) total_count = result.scalar_one_or_none() return total_count if total_count is not None else 0 @@ -656,7 +685,7 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute( select(PlatformStat.platform_id, func.sum(PlatformStat.count)) .where(PlatformStat.timestamp >= start_time) - .group_by(PlatformStat.platform_id) + .group_by(PlatformStat.platform_id), ) grouped_stats = result.all() deprecated_stats = DeprecatedStats() @@ -666,7 +695,7 @@ class SQLiteDatabase(BaseDatabase): name=platform_id, count=count, timestamp=int(start_time.timestamp()), - ) + ), ) return deprecated_stats diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 27fc9f3fb..7440b6f2a 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -10,18 +10,16 @@ class Result: class BaseVecDB: async def initialize(self): - """ - 初始化向量数据库 - """ - pass + """初始化向量数据库""" @abc.abstractmethod async def insert( - self, content: str, metadata: dict | None = None, id: str | None = None + self, + content: str, + metadata: dict | None = None, + id: str | None = None, ) -> int: - """ - 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 - """ + """插入一条文本和其对应向量,自动生成 ID 并保持一致性。""" ... @abc.abstractmethod @@ -35,11 +33,11 @@ class BaseVecDB: max_retries: int = 3, progress_callback=None, ) -> int: - """ - 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 + """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 Args: progress_callback: 进度回调函数,接收参数 (current, total) + """ ... @@ -52,8 +50,7 @@ class BaseVecDB: rerank: bool = False, metadata_filters: dict | None = None, ) -> list[Result]: - """ - 搜索最相似的文档。 + """搜索最相似的文档。 Args: query (str): 查询文本 top_k (int): 返回的最相似文档的数量 @@ -64,8 +61,7 @@ class BaseVecDB: @abc.abstractmethod async def delete(self, doc_id: str) -> bool: - """ - 删除指定文档。 + """删除指定文档。 Args: doc_id (str): 要删除的文档 ID Returns: diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 265c0cc43..e27eb6fe8 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -1,12 +1,13 @@ -import os import json -from datetime import datetime +import os from contextlib import asynccontextmanager +from datetime import datetime -from sqlalchemy import Text, Column +from sqlalchemy import Column, Text from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from sqlmodel import Field, SQLModel, select, col, func, text, MetaData +from sqlmodel import Field, MetaData, SQLModel, col, func, select, text + from astrbot.core import logger @@ -20,7 +21,9 @@ class Document(BaseDocModel, table=True): __tablename__ = "documents" # type: ignore id: int | None = Field( - default=None, primary_key=True, sa_column_kwargs={"autoincrement": True} + default=None, + primary_key=True, + sa_column_kwargs={"autoincrement": True}, ) doc_id: str = Field(nullable=False) text: str = Field(nullable=False) @@ -36,7 +39,8 @@ class DocumentStorage: self.engine: AsyncEngine | None = None self.async_session_maker: sessionmaker | None = None self.sqlite_init_path = os.path.join( - os.path.dirname(__file__), "sqlite_init.sql" + os.path.dirname(__file__), + "sqlite_init.sql", ) async def initialize(self): @@ -50,26 +54,26 @@ class DocumentStorage: await conn.execute( text( "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT " - "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED" - ) + "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED", + ), ) await conn.execute( text( "ALTER TABLE documents ADD COLUMN user_id TEXT " - "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED" - ) + "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED", + ), ) # Create indexes await conn.execute( text( - "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)" - ) + "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)", + ), ) await conn.execute( text( - "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)" - ) + "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)", + ), ) except BaseException: pass @@ -113,10 +117,11 @@ class DocumentStorage: Returns: list: The list of documents that match the filters. + """ if self.engine is None: logger.warning( - "Database connection is not initialized, returning empty result" + "Database connection is not initialized, returning empty result", ) return [] @@ -125,7 +130,7 @@ class DocumentStorage: for key, val in metadata_filters.items(): query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}") + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), ).params(**{f"filter_{key}": val}) if ids is not None and len(ids) > 0: @@ -153,24 +158,27 @@ class DocumentStorage: Returns: int: The integer ID of the inserted document. + """ assert self.engine is not None, "Database connection is not initialized." - async with self.get_session() as session: - async with session.begin(): - document = Document( - doc_id=doc_id, - text=text, - metadata_=json.dumps(metadata), - created_at=datetime.now(), - updated_at=datetime.now(), - ) - session.add(document) - await session.flush() # Flush to get the ID - return document.id # type: ignore + async with self.get_session() as session, session.begin(): + document = Document( + doc_id=doc_id, + text=text, + metadata_=json.dumps(metadata), + created_at=datetime.now(), + updated_at=datetime.now(), + ) + session.add(document) + await session.flush() # Flush to get the ID + return document.id # type: ignore async def insert_documents_batch( - self, doc_ids: list[str], texts: list[str], metadatas: list[dict] + self, + doc_ids: list[str], + texts: list[str], + metadatas: list[dict], ) -> list[int]: """Batch insert documents and return their integer IDs. @@ -181,44 +189,44 @@ class DocumentStorage: Returns: list[int]: List of integer IDs of the inserted documents. + """ assert self.engine is not None, "Database connection is not initialized." - async with self.get_session() as session: - async with session.begin(): - import json + async with self.get_session() as session, session.begin(): + import json - documents = [] - for doc_id, text, metadata in zip(doc_ids, texts, metadatas): - document = Document( - doc_id=doc_id, - text=text, - metadata_=json.dumps(metadata), - created_at=datetime.now(), - updated_at=datetime.now(), - ) - documents.append(document) - session.add(document) + documents = [] + for doc_id, text, metadata in zip(doc_ids, texts, metadatas): + document = Document( + doc_id=doc_id, + text=text, + metadata_=json.dumps(metadata), + created_at=datetime.now(), + updated_at=datetime.now(), + ) + documents.append(document) + session.add(document) - await session.flush() # Flush to get all IDs - return [doc.id for doc in documents] # type: ignore + await session.flush() # Flush to get all IDs + return [doc.id for doc in documents] # type: ignore async def delete_document_by_doc_id(self, doc_id: str): """Delete a document by its doc_id. Args: doc_id (str): The doc_id of the document to delete. + """ assert self.engine is not None, "Database connection is not initialized." - async with self.get_session() as session: - async with session.begin(): - query = select(Document).where(col(Document.doc_id) == doc_id) - result = await session.execute(query) - document = result.scalar_one_or_none() + async with self.get_session() as session, session.begin(): + query = select(Document).where(col(Document.doc_id) == doc_id) + result = await session.execute(query) + document = result.scalar_one_or_none() - if document: - await session.delete(document) + if document: + await session.delete(document) async def get_document_by_doc_id(self, doc_id: str): """Retrieve a document by its doc_id. @@ -228,6 +236,7 @@ class DocumentStorage: Returns: dict: The document data or None if not found. + """ assert self.engine is not None, "Database connection is not initialized." @@ -246,46 +255,46 @@ class DocumentStorage: Args: doc_id (str): The doc_id. new_text (str): The new text to update the document with. + """ assert self.engine is not None, "Database connection is not initialized." - async with self.get_session() as session: - async with session.begin(): - query = select(Document).where(col(Document.doc_id) == doc_id) - result = await session.execute(query) - document = result.scalar_one_or_none() + async with self.get_session() as session, session.begin(): + query = select(Document).where(col(Document.doc_id) == doc_id) + result = await session.execute(query) + document = result.scalar_one_or_none() - if document: - document.text = new_text - document.updated_at = datetime.now() - session.add(document) + if document: + document.text = new_text + document.updated_at = datetime.now() + session.add(document) async def delete_documents(self, metadata_filters: dict): """Delete documents by their metadata filters. Args: metadata_filters (dict): The metadata filters to apply. + """ if self.engine is None: logger.warning( - "Database connection is not initialized, skipping delete operation" + "Database connection is not initialized, skipping delete operation", ) return - async with self.get_session() as session: - async with session.begin(): - query = select(Document) + async with self.get_session() as session, session.begin(): + query = select(Document) - for key, val in metadata_filters.items(): - query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}") - ).params(**{f"filter_{key}": val}) + for key, val in metadata_filters.items(): + query = query.where( + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), + ).params(**{f"filter_{key}": val}) - result = await session.execute(query) - documents = result.scalars().all() + result = await session.execute(query) + documents = result.scalars().all() - for doc in documents: - await session.delete(doc) + for doc in documents: + await session.delete(doc) async def count_documents(self, metadata_filters: dict | None = None) -> int: """Count documents in the database. @@ -295,6 +304,7 @@ class DocumentStorage: Returns: int: The count of documents. + """ if self.engine is None: logger.warning("Database connection is not initialized, returning 0") @@ -306,7 +316,7 @@ class DocumentStorage: if metadata_filters: for key, val in metadata_filters.items(): query = query.where( - text(f"json_extract(metadata, '$.{key}') = :filter_{key}") + text(f"json_extract(metadata, '$.{key}') = :filter_{key}"), ).params(**{f"filter_{key}": val}) result = await session.execute(query) @@ -318,12 +328,13 @@ class DocumentStorage: Returns: list: A list of user IDs. + """ assert self.engine is not None, "Database connection is not initialized." async with self.get_session() as session: query = text( - "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL" + "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL", ) result = await session.execute(query) rows = result.fetchall() @@ -337,6 +348,7 @@ class DocumentStorage: Returns: dict: The converted dictionary. + """ return { "id": document.id, @@ -361,6 +373,7 @@ class DocumentStorage: dict: The converted dictionary. Note: This method is kept for backward compatibility but is no longer used internally. + """ return { "id": row[0], diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 2c0cc8dfe..24f1c323c 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -2,9 +2,10 @@ try: import faiss except ModuleNotFoundError: raise ImportError( - "faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。" + "faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。", ) import os + import numpy as np @@ -27,11 +28,12 @@ class EmbeddingStorage: id (int): 向量的ID Raises: ValueError: 如果向量的维度与存储的维度不匹配 + """ assert self.index is not None, "FAISS index is not initialized." if vector.shape[0] != self.dimension: raise ValueError( - f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}" + f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", ) self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() @@ -44,11 +46,12 @@ class EmbeddingStorage: ids (list[int]): 向量的ID列表 Raises: ValueError: 如果向量的维度与存储的维度不匹配 + """ assert self.index is not None, "FAISS index is not initialized." if vectors.shape[1] != self.dimension: raise ValueError( - f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}" + f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}", ) self.index.add_with_ids(vectors, np.array(ids)) await self.save_index() @@ -61,6 +64,7 @@ class EmbeddingStorage: k (int): 返回的最相似向量的数量 Returns: tuple: (距离, 索引) + """ assert self.index is not None, "FAISS index is not initialized." faiss.normalize_L2(vector) @@ -72,6 +76,7 @@ class EmbeddingStorage: Args: ids (list[int]): 要删除的向量ID列表 + """ assert self.index is not None, "FAISS index is not initialized." id_array = np.array(ids, dtype=np.int64) @@ -83,5 +88,6 @@ class EmbeddingStorage: Args: path (str): 保存索引的路径 + """ faiss.write_index(self.index, self.path) diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 8a21538ec..14221f1e8 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,18 +1,18 @@ -import uuid import time +import uuid + import numpy as np + +from astrbot import logger +from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider + +from ..base import BaseVecDB, Result from .document_storage import DocumentStorage from .embedding_storage import EmbeddingStorage -from ..base import Result, BaseVecDB -from astrbot.core.provider.provider import EmbeddingProvider -from astrbot.core.provider.provider import RerankProvider -from astrbot import logger class FaissVecDB(BaseVecDB): - """ - A class to represent a vector database. - """ + """A class to represent a vector database.""" def __init__( self, @@ -26,7 +26,8 @@ class FaissVecDB(BaseVecDB): self.embedding_provider = embedding_provider self.document_storage = DocumentStorage(doc_store_path) self.embedding_storage = EmbeddingStorage( - embedding_provider.get_dim(), index_store_path + embedding_provider.get_dim(), + index_store_path, ) self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider @@ -35,11 +36,12 @@ class FaissVecDB(BaseVecDB): await self.document_storage.initialize() async def insert( - self, content: str, metadata: dict | None = None, id: str | None = None + self, + content: str, + metadata: dict | None = None, + id: str | None = None, ) -> int: - """ - 插入一条文本和其对应向量,自动生成 ID 并保持一致性。 - """ + """插入一条文本和其对应向量,自动生成 ID 并保持一致性。""" metadata = metadata or {} str_id = id or str(uuid.uuid4()) # 使用 UUID 作为原始 ID @@ -63,11 +65,11 @@ class FaissVecDB(BaseVecDB): max_retries: int = 3, progress_callback=None, ) -> list[int]: - """ - 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 + """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 Args: progress_callback: 进度回调函数,接收参数 (current, total) + """ metadatas = metadatas or [{} for _ in contents] ids = ids or [str(uuid.uuid4()) for _ in contents] @@ -83,12 +85,14 @@ class FaissVecDB(BaseVecDB): ) end = time.time() logger.debug( - f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds." + f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.", ) # 使用 DocumentStorage 的批量插入方法 int_ids = await self.document_storage.insert_documents_batch( - ids, contents, metadatas + ids, + contents, + metadatas, ) # 批量插入向量到 FAISS @@ -104,8 +108,7 @@ class FaissVecDB(BaseVecDB): rerank: bool = False, metadata_filters: dict | None = None, ) -> list[Result]: - """ - 搜索最相似的文档。 + """搜索最相似的文档。 Args: query (str): 查询文本 @@ -116,6 +119,7 @@ class FaissVecDB(BaseVecDB): Returns: List[Result]: 查询结果 + """ embedding = await self.embedding_provider.get_embedding(query) scores, indices = await self.embedding_storage.search( @@ -128,7 +132,8 @@ class FaissVecDB(BaseVecDB): scores[0] = 1.0 - (scores[0] / 2.0) # NOTE: maybe the size is less than k. fetched_docs = await self.document_storage.get_documents( - metadata_filters=metadata_filters or {}, ids=indices[0] + metadata_filters=metadata_filters or {}, + ids=indices[0], ) if not fetched_docs: return [] @@ -149,7 +154,9 @@ class FaissVecDB(BaseVecDB): documents = [doc.data["text"] for doc in top_k_results] reranked_results = await self.rerank_provider.rerank(query, documents) reranked_results = sorted( - reranked_results, key=lambda x: x.relevance_score, reverse=True + reranked_results, + key=lambda x: x.relevance_score, + reverse=True, ) top_k_results = [ top_k_results[reranked_result.index] @@ -159,9 +166,7 @@ class FaissVecDB(BaseVecDB): return top_k_results async def delete(self, doc_id: str): - """ - 删除一条文档块(chunk) - """ + """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) int_id = result["id"] if result else None @@ -176,23 +181,23 @@ class FaissVecDB(BaseVecDB): await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: - """ - 计算文档数量 + """计算文档数量 Args: metadata_filter (dict | None): 元数据过滤器 + """ count = await self.document_storage.count_documents( - metadata_filters=metadata_filter or {} + metadata_filters=metadata_filter or {}, ) return count async def delete_documents(self, metadata_filters: dict): - """ - 根据元数据过滤器删除文档 - """ + """根据元数据过滤器删除文档""" docs = await self.document_storage.get_documents( - metadata_filters=metadata_filters, offset=None, limit=None + metadata_filters=metadata_filters, + offset=None, + limit=None, ) doc_ids: list[int] = [doc["id"] for doc in docs] await self.embedding_storage.delete(doc_ids) diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 2ae709396..749df753e 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -1,5 +1,4 @@ -""" -事件总线, 用于处理事件的分发和处理 +"""事件总线, 用于处理事件的分发和处理 事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理 其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑 @@ -13,10 +12,12 @@ class: import asyncio from asyncio import Queue -from astrbot.core.pipeline.scheduler import PipelineScheduler + from astrbot.core import logger -from .platform import AstrMessageEvent from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.pipeline.scheduler import PipelineScheduler + +from .platform import AstrMessageEvent class EventBus: @@ -46,14 +47,15 @@ class EventBus: Args: event (AstrMessageEvent): 事件对象 + """ # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( - f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" + f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}", ) # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( - f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}" + f"[{conf_name}] [{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}", ) diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index 56fe7ea10..ea97759c1 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -1,9 +1,9 @@ import asyncio import os -import uuid -import time -from urllib.parse import urlparse, unquote import platform +import time +import uuid +from urllib.parse import unquote, urlparse class FileTokenService: @@ -40,8 +40,8 @@ class FileTokenService: Raises: FileNotFoundError: 当路径不存在时抛出 - """ + """ # 处理 file:/// try: parsed_uri = urlparse(file_path) @@ -61,7 +61,7 @@ class FileTokenService: if not os.path.exists(local_path): raise FileNotFoundError( - f"文件不存在: {local_path} (原始输入: {file_path})" + f"文件不存在: {local_path} (原始输入: {file_path})", ) file_token = str(uuid.uuid4()) @@ -84,6 +84,7 @@ class FileTokenService: Raises: KeyError: 当令牌不存在或已过期时抛出 FileNotFoundError: 当文件本身已被删除时抛出 + """ async with self.lock: await self._cleanup_expired_tokens() diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index c6c01a304..f54d18641 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -1,5 +1,4 @@ -""" -AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。 +"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。 工作流程: 1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期 @@ -8,10 +7,10 @@ AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。 import asyncio import traceback -from astrbot.core import logger + +from astrbot.core import LogBroker, logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core import LogBroker from astrbot.dashboard.server import AstrBotDashboard @@ -39,7 +38,10 @@ class InitialLoader: webui_dir = self.webui_dir self.dashboard_server = AstrBotDashboard( - core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir + core_lifecycle, + self.db, + core_lifecycle.dashboard_shutdown_event, + webui_dir, ) coro = self.dashboard_server.run() diff --git a/astrbot/core/knowledge_base/chunking/__init__.py b/astrbot/core/knowledge_base/chunking/__init__.py index 3124afe81..805ddc242 100644 --- a/astrbot/core/knowledge_base/chunking/__init__.py +++ b/astrbot/core/knowledge_base/chunking/__init__.py @@ -1,6 +1,4 @@ -""" -文档分块模块 -""" +"""文档分块模块""" from .base import BaseChunker from .fixed_size import FixedSizeChunker diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index 5aaf84ba1..a45d86ad1 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -21,4 +21,5 @@ class BaseChunker(ABC): Returns: list[str]: 分块后的文本列表 + """ diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index c9b35d7d8..5439f070f 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -18,6 +18,7 @@ class FixedSizeChunker(BaseChunker): Args: chunk_size: 块的大小(字符数) chunk_overlap: 块之间的重叠字符数 + """ self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap @@ -32,6 +33,7 @@ class FixedSizeChunker(BaseChunker): Returns: list[str]: 分块后的文本列表 + """ chunk_size = kwargs.get("chunk_size", self.chunk_size) chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap) diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 21b76cba5..3f4aabb57 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -1,4 +1,5 @@ from collections.abc import Callable + from .base import BaseChunker @@ -11,8 +12,7 @@ class RecursiveCharacterChunker(BaseChunker): is_separator_regex: bool = False, separators: list[str] | None = None, ): - """ - 初始化递归字符文本分割器 + """初始化递归字符文本分割器 Args: chunk_size: 每个文本块的最大大小 @@ -20,6 +20,7 @@ class RecursiveCharacterChunker(BaseChunker): length_function: 计算文本长度的函数 is_separator_regex: 分隔符是否为正则表达式 separators: 用于分割文本的分隔符列表,按优先级排序 + """ self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap @@ -39,8 +40,7 @@ class RecursiveCharacterChunker(BaseChunker): ] async def chunk(self, text: str, **kwargs) -> list[str]: - """ - 递归地将文本分割成块 + """递归地将文本分割成块 Args: text: 要分割的文本 @@ -49,6 +49,7 @@ class RecursiveCharacterChunker(BaseChunker): Returns: 分割后的文本块列表 + """ if not text: return [] @@ -90,7 +91,7 @@ class RecursiveCharacterChunker(BaseChunker): combined_text, chunk_size=chunk_size, chunk_overlap=overlap, - ) + ), ) current_chunk = [] current_chunk_length = 0 @@ -98,8 +99,10 @@ class RecursiveCharacterChunker(BaseChunker): # 递归分割过大的部分 final_chunks.extend( await self.chunk( - split, chunk_size=chunk_size, chunk_overlap=overlap - ) + split, + chunk_size=chunk_size, + chunk_overlap=overlap, + ), ) # 如果添加这部分会使当前块超过chunk_size elif current_chunk_length + split_length > chunk_size: @@ -132,16 +135,19 @@ class RecursiveCharacterChunker(BaseChunker): return [text] def _split_by_character( - self, text: str, chunk_size: int | None = None, overlap: int | None = None + self, + text: str, + chunk_size: int | None = None, + overlap: int | None = None, ) -> list[str]: - """ - 按字符级别分割文本 + """按字符级别分割文本 Args: text: 要分割的文本 Returns: 分割后的文本块列表 + """ chunk_size = chunk_size or self.chunk_size overlap = overlap or self.chunk_overlap diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 827d621d3..5e1db842f 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,18 +1,18 @@ from contextlib import asynccontextmanager from pathlib import Path -from sqlmodel import col, desc -from sqlalchemy import text, func, select, update, delete +from sqlalchemy import delete, func, select, text, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlmodel import col, desc from astrbot.core import logger +from astrbot.core.db.vec_db.faiss_impl import FaissVecDB from astrbot.core.knowledge_base.models import ( BaseKBModel, KBDocument, KBMedia, KnowledgeBase, ) -from astrbot.core.db.vec_db.faiss_impl import FaissVecDB class KBSQLiteDatabase: @@ -21,6 +21,7 @@ class KBSQLiteDatabase: Args: db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db + """ self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" @@ -85,77 +86,77 @@ class KBSQLiteDatabase: await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_kb_kb_id " - "ON knowledge_bases(kb_id)" - ) + "ON knowledge_bases(kb_id)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_kb_name " - "ON knowledge_bases(kb_name)" - ) + "ON knowledge_bases(kb_name)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_kb_created_at " - "ON knowledge_bases(created_at)" - ) + "ON knowledge_bases(created_at)", + ), ) # 创建文档表索引 await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_doc_doc_id " - "ON kb_documents(doc_id)" - ) + "ON kb_documents(doc_id)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_doc_kb_id " - "ON kb_documents(kb_id)" - ) + "ON kb_documents(kb_id)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_doc_name " - "ON kb_documents(doc_name)" - ) + "ON kb_documents(doc_name)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_doc_type " - "ON kb_documents(file_type)" - ) + "ON kb_documents(file_type)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_doc_created_at " - "ON kb_documents(created_at)" - ) + "ON kb_documents(created_at)", + ), ) # 创建多媒体表索引 await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_media_media_id " - "ON kb_media(media_id)" - ) + "ON kb_media(media_id)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_media_doc_id " - "ON kb_media(doc_id)" - ) + "ON kb_media(doc_id)", + ), ) await session.execute( text( - "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)" - ) + "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)", + ), ) await session.execute( text( "CREATE INDEX IF NOT EXISTS idx_media_type " - "ON kb_media(media_type)" - ) + "ON kb_media(media_type)", + ), ) await session.commit() @@ -208,7 +209,10 @@ class KBSQLiteDatabase: return result.scalar_one_or_none() async def list_documents_by_kb( - self, kb_id: str, offset: int = 0, limit: int = 100 + self, + kb_id: str, + offset: int = 0, + limit: int = 100, ) -> list[KBDocument]: """列出知识库的所有文档""" async with self.get_db() as session: @@ -226,7 +230,7 @@ class KBSQLiteDatabase: """统计知识库的文档数量""" async with self.get_db() as session: stmt = select(func.count(col(KBDocument.id))).where( - col(KBDocument.kb_id) == kb_id + col(KBDocument.kb_id) == kb_id, ) result = await session.execute(stmt) return result.scalar() or 0 @@ -252,12 +256,11 @@ class KBSQLiteDatabase: async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): """删除单个文档及其相关数据""" # 在知识库表中删除 - async with self.get_db() as session: - async with session.begin(): - # 删除文档记录 - delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id) - await session.execute(delete_stmt) - await session.commit() + async with self.get_db() as session, session.begin(): + # 删除文档记录 + delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id) + await session.execute(delete_stmt) + await session.commit() # 在 vec db 中删除相关向量 await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id}) @@ -282,18 +285,17 @@ class KBSQLiteDatabase: """更新知识库统计信息""" chunk_cnt = await vec_db.count_documents() - async with self.get_db() as session: - async with session.begin(): - update_stmt = ( - update(KnowledgeBase) - .where(col(KnowledgeBase.kb_id) == kb_id) - .values( - doc_count=select(func.count(col(KBDocument.id))) - .where(col(KBDocument.kb_id) == kb_id) - .scalar_subquery(), - chunk_count=chunk_cnt, - ) + async with self.get_db() as session, session.begin(): + update_stmt = ( + update(KnowledgeBase) + .where(col(KnowledgeBase.kb_id) == kb_id) + .values( + doc_count=select(func.count(col(KBDocument.id))) + .where(col(KBDocument.kb_id) == kb_id) + .scalar_subquery(), + chunk_count=chunk_cnt, ) + ) - await session.execute(update_stmt) - await session.commit() + await session.execute(update_stmt) + await session.commit() diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 09b9c9fc8..b03b00369 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -1,16 +1,19 @@ -import uuid -import aiofiles import json +import uuid from pathlib import Path -from .models import KnowledgeBase, KBDocument, KBMedia -from .kb_db_sqlite import KBSQLiteDatabase + +import aiofiles + +from astrbot.core import logger from astrbot.core.db.vec_db.base import BaseVecDB from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB -from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider from astrbot.core.provider.manager import ProviderManager -from .parsers.util import select_parser +from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider + from .chunking.base import BaseChunker -from astrbot.core import logger +from .kb_db_sqlite import KBSQLiteDatabase +from .models import KBDocument, KBMedia, KnowledgeBase +from .parsers.util import select_parser class KBHelper: @@ -45,11 +48,11 @@ class KBHelper: if not self.kb.embedding_provider_id: raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id( - self.kb.embedding_provider_id + self.kb.embedding_provider_id, ) # type: ignore if not ep: raise ValueError( - f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider" + f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider", ) return ep @@ -57,11 +60,11 @@ class KBHelper: if not self.kb.rerank_provider_id: return None rp: RerankProvider = await self.prov_mgr.get_provider_by_id( - self.kb.rerank_provider_id + self.kb.rerank_provider_id, ) # type: ignore if not rp: raise ValueError( - f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider" + f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider", ) return rp @@ -122,6 +125,7 @@ class KBHelper: - stage: 当前阶段 ('parsing', 'chunking', 'embedding') - current: 当前进度 - total: 总数 + """ await self._ensure_vec_db() doc_id = str(uuid.uuid4()) @@ -162,7 +166,9 @@ class KBHelper: await progress_callback("chunking", 0, 100) chunks_text = await self.chunker.chunk( - text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap + text_content, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, ) contents = [] metadatas = [] @@ -173,7 +179,7 @@ class KBHelper: "kb_id": self.kb.kb_id, "kb_doc_id": doc_id, "chunk_index": idx, - } + }, ) if progress_callback: @@ -234,7 +240,9 @@ class KBHelper: raise e async def list_documents( - self, offset: int = 0, limit: int = 100 + self, + offset: int = 0, + limit: int = 100, ) -> list[KBDocument]: """列出知识库的所有文档""" docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit) @@ -288,12 +296,17 @@ class KBHelper: await session.refresh(doc) async def get_chunks_by_doc_id( - self, doc_id: str, offset: int = 0, limit: int = 100 + self, + doc_id: str, + offset: int = 0, + limit: int = 100, ) -> list[dict]: """获取文档的所有块及其元数据""" vec_db: FaissVecDB = self.vec_db # type: ignore chunks = await vec_db.document_storage.get_documents( - metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit + metadata_filters={"kb_doc_id": doc_id}, + offset=offset, + limit=limit, ) result = [] for chunk in chunks: @@ -306,7 +319,7 @@ class KBHelper: "chunk_index": chunk_md["chunk_index"], "content": chunk["text"], "char_count": len(chunk["text"]), - } + }, ) return result diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index c1c63d08a..f7e07fe15 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -1,19 +1,17 @@ import traceback from pathlib import Path + from astrbot.core import logger from astrbot.core.provider.manager import ProviderManager -from .retrieval.manager import RetrievalManager, RetrievalResult -from .retrieval.sparse_retriever import SparseRetriever -from .retrieval.rank_fusion import RankFusion -from .kb_db_sqlite import KBSQLiteDatabase - # from .chunking.fixed_size import FixedSizeChunker from .chunking.recursive import RecursiveCharacterChunker +from .kb_db_sqlite import KBSQLiteDatabase from .kb_helper import KBHelper - from .models import KnowledgeBase - +from .retrieval.manager import RetrievalManager, RetrievalResult +from .retrieval.rank_fusion import RankFusion +from .retrieval.sparse_retriever import SparseRetriever FILES_PATH = "data/knowledge_base" DB_PATH = Path(FILES_PATH) / "kb.db" @@ -257,6 +255,7 @@ class KnowledgeBaseManager: Returns: str: 格式化的上下文文本 + """ lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"] diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index 010d6113c..da919a384 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,7 +1,7 @@ import uuid from datetime import datetime, timezone -from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData +from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint class BaseKBModel(SQLModel, table=False): @@ -17,7 +17,9 @@ class KnowledgeBase(BaseKBModel, table=True): __tablename__ = "knowledge_bases" # type: ignore id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) kb_id: str = Field( max_length=36, @@ -63,7 +65,9 @@ class KBDocument(BaseKBModel, table=True): __tablename__ = "kb_documents" # type: ignore id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) doc_id: str = Field( max_length=36, @@ -95,7 +99,9 @@ class KBMedia(BaseKBModel, table=True): __tablename__ = "kb_media" # type: ignore id: int | None = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + primary_key=True, + sa_column_kwargs={"autoincrement": True}, + default=None, ) media_id: str = Field( max_length=36, diff --git a/astrbot/core/knowledge_base/parsers/__init__.py b/astrbot/core/knowledge_base/parsers/__init__.py index 6851edebd..184f2fd41 100644 --- a/astrbot/core/knowledge_base/parsers/__init__.py +++ b/astrbot/core/knowledge_base/parsers/__init__.py @@ -1,15 +1,13 @@ -""" -文档解析器模块 -""" +"""文档解析器模块""" from .base import BaseParser, MediaItem, ParseResult -from .text_parser import TextParser from .pdf_parser import PDFParser +from .text_parser import TextParser __all__ = [ "BaseParser", "MediaItem", + "PDFParser", "ParseResult", "TextParser", - "PDFParser", ] diff --git a/astrbot/core/knowledge_base/parsers/base.py b/astrbot/core/knowledge_base/parsers/base.py index 1c571db2e..4ffca9c6f 100644 --- a/astrbot/core/knowledge_base/parsers/base.py +++ b/astrbot/core/knowledge_base/parsers/base.py @@ -47,4 +47,5 @@ class BaseParser(ABC): Returns: ParseResult: 解析结果 + """ diff --git a/astrbot/core/knowledge_base/parsers/markitdown_parser.py b/astrbot/core/knowledge_base/parsers/markitdown_parser.py index 50af984e0..9ef347933 100644 --- a/astrbot/core/knowledge_base/parsers/markitdown_parser.py +++ b/astrbot/core/knowledge_base/parsers/markitdown_parser.py @@ -1,11 +1,12 @@ import io import os +from markitdown_no_magika import MarkItDown, StreamInfo + from astrbot.core.knowledge_base.parsers.base import ( BaseParser, ParseResult, ) -from markitdown_no_magika import MarkItDown, StreamInfo class MarkitdownParser(BaseParser): diff --git a/astrbot/core/knowledge_base/parsers/pdf_parser.py b/astrbot/core/knowledge_base/parsers/pdf_parser.py index fca626871..aeeea930a 100644 --- a/astrbot/core/knowledge_base/parsers/pdf_parser.py +++ b/astrbot/core/knowledge_base/parsers/pdf_parser.py @@ -29,6 +29,7 @@ class PDFParser(BaseParser): Returns: ParseResult: 包含文本和图片的解析结果 + """ pdf_file = io.BytesIO(file_content) reader = PdfReader(pdf_file) @@ -87,7 +88,7 @@ class PDFParser(BaseParser): file_name=f"page_{page_num}_img_{image_counter}.{ext}", content=image_data, mime_type=mime_type, - ) + ), ) except Exception: # 单个图片提取失败不影响整体 diff --git a/astrbot/core/knowledge_base/parsers/text_parser.py b/astrbot/core/knowledge_base/parsers/text_parser.py index 49a95a95c..bed2d09b8 100644 --- a/astrbot/core/knowledge_base/parsers/text_parser.py +++ b/astrbot/core/knowledge_base/parsers/text_parser.py @@ -26,6 +26,7 @@ class TextParser(BaseParser): Raises: ValueError: 如果无法解码文件 + """ # 尝试多种编码 for encoding in ["utf-8", "gbk", "gb2312", "gb18030"]: diff --git a/astrbot/core/knowledge_base/parsers/util.py b/astrbot/core/knowledge_base/parsers/util.py index 41cc5e4de..7a4463202 100644 --- a/astrbot/core/knowledge_base/parsers/util.py +++ b/astrbot/core/knowledge_base/parsers/util.py @@ -6,7 +6,7 @@ async def select_parser(ext: str) -> BaseParser: from .markitdown_parser import MarkitdownParser return MarkitdownParser() - elif ext == ".pdf": + if ext == ".pdf": from .pdf_parser import PDFParser return PDFParser() diff --git a/astrbot/core/knowledge_base/retrieval/__init__.py b/astrbot/core/knowledge_base/retrieval/__init__.py index 16a5e6645..f5d196cb9 100644 --- a/astrbot/core/knowledge_base/retrieval/__init__.py +++ b/astrbot/core/knowledge_base/retrieval/__init__.py @@ -1,16 +1,14 @@ -""" -检索模块 -""" +"""检索模块""" from .manager import RetrievalManager, RetrievalResult -from .sparse_retriever import SparseRetriever, SparseResult -from .rank_fusion import RankFusion, FusedResult +from .rank_fusion import FusedResult, RankFusion +from .sparse_retriever import SparseResult, SparseRetriever __all__ = [ + "FusedResult", + "RankFusion", "RetrievalManager", "RetrievalResult", - "SparseRetriever", "SparseResult", - "RankFusion", - "FusedResult", + "SparseRetriever", ] diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 278e4da20..9a42cd6cd 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -4,18 +4,17 @@ """ import time - from dataclasses import dataclass -from typing import List +from astrbot import logger +from astrbot.core.db.vec_db.base import Result +from astrbot.core.db.vec_db.faiss_impl import FaissVecDB from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever from astrbot.core.provider.provider import RerankProvider -from astrbot.core.db.vec_db.base import Result -from astrbot.core.db.vec_db.faiss_impl import FaissVecDB + from ..kb_helper import KBHelper -from astrbot import logger @dataclass @@ -53,6 +52,7 @@ class RetrievalManager: sparse_retriever: 稀疏检索器 rank_fusion: 结果融合器 kb_db: 知识库数据库实例 + """ self.sparse_retriever = sparse_retriever self.rank_fusion = rank_fusion @@ -61,11 +61,11 @@ class RetrievalManager: async def retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_id_helper_map: dict[str, KBHelper], top_k_fusion: int = 20, top_m_final: int = 5, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """混合检索 流程: @@ -82,6 +82,7 @@ class RetrievalManager: Returns: List[RetrievalResult]: 检索结果列表 + """ if not kb_ids: return [] @@ -114,7 +115,7 @@ class RetrievalManager: ) time_end = time.time() logger.debug( - f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results." + f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results.", ) # 2. 稀疏检索 @@ -126,7 +127,7 @@ class RetrievalManager: ) time_end = time.time() logger.debug( - f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results." + f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results.", ) # 3. 结果融合 @@ -138,7 +139,7 @@ class RetrievalManager: ) time_end = time.time() logger.debug( - f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results." + f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results.", ) # 4. 转换为 RetrievalResult (获取元数据) @@ -159,7 +160,7 @@ class RetrievalManager: "chunk_index": fr.chunk_index, "char_count": len(fr.content), }, - ) + ), ) # 5. Rerank @@ -188,7 +189,7 @@ class RetrievalManager: async def _dense_retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_options: dict, ): """稠密检索 (向量相似度) @@ -202,6 +203,7 @@ class RetrievalManager: Returns: List[Result]: 检索结果列表 + """ all_results: list[Result] = [] for kb_id in kb_ids: @@ -233,10 +235,10 @@ class RetrievalManager: async def _rerank( self, query: str, - results: List[RetrievalResult], + results: list[RetrievalResult], top_k: int, rerank_provider: RerankProvider, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """Rerank 重排序 Args: @@ -246,6 +248,7 @@ class RetrievalManager: Returns: List[RetrievalResult]: 重排序后的结果列表 + """ if not results: return [] diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 3ceba4ff8..26203f94b 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -37,6 +37,7 @@ class RankFusion: Args: kb_db: 知识库数据库实例 k: RRF 参数,用于平滑排名 + """ self.kb_db = kb_db self.k = k @@ -59,6 +60,7 @@ class RankFusion: Returns: List[FusedResult]: 融合后的结果列表 + """ # 1. 构建排名映射 dense_ranks = { @@ -101,7 +103,9 @@ class RankFusion: # 4. 排序 sorted_ids = sorted( - rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True + rrf_scores.keys(), + key=lambda cid: rrf_scores[cid], + reverse=True, )[:top_k] # 5. 构建融合结果 @@ -118,7 +122,7 @@ class RankFusion: kb_id=sr.kb_id, content=sr.content, score=rrf_scores[identifier], - ) + ), ) elif identifier in vec_doc_id_to_dense: # 从向量检索获取信息,需要从数据库获取块的详细信息 @@ -132,7 +136,7 @@ class RankFusion: kb_id=chunk_md["kb_id"], content=vec_result.data["text"], score=rrf_scores[identifier], - ) + ), ) return fused_results diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 315930b3e..ea5da1c9e 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -3,13 +3,15 @@ 使用 BM25 算法进行基于关键词的文档检索 """ -import jieba -import os import json +import os from dataclasses import dataclass + +import jieba from rank_bm25 import BM25Okapi -from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase + from astrbot.core.db.vec_db.faiss_impl import FaissVecDB +from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase @dataclass @@ -37,6 +39,7 @@ class SparseRetriever: Args: kb_db: 知识库数据库实例 + """ self.kb_db = kb_db self._index_cache = {} # 缓存 BM25 索引 @@ -64,6 +67,7 @@ class SparseRetriever: Returns: List[SparseResult]: 检索结果列表 + """ # 1. 获取所有相关块 top_k_sparse = 0 @@ -73,7 +77,9 @@ class SparseRetriever: if not vec_db: continue result = await vec_db.document_storage.get_documents( - metadata_filters={}, limit=None, offset=None + metadata_filters={}, + limit=None, + offset=None, ) chunk_mds = [json.loads(doc["metadata"]) for doc in result] result = [ @@ -122,7 +128,7 @@ class SparseRetriever: kb_id=chunk["kb_id"], content=chunk["text"], score=float(score), - ) + ), ) results.sort(key=lambda x: x.score, reverse=True) diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3a1c50371..376f5ffd6 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -1,5 +1,4 @@ -""" -日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能 +"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能 const: CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量 @@ -21,14 +20,14 @@ function: 4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流 """ -import logging -import colorlog import asyncio +import logging import os import sys -from collections import deque from asyncio import Queue -from typing import List +from collections import deque + +import colorlog # 日志缓存大小 CACHED_SIZE = 200 @@ -52,6 +51,7 @@ def is_plugin_path(pathname): Returns: bool: 如果路径来自插件目录,则返回 True,否则返回 False + """ if not pathname: return False @@ -68,6 +68,7 @@ def get_short_level_name(level_name): Returns: str: 四个字母的日志级别缩写 + """ level_map = { "DEBUG": "DBUG", @@ -87,13 +88,14 @@ class LogBroker: def __init__(self): self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 - self.subscribers: List[Queue] = [] # 订阅者列表 + self.subscribers: list[Queue] = [] # 订阅者列表 def register(self) -> Queue: """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 Returns: Queue: 订阅者的队列, 可用于接收日志消息 + """ q = Queue(maxsize=CACHED_SIZE + 10) self.subscribers.append(q) @@ -104,6 +106,7 @@ class LogBroker: Args: q (Queue): 需要取消订阅的队列 + """ self.subscribers.remove(q) @@ -113,6 +116,7 @@ class LogBroker: Args: log_entry (dict): 日志消息, 包含日志级别和日志内容. example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"} + """ self.log_cache.append(log_entry) for q in self.subscribers: @@ -138,6 +142,7 @@ class LogQueueHandler(logging.Handler): Args: record (logging.LogRecord): 日志记录对象, 包含日志信息 + """ log_entry = self.format(record) self.log_broker.publish( @@ -145,7 +150,7 @@ class LogQueueHandler(logging.Handler): "level": record.levelname, "time": record.asctime, "data": log_entry, - } + }, ) @@ -164,6 +169,7 @@ class LogManager: Returns: logging.Logger: 返回配置好的日志记录器 + """ logger = logging.getLogger(log_name) # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置 @@ -171,10 +177,10 @@ class LogManager: return logger # 如果logger没有处理器 console_handler = logging.StreamHandler( - sys.stdout + sys.stdout, ) # 创建一个StreamHandler用于控制台输出 console_handler.setLevel( - logging.DEBUG + logging.DEBUG, ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息 @@ -195,7 +201,8 @@ class LogManager: class FileNameFilter(logging.Filter): """文件名过滤器类, 用于修改日志记录的文件名格式 - 例如: 将文件路径 /path/to/file.py 转换为 file. 格式""" + 例如: 将文件路径 /path/to/file.py 转换为 file. 格式 + """ # 获取这个文件和父文件夹的名字:. 并且去除 .py def filter(self, record): @@ -231,6 +238,7 @@ class LogManager: Args: logger (logging.Logger): 日志记录器 log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息 + """ handler = LogQueueHandler(log_broker) handler.setLevel(logging.DEBUG) @@ -240,7 +248,7 @@ class LogManager: # 为队列处理器设置相同格式的formatter handler.setFormatter( logging.Formatter( - "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s" - ) + "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s", + ), ) logger.addHandler(handler) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index d9ec4b41b..bdab0b6e3 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -1,5 +1,4 @@ -""" -MIT License +"""MIT License Copyright (c) 2021 Lxns-Network @@ -26,7 +25,6 @@ import asyncio import base64 import json import os -import typing as T import uuid from enum import Enum @@ -81,7 +79,7 @@ class BaseMessageComponent(BaseModel): k = "type" if isinstance(v, bool): v = 1 if v else 0 - output += ",%s=%s" % ( + output += ",{}={}".format( k, str(v) .replace("&", "&") @@ -110,7 +108,7 @@ class BaseMessageComponent(BaseModel): class Plain(BaseMessageComponent): type = ComponentType.Plain text: str - convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息 + convert: bool | None = True # 若为 False 则直接发送未转换 CQ 码的消息 def __init__(self, text: str, convert: bool = True, **_): super().__init__(text=text, convert=convert, **_) @@ -139,17 +137,17 @@ class Face(BaseMessageComponent): class Record(BaseMessageComponent): type = ComponentType.Record - file: T.Optional[str] = "" - magic: T.Optional[bool] = False - url: T.Optional[str] = "" - cache: T.Optional[bool] = True - proxy: T.Optional[bool] = True - timeout: T.Optional[int] = 0 + file: str | None = "" + magic: bool | None = False + url: str | None = "" + cache: bool | None = True + proxy: bool | None = True + timeout: int | None = 0 # 额外 - path: T.Optional[str] + path: str | None - def __init__(self, file: T.Optional[str], **_): - for k in _.keys(): + def __init__(self, file: str | None, **_): + for k in _: if k == "url": pass # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}") @@ -174,15 +172,16 @@ class Record(BaseMessageComponent): Returns: str: 语音的本地路径,以绝对路径表示。 + """ if not self.file: raise Exception(f"not a valid file: {self.file}") if self.file.startswith("file:///"): return self.file[8:] - elif self.file.startswith("http"): + if self.file.startswith("http"): file_path = await download_image_by_url(self.file) return os.path.abspath(file_path) - elif self.file.startswith("base64://"): + if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) temp_dir = os.path.join(get_astrbot_data_path(), "temp") @@ -190,16 +189,16 @@ class Record(BaseMessageComponent): with open(file_path, "wb") as f: f.write(image_bytes) return os.path.abspath(file_path) - elif os.path.exists(self.file): + if os.path.exists(self.file): return os.path.abspath(self.file) - else: - raise Exception(f"not a valid file: {self.file}") + raise Exception(f"not a valid file: {self.file}") async def convert_to_base64(self) -> str: """将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。 Returns: str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ # convert to base64 if not self.file: @@ -219,14 +218,14 @@ class Record(BaseMessageComponent): return bs64_data async def register_to_file_service(self) -> str: - """ - 将语音注册到文件服务。 + """将语音注册到文件服务。 Returns: str: 注册后的URL Raises: Exception: 如果未配置 callback_api_base + """ callback_host = astrbot_config.get("callback_api_base") @@ -245,10 +244,10 @@ class Record(BaseMessageComponent): class Video(BaseMessageComponent): type = ComponentType.Video file: str - cover: T.Optional[str] = "" - c: T.Optional[int] = 2 + cover: str | None = "" + c: int | None = 2 # 额外 - path: T.Optional[str] = "" + path: str | None = "" def __init__(self, file: str, **_): super().__init__(file=file, **_) @@ -268,32 +267,31 @@ class Video(BaseMessageComponent): Returns: str: 视频的本地路径,以绝对路径表示。 + """ url = self.file if url and url.startswith("file:///"): return url[8:] - elif url and url.startswith("http"): + if url and url.startswith("http"): download_dir = os.path.join(get_astrbot_data_path(), "temp") video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") await download_file(url, video_file_path) if os.path.exists(video_file_path): return os.path.abspath(video_file_path) - else: - raise Exception(f"download failed: {url}") - elif os.path.exists(url): + raise Exception(f"download failed: {url}") + if os.path.exists(url): return os.path.abspath(url) - else: - raise Exception(f"not a valid file: {url}") + raise Exception(f"not a valid file: {url}") async def register_to_file_service(self): - """ - 将视频注册到文件服务。 + """将视频注册到文件服务。 Returns: str: 注册后的URL Raises: Exception: 如果未配置 callback_api_base + """ callback_host = astrbot_config.get("callback_api_base") @@ -330,8 +328,8 @@ class Video(BaseMessageComponent): class At(BaseMessageComponent): type = ComponentType.At - qq: T.Union[int, str] # 此处str为all时代表所有人 - name: T.Optional[str] = "" + qq: int | str # 此处str为all时代表所有人 + name: str | None = "" def __init__(self, **_): super().__init__(**_) @@ -373,7 +371,7 @@ class Shake(BaseMessageComponent): # TODO class Anonymous(BaseMessageComponent): # TODO type = ComponentType.Anonymous - ignore: T.Optional[bool] = False + ignore: bool | None = False def __init__(self, **_): super().__init__(**_) @@ -383,8 +381,8 @@ class Share(BaseMessageComponent): type = ComponentType.Share url: str title: str - content: T.Optional[str] = "" - image: T.Optional[str] = "" + content: str | None = "" + image: str | None = "" def __init__(self, **_): super().__init__(**_) @@ -393,7 +391,7 @@ class Share(BaseMessageComponent): class Contact(BaseMessageComponent): # TODO type = ComponentType.Contact _type: str # type 字段冲突 - id: T.Optional[int] = 0 + id: int | None = 0 def __init__(self, **_): super().__init__(**_) @@ -403,8 +401,8 @@ class Location(BaseMessageComponent): # TODO type = ComponentType.Location lat: float lon: float - title: T.Optional[str] = "" - content: T.Optional[str] = "" + title: str | None = "" + content: str | None = "" def __init__(self, **_): super().__init__(**_) @@ -413,12 +411,12 @@ class Location(BaseMessageComponent): # TODO class Music(BaseMessageComponent): type = ComponentType.Music _type: str - id: T.Optional[int] = 0 - url: T.Optional[str] = "" - audio: T.Optional[str] = "" - title: T.Optional[str] = "" - content: T.Optional[str] = "" - image: T.Optional[str] = "" + id: int | None = 0 + url: str | None = "" + audio: str | None = "" + title: str | None = "" + content: str | None = "" + image: str | None = "" def __init__(self, **_): # for k in _.keys(): @@ -429,18 +427,18 @@ class Music(BaseMessageComponent): class Image(BaseMessageComponent): type = ComponentType.Image - file: T.Optional[str] = "" - _type: T.Optional[str] = "" - subType: T.Optional[int] = 0 - url: T.Optional[str] = "" - cache: T.Optional[bool] = True - id: T.Optional[int] = 40000 - c: T.Optional[int] = 2 + file: str | None = "" + _type: str | None = "" + subType: int | None = 0 + url: str | None = "" + cache: bool | None = True + id: int | None = 40000 + c: int | None = 2 # 额外 - path: T.Optional[str] = "" - file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 + path: str | None = "" + file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: T.Optional[str], **_): + def __init__(self, file: str | None, **_): super().__init__(file=file, **_) @staticmethod @@ -470,16 +468,17 @@ class Image(BaseMessageComponent): Returns: str: 图片的本地路径,以绝对路径表示。 + """ url = self.url or self.file if not url: raise ValueError("No valid file or URL provided") if url.startswith("file:///"): return url[8:] - elif url.startswith("http"): + if url.startswith("http"): image_file_path = await download_image_by_url(url) return os.path.abspath(image_file_path) - elif url.startswith("base64://"): + if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) temp_dir = os.path.join(get_astrbot_data_path(), "temp") @@ -487,16 +486,16 @@ class Image(BaseMessageComponent): with open(image_file_path, "wb") as f: f.write(image_bytes) return os.path.abspath(image_file_path) - elif os.path.exists(url): + if os.path.exists(url): return os.path.abspath(url) - else: - raise Exception(f"not a valid file: {url}") + raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: """将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。 Returns: str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。 + """ # convert to base64 url = self.url or self.file @@ -517,14 +516,14 @@ class Image(BaseMessageComponent): return bs64_data async def register_to_file_service(self) -> str: - """ - 将图片注册到文件服务。 + """将图片注册到文件服务。 Returns: str: 注册后的URL Raises: Exception: 如果未配置 callback_api_base + """ callback_host = astrbot_config.get("callback_api_base") @@ -542,24 +541,24 @@ class Image(BaseMessageComponent): class Reply(BaseMessageComponent): type = ComponentType.Reply - id: T.Union[str, int] + id: str | int """所引用的消息 ID""" - chain: T.Optional[T.List["BaseMessageComponent"]] = [] + chain: list["BaseMessageComponent"] | None = [] """被引用的消息段列表""" - sender_id: T.Optional[int] | T.Optional[str] = 0 + sender_id: int | None | str = 0 """被引用的消息对应的发送者的 ID""" - sender_nickname: T.Optional[str] = "" + sender_nickname: str | None = "" """被引用的消息对应的发送者的昵称""" - time: T.Optional[int] = 0 + time: int | None = 0 """被引用的消息发送时间""" - message_str: T.Optional[str] = "" + message_str: str | None = "" """被引用的消息解析后的纯文本消息字符串""" - text: T.Optional[str] = "" + text: str | None = "" """deprecated""" - qq: T.Optional[int] = 0 + qq: int | None = 0 """deprecated""" - seq: T.Optional[int] = 0 + seq: int | None = 0 """deprecated""" def __init__(self, **_): @@ -576,8 +575,8 @@ class RedBag(BaseMessageComponent): class Poke(BaseMessageComponent): type: str = ComponentType.Poke - id: T.Optional[int] = 0 - qq: T.Optional[int] = 0 + id: int | None = 0 + qq: int | None = 0 def __init__(self, type: str, **_): type = f"Poke:{type}" @@ -596,12 +595,12 @@ class Node(BaseMessageComponent): """群合并转发消息""" type = ComponentType.Node - id: T.Optional[int] = 0 # 忽略 - name: T.Optional[str] = "" # qq昵称 - uin: T.Optional[str] = "0" # qq号 - content: T.Optional[list[BaseMessageComponent]] = [] - seq: T.Optional[T.Union[str, list]] = "" # 忽略 - time: T.Optional[int] = 0 # 忽略 + id: int | None = 0 # 忽略 + name: str | None = "" # qq昵称 + uin: str | None = "0" # qq号 + content: list[BaseMessageComponent] | None = [] + seq: str | list | None = "" # 忽略 + time: int | None = 0 # 忽略 def __init__(self, content: list[BaseMessageComponent], **_): if isinstance(content, Node): @@ -619,7 +618,7 @@ class Node(BaseMessageComponent): { "type": comp.type.lower(), "data": {"file": f"base64://{bs64}"}, - } + }, ) elif isinstance(comp, Plain): # For Plain segments, we need to handle the plain differently @@ -648,9 +647,9 @@ class Node(BaseMessageComponent): class Nodes(BaseMessageComponent): type = ComponentType.Nodes - nodes: T.List[Node] + nodes: list[Node] - def __init__(self, nodes: T.List[Node], **_): + def __init__(self, nodes: list[Node], **_): super().__init__(nodes=nodes, **_) def toDict(self): @@ -675,7 +674,7 @@ class Nodes(BaseMessageComponent): class Xml(BaseMessageComponent): type = ComponentType.Xml data: str - resid: T.Optional[int] = 0 + resid: int | None = 0 def __init__(self, **_): super().__init__(**_) @@ -683,8 +682,8 @@ class Xml(BaseMessageComponent): class Json(BaseMessageComponent): type = ComponentType.Json - data: T.Union[str, dict] - resid: T.Optional[int] = 0 + data: str | dict + resid: int | None = 0 def __init__(self, data, **_): if isinstance(data, dict): @@ -695,13 +694,13 @@ class Json(BaseMessageComponent): class CardImage(BaseMessageComponent): type = ComponentType.CardImage file: str - cache: T.Optional[bool] = True - minwidth: T.Optional[int] = 400 - minheight: T.Optional[int] = 400 - maxwidth: T.Optional[int] = 500 - maxheight: T.Optional[int] = 500 - source: T.Optional[str] = "" - icon: T.Optional[str] = "" + cache: bool | None = True + minwidth: int | None = 400 + minheight: int | None = 400 + maxwidth: int | None = 500 + maxheight: int | None = 500 + source: str | None = "" + icon: str | None = "" def __init__(self, **_): super().__init__(**_) @@ -728,14 +727,12 @@ class Unknown(BaseMessageComponent): class File(BaseMessageComponent): - """ - 文件消息段 - """ + """文件消息段""" type = ComponentType.File - name: T.Optional[str] = "" # 名字 - file_: T.Optional[str] = "" # 本地路径 - url: T.Optional[str] = "" # url + name: str | None = "" # 名字 + file_: str | None = "" # 本地路径 + url: str | None = "" # url def __init__(self, name: str, file: str = "", url: str = ""): """文件消息段。""" @@ -743,11 +740,11 @@ class File(BaseMessageComponent): @property def file(self) -> str: - """ - 获取文件路径,如果文件不存在但有URL,则同步下载文件 + """获取文件路径,如果文件不存在但有URL,则同步下载文件 Returns: str: 文件路径 + """ if self.file_ and os.path.exists(self.file_): return os.path.abspath(self.file_) @@ -757,19 +754,16 @@ class File(BaseMessageComponent): loop = asyncio.get_event_loop() if loop.is_running(): logger.warning( - ( - "不可以在异步上下文中同步等待下载! " - "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" - "请使用 await get_file() 代替直接获取 .file 字段" - ) + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段", ) return "" - else: - # 等待下载完成 - loop.run_until_complete(self._download_file()) + # 等待下载完成 + loop.run_until_complete(self._download_file()) - if self.file_ and os.path.exists(self.file_): - return os.path.abspath(self.file_) + if self.file_ and os.path.exists(self.file_): + return os.path.abspath(self.file_) except Exception as e: logger.error(f"文件下载失败: {e}") @@ -777,11 +771,11 @@ class File(BaseMessageComponent): @file.setter def file(self, value: str): - """ - 向前兼容, 设置file属性, 传入的参数可能是文件路径或URL + """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL Args: value (str): 文件路径或URL + """ if value.startswith("http://") or value.startswith("https://"): self.url = value @@ -796,6 +790,7 @@ class File(BaseMessageComponent): 注意,如果为 True,也可能返回文件路径。 Returns: str: 文件路径或者 http 下载链接 + """ if allow_return_url and self.url: return self.url @@ -818,14 +813,14 @@ class File(BaseMessageComponent): self.file_ = os.path.abspath(file_path) async def register_to_file_service(self): - """ - 将文件注册到文件服务。 + """将文件注册到文件服务。 Returns: str: 注册后的URL Raises: Exception: 如果未配置 callback_api_base + """ callback_host = astrbot_config.get("callback_api_base") @@ -863,9 +858,9 @@ class File(BaseMessageComponent): class WechatEmoji(BaseMessageComponent): type = ComponentType.WechatEmoji - md5: T.Optional[str] = "" - md5_len: T.Optional[int] = 0 - cdnurl: T.Optional[str] = "" + md5: str | None = "" + md5_len: int | None = 0 + cdnurl: str | None = "" def __init__(self, **_): super().__init__(**_) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 7bfdd34c8..ed4e25f43 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,15 +1,16 @@ import enum - -from typing import List, Optional, Union, AsyncGenerator +from collections.abc import AsyncGenerator from dataclasses import dataclass, field + +from typing_extensions import deprecated + from astrbot.core.message.components import ( - BaseMessageComponent, - Plain, - Image, At, AtAll, + BaseMessageComponent, + Image, + Plain, ) -from typing_extensions import deprecated @dataclass @@ -20,18 +21,18 @@ class MessageChain: Attributes: `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 + """ - chain: List[BaseMessageComponent] = field(default_factory=list) - use_t2i_: Optional[bool] = None # None 为跟随用户设置 - type: Optional[str] = None + chain: list[BaseMessageComponent] = field(default_factory=list) + use_t2i_: bool | None = None # None 为跟随用户设置 + type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" def message(self, message: str): """添加一条文本消息到消息链 `chain` 中。 Example: - CommandResult().message("Hello ").message("world!") # 输出 Hello world! @@ -39,11 +40,10 @@ class MessageChain: self.chain.append(Plain(message)) return self - def at(self, name: str, qq: Union[str, int]): + def at(self, name: str, qq: str | int): """添加一条 At 消息到消息链 `chain` 中。 Example: - CommandResult().at("张三", "12345678910") # 输出 @张三 @@ -55,7 +55,6 @@ class MessageChain: """添加一条 AtAll 消息到消息链 `chain` 中。 Example: - CommandResult().at_all() # 输出 @所有人 @@ -68,7 +67,6 @@ class MessageChain: """添加一条错误消息到消息链 `chain` 中 Example: - CommandResult().error("解析失败") """ @@ -82,7 +80,6 @@ class MessageChain: 如果需要发送本地图片,请使用 `file_image` 方法。 Example: - CommandResult().image("https://example.com/image.jpg") """ @@ -96,6 +93,7 @@ class MessageChain: 如果需要发送网络图片,请使用 `url_image` 方法。 CommandResult().image("image.jpg") + """ self.chain.append(Image.fromFileSystem(path)) return self @@ -114,6 +112,7 @@ class MessageChain: Args: use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 + """ self.use_t2i_ = use_t2i return self @@ -125,7 +124,7 @@ class MessageChain: def squash_plain(self): """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: - return + return None new_chain = [] first_plain = None @@ -153,6 +152,7 @@ class EventResultType(enum.Enum): Attributes: CONTINUE: 事件将会继续传播 STOP: 事件将会终止传播 + """ CONTINUE = enum.auto() @@ -181,17 +181,18 @@ class MessageEventResult(MessageChain): `chain` (list): 用于顺序存储各个组件。 `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 `result_type` (EventResultType): 事件处理的结果类型。 + """ - result_type: Optional[EventResultType] = field( - default_factory=lambda: EventResultType.CONTINUE + result_type: EventResultType | None = field( + default_factory=lambda: EventResultType.CONTINUE, ) - result_content_type: Optional[ResultContentType] = field( - default_factory=lambda: ResultContentType.GENERAL_RESULT + result_content_type: ResultContentType | None = field( + default_factory=lambda: ResultContentType.GENERAL_RESULT, ) - async_stream: Optional[AsyncGenerator] = None + async_stream: AsyncGenerator | None = None """异步流""" def stop_event(self) -> "MessageEventResult": @@ -205,9 +206,7 @@ class MessageEventResult(MessageChain): return self def is_stopped(self) -> bool: - """ - 是否终止事件传播。 - """ + """是否终止事件传播。""" return self.result_type == EventResultType.STOP def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": @@ -220,6 +219,7 @@ class MessageEventResult(MessageChain): Args: result_type (EventResultType): 事件处理的结果类型。 + """ self.result_content_type = typ return self diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index add3c74bc..482b5887c 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -1,8 +1,8 @@ +from astrbot import logger +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Persona, Personality -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.platform.message_session import MessageSession -from astrbot import logger DEFAULT_PERSONALITY = Personality( prompt="You are a helpful and friendly assistant.", @@ -41,12 +41,14 @@ class PersonaManager: return persona async def get_default_persona_v3( - self, umo: str | MessageSession | None = None + self, + umo: str | MessageSession | None = None, ) -> Personality: """获取默认 persona""" cfg = self.acm.get_conf(umo) default_persona_id = cfg.get("provider_settings", {}).get( - "default_personality", "default" + "default_personality", + "default", ) if not default_persona_id or default_persona_id == "default": return DEFAULT_PERSONALITY @@ -75,7 +77,10 @@ class PersonaManager: if not existing_persona: raise ValueError(f"Persona with ID {persona_id} does not exist.") persona = await self.db.update_persona( - persona_id, system_prompt, begin_dialogs, tools=tools + persona_id, + system_prompt, + begin_dialogs, + tools=tools, ) if persona: for i, p in enumerate(self.personas): @@ -100,7 +105,10 @@ class PersonaManager: if await self.db.get_persona_by_id(persona_id): raise ValueError(f"Persona with ID {persona_id} already exists.") new_persona = await self.db.insert_persona( - persona_id, system_prompt, begin_dialogs, tools=tools + persona_id, + system_prompt, + begin_dialogs, + tools=tools, ) self.personas.append(new_persona) self.get_v3_persona_data() @@ -115,6 +123,7 @@ class PersonaManager: - list[dict]: 包含 persona 配置的字典列表。 - list[Personality]: 包含 Personality 对象的列表。 - Personality: 默认选择的 Personality 对象。 + """ v3_persona_config = [ { @@ -136,7 +145,7 @@ class PersonaManager: if begin_dialogs: if len(begin_dialogs) % 2 != 0: logger.error( - f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。" + f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。", ) begin_dialogs = [] user_turn = True @@ -146,7 +155,7 @@ class PersonaManager: "role": "user" if user_turn else "assistant", "content": dialog, "_no_save": None, # 不持久化到 db - } + }, ) user_turn = not user_turn diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 29a324a1d..75fef84d3 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -27,15 +27,15 @@ STAGES_ORDER = [ ] __all__ = [ - "WakingCheckStage", - "WhitelistCheckStage", - "SessionStatusCheckStage", - "RateLimitStage", "ContentSafetyCheckStage", + "EventResultType", + "MessageEventResult", "PreProcessStage", "ProcessStage", - "ResultDecorateStage", + "RateLimitStage", "RespondStage", - "MessageEventResult", - "EventResultType", + "ResultDecorateStage", + "SessionStatusCheckStage", + "WakingCheckStage", + "WhitelistCheckStage", ] diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index e6ecd995c..c477cc23a 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -1,9 +1,11 @@ -from typing import Union, AsyncGenerator -from ..stage import Stage, register_stage -from ..context import PipelineContext -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult +from collections.abc import AsyncGenerator + from astrbot.core import logger +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from ..context import PipelineContext +from ..stage import Stage, register_stage from .strategies.strategy import StrategySelector @@ -19,8 +21,10 @@ class ContentSafetyCheckStage(Stage): self.strategy_selector = StrategySelector(config) async def process( - self, event: AstrMessageEvent, check_text: str | None = None - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + check_text: str | None = None, + ) -> None | AsyncGenerator[None, None]: """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) @@ -28,8 +32,8 @@ class ContentSafetyCheckStage(Stage): if event.is_at_or_wake_command: event.set_result( MessageEventResult().message( - "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。" - ) + "你的消息或者大模型的响应中包含不适当的内容,已被屏蔽。", + ), ) yield event.stop_event() diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py index 5701f0634..f0a34e73f 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py @@ -1,8 +1,7 @@ import abc -from typing import Tuple class ContentSafetyStrategy(abc.ABC): @abc.abstractmethod - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: raise NotImplementedError diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py index 26284e1a1..c11822896 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -1,9 +1,8 @@ -""" -使用此功能应该先 pip install baidu-aip -""" +"""使用此功能应该先 pip install baidu-aip""" + +from aip import AipContentCensor from . import ContentSafetyStrategy -from aip import AipContentCensor class BaiduAipStrategy(ContentSafetyStrategy): @@ -19,12 +18,11 @@ class BaiduAipStrategy(ContentSafetyStrategy): return False, "" if res["conclusionType"] == 1: return True, "" - else: - if "data" not in res: - return False, "" - count = len(res["data"]) - info = f"百度审核服务发现 {count} 处违规:\n" - for i in res["data"]: - info += f"{i['msg']};\n" - info += "\n判断结果:" + res["conclusion"] - return False, info + if "data" not in res: + return False, "" + count = len(res["data"]) + info = f"百度审核服务发现 {count} 处违规:\n" + for i in res["data"]: + info += f"{i['msg']};\n" + info += "\n判断结果:" + res["conclusion"] + return False, info diff --git a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py index c65faa000..53ad900f7 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py @@ -1,4 +1,5 @@ import re + from . import ContentSafetyStrategy diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index af960328f..c971ef26f 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -1,16 +1,16 @@ -from . import ContentSafetyStrategy -from typing import List, Tuple from astrbot import logger +from . import ContentSafetyStrategy + class StrategySelector: def __init__(self, config: dict) -> None: - self.enabled_strategies: List[ContentSafetyStrategy] = [] + self.enabled_strategies: list[ContentSafetyStrategy] = [] if config["internal_keywords"]["enable"]: from .keywords import KeywordsStrategy self.enabled_strategies.append( - KeywordsStrategy(config["internal_keywords"]["extra_keywords"]) + KeywordsStrategy(config["internal_keywords"]["extra_keywords"]), ) if config["baidu_aip"]["enable"]: try: @@ -23,10 +23,10 @@ class StrategySelector: config["baidu_aip"]["app_id"], config["baidu_aip"]["api_key"], config["baidu_aip"]["secret_key"], - ) + ), ) - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: for strategy in self.enabled_strategies: ok, info = strategy.check(content) if not ok: diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 803626aaa..a6cd567e0 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,7 +1,9 @@ from dataclasses import dataclass + from astrbot.core.config import AstrBotConfig from astrbot.core.star import PluginManager -from .context_utils import call_handler, call_event_hook + +from .context_utils import call_event_hook, call_handler @dataclass diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index e7ac120b7..73d28c5d1 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -1,11 +1,12 @@ import inspect import traceback import typing as T + from astrbot import logger -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from astrbot.core.message.message_event_result import CommandResult, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import EventType, star_handlers_registry async def call_handler( @@ -26,6 +27,7 @@ async def call_handler( Returns: AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 + """ ready_to_call = None # 一个协程或者异步生成器 @@ -80,14 +82,17 @@ async def call_event_hook( Returns: bool: 如果事件被终止,返回 True - #""" + # + + """ handlers = star_handlers_registry.get_handlers_by_event_type( - hook_type, plugins_name=event.plugins_name + hook_type, + plugins_name=event.plugins_name, ) for handler in handlers: try: logger.debug( - f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) await handler.handler(event, *args, **kwargs) except BaseException: @@ -95,7 +100,7 @@ async def call_event_hook( if event.is_stopped(): logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。", ) return True diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 5c075687f..a69d07ffb 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -1,12 +1,14 @@ -import traceback import asyncio import random -from typing import Union, AsyncGenerator -from ..stage import Stage, register_stage -from ..context import PipelineContext -from astrbot.core.platform.astr_message_event import AstrMessageEvent +import traceback +from collections.abc import AsyncGenerator + from astrbot.core import logger -from astrbot.core.message.components import Plain, Record, Image +from astrbot.core.message.components import Image, Plain, Record +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from ..context import PipelineContext +from ..stage import Stage, register_stage @register_stage @@ -20,8 +22,9 @@ class PreProcessStage(Stage): self.platform_settings: dict = self.config.get("platform_settings", {}) async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: """在处理事件之前的预处理""" # 平台特异配置:platform_specific..pre_ack_emoji supported = {"telegram", "lark"} @@ -68,7 +71,7 @@ class PreProcessStage(Stage): stt_provider = ctx.get_using_stt_provider(event.unified_msg_origin) if not stt_provider: logger.warning( - f"会话 {event.unified_msg_origin} 未配置语音转文本模型。" + f"会话 {event.unified_msg_origin} 未配置语音转文本模型。", ) return message_chain = event.get_messages() diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 703b3681c..d1cffc43f 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -1,15 +1,21 @@ -""" -本地 Agent 模式的 LLM 调用 Stage -""" +"""本地 Agent 模式的 LLM 调用 Stage""" import asyncio import copy import json import traceback -from datetime import timedelta from collections.abc import AsyncGenerator -from astrbot.core.conversation_mgr import Conversation +from datetime import timedelta + from astrbot.core import logger +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( MessageChain, @@ -22,21 +28,14 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.agent.hooks import BaseAgentRunHooks -from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolSet, FunctionTool -from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor -from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.provider.register import llm_tools from astrbot.core.star.session_llm_manager import SessionServiceManager -from astrbot.core.star.star_handler import EventType +from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric + from ...context import PipelineContext, call_event_hook, call_handler from ..stage import Stage from ..utils import inject_kb_context -from astrbot.core.provider.register import llm_tools -from astrbot.core.star.star_handler import star_map -from astrbot.core.astr_agent_context import AstrAgentContext try: import mcp @@ -59,6 +58,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): Returns: AsyncGenerator[None | mcp.types.CallToolResult, None] + """ if isinstance(tool, HandoffTool): async for r in cls._execute_handoff(tool, run_context, **tool_args): @@ -117,14 +117,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}") await run_context.event.send( - MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name) + MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name), ) await agent_runner.reset( provider=run_context.context.provider, request=request, run_context=AgentContextWrapper( - context=astr_agent_ctx, event=run_context.event + context=astr_agent_ctx, + event=run_context.event, ), tool_executor=FunctionToolExecutor(), agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](), @@ -146,7 +147,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): return logger.debug( - f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}" + f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}", ) result = ( @@ -180,7 +181,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): # 检查 tool 下有没有 run 方法 if not tool.handler and not hasattr(tool, "run"): raise ValueError("Tool must have a valid handler or 'run' method.") - awaitable = tool.handler or getattr(tool, "run") + awaitable = tool.handler or tool.run wrapper = call_handler( event=run_context.event, @@ -210,7 +211,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): yield None except asyncio.TimeoutError: raise Exception( - f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds." + f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds.", ) except StopAsyncIteration: break @@ -232,7 +233,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): name=tool.name, arguments=tool_args, read_timeout_seconds=timedelta( - seconds=run_context.context.tool_call_timeout + seconds=run_context.context.tool_call_timeout, ), ) if not res: @@ -244,7 +245,9 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 await call_event_hook( - run_context.event, EventType.OnLLMResponseEvent, llm_response + run_context.event, + EventType.OnLLMResponseEvent, + llm_response, ) @@ -252,7 +255,9 @@ MAIN_AGENT_HOOKS = MainAgentHooks() async def run_agent( - agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True + agent_runner: AgentRunner, + max_step: int = 30, + show_tool_use: bool = True, ) -> AsyncGenerator[MessageChain, None]: step_idx = 0 astr_event = agent_runner.run_context.event @@ -290,19 +295,18 @@ async def run_agent( MessageEventResult( chain=resp.data["chain"].chain, result_content_type=content_typ, - ) + ), ) yield astr_event.clear_result() - else: - if resp.type == "streaming_delta": - yield resp.data["chain"] # MessageChain + elif resp.type == "streaming_delta": + yield resp.data["chain"] # MessageChain if agent_runner.done(): break except Exception as e: logger.error(traceback.format_exc()) - err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n" + err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" if agent_runner.streaming: yield MessageChain().message(err_msg) else: @@ -332,7 +336,7 @@ class LLMRequestSubStage(Stage): for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): logger.info( - f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。" + f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", ) self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :] @@ -367,7 +371,9 @@ class LLMRequestSubStage(Stage): return conversation async def process( - self, event: AstrMessageEvent, _nested: bool = False + self, + event: AstrMessageEvent, + _nested: bool = False, ) -> None | AsyncGenerator[None, None]: req: ProviderRequest | None = None @@ -423,7 +429,9 @@ class LLMRequestSubStage(Stage): # 应用知识库 try: await inject_kb_context( - umo=event.unified_msg_origin, p_ctx=self.ctx, req=req + umo=event.unified_msg_origin, + p_ctx=self.ctx, + req=req, ) except Exception as e: logger.error(f"调用知识库时遇到问题: {e}") @@ -475,7 +483,7 @@ class LLMRequestSubStage(Stage): # 如果模型不支持工具使用,但请求中包含工具列表,则清空。 if "tool_use" not in provider_cfg: logger.debug( - f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。" + f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。", ) req.func_tool = None # 插件可用性设置 @@ -498,7 +506,7 @@ class LLMRequestSubStage(Stage): # run agent agent_runner = AgentRunner() logger.debug( - f"handle provider[id: {provider.provider_config['id']}] request: {req}" + f"handle provider[id: {provider.provider_config['id']}] request: {req}", ) astr_agent_ctx = AstrAgentContext( provider=provider, @@ -522,8 +530,8 @@ class LLMRequestSubStage(Stage): MessageEventResult() .set_result_content_type(ResultContentType.STREAMING_RESULT) .set_async_stream( - run_agent(agent_runner, self.max_step, self.show_tool_use) - ) + run_agent(agent_runner, self.max_step, self.show_tool_use), + ), ) yield if agent_runner.done(): @@ -540,7 +548,7 @@ class LLMRequestSubStage(Stage): MessageEventResult( chain=chain, result_content_type=ResultContentType.STREAMING_FINISH, - ) + ), ) else: async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use): @@ -560,17 +568,21 @@ class LLMRequestSubStage(Stage): llm_tick=1, model_name=agent_runner.provider.get_model(), provider_type=agent_runner.provider.meta().type, - ) + ), ) async def _handle_webchat( - self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider + self, + event: AstrMessageEvent, + req: ProviderRequest, + prov: Provider, ): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return conversation = await self.conv_manager.get_conversation( - event.unified_msg_origin, req.conversation.cid + event.unified_msg_origin, + req.conversation.cid, ) if conversation and not req.conversation.title: messages = json.loads(conversation.history) @@ -607,7 +619,7 @@ class LLMRequestSubStage(Stage): ) if llm_resp and llm_resp.completion_text: logger.debug( - f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}" + f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}", ) title = llm_resp.completion_text.strip() if not title or "" in title: @@ -650,7 +662,9 @@ class LLMRequestSubStage(Stage): messages.append({"role": "assistant", "content": llm_response.completion_text}) messages = list(filter(lambda item: "_no_save" not in item, messages)) await self.conv_manager.update_conversation( - event.unified_msg_origin, req.conversation.cid, history=messages + event.unified_msg_origin, + req.conversation.cid, + history=messages, ) def fix_messages(self, messages: list[dict]) -> list[dict]: diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 42990aae5..ff8120b16 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -1,16 +1,17 @@ -""" -本地 Agent 模式的 AstrBot 插件调用 Stage -""" +"""本地 Agent 模式的 AstrBot 插件调用 Stage""" + +import traceback +from collections.abc import AsyncGenerator +from typing import Any + +from astrbot.core import logger +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import StarHandlerMetadata from ...context import PipelineContext, call_handler from ..stage import Stage -from typing import Dict, Any, List, AsyncGenerator, Union -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult -from astrbot.core import logger -from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core.star.star import star_map -import traceback class StarRequestSubStage(Stage): @@ -21,13 +22,14 @@ class StarRequestSubStage(Stage): self.ctx = ctx async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: - activated_handlers: List[StarHandlerMetadata] = event.get_extra( - "activated_handlers" + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: + activated_handlers: list[StarHandlerMetadata] = event.get_extra( + "activated_handlers", ) - handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra( - "handlers_parsed_params" + handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra( + "handlers_parsed_params", ) if not handlers_parsed_params: handlers_parsed_params = {} @@ -37,7 +39,7 @@ class StarRequestSubStage(Stage): md = star_map.get(handler.handler_module_path) if not md: logger.warning( - f"Cannot find plugin for given handler module path: {handler.handler_module_path}" + f"Cannot find plugin for given handler module path: {handler.handler_module_path}", ) continue logger.debug(f"plugin -> {md.name} - {handler.handler_name}") diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index f653a9fb9..9f0b5f92a 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -1,12 +1,14 @@ -from typing import List, Union, AsyncGenerator -from ..stage import Stage, register_stage +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.star.star_handler import StarHandlerMetadata + from ..context import PipelineContext +from ..stage import Stage, register_stage from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core.provider.entities import ProviderRequest -from astrbot.core import logger @register_stage @@ -22,11 +24,12 @@ class ProcessStage(Stage): await self.star_request_sub_stage.initialize(ctx) async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: """处理事件""" - activated_handlers: List[StarHandlerMetadata] = event.get_extra( - "activated_handlers" + activated_handlers: list[StarHandlerMetadata] = event.get_extra( + "activated_handlers", ) # 有插件 Handler 被激活 if activated_handlers: diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index e799ad4d0..b1168aa0a 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -1,6 +1,7 @@ -from ..context import PipelineContext -from astrbot.core.provider.entities import ProviderRequest from astrbot.api import logger, sp +from astrbot.core.provider.entities import ProviderRequest + +from ..context import PipelineContext async def inject_kb_context( @@ -8,14 +9,14 @@ async def inject_kb_context( p_ctx: PipelineContext, req: ProviderRequest, ) -> None: - """inject knowledge base context into the provider request + """Inject knowledge base context into the provider request Args: umo: Unique message object (session ID) p_ctx: Pipeline context req: Provider request - """ + """ kb_mgr = p_ctx.plugin_manager.context.kb_manager # 1. 优先读取会话级配置 @@ -45,7 +46,7 @@ async def inject_kb_context( if invalid_kb_ids: logger.warning( - f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}" + f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}", ) if not kb_names: diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index b36a2fbd0..64e21dd7e 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -1,18 +1,19 @@ import asyncio -from datetime import datetime, timedelta from collections import defaultdict, deque -from typing import DefaultDict, Deque, Union, AsyncGenerator -from ..stage import Stage, register_stage -from ..context import PipelineContext -from astrbot.core.platform.astr_message_event import AstrMessageEvent +from collections.abc import AsyncGenerator +from datetime import datetime, timedelta + from astrbot.core import logger from astrbot.core.config.astrbot_config import RateLimitStrategy +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from ..context import PipelineContext +from ..stage import Stage, register_stage @register_stage class RateLimitStage(Stage): - """ - 检查是否需要限制消息发送的限流器。 + """检查是否需要限制消息发送的限流器。 使用 Fixed Window 算法。 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 @@ -20,32 +21,30 @@ class RateLimitStage(Stage): def __init__(self): # 存储每个会话的请求时间队列 - self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque) + self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 - self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock) # 限流参数 self.rate_limit_count: int = 0 self.rate_limit_time: timedelta = timedelta(0) async def initialize(self, ctx: PipelineContext) -> None: - """ - 初始化限流器,根据配置设置限流参数。 - """ + """初始化限流器,根据配置设置限流参数。""" self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][ "count" ] self.rate_limit_time = timedelta( - seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"] + seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"], ) self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][ "strategy" ] # stall or discard async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: - """ - 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: + """检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 Args: event (AstrMessageEvent): 当前消息事件。 @@ -53,6 +52,7 @@ class RateLimitStage(Stage): Returns: MessageEventResult: 继续或停止事件处理的结果。 + """ session_id = event.session_id now = datetime.now() @@ -66,32 +66,33 @@ class RateLimitStage(Stage): if len(timestamps) < self.rate_limit_count: timestamps.append(now) break - else: - next_window_time = timestamps[0] + self.rate_limit_time - stall_duration = (next_window_time - now).total_seconds() + 0.3 + next_window_time = timestamps[0] + self.rate_limit_time + stall_duration = (next_window_time - now).total_seconds() + 0.3 - match self.rl_strategy: - case RateLimitStrategy.STALL.value: - logger.info( - f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。" - ) - await asyncio.sleep(stall_duration) - now = datetime.now() - case RateLimitStrategy.DISCARD.value: - logger.info( - f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。" - ) - return event.stop_event() + match self.rl_strategy: + case RateLimitStrategy.STALL.value: + logger.info( + f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。", + ) + await asyncio.sleep(stall_duration) + now = datetime.now() + case RateLimitStrategy.DISCARD.value: + logger.info( + f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。", + ) + return event.stop_event() def _remove_expired_timestamps( - self, timestamps: Deque[datetime], now: datetime + self, + timestamps: deque[datetime], + now: datetime, ) -> None: - """ - 移除时间窗口外的时间戳。 + """移除时间窗口外的时间戳。 Args: timestamps (Deque[datetime]): 当前会话的时间戳队列。 now (datetime): 当前时间,用于计算过期时间。 + """ expiry_threshold: datetime = now - self.rate_limit_time while timestamps and timestamps[0] < expiry_threshold: diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index dc6a67e2f..f20445594 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -1,25 +1,27 @@ -import random import asyncio import math +import random +from collections.abc import AsyncGenerator + import astrbot.core.message.components as Comp -from typing import Union, AsyncGenerator -from ..stage import register_stage, Stage -from ..context import PipelineContext, call_event_hook -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core import logger from astrbot.core.message.components import BaseMessageComponent, ComponentType +from astrbot.core.message.message_event_result import MessageChain, ResultContentType +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import EventType from astrbot.core.utils.path_util import path_Mapping from astrbot.core.utils.session_lock import session_lock_manager +from ..context import PipelineContext, call_event_hook +from ..stage import Stage, register_stage + @register_stage class RespondStage(Stage): # 组件类型到其非空判断函数的映射 _component_validators = { Comp.Plain: lambda comp: bool( - comp.text and comp.text.strip() + comp.text and comp.text.strip(), ), # 纯文本消息需要strip Comp.Face: lambda comp: comp.id is not None, # QQ表情 Comp.Record: lambda comp: bool(comp.file), # 语音 @@ -58,7 +60,7 @@ class RespondStage(Stage): "segmented_reply" ]["interval_method"] self.log_base = float( - ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"] + ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"], ) interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][ "interval" @@ -86,17 +88,16 @@ class RespondStage(Stage): wc = await self._word_cnt(comp.text) i = math.log(wc + 1, self.log_base) return random.uniform(i, i + 0.5) - else: - return random.uniform(1, 1.75) - else: - # random - return random.uniform(self.interval[0], self.interval[1]) + return random.uniform(1, 1.75) + # random + return random.uniform(self.interval[0], self.interval[1]) async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): """检查消息链是否为空 Args: chain (list[BaseMessageComponent]): 包含消息对象的列表 + """ if not chain: return True @@ -150,8 +151,9 @@ class RespondStage(Stage): return extracted async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None: return @@ -159,7 +161,7 @@ class RespondStage(Stage): return logger.info( - f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}" + f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}", ) if result.result_content_type == ResultContentType.STREAMING_RESULT: @@ -168,12 +170,13 @@ class RespondStage(Stage): return # 流式结果直接交付平台适配器处理 use_fallback = self.config.get("provider_settings", {}).get( - "streaming_segmented", False + "streaming_segmented", + False, ) logger.info(f"应用流式输出({event.get_platform_id()})") await event.send_streaming(result.async_stream, use_fallback) return - elif len(result.chain) > 0: + if len(result.chain) > 0: # 检查路径映射 if mappings := self.platform_settings.get("path_mapping", []): for idx, component in enumerate(result.chain): @@ -212,7 +215,7 @@ class RespondStage(Stage): if not result.chain or len(result.chain) == 0: # may fix #2670 logger.warning( - f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}" + f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}", ) return async with session_lock_manager.acquire_lock(event.unified_msg_origin): @@ -237,7 +240,7 @@ class RespondStage(Stage): ): # may fix #2670 logger.warning( - f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}" + f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}", ) return sep_comps = self._extract_comp( diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index c1f893baf..08661a367 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -1,7 +1,7 @@ import re import time import traceback -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply @@ -30,8 +30,7 @@ class ResultDecorateStage(Stage): self.t2i_word_threshold = ctx.astrbot_config["t2i_word_threshold"] try: self.t2i_word_threshold = int(self.t2i_word_threshold) - if self.t2i_word_threshold < 50: - self.t2i_word_threshold = 50 + self.t2i_word_threshold = max(self.t2i_word_threshold, 50) except BaseException: self.t2i_word_threshold = 150 self.t2i_strategy = ctx.astrbot_config["t2i_strategy"] @@ -46,7 +45,7 @@ class ResultDecorateStage(Stage): self.words_count_threshold = int( ctx.astrbot_config["platform_settings"]["segmented_reply"][ "words_count_threshold" - ] + ], ) self.enable_segmented_reply = ctx.astrbot_config["platform_settings"][ "segmented_reply" @@ -71,8 +70,9 @@ class ResultDecorateStage(Stage): await self.content_safe_check_stage.initialize(ctx) async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None or not result.chain: return @@ -94,34 +94,36 @@ class ResultDecorateStage(Stage): if isinstance(comp, Plain): text += comp.text async for _ in self.content_safe_check_stage.process( - event, check_text=text + event, + check_text=text, ): yield # 发送消息前事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnDecoratingResultEvent, plugins_name=event.plugins_name + EventType.OnDecoratingResultEvent, + plugins_name=event.plugins_name, ) for handler in handlers: try: logger.debug( - f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) if is_stream: logger.warning( - "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作" + "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", ) await handler.handler(event) if event.get_result() is None or not event.get_result().chain: logger.debug( - f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。" + f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", ) except BaseException: logger.error(traceback.format_exc()) if event.is_stopped(): logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。", ) return @@ -160,7 +162,9 @@ class ResultDecorateStage(Stage): new_chain.append(comp) continue split_response = re.findall( - self.regex, comp.text, re.DOTALL | re.MULTILINE + self.regex, + comp.text, + re.DOTALL | re.MULTILINE, ) if not split_response: new_chain.append(comp) @@ -177,7 +181,7 @@ class ResultDecorateStage(Stage): # TTS tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( - event.unified_msg_origin + event.unified_msg_origin, ) if ( @@ -187,7 +191,7 @@ class ResultDecorateStage(Stage): ): if not tts_provider: logger.warning( - f"会话 {event.unified_msg_origin} 未配置文本转语音模型。" + f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", ) else: new_chain = [] @@ -199,7 +203,7 @@ class ResultDecorateStage(Stage): logger.info(f"TTS 结果: {audio_path}") if not audio_path: logger.error( - f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}" + f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}", ) new_chain.append(comp) continue @@ -217,7 +221,7 @@ class ResultDecorateStage(Stage): url = None if use_file_service and callback_api_base: token = await file_token_service.register_file( - audio_path + audio_path, ) url = f"{callback_api_base}/api/file/{token}" logger.debug(f"已注册:{url}") @@ -226,7 +230,7 @@ class ResultDecorateStage(Stage): Record( file=url or audio_path, url=url or audio_path, - ) + ), ) if dual_output: new_chain.append(comp) @@ -262,7 +266,7 @@ class ResultDecorateStage(Stage): return if time.time() - render_start > 3: logger.warning( - "文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。" + "文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。", ) if url: if url.startswith("http"): @@ -286,7 +290,9 @@ class ResultDecorateStage(Stage): word_cnt += len(comp.text) if word_cnt > self.forward_threshold: node = Node( - uin=event.get_self_id(), name="AstrBot", content=[*result.chain] + uin=event.get_self_id(), + name="AstrBot", + content=[*result.chain], ) result.chain = [node] @@ -298,7 +304,8 @@ class ResultDecorateStage(Stage): and event.get_message_type() != MessageType.FRIEND_MESSAGE ): result.chain.insert( - 0, At(qq=event.get_sender_id(), name=event.get_sender_name()) + 0, + At(qq=event.get_sender_id(), name=event.get_sender_name()), ) if len(result.chain) > 1 and isinstance(result.chain[1], Plain): result.chain[1].text = "\n" + result.chain[1].text diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 7a38ec03f..5c461a1e1 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -1,9 +1,11 @@ -from . import STAGES_ORDER -from .stage import registered_stages -from .context import PipelineContext -from typing import AsyncGenerator -from astrbot.core.platform import AstrMessageEvent +from collections.abc import AsyncGenerator + from astrbot.core import logger +from astrbot.core.platform import AstrMessageEvent + +from . import STAGES_ORDER +from .context import PipelineContext +from .stage import registered_stages class PipelineScheduler: @@ -11,7 +13,7 @@ class PipelineScheduler: def __init__(self, context: PipelineContext): registered_stages.sort( - key=lambda x: STAGES_ORDER.index(x.__name__) + key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 @@ -29,12 +31,13 @@ class PipelineScheduler: Args: event (AstrMessageEvent): 事件对象 from_stage (int): 从第几个阶段开始执行, 默认从0开始 + """ for i in range(from_stage, len(self.stages)): stage = self.stages[i] # 获取当前要执行的阶段 # logger.debug(f"执行阶段 {stage.__class__.__name__}") coroutine = stage.process( - event + event, ) # 调用阶段的process方法, 返回协程或者异步生成器 if isinstance(coroutine, AsyncGenerator): @@ -43,7 +46,7 @@ class PipelineScheduler: # 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段 if event.is_stopped(): logger.debug( - f"阶段 {stage.__class__.__name__} 已终止事件传播。" + f"阶段 {stage.__class__.__name__} 已终止事件传播。", ) break @@ -53,7 +56,7 @@ class PipelineScheduler: # 此处是后续所有阶段处理完毕后返回的点, 执行后置处理 if event.is_stopped(): logger.debug( - f"阶段 {stage.__class__.__name__} 已终止事件传播。" + f"阶段 {stage.__class__.__name__} 已终止事件传播。", ) break else: @@ -70,6 +73,7 @@ class PipelineScheduler: Args: event (AstrMessageEvent): 事件对象 + """ await self._process_stages(event) diff --git a/astrbot/core/pipeline/session_status_check/stage.py b/astrbot/core/pipeline/session_status_check/stage.py index 3c451e26a..7feeeb86a 100644 --- a/astrbot/core/pipeline/session_status_check/stage.py +++ b/astrbot/core/pipeline/session_status_check/stage.py @@ -1,9 +1,11 @@ -from ..stage import Stage, register_stage -from ..context import PipelineContext -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator + +from astrbot.core import logger from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.session_llm_manager import SessionServiceManager -from astrbot.core import logger + +from ..context import PipelineContext +from ..stage import Stage, register_stage @register_stage @@ -15,19 +17,21 @@ class SessionStatusCheckStage(Stage): self.conv_mgr = ctx.plugin_manager.context.conversation_manager async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: # 检查会话是否整体启用 if not SessionServiceManager.is_session_enabled(event.unified_msg_origin): logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。") # workaround for #2309 conv_id = await self.conv_mgr.get_curr_conversation_id( - event.unified_msg_origin + event.unified_msg_origin, ) if not conv_id: await self.conv_mgr.new_conversation( - event.unified_msg_origin, platform_id=event.get_platform_id() + event.unified_msg_origin, + platform_id=event.get_platform_id(), ) event.stop_event() diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index c4550495a..74aca4ef1 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,10 +1,13 @@ from __future__ import annotations + import abc -from typing import List, AsyncGenerator, Union, Type +from collections.abc import AsyncGenerator + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from .context import PipelineContext -registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 +registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 def register_stage(cls): @@ -22,18 +25,21 @@ class Stage(abc.ABC): Args: ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 + """ raise NotImplementedError @abc.abstractmethod async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: """处理事件 Args: event (AstrMessageEvent): 事件对象,包含事件的相关信息 Returns: Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) + """ raise NotImplementedError diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index de6ad5e35..814919115 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,11 +1,11 @@ -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator from astrbot import logger from astrbot.core.message.components import At, AtAll, Reply from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.session_plugin_manager import SessionPluginManager from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry @@ -30,10 +30,12 @@ class WakingCheckStage(Stage): Args: ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 + """ self.ctx = ctx self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get( - "no_permission_reply", True + "no_permission_reply", + True, ) # 私聊是否需要 wake_prefix 才能唤醒机器人 self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[ @@ -41,15 +43,18 @@ class WakingCheckStage(Stage): ].get("friend_message_needs_wake_prefix", False) # 是否忽略机器人自己发送的消息 self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get( - "ignore_bot_self_message", False + "ignore_bot_self_message", + False, ) self.ignore_at_all = self.ctx.astrbot_config["platform_settings"].get( - "ignore_at_all", False + "ignore_at_all", + False, ) async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: if ( self.ignore_bot_self_message and event.get_self_id() == event.get_sender_id() @@ -123,7 +128,8 @@ class WakingCheckStage(Stage): logger.debug(f"enabled_plugins_name: {enabled_plugins_name}") for handler in star_handlers_registry.get_handlers_by_event_type( - EventType.AdapterMessageEvent, plugins_name=event.plugins_name + EventType.AdapterMessageEvent, + plugins_name=event.plugins_name, ): # filter 需满足 AND 逻辑关系 passed = True @@ -138,15 +144,14 @@ class WakingCheckStage(Stage): if not filter.filter(event, self.ctx.astrbot_config): permission_not_pass = True permission_filter_raise_error = filter.raise_error - else: - if not filter.filter(event, self.ctx.astrbot_config): - passed = False - break + elif not filter.filter(event, self.ctx.astrbot_config): + passed = False + break except Exception as e: await event.send( MessageEventResult().message( - f"插件 {star_map[handler.handler_module_path].name}: {e}" - ) + f"插件 {star_map[handler.handler_module_path].name}: {e}", + ), ) event.stop_event() passed = False @@ -159,11 +164,11 @@ class WakingCheckStage(Stage): if self.no_permission_reply: await event.send( MessageChain().message( - f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。" - ) + f"您(ID: {event.get_sender_id()})的权限不足以使用此指令。通过 /sid 获取 ID 并请管理员添加。", + ), ) logger.info( - f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。" + f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。", ) event.stop_event() return @@ -185,7 +190,8 @@ class WakingCheckStage(Stage): # 根据会话配置过滤插件处理器 activated_handlers = SessionPluginManager.filter_handlers_by_session( - event, activated_handlers + event, + activated_handlers, ) event.set_extra("activated_handlers", activated_handlers) diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py index b140d23ba..ea9c55228 100644 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -1,9 +1,11 @@ -from ..stage import Stage, register_stage -from ..context import PipelineContext -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator + +from astrbot.core import logger from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType -from astrbot.core import logger + +from ..context import PipelineContext +from ..stage import Stage, register_stage @register_stage @@ -27,8 +29,9 @@ class WhitelistCheckStage(Stage): self.wl_log = ctx.astrbot_config["platform_settings"]["id_whitelist_log"] async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + self, + event: AstrMessageEvent, + ) -> None | AsyncGenerator[None, None]: if not self.enable_whitelist_check: # 白名单检查未启用 return @@ -60,6 +63,6 @@ class WhitelistCheckStage(Stage): ): if self.wl_log: logger.info( - f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。" + f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。", ) event.stop_event() diff --git a/astrbot/core/platform/__init__.py b/astrbot/core/platform/__init__.py index 4007b2d90..30b94723e 100644 --- a/astrbot/core/platform/__init__.py +++ b/astrbot/core/platform/__init__.py @@ -1,14 +1,14 @@ -from .platform import Platform from .astr_message_event import AstrMessageEvent +from .astrbot_message import AstrBotMessage, Group, MessageMember, MessageType +from .platform import Platform from .platform_metadata import PlatformMetadata -from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group __all__ = [ - "Platform", - "AstrMessageEvent", - "PlatformMetadata", "AstrBotMessage", + "AstrMessageEvent", + "Group", "MessageMember", "MessageType", - "Group", + "Platform", + "PlatformMetadata", ] diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3a4b8c128..9605eaffb 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,30 +1,31 @@ import abc import asyncio -import re import hashlib +import re import uuid - -from typing import List, Union, Optional, AsyncGenerator, Any +from collections.abc import AsyncGenerator +from typing import Any from astrbot import logger from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( - Plain, - Image, - BaseMessageComponent, - Face, At, AtAll, + BaseMessageComponent, + Face, Forward, + Image, + Plain, Reply, ) -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric + from .astrbot_message import AstrBotMessage, Group +from .message_session import MessageSesion, MessageSession # noqa from .platform_metadata import PlatformMetadata -from .message_session import MessageSession, MessageSesion # noqa class AstrMessageEvent(abc.ABC): @@ -74,7 +75,8 @@ class AstrMessageEvent(abc.ABC): def get_platform_name(self): """获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。 - NOTE: 用户可能会同时运行多个相同类型的平台适配器。""" + NOTE: 用户可能会同时运行多个相同类型的平台适配器。 + """ return self.platform_meta.name def get_platform_id(self): @@ -85,12 +87,10 @@ class AstrMessageEvent(abc.ABC): return self.platform_meta.id def get_message_str(self) -> str: - """ - 获取消息字符串。 - """ + """获取消息字符串。""" return self.message_str - def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str: + def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: outline = "" if not chain: return outline @@ -120,98 +120,69 @@ class AstrMessageEvent(abc.ABC): return outline def get_message_outline(self) -> str: - """ - 获取消息概要。 + """获取消息概要。 除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。 """ return self._outline_chain(self.message_obj.message) - def get_messages(self) -> List[BaseMessageComponent]: - """ - 获取消息链。 - """ + def get_messages(self) -> list[BaseMessageComponent]: + """获取消息链。""" return self.message_obj.message def get_message_type(self) -> MessageType: - """ - 获取消息类型。 - """ + """获取消息类型。""" return self.message_obj.type def get_session_id(self) -> str: - """ - 获取会话id。 - """ + """获取会话id。""" return self.session_id def get_group_id(self) -> str: - """ - 获取群组id。如果不是群组消息,返回空字符串。 - """ + """获取群组id。如果不是群组消息,返回空字符串。""" return self.message_obj.group_id def get_self_id(self) -> str: - """ - 获取机器人自身的id。 - """ + """获取机器人自身的id。""" return self.message_obj.self_id def get_sender_id(self) -> str: - """ - 获取消息发送者的id。 - """ + """获取消息发送者的id。""" return self.message_obj.sender.user_id def get_sender_name(self) -> str: - """ - 获取消息发送者的名称。(可能会返回空字符串) - """ + """获取消息发送者的名称。(可能会返回空字符串)""" return self.message_obj.sender.nickname def set_extra(self, key, value): - """ - 设置额外的信息。 - """ + """设置额外的信息。""" self._extras[key] = value def get_extra(self, key: str | None = None, default=None) -> Any: - """ - 获取额外的信息。 - """ + """获取额外的信息。""" if key is None: return self._extras return self._extras.get(key, default) def clear_extra(self): - """ - 清除额外的信息。 - """ + """清除额外的信息。""" logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() def is_private_chat(self) -> bool: - """ - 是否是私聊。 - """ + """是否是私聊。""" return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value def is_wake_up(self) -> bool: - """ - 是否是唤醒机器人的事件。 - """ + """是否是唤醒机器人的事件。""" return self.is_wake def is_admin(self) -> bool: - """ - 是否是管理员。 - """ + """是否是管理员。""" return self.role == "admin" async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str: - """ - 将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。 - """ + """将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。""" while True: match = re.search(pattern, buffer) if not match: @@ -223,14 +194,16 @@ class AstrMessageEvent(abc.ABC): return buffer async def send_streaming( - self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + self, + generator: AsyncGenerator[MessageChain, None], + use_fallback: bool = False, ): """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 Fallback仅支持 aiocqhttp。 """ asyncio.create_task( - Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name), ) self._has_send_oper = True @@ -240,7 +213,7 @@ class AstrMessageEvent(abc.ABC): async def _post_send(self): """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: Union[MessageEventResult, str]): + def set_result(self, result: MessageEventResult | str): """设置消息事件的结果。 Note: @@ -260,6 +233,7 @@ class AstrMessageEvent(abc.ABC): event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE)) return ``` + """ if isinstance(result, str): result = MessageEventResult().message(result) @@ -283,41 +257,32 @@ class AstrMessageEvent(abc.ABC): self._result.continue_event() def is_stopped(self) -> bool: - """ - 是否终止事件传播。 - """ + """是否终止事件传播。""" if self._result is None: return False # 默认是继续传播 return self._result.is_stopped() def should_call_llm(self, call_llm: bool): - """ - 是否在此消息事件中禁止默认的 LLM 请求。 + """是否在此消息事件中禁止默认的 LLM 请求。 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 """ self.call_llm = call_llm def get_result(self) -> MessageEventResult: - """ - 获取消息事件的结果。 - """ + """获取消息事件的结果。""" return self._result def clear_result(self): - """ - 清除消息事件的结果。 - """ + """清除消息事件的结果。""" self._result = None """消息链相关""" def make_result(self) -> MessageEventResult: - """ - 创建一个空的消息事件结果。 + """创建一个空的消息事件结果。 Example: - ```python # 纯文本回复 yield event.make_result().message("Hi") @@ -325,18 +290,16 @@ class AstrMessageEvent(abc.ABC): yield event.make_result().url_image("https://example.com/image.jpg") yield event.make_result().file_image("image.jpg") ``` + """ return MessageEventResult() def plain_result(self, text: str) -> MessageEventResult: - """ - 创建一个空的消息事件结果,只包含一条文本消息。 - """ + """创建一个空的消息事件结果,只包含一条文本消息。""" return MessageEventResult().message(text) def image_result(self, url_or_path: str) -> MessageEventResult: - """ - 创建一个空的消息事件结果,只包含一条图片消息。 + """创建一个空的消息事件结果,只包含一条图片消息。 根据开头是否包含 http 来判断是网络图片还是本地图片。 """ @@ -344,10 +307,8 @@ class AstrMessageEvent(abc.ABC): return MessageEventResult().url_image(url_or_path) return MessageEventResult().file_image(url_or_path) - def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult: - """ - 创建一个空的消息事件结果,包含指定的消息链。 - """ + def chain_result(self, chain: list[BaseMessageComponent]) -> MessageEventResult: + """创建一个空的消息事件结果,包含指定的消息链。""" mer = MessageEventResult() mer.chain = chain return mer @@ -359,13 +320,12 @@ class AstrMessageEvent(abc.ABC): prompt: str, func_tool_manager=None, session_id: str = None, - image_urls: List[str] = [], - contexts: List = [], + image_urls: list[str] = [], + contexts: list = [], system_prompt: str = "", conversation: Conversation = None, ) -> ProviderRequest: - """ - 创建一个 LLM 请求。 + """创建一个 LLM 请求。 Examples: ```py @@ -384,8 +344,8 @@ class AstrMessageEvent(abc.ABC): func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。 conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。 - """ + """ if len(contexts) > 0 and conversation: conversation = None @@ -406,20 +366,22 @@ class AstrMessageEvent(abc.ABC): Args: message (MessageChain): 消息链,具体使用方式请参考文档。 + """ # Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy. hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16) sid = str(uuid.UUID(bytes=hash_obj.digest())) asyncio.create_task( Metric.upload( - msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid - ) + msg_event_tick=1, + adapter_name=self.platform_meta.name, + sid=sid, + ), ) self._has_send_oper = True async def react(self, emoji: str): - """ - 对消息添加表情回应。 + """对消息添加表情回应。 默认实现为发送一条包含该表情的消息。 注意:此实现并不一定符合所有平台的原生“表情回应”行为。 @@ -427,11 +389,10 @@ class AstrMessageEvent(abc.ABC): """ await self.send(MessageChain([Plain(emoji)])) - async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]: + async def get_group(self, group_id: str = None, **kwargs) -> Group | None: """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 适配情况: - aiocqhttp(OneBotv11) """ - ... diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 1808c2911..dcc70b0f2 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -1,7 +1,8 @@ import time -from typing import List from dataclasses import dataclass + from astrbot.core.message.components import BaseMessageComponent + from .message_type import MessageType @@ -28,9 +29,9 @@ class Group: """群头像""" group_owner: str = None """群主 id""" - group_admins: List[str] = None + group_admins: list[str] = None """群管理员 id""" - members: List[MessageMember] = None + members: list[MessageMember] = None """所有群成员""" def __str__(self): @@ -47,9 +48,7 @@ class Group: class AstrBotMessage: - """ - AstrBot 的消息对象 - """ + """AstrBot 的消息对象""" type: MessageType # 消息类型 self_id: str # 机器人的识别id @@ -57,7 +56,7 @@ class AstrBotMessage: message_id: str # 消息id group: Group # 群组 sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 raw_message: object timestamp: int # 消息时间戳 @@ -71,8 +70,7 @@ class AstrBotMessage: @property def group_id(self) -> str: - """ - 向后兼容的 group_id 属性 + """向后兼容的 group_id 属性 群组id,如果为私聊,则为空 """ if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 7090c669c..9ff892025 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -1,18 +1,19 @@ -import traceback import asyncio -from astrbot.core.config.astrbot_config import AstrBotConfig -from .platform import Platform -from typing import List +import traceback from asyncio import Queue -from .register import platform_cls_map + from astrbot.core import logger -from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map + +from .platform import Platform +from .register import platform_cls_map from .sources.webchat.webchat_adapter import WebChatAdapter class PlatformManager: def __init__(self, config: AstrBotConfig, event_queue: Queue): - self.platform_insts: List[Platform] = [] + self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" self._inst_map = {} @@ -36,7 +37,7 @@ class PlatformManager: webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) self.platform_insts.append(webchat_inst) asyncio.create_task( - self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")) + self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")), ) async def load_platform(self, platform_config: dict): @@ -47,7 +48,7 @@ class PlatformManager: return logger.info( - f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ..." + f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...", ) match platform_config["type"]: case "aiocqhttp": @@ -106,14 +107,14 @@ class PlatformManager: ) except (ImportError, ModuleNotFoundError) as e: logger.error( - f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" + f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。", ) except Exception as e: logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") if platform_config["type"] not in platform_cls_map: logger.error( - f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误" + f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误", ) return cls_type = platform_cls_map[platform_config["type"]] @@ -129,16 +130,16 @@ class PlatformManager: asyncio.create_task( inst.run(), name=f"platform_{platform_config['type']}_{platform_config['id']}", - ) - ) + ), + ), ) handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnPlatformLoadedEvent + EventType.OnPlatformLoadedEvent, ) for handler in handlers: try: logger.info( - f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) await handler.handler() except Exception: @@ -180,7 +181,7 @@ class PlatformManager: inst for inst in self.platform_insts if inst.client_self_id == client_id - ) + ), ) except Exception: logger.warning(f"可能未完全移除 {platform_id} 平台适配器") diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index bf5a72a9a..62240b621 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -1,11 +1,13 @@ -from astrbot.core.platform.message_type import MessageType from dataclasses import dataclass +from astrbot.core.platform.message_type import MessageType + @dataclass class MessageSession: """描述一条消息在 AstrBot 中对应的会话的唯一标识。 - 如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。""" + 如果您需要实例化 MessageSession,请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。 + """ platform_name: str """平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。""" diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c109f29b4..3f36e17f3 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,13 +1,16 @@ import abc import uuid -from typing import Awaitable, Any from asyncio import Queue -from .platform_metadata import PlatformMetadata -from .astr_message_event import AstrMessageEvent +from collections.abc import Awaitable +from typing import Any + from astrbot.core.message.message_event_result import MessageChain -from .message_session import MessageSesion from astrbot.core.utils.metrics import Metric +from .astr_message_event import AstrMessageEvent +from .message_session import MessageSesion +from .platform_metadata import PlatformMetadata + class Platform(abc.ABC): def __init__(self, event_queue: Queue): @@ -18,42 +21,31 @@ class Platform(abc.ABC): @abc.abstractmethod def run(self) -> Awaitable[Any]: - """ - 得到一个平台的运行实例,需要返回一个协程对象。 - """ + """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError async def terminate(self): - """ - 终止一个平台的运行实例。 - """ - ... + """终止一个平台的运行实例。""" @abc.abstractmethod def meta(self) -> PlatformMetadata: - """ - 得到一个平台的元数据。 - """ + """得到一个平台的元数据。""" raise NotImplementedError async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ) -> Awaitable[Any]: - """ - 通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 + """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 异步方法。 """ await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) def commit_event(self, event: AstrMessageEvent): - """ - 提交一个事件到事件队列。 - """ + """提交一个事件到事件队列。""" self._event_queue.put_nowait(event) def get_client(self): - """ - 获取平台的客户端对象。 - """ - pass + """获取平台的客户端对象。""" diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 97c33a43e..4cd62ede0 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -1,10 +1,10 @@ -from typing import List, Dict, Type -from .platform_metadata import PlatformMetadata from astrbot.core import logger -platform_registry: List[PlatformMetadata] = [] +from .platform_metadata import PlatformMetadata + +platform_registry: list[PlatformMetadata] = [] """维护了通过装饰器注册的平台适配器""" -platform_cls_map: Dict[str, Type] = {} +platform_cls_map: dict[str, type] = {} """维护了平台适配器名称和适配器类的映射""" @@ -24,7 +24,7 @@ def register_platform_adapter( def decorator(cls): if adapter_name in platform_cls_map: raise ValueError( - f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。" + f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。", ) # 添加必备选项 diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index b8bb723d5..5fe605a74 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,24 +1,31 @@ import asyncio import re -from typing import AsyncGenerator, Dict, List +from collections.abc import AsyncGenerator + from aiocqhttp import CQHttp, Event + from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( + BaseMessageComponent, + File, Image, Node, Nodes, Plain, Record, Video, - File, - BaseMessageComponent, ) from astrbot.api.platform import Group, MessageMember class AiocqhttpMessageEvent(AstrMessageEvent): def __init__( - self, message_str, message_obj, platform_meta, session_id, bot: CQHttp + self, + message_str, + message_obj, + platform_meta, + session_id, + bot: CQHttp, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -35,16 +42,15 @@ class AiocqhttpMessageEvent(AstrMessageEvent): "file": f"base64://{bs64}", }, } - elif isinstance(segment, File): + if isinstance(segment, File): # For File segments, we need to handle the file differently d = await segment.to_dict() return d - elif isinstance(segment, Video): + if isinstance(segment, Video): d = await segment.to_dict() return d - else: - # For other segments, we simply convert them to a dict by calling toDict - return segment.toDict() + # For other segments, we simply convert them to a dict by calling toDict + return segment.toDict() @staticmethod async def _parse_onebot_json(message_chain: MessageChain): @@ -78,7 +84,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await bot.send(event=event, message=messages) else: raise ValueError( - f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})" + f"无法发送消息:缺少有效的数字 session_id({session_id}) 或 event({event})", ) @classmethod @@ -98,8 +104,8 @@ class AiocqhttpMessageEvent(AstrMessageEvent): event (Event | None, optional): aiocqhttp 事件对象. is_group (bool, optional): 是否为群消息. session_id (str | None, optional): 会话 ID(群号或 QQ 号 - """ + """ # 转发消息、文件消息不能和普通消息混在一起发送 send_one_by_one = any( isinstance(seg, (Node, Nodes, File)) for seg in message_chain.chain @@ -152,7 +158,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await super().send(message) async def send_streaming( - self, generator: AsyncGenerator, use_fallback: bool = False + self, + generator: AsyncGenerator, + use_fallback: bool = False, ): if not use_fallback: buffer = None @@ -162,7 +170,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) @@ -198,7 +206,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent): group_id=group_id, ) - members: List[Dict] = await self.bot.call_action( + members: list[dict] = await self.bot.call_action( "get_group_member_list", group_id=group_id, ) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d1992b6c3..bb9c4474a 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -1,33 +1,41 @@ -import time import asyncio -import logging -import uuid import itertools -from typing import Awaitable, Any +import logging +import time +import uuid +from collections.abc import Awaitable +from typing import Any + from aiocqhttp import CQHttp, Event +from aiocqhttp.exceptions import ActionFailed + +from astrbot.api import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import * from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.api.event import MessageChain -from .aiocqhttp_message_event import * # noqa: F403 -from astrbot.api.message_components import * # noqa: F403 -from astrbot.api import logger -from .aiocqhttp_message_event import AiocqhttpMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion + from ...register import register_platform_adapter -from aiocqhttp.exceptions import ActionFailed +from .aiocqhttp_message_event import * +from .aiocqhttp_message_event import AiocqhttpMessageEvent @register_platform_adapter( - "aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。" + "aiocqhttp", + "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。", ) class AiocqhttpAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -48,7 +56,7 @@ class AiocqhttpAdapter(Platform): import_name="aiocqhttp", api_timeout_sec=180, access_token=platform_config.get( - "ws_reverse_token" + "ws_reverse_token", ), # 以防旧版本配置不存在 ) @@ -81,7 +89,9 @@ class AiocqhttpAdapter(Platform): logger.info("aiocqhttp(OneBot v11) 适配器已连接。") async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): is_group = session.message_type == MessageType.GROUP_MESSAGE if is_group: @@ -104,7 +114,7 @@ class AiocqhttpAdapter(Platform): abm = await self._convert_handle_message_event(event) if abm.sender.user_id == "2854196310": # 屏蔽 QQ 管家的消息 - return + return None elif event["post_type"] == "notice": abm = await self._convert_handle_notice_event(event) elif event["post_type"] == "request": @@ -118,7 +128,7 @@ class AiocqhttpAdapter(Platform): abm.self_id = str(event.self_id) abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) abm.type = MessageType.OTHER_MESSAGE - if "group_id" in event and event["group_id"]: + if event.get("group_id"): abm.type = MessageType.GROUP_MESSAGE abm.group_id = str(event.group_id) else: @@ -144,7 +154,7 @@ class AiocqhttpAdapter(Platform): abm.self_id = str(event.self_id) abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) abm.type = MessageType.OTHER_MESSAGE - if "group_id" in event and event["group_id"]: + if event.get("group_id"): abm.group_id = str(event.group_id) abm.type = MessageType.GROUP_MESSAGE else: @@ -167,12 +177,14 @@ class AiocqhttpAdapter(Platform): if "sub_type" in event: if event["sub_type"] == "poke" and "target_id" in event: - abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405 + abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) return abm async def _convert_handle_message_event( - self, event: Event, get_reply=True + self, + event: Event, + get_reply=True, ) -> AstrBotMessage: """OneBot V11 消息类事件 @@ -207,13 +219,13 @@ class AiocqhttpAdapter(Platform): message_str = "" if not isinstance(event.message, list): - err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。" + err = f"aiocqhttp: 无法识别的消息类型: {event.message!s},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。" logger.critical(err) try: self.bot.send(event, err) except BaseException as e: logger.error(f"回复消息失败: {e}") - return + return None # 按消息段类型类型适配 for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): @@ -224,7 +236,7 @@ class AiocqhttpAdapter(Platform): # 如果文本段为空,则跳过 continue message_str += current_text - a = ComponentTypes[t](text=current_text) # noqa: F405 + a = ComponentTypes[t](text=current_text) abm.message.append(a) elif t == "file": @@ -264,7 +276,7 @@ class AiocqhttpAdapter(Platform): elif t == "reply": for m in m_group: if not get_reply: - a = ComponentTypes[t](**m["data"]) # noqa: F405 + a = ComponentTypes[t](**m["data"]) abm.message.append(a) else: try: @@ -277,11 +289,12 @@ class AiocqhttpAdapter(Platform): new_event = Event.from_payload(reply_event_data) if not new_event: logger.error( - f"无法从回复消息数据构造 Event 对象: {reply_event_data}" + f"无法从回复消息数据构造 Event 对象: {reply_event_data}", ) continue abm_reply = await self._convert_handle_message_event( - new_event, get_reply=False + new_event, + get_reply=False, ) reply_seg = Reply( @@ -298,7 +311,7 @@ class AiocqhttpAdapter(Platform): abm.message.append(reply_seg) except BaseException as e: logger.error(f"获取引用消息失败: {e}。") - a = ComponentTypes[t](**m["data"]) # noqa: F405 + a = ComponentTypes[t](**m["data"]) abm.message.append(a) elif t == "at": first_at_self_processed = False @@ -324,7 +337,8 @@ class AiocqhttpAdapter(Platform): no_cache=False, ) nickname = at_info.get("nick", "") or at_info.get( - "nickname", "" + "nickname", + "", ) is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"} @@ -332,7 +346,7 @@ class AiocqhttpAdapter(Platform): At( qq=m["data"]["qq"], name=nickname, - ) + ), ) if is_at_self and not first_at_self_processed: @@ -349,7 +363,7 @@ class AiocqhttpAdapter(Platform): logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。") else: for m in m_group: - a = ComponentTypes[t](**m["data"]) # noqa: F405 + a = ComponentTypes[t](**m["data"]) abm.message.append(a) abm.timestamp = int(time.time()) @@ -361,7 +375,7 @@ class AiocqhttpAdapter(Platform): def run(self) -> Awaitable[Any]: if not self.host or not self.port: logger.warning( - "aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199" + "aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199", ) self.host = "0.0.0.0" self.port = 6199 diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index ec79c89ae..43d231771 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -1,26 +1,28 @@ import asyncio import os +import threading import uuid + import aiohttp import dingtalk_stream -import threading +from dingtalk_stream import AckMessage +from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import At, Image, Plain from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.api.event import MessageChain -from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion -from .dingtalk_event import DingtalkMessageEvent -from ...register import register_platform_adapter -from astrbot import logger -from dingtalk_stream import AckMessage -from astrbot.core.utils.io import download_file from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file + +from ...register import register_platform_adapter +from .dingtalk_event import DingtalkMessageEvent class MyEventHandler(dingtalk_stream.EventHandler): @@ -38,7 +40,10 @@ class MyEventHandler(dingtalk_stream.EventHandler): @register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器") class DingtalkPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -64,12 +69,15 @@ class DingtalkPlatformAdapter(Platform): client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger) client.register_all_event_handler(MyEventHandler()) client.register_callback_handler( - dingtalk_stream.ChatbotMessage.TOPIC, self.client + dingtalk_stream.ChatbotMessage.TOPIC, + self.client, ) self.client_ = client # 用于 websockets 的 client async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): raise NotImplementedError("钉钉机器人适配器不支持 send_by_session") @@ -81,7 +89,8 @@ class DingtalkPlatformAdapter(Platform): ) async def convert_msg( - self, message: dingtalk_stream.ChatbotMessage + self, + message: dingtalk_stream.ChatbotMessage, ) -> AstrBotMessage: abm = AstrBotMessage() abm.message = [] @@ -93,7 +102,8 @@ class DingtalkPlatformAdapter(Platform): else MessageType.FRIEND_MESSAGE ) abm.sender = MessageMember( - user_id=message.sender_id, nickname=message.sender_nick + user_id=message.sender_id, + nickname=message.sender_nick, ) abm.self_id = message.chatbot_user_id abm.message_id = message.message_id @@ -139,7 +149,10 @@ class DingtalkPlatformAdapter(Platform): return abm # 别忘了返回转换后的消息对象 async def download_ding_file( - self, download_code: str, robot_code: str, ext: str + self, + download_code: str, + robot_code: str, + ext: str, ) -> str: """下载钉钉文件 @@ -159,20 +172,22 @@ class DingtalkPlatformAdapter(Platform): } temp_dir = os.path.join(get_astrbot_data_path(), "temp") f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}") - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( "https://api.dingtalk.com/v1.0/robot/messageFiles/download", headers=headers, json=payload, - ) as resp: - if resp.status != 200: - logger.error( - f"下载钉钉文件失败: {resp.status}, {await resp.text()}" - ) - return None - resp_data = await resp.json() - download_url = resp_data["data"]["downloadUrl"] - await download_file(download_url, f_path) + ) as resp, + ): + if resp.status != 200: + logger.error( + f"下载钉钉文件失败: {resp.status}, {await resp.text()}", + ) + return None + resp_data = await resp.json() + download_url = resp_data["data"]["downloadUrl"] + await download_file(download_url, f_path) return f_path async def get_access_token(self) -> str: @@ -187,7 +202,7 @@ class DingtalkPlatformAdapter(Platform): ) as resp: if resp.status != 200: logger.error( - f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}" + f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", ) return None return (await resp.json())["data"]["accessToken"] diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 1e6ddd49f..a1cd9c1aa 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -1,8 +1,10 @@ import asyncio + import dingtalk_stream + import astrbot.api.message_components as Comp -from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot import logger +from astrbot.api.event import AstrMessageEvent, MessageChain class DingtalkMessageEvent(AstrMessageEvent): @@ -18,7 +20,9 @@ class DingtalkMessageEvent(AstrMessageEvent): self.client = client async def send_with_client( - self, client: dingtalk_stream.ChatbotHandler, message: MessageChain + self, + client: dingtalk_stream.ChatbotHandler, + message: MessageChain, ): for segment in message.chain: if isinstance(segment, Comp.Plain): @@ -69,7 +73,7 @@ class DingtalkMessageEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 78894491f..0a2982ce8 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,7 +1,9 @@ -import discord -from astrbot import logger import sys +import discord + +from astrbot import logger + if sys.version_info >= (3, 12): from typing import override else: @@ -41,7 +43,8 @@ class DiscordBotClient(discord.Bot): await self.on_ready_once_callback() except Exception as e: logger.error( - f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True + f"[Discord] on_ready_once_callback 执行失败: {e}", + exc_info=True, ) def _create_message_data(self, message: discord.Message) -> dict: @@ -84,7 +87,7 @@ class DiscordBotClient(discord.Bot): return logger.debug( - f"[Discord] 收到原始消息 from {message.author.name}: {message.content}" + f"[Discord] 收到原始消息 from {message.author.name}: {message.content}", ) if self.on_message_received: @@ -103,12 +106,12 @@ class DiscordBotClient(discord.Bot): command_name = interaction_data.get("name", "") if options := interaction_data.get("options", []): params = " ".join( - [f"{opt['name']}:{opt.get('value', '')}" for opt in options] + [f"{opt['name']}:{opt.get('value', '')}" for opt in options], ) return f"/{command_name} {params}" return f"/{command_name}" - elif interaction_type == discord.InteractionType.component: + if interaction_type == discord.InteractionType.component: custom_id = interaction_data.get("custom_id", "") component_type = interaction_data.get("component_type", "") return f"component:{custom_id}:{component_type}" diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 07e712161..dbddd1686 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -1,5 +1,5 @@ import discord -from typing import List + from astrbot.api.message_components import BaseMessageComponent @@ -18,7 +18,7 @@ class DiscordEmbed(BaseMessageComponent): thumbnail: str = None, image: str = None, footer: str = None, - fields: List[dict] = None, + fields: list[dict] = None, ): self.title = title self.description = description @@ -96,7 +96,9 @@ class DiscordView(BaseMessageComponent): type: str = "discord_view" def __init__( - self, components: List[BaseMessageComponent] = None, timeout: float = None + self, + components: list[BaseMessageComponent] = None, + timeout: float = None, ): self.components = components or [] self.timeout = timeout @@ -108,7 +110,9 @@ class DiscordView(BaseMessageComponent): for component in self.components: if isinstance(component, DiscordButton): button_style = getattr( - discord.ButtonStyle, component.style, discord.ButtonStyle.primary + discord.ButtonStyle, + component.style, + discord.ButtonStyle.primary, ) if component.url: diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 6764eda61..5dc1fd8a6 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,30 +1,32 @@ import asyncio -import discord -import sys import re +import sys +from typing import Any + +import discord from discord.abc import Messageable from discord.channel import DMChannel + +from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import File, Image, Plain from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, - PlatformMetadata, MessageType, + Platform, + PlatformMetadata, + register_platform_adapter, ) -from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image, File from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.api.platform import register_platform_adapter -from astrbot import logger -from .client import DiscordBotClient -from .discord_platform_event import DiscordPlatformEvent - -from typing import Any, Tuple from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.star import star_map from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry +from .client import DiscordBotClient +from .discord_platform_event import DiscordPlatformEvent + if sys.version_info >= (3, 12): from typing import override else: @@ -35,7 +37,10 @@ else: @register_platform_adapter("discord", "Discord 适配器 (基于 Pycord)") class DiscordPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config @@ -51,7 +56,9 @@ class DiscordPlatformAdapter(Platform): @override async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): """通过会话发送消息""" # 创建一个 message_obj 以便在 event 中使用 @@ -71,14 +78,15 @@ class DiscordPlatformAdapter(Platform): message_obj.group_id = self._get_channel_id(channel) else: logger.warning( - f"[Discord] Can't get channel info for {channel_id_str}, will guess message type." + f"[Discord] Can't get channel info for {channel_id_str}, will guess message type.", ) message_obj.type = MessageType.GROUP_MESSAGE message_obj.group_id = session.session_id message_obj.message_str = message_chain.get_plain_text() message_obj.sender = MessageMember( - user_id=str(self.client_self_id), nickname=self.client.user.display_name + user_id=str(self.client_self_id), + nickname=self.client.user.display_name, ) message_obj.self_id = self.client_self_id message_obj.session_id = session.session_id @@ -149,7 +157,9 @@ class DiscordPlatformAdapter(Platform): logger.error(f"[Discord] 适配器运行时发生意外错误: {e}", exc_info=True) def _get_message_type( - self, channel: Messageable, guild_id: int | None = None + self, + channel: Messageable, + guild_id: int | None = None, ) -> MessageType: """根据 channel 对象和 guild_id 判断消息类型""" if guild_id is not None: @@ -201,7 +211,8 @@ class DiscordPlatformAdapter(Platform): abm.group_id = self._get_channel_id(message.channel) abm.message_str = content abm.sender = MessageMember( - user_id=str(message.author.id), nickname=message.author.display_name + user_id=str(message.author.id), + nickname=message.author.display_name, ) message_chain = [] if abm.message_str: @@ -209,14 +220,14 @@ class DiscordPlatformAdapter(Platform): if message.attachments: for attachment in message.attachments: if attachment.content_type and attachment.content_type.startswith( - "image/" + "image/", ): message_chain.append( - Image(file=attachment.url, filename=attachment.filename) + Image(file=attachment.url, filename=attachment.filename), ) else: message_chain.append( - File(name=attachment.filename, url=attachment.url) + File(name=attachment.filename, url=attachment.url), ) abm.message = message_chain abm.raw_message = message @@ -260,7 +271,7 @@ class DiscordPlatformAdapter(Platform): if hasattr(message.raw_message, "guild") and message.raw_message.guild: try: bot_member = message.raw_message.guild.get_member( - self.client.user.id + self.client.user.id, ) except Exception: bot_member = None @@ -346,7 +357,7 @@ class DiscordPlatformAdapter(Platform): description="指令的所有参数", type=discord.SlashCommandOptionType.string, required=False, - ) + ), ] # 创建SlashCommand @@ -362,7 +373,7 @@ class DiscordPlatformAdapter(Platform): if registered_commands: logger.info( - f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}" + f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}", ) else: logger.info("[Discord] 没有发现可注册的指令。") @@ -387,7 +398,7 @@ class DiscordPlatformAdapter(Platform): logger.debug( f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 " f"原始参数: '{params}'. " - f"构建的指令字符串: '{message_str_for_filter}'" + f"构建的指令字符串: '{message_str_for_filter}'", ) # 尝试立即响应,防止超时 @@ -404,7 +415,8 @@ class DiscordPlatformAdapter(Platform): abm.group_id = self._get_channel_id(ctx.channel) abm.message_str = message_str_for_filter abm.sender = MessageMember( - user_id=str(ctx.author.id), nickname=ctx.author.display_name + user_id=str(ctx.author.id), + nickname=ctx.author.display_name, ) abm.message = [Plain(text=message_str_for_filter)] abm.raw_message = ctx.interaction @@ -419,8 +431,9 @@ class DiscordPlatformAdapter(Platform): @staticmethod def _extract_command_info( - event_filter: Any, handler_metadata: StarHandlerMetadata - ) -> Tuple[str, str, CommandFilter] | None: + event_filter: Any, + handler_metadata: StarHandlerMetadata, + ) -> tuple[str, str, CommandFilter] | None: """从事件过滤器中提取指令信息""" cmd_name = None # is_group = False diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 2c8d055fc..3c701c4ce 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -1,21 +1,22 @@ import asyncio -import discord import base64 +import sys from io import BytesIO from pathlib import Path -from typing import Optional -import sys +import discord + +from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At from astrbot.api.message_components import ( - Plain, - Image, - File, BaseMessageComponent, + File, + Image, + Plain, Reply, ) -from astrbot import logger +from astrbot.api.platform import AstrBotMessage, At, PlatformMetadata + from .client import DiscordBotClient from .components import DiscordEmbed, DiscordView @@ -41,7 +42,7 @@ class DiscordPlatformEvent(AstrMessageEvent): platform_meta: PlatformMetadata, session_id: str, client: DiscordBotClient, - interaction_followup_webhook: Optional[discord.Webhook] = None, + interaction_followup_webhook: discord.Webhook | None = None, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -50,7 +51,6 @@ class DiscordPlatformEvent(AstrMessageEvent): @override async def send(self, message: MessageChain): """发送消息到Discord平台""" - # 解析消息链为 Discord 所需的对象 try: ( @@ -90,20 +90,19 @@ class DiscordPlatformEvent(AstrMessageEvent): channel = await self._get_channel() if not channel: return - else: - await channel.send(**kwargs) + await channel.send(**kwargs) except Exception as e: logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True) await super().send(message) - async def _get_channel(self) -> Optional[discord.abc.Messageable]: + async def _get_channel(self) -> discord.abc.Messageable | None: """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) return self.client.get_channel( - channel_id + channel_id, ) or await self.client.fetch_channel(channel_id) except (ValueError, discord.errors.NotFound, discord.errors.Forbidden): logger.error(f"[Discord] 无法获取频道 {self.session_id}") @@ -112,7 +111,7 @@ class DiscordPlatformEvent(AstrMessageEvent): async def _parse_to_discord( self, message: MessageChain, - ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: + ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" content = "" files = [] @@ -146,13 +145,14 @@ class DiscordPlatformEvent(AstrMessageEvent): continue # 2. File URI - elif file_content.startswith("file:///"): + if file_content.startswith("file:///"): logger.debug(f"[Discord] 处理 File URI: {file_content}") path = Path(file_content[8:]) if await asyncio.to_thread(path.exists): file_bytes = await asyncio.to_thread(path.read_bytes) discord_file = discord.File( - BytesIO(file_bytes), filename=filename or path.name + BytesIO(file_bytes), + filename=filename or path.name, ) else: logger.warning(f"[Discord] 图片文件不存在: {path}") @@ -166,7 +166,8 @@ class DiscordPlatformEvent(AstrMessageEvent): b64_data += "=" * (4 - missing_padding) img_bytes = base64.b64decode(b64_data) discord_file = discord.File( - BytesIO(img_bytes), filename=filename or "image.png" + BytesIO(img_bytes), + filename=filename or "image.png", ) # 4. 裸 Base64 或本地路径 @@ -179,17 +180,19 @@ class DiscordPlatformEvent(AstrMessageEvent): b64_data += "=" * (4 - missing_padding) img_bytes = base64.b64decode(b64_data) discord_file = discord.File( - BytesIO(img_bytes), filename=filename or "image.png" + BytesIO(img_bytes), + filename=filename or "image.png", ) except (ValueError, TypeError, base64.binascii.Error): logger.debug( - f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}" + f"[Discord] 裸 Base64 解码失败,作为本地路径处理: {file_content}", ) path = Path(file_content) if await asyncio.to_thread(path.exists): file_bytes = await asyncio.to_thread(path.read_bytes) discord_file = discord.File( - BytesIO(file_bytes), filename=filename or path.name + BytesIO(file_bytes), + filename=filename or path.name, ) else: logger.warning(f"[Discord] 图片文件不存在: {path}") @@ -212,11 +215,11 @@ class DiscordPlatformEvent(AstrMessageEvent): if await asyncio.to_thread(path.exists): file_bytes = await asyncio.to_thread(path.read_bytes) files.append( - discord.File(BytesIO(file_bytes), filename=i.name) + discord.File(BytesIO(file_bytes), filename=i.name), ) else: logger.warning( - f"[Discord] 获取文件失败,路径不存在: {file_path_str}" + f"[Discord] 获取文件失败,路径不存在: {file_path_str}", ) else: logger.warning(f"[Discord] 获取文件失败: {i.name}") @@ -244,7 +247,8 @@ class DiscordPlatformEvent(AstrMessageEvent): """对原消息添加反应""" try: if hasattr(self.message_obj, "raw_message") and hasattr( - self.message_obj.raw_message, "add_reaction" + self.message_obj.raw_message, + "add_reaction", ): await self.message_obj.raw_message.add_reaction(emoji) except Exception as e: @@ -279,7 +283,8 @@ class DiscordPlatformEvent(AstrMessageEvent): def is_mentioned(self) -> bool: """判断机器人是否被@""" if hasattr(self.message_obj, "raw_message") and hasattr( - self.message_obj.raw_message, "mentions" + self.message_obj.raw_message, + "mentions", ): return any( mention.id == int(self.message_obj.self_id) @@ -290,7 +295,8 @@ class DiscordPlatformEvent(AstrMessageEvent): def get_mention_clean_content(self) -> str: """获取去除@后的清洁内容""" if hasattr(self.message_obj, "raw_message") and hasattr( - self.message_obj.raw_message, "clean_content" + self.message_obj.raw_message, + "clean_content", ): return self.message_obj.raw_message.clean_content return self.message_str diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 4a7ca0966..b59dbaca4 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -1,30 +1,35 @@ -import base64 import asyncio +import base64 import json import re import uuid -import astrbot.api.message_components as Comp +import lark_oapi as lark +from lark_oapi.api.im.v1 import * + +import astrbot.api.message_components as Comp +from astrbot import logger +from astrbot.api.event import MessageChain from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.api.event import MessageChain from astrbot.core.platform.astr_message_event import MessageSesion -from .lark_event import LarkMessageEvent + from ...register import register_platform_adapter -from astrbot import logger -import lark_oapi as lark -from lark_oapi.api.im.v1 import * +from .lark_event import LarkMessageEvent @register_platform_adapter("lark", "飞书机器人官方 API 适配器") class LarkPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -65,14 +70,16 @@ class LarkPlatformAdapter(Platform): ) async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) wrapped = { "zh_cn": { "title": "", "content": res, - } + }, } if session.message_type == MessageType.GROUP_MESSAGE: @@ -91,7 +98,7 @@ class LarkPlatformAdapter(Platform): .content(json.dumps(wrapped)) .msg_type("post") .uuid(str(uuid.uuid4())) - .build() + .build(), ) .build() ) @@ -160,7 +167,7 @@ class LarkPlatformAdapter(Platform): content_json_b = _ls elif message.message_type == "image": content_json_b = [ - {"tag": "img", "image_key": content_json_b["image_key"], "style": []} + {"tag": "img", "image_key": content_json_b["image_key"], "style": []}, ] if message.message_type in ("post", "image"): @@ -200,11 +207,10 @@ class LarkPlatformAdapter(Platform): abm.session_id = abm.group_id else: abm.session_id = abm.sender.user_id + elif abm.type == MessageType.GROUP_MESSAGE: + abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id else: - if abm.type == MessageType.GROUP_MESSAGE: - abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id - else: - abm.session_id = abm.sender.user_id + abm.session_id = abm.sender.user_id logger.debug(abm) await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 2174c497c..04204d35e 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,27 +1,34 @@ +import base64 import json import os import uuid -import base64 -import lark_oapi as lark from io import BytesIO -from typing import List -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Plain, Image as AstrBotImage, At -from astrbot.core.utils.io import download_image_by_url + +import lark_oapi as lark from lark_oapi.api.im.v1 import * + from astrbot import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import At, Plain +from astrbot.api.message_components import Image as AstrBotImage from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_image_by_url class LarkMessageEvent(AstrMessageEvent): def __init__( - self, message_str, message_obj, platform_meta, session_id, bot: lark.Client + self, + message_str, + message_obj, + platform_meta, + session_id, + bot: lark.Client, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @staticmethod - async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List: + async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list: ret = [] _stage = [] for comp in message.chain: @@ -58,7 +65,7 @@ class LarkMessageEvent(AstrMessageEvent): CreateImageRequestBody.builder() .image_type("message") .image(image_file) - .build() + .build(), ) .build() ) @@ -83,7 +90,7 @@ class LarkMessageEvent(AstrMessageEvent): "zh_cn": { "title": "", "content": res, - } + }, } request = ( @@ -95,7 +102,7 @@ class LarkMessageEvent(AstrMessageEvent): .msg_type("post") .uuid(str(uuid.uuid4())) .reply_in_thread(False) - .build() + .build(), ) .build() ) @@ -114,14 +121,14 @@ class LarkMessageEvent(AstrMessageEvent): .request_body( CreateMessageReactionRequestBody.builder() .reaction_type(Emoji.builder().emoji_type(emoji).build()) - .build() + .build(), ) .build() ) response = await self.bot.im.v1.message_reaction.acreate(request) if not response.success(): logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") - return None + return async def send_streaming(self, generator, use_fallback: bool = False): buffer = None @@ -131,7 +138,7 @@ class LarkMessageEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 981d05c82..0a553dc6f 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,7 +1,10 @@ import asyncio +import os import random -from typing import Dict, Any, Optional, Awaitable, List +from collections.abc import Awaitable +from typing import Any +import astrbot.api.message_components as Comp from astrbot.api import logger from astrbot.api.event import MessageChain from astrbot.api.platform import ( @@ -11,32 +14,31 @@ from astrbot.api.platform import ( register_platform_adapter, ) from astrbot.core.platform.astr_message_event import MessageSession -import astrbot.api.message_components as Comp from .misskey_api import MisskeyAPI -import os try: import magic # type: ignore except Exception: magic = None +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + from .misskey_event import MisskeyPlatformEvent from .misskey_utils import ( - serialize_message_chain, - resolve_message_visibility, - is_valid_user_session_id, - is_valid_room_session_id, add_at_mention_if_needed, - process_files, - extract_sender_info, - create_base_message, - process_at_mention, - format_poll, - cache_user_info, cache_room_info, + cache_user_info, + create_base_message, + extract_sender_info, + format_poll, + is_valid_room_session_id, + is_valid_user_session_id, + process_at_mention, + process_files, + resolve_message_visibility, + serialize_message_chain, ) -from astrbot.core.utils.astrbot_path import get_astrbot_data_path # Constants MAX_FILE_UPLOAD_COUNT = 16 @@ -46,7 +48,10 @@ DEFAULT_UPLOAD_CONCURRENCY = 3 @register_platform_adapter("misskey", "Misskey 平台适配器") class MisskeyPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config or {} @@ -55,7 +60,8 @@ class MisskeyPlatformAdapter(Platform): self.access_token = self.config.get("misskey_token", "") self.max_message_length = self.config.get("max_message_length", 3000) self.default_visibility = self.config.get( - "misskey_default_visibility", "public" + "misskey_default_visibility", + "public", ) self.local_only = self.config.get("misskey_local_only", False) self.enable_chat = self.config.get("misskey_enable_chat", True) @@ -64,7 +70,7 @@ class MisskeyPlatformAdapter(Platform): # download / security related options (exposed to platform_config) self.allow_insecure_downloads = bool( - self.config.get("misskey_allow_insecure_downloads", False) + self.config.get("misskey_allow_insecure_downloads", False), ) # parse download timeout and chunk size safely _dt = self.config.get("misskey_download_timeout") @@ -87,7 +93,7 @@ class MisskeyPlatformAdapter(Platform): self.unique_session = platform_settings["unique_session"] - self.api: Optional[MisskeyAPI] = None + self.api: MisskeyAPI | None = None self._running = False self.client_self_id = "" self._bot_username = "" @@ -136,7 +142,7 @@ class MisskeyPlatformAdapter(Platform): self.client_self_id = str(user_info.get("id", "")) self._bot_username = user_info.get("username", "") logger.info( - f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})" + f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})", ) except Exception as e: logger.error(f"[Misskey] 获取用户信息失败: {e}") @@ -153,12 +159,17 @@ class MisskeyPlatformAdapter(Platform): if self.enable_chat: streaming.add_message_handler("newChatMessage", self._handle_chat_message) streaming.add_message_handler( - "messaging:newChatMessage", self._handle_chat_message + "messaging:newChatMessage", + self._handle_chat_message, ) streaming.add_message_handler("_debug", self._debug_handler) async def _send_text_only_message( - self, session_id: str, text: str, session, message_chain + self, + session_id: str, + text: str, + session, + message_chain, ): """发送纯文本消息(无文件上传)""" if not self.api: @@ -168,7 +179,7 @@ class MisskeyPlatformAdapter(Platform): from .misskey_utils import extract_user_id_from_session_id user_id = extract_user_id_from_session_id(session_id) - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} await self.api.send_message(payload) elif session_id and is_valid_room_session_id(session_id): from .misskey_utils import extract_room_id_from_session_id @@ -180,14 +191,17 @@ class MisskeyPlatformAdapter(Platform): return await super().send_by_session(session, message_chain) def _process_poll_data( - self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str] + self, + message: AstrBotMessage, + poll: dict[str, Any], + message_parts: list[str], ): """处理投票数据,将其添加到消息中""" try: if not isinstance(message.raw_message, dict): message.raw_message = {} message.raw_message["poll"] = poll - setattr(message, "poll", poll) + message.poll = poll except Exception: pass @@ -196,25 +210,26 @@ class MisskeyPlatformAdapter(Platform): message.message.append(Comp.Plain(poll_text)) message_parts.append(poll_text) - def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]: + def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: """从会话和消息链中提取额外字段""" fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None} for comp in message_chain.chain: if hasattr(comp, "cw") and getattr(comp, "cw", None): - fields["cw"] = getattr(comp, "cw") + fields["cw"] = comp.cw break if hasattr(session, "extra_data") and isinstance( - getattr(session, "extra_data", None), dict + getattr(session, "extra_data", None), + dict, ): - extra_data = getattr(session, "extra_data") + extra_data = session.extra_data fields.update( { "poll": extra_data.get("poll"), "renote_id": extra_data.get("renote_id"), "channel_id": extra_data.get("channel_id"), - } + }, ) return fields @@ -237,7 +252,7 @@ class MisskeyPlatformAdapter(Platform): if await streaming.connect(): logger.info( - f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})" + f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})", ) connection_attempts = 0 await streaming.subscribe_channel("main") @@ -250,34 +265,34 @@ class MisskeyPlatformAdapter(Platform): await streaming.listen() else: logger.error( - f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})" + f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})", ) except Exception as e: logger.error( - f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}" + f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}", ) if self._running: jitter = random.uniform(0, 1.0) sleep_time = backoff_delay + jitter logger.info( - f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})" + f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})", ) await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: Dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]): try: notification_type = data.get("type") logger.debug( - f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}" + f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}", ) if notification_type in ["mention", "reply", "quote"]: note = data.get("note") if note and self._is_bot_mentioned(note): logger.info( - f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..." + f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}...", ) message = await self.convert_message(note) event = MisskeyPlatformEvent( @@ -291,14 +306,14 @@ class MisskeyPlatformAdapter(Platform): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: Dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]): try: sender_id = str( - data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "") + data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""), ) room_id = data.get("toRoomId") logger.debug( - f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}" + f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}", ) if sender_id == self.client_self_id: return @@ -306,7 +321,7 @@ class MisskeyPlatformAdapter(Platform): if room_id: raw_text = data.get("text", "") logger.debug( - f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'" + f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'", ) message = await self.convert_room_message(data) @@ -326,13 +341,13 @@ class MisskeyPlatformAdapter(Platform): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: Dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]): event_type = data.get("type", "unknown") logger.debug( - f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}" + f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}", ) - def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool: + def _is_bot_mentioned(self, note: dict[str, Any]) -> bool: text = note.get("text", "") if not text: return False @@ -352,7 +367,9 @@ class MisskeyPlatformAdapter(Platform): return False async def send_by_session( - self, session: MessageSession, message_chain: MessageChain + self, + session: MessageSession, + message_chain: MessageChain, ) -> Awaitable[Any]: if not self.api: logger.error("[Misskey] API 客户端未初始化") @@ -394,30 +411,33 @@ class MisskeyPlatformAdapter(Platform): if not has_file_components: logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送") return await super().send_by_session(session, message_chain) - else: - text = "" + text = "" if len(text) > self.max_message_length: text = text[: self.max_message_length] + "..." - file_ids: List[str] = [] - fallback_urls: List[str] = [] + file_ids: list[str] = [] + fallback_urls: list[str] = [] if not self.enable_file_upload: return await self._send_text_only_message( - session_id, text, session, message_chain + session_id, + text, + session, + message_chain, ) MAX_UPLOAD_CONCURRENCY = 10 upload_concurrency = int( self.config.get( - "misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY - ) + "misskey_upload_concurrency", + DEFAULT_UPLOAD_CONCURRENCY, + ), ) upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY) sem = asyncio.Semaphore(upload_concurrency) - async def _upload_comp(comp) -> Optional[object]: + async def _upload_comp(comp) -> object | None: """组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)""" from .misskey_utils import ( resolve_component_url_or_path, @@ -432,14 +452,16 @@ class MisskeyPlatformAdapter(Platform): # 解析组件的 URL 或本地路径 url_candidate, local_path = await resolve_component_url_or_path( - comp + comp, ) if not url_candidate and not local_path: return None preferred_name = getattr(comp, "name", None) or getattr( - comp, "file", None + comp, + "file", + None, ) # URL 上传:下载后本地上传 @@ -479,7 +501,7 @@ class MisskeyPlatformAdapter(Platform): if local_path and isinstance(local_path, str): data_temp = os.path.join(get_astrbot_data_path(), "temp") if local_path.startswith(data_temp) and os.path.exists( - local_path + local_path, ): try: os.remove(local_path) @@ -508,7 +530,7 @@ class MisskeyPlatformAdapter(Platform): if len(file_components) > MAX_FILE_UPLOAD_COUNT: logger.warning( - f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件" + f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件", ) file_components = file_components[:MAX_FILE_UPLOAD_COUNT] @@ -540,7 +562,7 @@ class MisskeyPlatformAdapter(Platform): if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toRoomId": room_id, "text": text} + payload: dict[str, Any] = {"toRoomId": room_id, "text": text} if file_ids: payload["fileIds"] = file_ids await self.api.send_room_message(payload) @@ -555,13 +577,13 @@ class MisskeyPlatformAdapter(Platform): if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} if file_ids: # 聊天消息只支持单个文件,使用 fileId 而不是 fileIds payload["fileId"] = file_ids[0] if len(file_ids) > 1: logger.warning( - f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件" + f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件", ) await self.api.send_message(payload) else: @@ -581,7 +603,7 @@ class MisskeyPlatformAdapter(Platform): default_visibility=self.default_visibility, ) logger.debug( - f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}" + f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}", ) fields = self._extract_additional_fields(session, message_chain) @@ -610,7 +632,7 @@ class MisskeyPlatformAdapter(Platform): return await super().send_by_session(session, message_chain) - async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 贴文数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=False) message = create_base_message( @@ -621,7 +643,11 @@ class MisskeyPlatformAdapter(Platform): unique_session=self.unique_session, ) cache_user_info( - self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False + self._user_cache, + sender_info, + raw_data, + self.client_self_id, + is_chat=False, ) message_parts = [] @@ -629,7 +655,10 @@ class MisskeyPlatformAdapter(Platform): if raw_text: text_parts, processed_text = process_at_mention( - message, raw_text, self._bot_username, self.client_self_id + message, + raw_text, + self._bot_username, + self.client_self_id, ) message_parts.extend(text_parts) @@ -652,7 +681,7 @@ class MisskeyPlatformAdapter(Platform): ) return message - async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_chat_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 聊天消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) message = create_base_message( @@ -663,7 +692,11 @@ class MisskeyPlatformAdapter(Platform): unique_session=self.unique_session, ) cache_user_info( - self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True + self._user_cache, + sender_info, + raw_data, + self.client_self_id, + is_chat=True, ) raw_text = raw_data.get("text", "") @@ -676,7 +709,7 @@ class MisskeyPlatformAdapter(Platform): message.message_str = raw_text if raw_text else "" return message - async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 群聊消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) room_id = raw_data.get("toRoomId", "") @@ -690,7 +723,11 @@ class MisskeyPlatformAdapter(Platform): ) cache_user_info( - self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False + self._user_cache, + sender_info, + raw_data, + self.client_self_id, + is_chat=False, ) cache_room_info(self._user_cache, raw_data, self.client_self_id) @@ -700,7 +737,10 @@ class MisskeyPlatformAdapter(Platform): if raw_text: if self._bot_username and f"@{self._bot_username}" in raw_text: text_parts, processed_text = process_at_mention( - message, raw_text, self._bot_username, self.client_self_id + message, + raw_text, + self._bot_username, + self.client_self_id, ) message_parts.extend(text_parts) else: diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 4b920508f..06dc6304d 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -1,18 +1,20 @@ +import asyncio import json import random -import asyncio -from typing import Any, Optional, Dict, List, Callable, Awaitable import uuid +from collections.abc import Awaitable, Callable +from typing import Any try: import aiohttp import websockets except ImportError as e: raise ImportError( - "aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets" + "aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets", ) from e from astrbot.api import logger + from .misskey_utils import FileIDExtractor # Constants @@ -23,54 +25,47 @@ HTTP_OK = 200 class APIError(Exception): """Misskey API 基础异常""" - pass - class APIConnectionError(APIError): """网络连接异常""" - pass - class APIRateLimitError(APIError): """API 频率限制异常""" - pass - class AuthenticationError(APIError): """认证失败异常""" - pass - class WebSocketError(APIError): """WebSocket 连接异常""" - pass - class StreamingClient: def __init__(self, instance_url: str, access_token: str): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self.websocket: Optional[Any] = None + self.websocket: Any | None = None self.is_connected = False - self.message_handlers: Dict[str, Callable] = {} - self.channels: Dict[str, str] = {} - self.desired_channels: Dict[str, Optional[Dict]] = {} + self.message_handlers: dict[str, Callable] = {} + self.channels: dict[str, str] = {} + self.desired_channels: dict[str, dict | None] = {} self._running = False self._last_pong = None async def connect(self) -> bool: try: ws_url = self.instance_url.replace("https://", "wss://").replace( - "http://", "ws://" + "http://", + "ws://", ) ws_url += f"/streaming?i={self.access_token}" self.websocket = await websockets.connect( - ws_url, ping_interval=30, ping_timeout=10 + ws_url, + ping_interval=30, + ping_timeout=10, ) self.is_connected = True self._running = True @@ -84,7 +79,7 @@ class StreamingClient: await self.subscribe_channel(channel_type, params) except Exception as e: logger.warning( - f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}" + f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}", ) except Exception: pass @@ -104,7 +99,9 @@ class StreamingClient: logger.info("[Misskey WebSocket] 连接已断开") async def subscribe_channel( - self, channel_type: str, params: Optional[Dict] = None + self, + channel_type: str, + params: dict | None = None, ) -> str: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -136,7 +133,9 @@ class StreamingClient: self.desired_channels.pop(channel_type, None) def add_message_handler( - self, event_type: str, handler: Callable[[Dict], Awaitable[None]] + self, + event_type: str, + handler: Callable[[dict], Awaitable[None]], ): self.message_handlers[event_type] = handler @@ -166,7 +165,7 @@ class StreamingClient: pass except websockets.exceptions.ConnectionClosed as e: logger.warning( - f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})" + f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})", ) self.is_connected = False try: @@ -188,11 +187,11 @@ class StreamingClient: except Exception: pass - async def _handle_message(self, data: Dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]): message_type = data.get("type") body = data.get("body", {}) - def _build_channel_summary(message_type: Optional[str], body: Any) -> str: + def _build_channel_summary(message_type: str | None, body: Any) -> str: try: if not isinstance(body, dict): return f"[Misskey WebSocket] 收到消息类型: {message_type}" @@ -228,7 +227,7 @@ class StreamingClient: event_body = body.get("body", {}) logger.debug( - f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}" + f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}", ) if channel_id in self.channels: @@ -243,7 +242,7 @@ class StreamingClient: await self.message_handlers[event_type](event_body) else: logger.debug( - f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}" + f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}", ) if "_debug" in self.message_handlers: await self.message_handlers["_debug"]( @@ -251,7 +250,7 @@ class StreamingClient: "type": event_type, "body": event_body, "channel": channel_type, - } + }, ) elif message_type in self.message_handlers: @@ -269,14 +268,14 @@ def retry_async( backoff_base: float = 1.0, max_backoff: float = 30.0, ): - """ - 智能异步重试装饰器 + """智能异步重试装饰器 Args: max_retries: 最大重试次数 retryable_exceptions: 可重试的异常类型 backoff_base: 退避基数 max_backoff: 最大退避时间 + """ def decorator(func): @@ -291,7 +290,7 @@ def retry_async( last_exc = e if attempt == max_retries: logger.error( - f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}" + f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}", ) break @@ -308,7 +307,7 @@ def retry_async( logger.warning( f"[Misskey API] {func_name} 第 {attempt} 次重试失败: {e}," - f"{sleep_time:.1f}s后重试" + f"{sleep_time:.1f}s后重试", ) await asyncio.sleep(sleep_time) continue @@ -334,12 +333,12 @@ class MisskeyAPI: allow_insecure_downloads: bool = False, download_timeout: int = 15, chunk_size: int = 64 * 1024, - max_download_bytes: Optional[int] = None, + max_download_bytes: int | None = None, ): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self._session: Optional[aiohttp.ClientSession] = None - self.streaming: Optional[StreamingClient] = None + self._session: aiohttp.ClientSession | None = None + self.streaming: StreamingClient | None = None # download options self.allow_insecure_downloads = allow_insecure_downloads self.download_timeout = download_timeout @@ -381,39 +380,40 @@ class MisskeyAPI: if status == 400: logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") raise APIError(f"Bad request for {endpoint}") - elif status == 401: + if status == 401: logger.error(f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})") raise AuthenticationError(f"Unauthorized access for {endpoint}") - elif status == 403: + if status == 403: logger.error(f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})") raise AuthenticationError(f"Forbidden access for {endpoint}") - elif status == 404: + if status == 404: logger.error(f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})") raise APIError(f"Resource not found for {endpoint}") - elif status == 413: + if status == 413: logger.error(f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})") raise APIError(f"Request entity too large for {endpoint}") - elif status == 429: + if status == 429: logger.warning(f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})") raise APIRateLimitError(f"Rate limit exceeded for {endpoint}") - elif status == 500: + if status == 500: logger.error(f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})") raise APIConnectionError(f"Internal server error for {endpoint}") - elif status == 502: + if status == 502: logger.error(f"[Misskey API] 网关错误: {endpoint} (HTTP {status})") raise APIConnectionError(f"Bad gateway for {endpoint}") - elif status == 503: + if status == 503: logger.error(f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})") raise APIConnectionError(f"Service unavailable for {endpoint}") - elif status == 504: + if status == 504: logger.error(f"[Misskey API] 网关超时: {endpoint} (HTTP {status})") raise APIConnectionError(f"Gateway timeout for {endpoint}") - else: - logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})") - raise APIConnectionError(f"HTTP {status} for {endpoint}") + logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})") + raise APIConnectionError(f"HTTP {status} for {endpoint}") async def _process_response( - self, response: aiohttp.ClientResponse, endpoint: str + self, + response: aiohttp.ClientResponse, + endpoint: str, ) -> Any: """处理 API 响应""" if response.status == HTTP_OK: @@ -429,7 +429,7 @@ class MisskeyAPI: ) if notifications_data: logger.debug( - f"[Misskey API] 获取到 {len(notifications_data)} 条新通知" + f"[Misskey API] 获取到 {len(notifications_data)} 条新通知", ) else: logger.debug(f"[Misskey API] 请求成功: {endpoint}") @@ -441,11 +441,11 @@ class MisskeyAPI: try: error_text = await response.text() logger.error( - f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}" + f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}", ) except Exception: logger.error( - f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}" + f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}", ) self._handle_response_status(response.status, endpoint) @@ -456,7 +456,9 @@ class MisskeyAPI: retryable_exceptions=(APIConnectionError, APIRateLimitError), ) async def _make_request( - self, endpoint: str, data: Optional[Dict[str, Any]] = None + self, + endpoint: str, + data: dict[str, Any] | None = None, ) -> Any: url = f"{self.instance_url}/api/{endpoint}" payload = {"i": self.access_token} @@ -472,24 +474,24 @@ class MisskeyAPI: async def create_note( self, - text: Optional[str] = None, + text: str | None = None, visibility: str = "public", - reply_id: Optional[str] = None, - visible_user_ids: Optional[List[str]] = None, - file_ids: Optional[List[str]] = None, + reply_id: str | None = None, + visible_user_ids: list[str] | None = None, + file_ids: list[str] | None = None, local_only: bool = False, - cw: Optional[str] = None, - poll: Optional[Dict[str, Any]] = None, - renote_id: Optional[str] = None, - channel_id: Optional[str] = None, - reaction_acceptance: Optional[str] = None, - no_extract_mentions: Optional[bool] = None, - no_extract_hashtags: Optional[bool] = None, - no_extract_emojis: Optional[bool] = None, - media_ids: Optional[List[str]] = None, - ) -> Dict[str, Any]: + cw: str | None = None, + poll: dict[str, Any] | None = None, + renote_id: str | None = None, + channel_id: str | None = None, + reaction_acceptance: str | None = None, + no_extract_mentions: bool | None = None, + no_extract_hashtags: bool | None = None, + no_extract_emojis: bool | None = None, + media_ids: list[str] | None = None, + ) -> dict[str, Any]: """Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API.""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} if text is not None: data["text"] = text @@ -537,9 +539,9 @@ class MisskeyAPI: async def upload_file( self, file_path: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, - ) -> Dict[str, Any]: + name: str | None = None, + folder_id: str | None = None, + ) -> dict[str, Any]: """Upload a file to Misskey drive/files/create and return a dict containing id and raw result.""" if not file_path: raise APIError("No file path provided for upload") @@ -565,7 +567,7 @@ class MisskeyAPI: result = await self._process_response(resp, "drive/files/create") file_id = FileIDExtractor.extract_file_id(result) logger.debug( - f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}" + f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", ) return {"id": file_id, "raw": result} finally: @@ -574,7 +576,7 @@ class MisskeyAPI: logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e - async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: + async def find_files_by_hash(self, md5_hash: str) -> list[dict[str, Any]]: """Find files by MD5 hash""" if not md5_hash: raise APIError("No MD5 hash provided for find-by-hash") @@ -585,7 +587,7 @@ class MisskeyAPI: logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}") result = await self._make_request("drive/files/find-by-hash", data) logger.debug( - f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件", ) return result if isinstance(result, list) else [] except Exception as e: @@ -593,13 +595,15 @@ class MisskeyAPI: raise async def find_files_by_name( - self, name: str, folder_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, + name: str, + folder_id: str | None = None, + ) -> list[dict[str, Any]]: """Find files by name""" if not name: raise APIError("No name provided for find") - data: Dict[str, Any] = {"name": name} + data: dict[str, Any] = {"name": name} if folder_id: data["folderId"] = folder_id @@ -607,7 +611,7 @@ class MisskeyAPI: logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}") result = await self._make_request("drive/files/find", data) logger.debug( - f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件", ) return result if isinstance(result, list) else [] except Exception as e: @@ -617,11 +621,11 @@ class MisskeyAPI: async def find_files( self, limit: int = 10, - folder_id: Optional[str] = None, - type: Optional[str] = None, - ) -> List[Dict[str, Any]]: + folder_id: str | None = None, + type: str | None = None, + ) -> list[dict[str, Any]]: """List files with optional filters""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if folder_id is not None: data["folderId"] = folder_id if type is not None: @@ -629,11 +633,11 @@ class MisskeyAPI: try: logger.debug( - f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}" + f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}", ) result = await self._make_request("drive/files", data) logger.debug( - f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件" + f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件", ) return result if isinstance(result, list) else [] except Exception as e: @@ -641,27 +645,34 @@ class MisskeyAPI: raise async def _download_with_existing_session( - self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + self, + url: str, + ssl_verify: bool = True, + ) -> bytes | None: """使用现有会话下载文件""" if not (hasattr(self, "session") and self.session): raise APIConnectionError("No existing session available") async with self.session.get( - url, timeout=aiohttp.ClientTimeout(total=15), ssl=ssl_verify + url, + timeout=aiohttp.ClientTimeout(total=15), + ssl=ssl_verify, ) as response: if response.status == 200: return await response.read() return None async def _download_with_temp_session( - self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + self, + url: str, + ssl_verify: bool = True, + ) -> bytes | None: """使用临时会话下载文件""" connector = aiohttp.TCPConnector(ssl=ssl_verify) async with aiohttp.ClientSession(connector=connector) as temp_session: async with temp_session.get( - url, timeout=aiohttp.ClientTimeout(total=15) + url, + timeout=aiohttp.ClientTimeout(total=15), ) as response: if response.status == 200: return await response.read() @@ -670,13 +681,12 @@ class MisskeyAPI: async def upload_and_find_file( self, url: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, + name: str | None = None, + folder_id: str | None = None, max_wait_time: float = 30.0, check_interval: float = 2.0, - ) -> Optional[Dict[str, Any]]: - """ - 简化的文件上传:尝试 URL 上传,失败则下载后本地上传 + ) -> dict[str, Any] | None: + """简化的文件上传:尝试 URL 上传,失败则下载后本地上传 Args: url: 文件URL @@ -687,28 +697,31 @@ class MisskeyAPI: Returns: 包含文件ID和元信息的字典,失败时返回None + """ if not url: raise APIError("URL不能为空") # 通过本地上传获取即时文件 ID(下载文件 → 上传 → 返回 ID) try: - import tempfile import os + import tempfile # SSL 验证下载,失败则重试不验证 SSL tmp_bytes = None try: tmp_bytes = await self._download_with_existing_session( - url, ssl_verify=True + url, + ssl_verify=True, ) or await self._download_with_temp_session(url, ssl_verify=True) except Exception as ssl_error: logger.debug( - f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL" + f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL", ) try: tmp_bytes = await self._download_with_existing_session( - url, ssl_verify=False + url, + ssl_verify=False, ) or await self._download_with_temp_session(url, ssl_verify=False) except Exception: pass @@ -732,13 +745,15 @@ class MisskeyAPI: return None - async def get_current_user(self) -> Dict[str, Any]: + async def get_current_user(self) -> dict[str, Any]: """获取当前用户信息""" return await self._make_request("i", {}) async def send_message( - self, user_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, + user_id_or_payload: Any, + text: str | None = None, + ) -> dict[str, Any]: """发送聊天消息。 Accepts either (user_id: str, text: str) or a single dict payload prepared by caller. @@ -754,8 +769,10 @@ class MisskeyAPI: return result async def send_room_message( - self, room_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, + room_id_or_payload: Any, + text: str | None = None, + ) -> dict[str, Any]: """发送房间消息。 Accepts either (room_id: str, text: str) or a single dict payload. @@ -771,10 +788,13 @@ class MisskeyAPI: return result async def get_messages( - self, user_id: str, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, + user_id: str, + limit: int = 10, + since_id: str | None = None, + ) -> list[dict[str, Any]]: """获取聊天消息历史""" - data: Dict[str, Any] = {"userId": user_id, "limit": limit} + data: dict[str, Any] = {"userId": user_id, "limit": limit} if since_id: data["sinceId"] = since_id @@ -785,10 +805,12 @@ class MisskeyAPI: return [] async def get_mentions( - self, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, + limit: int = 10, + since_id: str | None = None, + ) -> list[dict[str, Any]]: """获取提及通知""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if since_id: data["sinceId"] = since_id data["includeTypes"] = ["mention", "reply", "quote"] @@ -796,23 +818,21 @@ class MisskeyAPI: result = await self._make_request("i/notifications", data) if isinstance(result, list): return result - elif isinstance(result, dict) and "notifications" in result: + if isinstance(result, dict) and "notifications" in result: return result["notifications"] - else: - logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}") - return [] + logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}") + return [] async def send_message_with_media( self, message_type: str, target_id: str, - text: Optional[str] = None, - media_urls: Optional[List[str]] = None, - local_files: Optional[List[str]] = None, + text: str | None = None, + media_urls: list[str] | None = None, + local_files: list[str] | None = None, **kwargs, - ) -> Dict[str, Any]: - """ - 通用消息发送函数:统一处理文本+媒体发送 + ) -> dict[str, Any]: + """通用消息发送函数:统一处理文本+媒体发送 Args: message_type: 消息类型 ('chat', 'room', 'note') @@ -827,6 +847,7 @@ class MisskeyAPI: Raises: APIError: 参数错误或发送失败 + """ if not text and not media_urls and not local_files: raise APIError("消息内容不能为空:需要文本或媒体文件") @@ -843,10 +864,14 @@ class MisskeyAPI: # 根据消息类型发送 return await self._dispatch_message( - message_type, target_id, text, file_ids, **kwargs + message_type, + target_id, + text, + file_ids, + **kwargs, ) - async def _process_media_urls(self, urls: List[str]) -> List[str]: + async def _process_media_urls(self, urls: list[str]) -> list[str]: """处理远程媒体文件URL列表,返回文件ID列表""" file_ids = [] for url in urls: @@ -863,7 +888,7 @@ class MisskeyAPI: continue return file_ids - async def _process_local_files(self, file_paths: List[str]) -> List[str]: + async def _process_local_files(self, file_paths: list[str]) -> list[str]: """处理本地文件路径列表,返回文件ID列表""" file_ids = [] for file_path in file_paths: @@ -883,10 +908,10 @@ class MisskeyAPI: self, message_type: str, target_id: str, - text: Optional[str], - file_ids: List[str], + text: str | None, + file_ids: list[str], **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """根据消息类型分发到对应的发送方法""" if message_type == "chat": # 聊天消息使用 fileId (单数) @@ -907,7 +932,7 @@ class MisskeyAPI: return {"multiple": True, "results": results} return await self.send_message(payload) - elif message_type == "room": + if message_type == "room": # 房间消息使用 fileId (单数) payload = {"toRoomId": target_id} if text: @@ -926,7 +951,7 @@ class MisskeyAPI: return {"multiple": True, "results": results} return await self.send_room_message(payload) - elif message_type == "note": + if message_type == "note": # 发帖使用 fileIds (复数) note_kwargs = { "text": text, @@ -936,5 +961,4 @@ class MisskeyAPI: note_kwargs.update(kwargs) return await self.create_note(**note_kwargs) - else: - raise APIError(f"不支持的消息类型: {message_type}") + raise APIError(f"不支持的消息类型: {message_type}") diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index cd737f78e..7975f0ec7 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -1,19 +1,20 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator + from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import PlatformMetadata, AstrBotMessage from astrbot.api.message_components import Plain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata from .misskey_utils import ( - serialize_message_chain, - resolve_visibility_from_raw_message, - is_valid_user_session_id, - is_valid_room_session_id, add_at_mention_if_needed, - extract_user_id_from_session_id, extract_room_id_from_session_id, + extract_user_id_from_session_id, + is_valid_room_session_id, + is_valid_user_session_id, + resolve_visibility_from_raw_message, + serialize_message_chain, ) @@ -43,7 +44,7 @@ class MisskeyPlatformEvent(AstrMessageEvent): """发送消息,使用适配器的完整上传和发送逻辑""" try: logger.debug( - f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件" + f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件", ) # 使用适配器的 send_by_session 方法,它包含文件上传逻辑 @@ -65,7 +66,7 @@ class MisskeyPlatformEvent(AstrMessageEvent): ) logger.debug( - f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}" + f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}", ) # 调用适配器的 send_by_session 方法 @@ -88,25 +89,27 @@ class MisskeyPlatformEvent(AstrMessageEvent): user_info = { "username": user_data.get("username", ""), "nickname": user_data.get( - "name", user_data.get("username", "") + "name", + user_data.get("username", ""), ), } content = add_at_mention_if_needed(content, user_info, has_at) # 根据会话类型选择发送方式 if hasattr(self.client, "send_message") and is_valid_user_session_id( - self.session_id + self.session_id, ): user_id = extract_user_id_from_session_id(self.session_id) await self.client.send_message(user_id, content) elif hasattr( - self.client, "send_room_message" + self.client, + "send_room_message", ) and is_valid_room_session_id(self.session_id): room_id = extract_room_id_from_session_id(self.session_id) await self.client.send_room_message(room_id, content) elif original_message_id and hasattr(self.client, "create_note"): visibility, visible_user_ids = resolve_visibility_from_raw_message( - raw_message + raw_message, ) await self.client.create_note( content, @@ -124,7 +127,9 @@ class MisskeyPlatformEvent(AstrMessageEvent): logger.error(f"[MisskeyEvent] 发送失败: {e}") async def send_streaming( - self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False + self, + generator: AsyncGenerator[MessageChain, None], + use_fallback: bool = False, ): if not use_fallback: buffer = None @@ -134,7 +139,7 @@ class MisskeyPlatformEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index ebc95d8d7..290acd64e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -1,6 +1,7 @@ """Misskey 平台适配器通用工具函数""" -from typing import Dict, Any, List, Tuple, Optional, Union +from typing import Any + import astrbot.api.message_components as Comp from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType @@ -9,7 +10,7 @@ class FileIDExtractor: """从 API 响应中提取文件 ID 的帮助类(无状态)。""" @staticmethod - def extract_file_id(result: Any) -> Optional[str]: + def extract_file_id(result: Any) -> str | None: if not isinstance(result, dict): return None @@ -34,8 +35,10 @@ class MessagePayloadBuilder: @staticmethod def build_chat_payload( - user_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + user_id: str, + text: str | None, + file_id: str | None = None, + ) -> dict[str, Any]: payload = {"toUserId": user_id} if text: payload["text"] = text @@ -45,8 +48,10 @@ class MessagePayloadBuilder: @staticmethod def build_room_payload( - room_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + room_id: str, + text: str | None, + file_id: str | None = None, + ) -> dict[str, Any]: payload = {"toRoomId": room_id} if text: payload["text"] = text @@ -56,9 +61,11 @@ class MessagePayloadBuilder: @staticmethod def build_note_payload( - text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs - ) -> Dict[str, Any]: - payload: Dict[str, Any] = {} + text: str | None, + file_ids: list[str] | None = None, + **kwargs, + ) -> dict[str, Any]: + payload: dict[str, Any] = {} if text: payload["text"] = text if file_ids: @@ -67,7 +74,7 @@ class MessagePayloadBuilder: return payload -def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: +def serialize_message_chain(chain: list[Any]) -> tuple[str, bool]: """将消息链序列化为文本字符串""" text_parts = [] has_at = False @@ -76,27 +83,25 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: nonlocal has_at if isinstance(component, Comp.Plain): return component.text - elif isinstance(component, Comp.File): + if isinstance(component, Comp.File): # 为文件组件返回占位符,但适配器仍会处理原组件 return "[文件]" - elif isinstance(component, Comp.Image): + if isinstance(component, Comp.Image): # 为图片组件返回占位符,但适配器仍会处理原组件 return "[图片]" - elif isinstance(component, Comp.At): + if isinstance(component, Comp.At): has_at = True # 优先使用name字段(用户名),如果没有则使用qq字段 # 这样可以避免在Misskey中生成 @ 这样的无效提及 if hasattr(component, "name") and component.name: return f"@{component.name}" - else: - return f"@{component.qq}" - elif hasattr(component, "text"): + return f"@{component.qq}" + if hasattr(component, "text"): text = getattr(component, "text", "") if "@" in text: has_at = True return text - else: - return str(component) + return str(component) for component in chain: if isinstance(component, Comp.Node) and component.content: @@ -113,12 +118,12 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: def resolve_message_visibility( - user_id: Optional[str] = None, - user_cache: Optional[Dict[str, Any]] = None, - self_id: Optional[str] = None, - raw_message: Optional[Dict[str, Any]] = None, + user_id: str | None = None, + user_cache: dict[str, Any] | None = None, + self_id: str | None = None, + raw_message: dict[str, Any] | None = None, default_visibility: str = "public", -) -> Tuple[str, Optional[List[str]]]: +) -> tuple[str, list[str] | None]: """解析 Misskey 消息的可见性设置 可以从 user_cache 或 raw_message 中解析,支持两种调用方式: @@ -169,13 +174,14 @@ def resolve_message_visibility( # 保留旧函数名作为向后兼容的别名 def resolve_visibility_from_raw_message( - raw_message: Dict[str, Any], self_id: Optional[str] = None -) -> Tuple[str, Optional[List[str]]]: + raw_message: dict[str, Any], + self_id: str | None = None, +) -> tuple[str, list[str] | None]: """从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)""" return resolve_message_visibility(raw_message=raw_message, self_id=self_id) -def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_user_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -189,7 +195,7 @@ def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_room_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的房间 session_id (仅限room%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -203,7 +209,7 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_chat_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -236,7 +242,9 @@ def extract_room_id_from_session_id(session_id: str) -> str: def add_at_mention_if_needed( - text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False + text: str, + user_info: dict[str, Any] | None, + has_at: bool = False, ) -> str: """如果需要且没有@用户,则添加@用户 @@ -258,7 +266,7 @@ def add_at_mention_if_needed( return text -def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]: +def create_file_component(file_info: dict[str, Any]) -> tuple[Any, str]: """创建文件组件和描述文本""" file_url = file_info.get("url", "") file_name = file_info.get("name", "未知文件") @@ -266,16 +274,17 @@ def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]: if file_type.startswith("image/"): return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]" - elif file_type.startswith("audio/"): + if file_type.startswith("audio/"): return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]" - elif file_type.startswith("video/"): + if file_type.startswith("video/"): return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]" - else: - return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]" + return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]" def process_files( - message: AstrBotMessage, files: list, include_text_parts: bool = True + message: AstrBotMessage, + files: list, + include_text_parts: bool = True, ) -> list: """处理文件列表,添加到消息组件中并返回文本描述""" file_parts = [] @@ -287,7 +296,7 @@ def process_files( return file_parts -def format_poll(poll: Dict[str, Any]) -> str: +def format_poll(poll: dict[str, Any]) -> str: """将 Misskey 的 poll 对象格式化为可读字符串。""" if not poll or not isinstance(poll, dict): return "" @@ -304,8 +313,9 @@ def format_poll(poll: Dict[str, Any]) -> str: def extract_sender_info( - raw_data: Dict[str, Any], is_chat: bool = False -) -> Dict[str, Any]: + raw_data: dict[str, Any], + is_chat: bool = False, +) -> dict[str, Any]: """提取发送者信息""" if is_chat: sender = raw_data.get("fromUser", {}) @@ -323,11 +333,11 @@ def extract_sender_info( def create_base_message( - raw_data: Dict[str, Any], - sender_info: Dict[str, Any], + raw_data: dict[str, Any], + sender_info: dict[str, Any], client_self_id: str, is_chat: bool = False, - room_id: Optional[str] = None, + room_id: str | None = None, unique_session: bool = False, ) -> AstrBotMessage: """创建基础消息对象""" @@ -366,8 +376,11 @@ def create_base_message( def process_at_mention( - message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str -) -> Tuple[List[str], str]: + message: AstrBotMessage, + raw_text: str, + bot_username: str, + client_self_id: str, +) -> tuple[list[str], str]: """处理@提及逻辑,返回消息部分列表和处理后的文本""" message_parts = [] @@ -382,16 +395,15 @@ def process_at_mention( message.message.append(Comp.Plain(remaining_text)) message_parts.append(remaining_text) return message_parts, remaining_text - else: - message.message.append(Comp.Plain(raw_text)) - message_parts.append(raw_text) - return message_parts, raw_text + message.message.append(Comp.Plain(raw_text)) + message_parts.append(raw_text) + return message_parts, raw_text def cache_user_info( - user_cache: Dict[str, Any], - sender_info: Dict[str, Any], - raw_data: Dict[str, Any], + user_cache: dict[str, Any], + sender_info: dict[str, Any], + raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, ): @@ -417,7 +429,9 @@ def cache_user_info( def cache_room_info( - user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str + user_cache: dict[str, Any], + raw_data: dict[str, Any], + client_self_id: str, ): """缓存房间信息""" room_data = raw_data.get("toRoom") @@ -437,7 +451,7 @@ def cache_room_info( async def resolve_component_url_or_path( comp: Any, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """尝试从组件解析可上传的远程 URL 或本地路径。 返回 (url_candidate, local_path)。两者可能都为 None。 @@ -468,8 +482,7 @@ async def resolve_component_url_or_path( if value.startswith("http"): url_candidate = value break - else: - local_path = value + local_path = value except Exception: continue @@ -491,9 +504,8 @@ async def resolve_component_url_or_path( if value.startswith("http"): url_candidate = value break - else: - local_path = value - break + local_path = value + break except Exception: continue @@ -503,7 +515,7 @@ async def resolve_component_url_or_path( return url_candidate, local_path -def summarize_component_for_log(comp: Any) -> Dict[str, Any]: +def summarize_component_for_log(comp: Any) -> dict[str, Any]: """生成适合日志的组件属性字典(尽量不抛异常)。""" attrs = {} for a in ("file", "url", "path", "src", "source", "name"): @@ -519,15 +531,15 @@ def summarize_component_for_log(comp: Any) -> Dict[str, Any]: async def upload_local_with_retries( api: Any, local_path: str, - preferred_name: Optional[str], - folder_id: Optional[str], -) -> Optional[str]: + preferred_name: str | None, + folder_id: str | None, +) -> str | None: """尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。""" try: res = await api.upload_file(local_path, preferred_name, folder_id) if isinstance(res, dict): fid = res.get("id") or (res.get("raw") or {}).get("createdFile", {}).get( - "id" + "id", ) if fid: return str(fid) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2096237ce..f3c2ef0e5 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -1,25 +1,26 @@ +import asyncio +import base64 +import os +import random +import uuid + +import aiofiles import botpy import botpy.message import botpy.types import botpy.types.message -import asyncio -import base64 -import aiofiles -from astrbot.core.utils.io import file_to_base64, download_image_by_url -from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, Record from botpy import Client from botpy.http import Route -from astrbot.api import logger -from botpy.types.message import Media from botpy.types import message -from typing import Optional -import random -import uuid -import os +from botpy.types.message import Media + +from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.platform import AstrBotMessage, PlatformMetadata +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_image_by_url, file_to_base64 +from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk class QQOfficialMessageEvent(AstrMessageEvent): @@ -77,7 +78,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): async def _post_send(self, stream: dict = None): if not self.send_buffer: - return + return None source = self.message_obj.raw_message assert isinstance( @@ -103,7 +104,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): and not image_path and not record_file_path ): - return + return None payload = { "content": plain_text, @@ -119,29 +120,38 @@ class QQOfficialMessageEvent(AstrMessageEvent): case botpy.message.GroupMessage: if image_base64: media = await self.upload_group_and_c2c_image( - image_base64, 1, group_openid=source.group_openid + image_base64, + 1, + group_openid=source.group_openid, ) payload["media"] = media payload["msg_type"] = 7 if record_file_path: # group record msg media = await self.upload_group_and_c2c_record( - record_file_path, 3, group_openid=source.group_openid + record_file_path, + 3, + group_openid=source.group_openid, ) payload["media"] = media payload["msg_type"] = 7 ret = await self.bot.api.post_group_message( - group_openid=source.group_openid, **payload + group_openid=source.group_openid, + **payload, ) case botpy.message.C2CMessage: if image_base64: media = await self.upload_group_and_c2c_image( - image_base64, 1, openid=source.author.user_openid + image_base64, + 1, + openid=source.author.user_openid, ) payload["media"] = media payload["msg_type"] = 7 if record_file_path: # c2c record media = await self.upload_group_and_c2c_record( - record_file_path, 3, openid=source.author.user_openid + record_file_path, + 3, + openid=source.author.user_openid, ) payload["media"] = media payload["msg_type"] = 7 @@ -153,14 +163,16 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) else: ret = await self.post_c2c_message( - openid=source.author.user_openid, **payload + openid=source.author.user_openid, + **payload, ) logger.debug(f"Message sent to C2C: {ret}") case botpy.message.Message: if image_path: payload["file_image"] = image_path ret = await self.bot.api.post_message( - channel_id=source.channel_id, **payload + channel_id=source.channel_id, + **payload, ) case botpy.message.DirectMessage: if image_path: @@ -174,7 +186,10 @@ class QQOfficialMessageEvent(AstrMessageEvent): return ret async def upload_group_and_c2c_image( - self, image_base64: str, file_type: int, **kwargs + self, + image_base64: str, + file_type: int, + **kwargs, ) -> botpy.types.message.Media: payload = { "file_data": image_base64, @@ -185,7 +200,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload["openid"] = kwargs["openid"] route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) return await self.bot.api._http.request(route, json=payload) - elif "group_openid" in kwargs: + if "group_openid" in kwargs: payload["group_openid"] = kwargs["group_openid"] route = Route( "POST", @@ -195,11 +210,13 @@ class QQOfficialMessageEvent(AstrMessageEvent): return await self.bot.api._http.request(route, json=payload) async def upload_group_and_c2c_record( - self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs - ) -> Optional[Media]: - """ - 上传媒体文件 - """ + self, + file_source: str, + file_type: int, + srv_send_msg: bool = False, + **kwargs, + ) -> Media | None: + """上传媒体文件""" # 构建基础payload payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} @@ -291,11 +308,13 @@ class QQOfficialMessageEvent(AstrMessageEvent): record_wav_path = await i.convert_to_file_path() # wav 路径 temp_dir = os.path.join(get_astrbot_data_path(), "temp") record_tecent_silk_path = os.path.join( - temp_dir, f"{uuid.uuid4()}.silk" + temp_dir, + f"{uuid.uuid4()}.silk", ) try: duration = await wav_to_tencent_silk( - record_wav_path, record_tecent_silk_path + record_wav_path, + record_tecent_silk_path, ) if duration > 0: record_file_path = record_tecent_silk_path diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index d5285f759..96be734fd 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -1,30 +1,31 @@ from __future__ import annotations -import botpy -import logging -import time import asyncio +import logging +import os +import time + +import botpy import botpy.message import botpy.types import botpy.types.message -import os - from botpy import Client + +from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import At, Image, Plain from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot import logger -from astrbot.api.event import MessageChain -from typing import Union, List -from astrbot.api.message_components import Image, Plain, At -from astrbot.core.platform.astr_message_event import MessageSesion -from .qqofficial_message_event import QQOfficialMessageEvent -from ...register import register_platform_adapter from astrbot.core.message.components import BaseMessageComponent +from astrbot.core.platform.astr_message_event import MessageSesion + +from ...register import register_platform_adapter +from .qqofficial_message_event import QQOfficialMessageEvent # remove logger handler for handler in logging.root.handlers[:]: @@ -33,13 +34,14 @@ for handler in logging.root.handlers[:]: # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialPlatformAdapter"): + def set_platform(self, platform: QQOfficialPlatformAdapter): self.platform = platform # 收到群消息 async def on_group_at_message_create(self, message: botpy.message.GroupMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.GROUP_MESSAGE + message, + MessageType.GROUP_MESSAGE, ) abm.session_id = ( abm.sender.user_id if self.platform.unique_session else message.group_openid @@ -49,7 +51,8 @@ class botClient(Client): # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.GROUP_MESSAGE + message, + MessageType.GROUP_MESSAGE, ) abm.session_id = ( abm.sender.user_id if self.platform.unique_session else message.channel_id @@ -59,7 +62,8 @@ class botClient(Client): # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.FRIEND_MESSAGE + message, + MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id self._commit(abm) @@ -67,7 +71,8 @@ class botClient(Client): # 收到 C2C 消息 async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.FRIEND_MESSAGE + message, + MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id self._commit(abm) @@ -80,14 +85,17 @@ class botClient(Client): self.platform.meta(), abm.session_id, self.platform.client, - ) + ), ) @register_platform_adapter("qq_official", "QQ 机器人官方 API 适配器") class QQOfficialPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -107,7 +115,8 @@ class QQOfficialPlatformAdapter(Platform): ) else: self.intents = botpy.Intents( - public_guild_messages=True, direct_message=guild_dm + public_guild_messages=True, + direct_message=guild_dm, ) self.client = botClient( intents=self.intents, @@ -120,7 +129,9 @@ class QQOfficialPlatformAdapter(Platform): self.test_mode = os.environ.get("TEST_MODE", "off") == "on" async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") @@ -133,7 +144,7 @@ class QQOfficialPlatformAdapter(Platform): @staticmethod def _parse_from_qqofficial( - message: Union[botpy.message.Message, botpy.message.GroupMessage], + message: botpy.message.Message | botpy.message.GroupMessage, message_type: MessageType, ): abm = AstrBotMessage() @@ -142,10 +153,11 @@ class QQOfficialPlatformAdapter(Platform): abm.raw_message = message abm.message_id = message.id abm.tag = "qq_official" - msg: List[BaseMessageComponent] = [] + msg: list[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance( - message, botpy.message.C2CMessage + message, + botpy.message.C2CMessage, ): if isinstance(message, botpy.message.GroupMessage): abm.sender = MessageMember(message.author.member_openid, "") @@ -167,7 +179,8 @@ class QQOfficialPlatformAdapter(Platform): abm.message = msg elif isinstance(message, botpy.message.Message) or isinstance( - message, botpy.message.DirectMessage + message, + botpy.message.DirectMessage, ): try: abm.self_id = str(message.mentions[0].id) @@ -175,7 +188,8 @@ class QQOfficialPlatformAdapter(Platform): abm.self_id = "" plain_content = message.content.replace( - "<@!" + str(abm.self_id) + ">", "" + "<@!" + str(abm.self_id) + ">", + "", ).strip() if message.attachments: @@ -189,7 +203,8 @@ class QQOfficialPlatformAdapter(Platform): abm.message = msg abm.message_str = plain_content abm.sender = MessageMember( - str(message.author.id), str(message.author.username) + str(message.author.id), + str(message.author.username), ) msg.append(At(qq="qq_official")) msg.append(Plain(plain_content)) diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index cc12e9765..2b8c0b420 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,19 +1,21 @@ -import botpy -import logging import asyncio +import logging + +import botpy import botpy.message import botpy.types import botpy.types.message - from botpy import Client -from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata -from astrbot.api.event import MessageChain -from astrbot.core.platform.astr_message_event import MessageSesion -from .qo_webhook_event import QQOfficialWebhookMessageEvent -from ...register import register_platform_adapter -from .qo_webhook_server import QQOfficialWebhook -from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter + from astrbot import logger +from astrbot.api.event import MessageChain +from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata +from astrbot.core.platform.astr_message_event import MessageSesion + +from ...register import register_platform_adapter +from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter +from .qo_webhook_event import QQOfficialWebhookMessageEvent +from .qo_webhook_server import QQOfficialWebhook # remove logger handler for handler in logging.root.handlers[:]: @@ -28,7 +30,8 @@ class botClient(Client): # 收到群消息 async def on_group_at_message_create(self, message: botpy.message.GroupMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.GROUP_MESSAGE + message, + MessageType.GROUP_MESSAGE, ) abm.session_id = ( abm.sender.user_id if self.platform.unique_session else message.group_openid @@ -38,7 +41,8 @@ class botClient(Client): # 收到频道消息 async def on_at_message_create(self, message: botpy.message.Message): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.GROUP_MESSAGE + message, + MessageType.GROUP_MESSAGE, ) abm.session_id = ( abm.sender.user_id if self.platform.unique_session else message.channel_id @@ -48,7 +52,8 @@ class botClient(Client): # 收到私聊消息 async def on_direct_message_create(self, message: botpy.message.DirectMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.FRIEND_MESSAGE + message, + MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id self._commit(abm) @@ -56,7 +61,8 @@ class botClient(Client): # 收到 C2C 消息 async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm = QQOfficialPlatformAdapter._parse_from_qqofficial( - message, MessageType.FRIEND_MESSAGE + message, + MessageType.FRIEND_MESSAGE, ) abm.session_id = abm.sender.user_id self._commit(abm) @@ -64,15 +70,22 @@ class botClient(Client): def _commit(self, abm: AstrBotMessage): self.platform.commit_event( QQOfficialWebhookMessageEvent( - abm.message_str, abm, self.platform.meta(), abm.session_id, self - ) + abm.message_str, + abm, + self.platform.meta(), + abm.session_id, + self, + ), ) @register_platform_adapter("qq_official_webhook", "QQ 机器人官方 API 适配器(Webhook)") class QQOfficialWebhookPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -83,7 +96,9 @@ class QQOfficialWebhookPlatformAdapter(Platform): self.unique_session = platform_settings["unique_session"] intents = botpy.Intents( - public_messages=True, public_guild_messages=True, direct_message=True + public_messages=True, + public_guild_messages=True, + direct_message=True, ) self.client = botClient( intents=intents, # 已经无用 @@ -93,7 +108,9 @@ class QQOfficialWebhookPlatformAdapter(Platform): self.client.set_platform(self) async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") @@ -106,7 +123,9 @@ class QQOfficialWebhookPlatformAdapter(Platform): async def run(self): self.webhook_helper = QQOfficialWebhook( - self.config, self._event_queue, self.client + self.config, + self._event_queue, + self.client, ) await self.webhook_helper.initialize() await self.webhook_helper.start_polling() diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 4c0bf8329..306db5e56 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -1,5 +1,7 @@ -from astrbot.api.platform import AstrBotMessage, PlatformMetadata from botpy import Client + +from astrbot.api.platform import AstrBotMessage, PlatformMetadata + from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 4a2eae747..65b7c701a 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -1,10 +1,12 @@ -import quart -import logging import asyncio -from botpy import BotAPI, BotHttp, Client, Token, BotWebSocket, ConnectionSession -from astrbot.api import logger +import logging + +import quart +from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token from cryptography.hazmat.primitives.asymmetric import ed25519 +from astrbot.api import logger + # remove logger handler for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -27,7 +29,9 @@ class QQOfficialWebhook: self.server = quart.Quart(__name__) self.server.add_url_rule( - "/astrbot-qo-webhook/callback", view_func=self.callback, methods=["POST"] + "/astrbot-qo-webhook/callback", + view_func=self.callback, + methods=["POST"], ) self.client = botpy_client self.event_queue = event_queue @@ -62,7 +66,8 @@ class QQOfficialWebhook: seed = await self.repeat_seed(self.secret) private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed) msg = validation_payload.get("event_ts", "") + validation_payload.get( - "plain_token", "" + "plain_token", + "", ) # sign signature = private_key.sign(msg.encode()).hex() @@ -99,7 +104,7 @@ class QQOfficialWebhook: async def start_polling(self): logger.info( - f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。" + f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。", ) await self.server.run_task( host=self.callback_server_host, diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index a3f4f53ec..b5751ebd2 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -1,13 +1,22 @@ import asyncio import json import time +from xml.etree import ElementTree as ET + import websockets -from websockets.asyncio.client import connect -from typing import Optional from aiohttp import ClientSession, ClientTimeout -from websockets.asyncio.client import ClientConnection +from websockets.asyncio.client import ClientConnection, connect + from astrbot.api import logger from astrbot.api.event import MessageChain +from astrbot.api.message_components import ( + At, + File, + Image, + Plain, + Record, + Reply, +) from astrbot.api.platform import ( AstrBotMessage, MessageMember, @@ -17,15 +26,6 @@ from astrbot.api.platform import ( register_platform_adapter, ) from astrbot.core.platform.astr_message_event import MessageSession -from astrbot.api.message_components import ( - Plain, - Image, - At, - File, - Record, - Reply, -) -from xml.etree import ElementTree as ET @register_platform_adapter( @@ -34,18 +34,23 @@ from xml.etree import ElementTree as ET ) class SatoriPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config self.settings = platform_settings self.api_base_url = self.config.get( - "satori_api_base_url", "http://localhost:5140/satori/v1" + "satori_api_base_url", + "http://localhost:5140/satori/v1", ) self.token = self.config.get("satori_token", "") self.endpoint = self.config.get( - "satori_endpoint", "ws://localhost:5140/satori/v1/events" + "satori_endpoint", + "ws://localhost:5140/satori/v1/events", ) self.auto_reconnect = self.config.get("satori_auto_reconnect", True) self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10) @@ -57,21 +62,25 @@ class SatoriPlatformAdapter(Platform): id=self.config["id"], ) - self.ws: Optional[ClientConnection] = None - self.session: Optional[ClientSession] = None + self.ws: ClientConnection | None = None + self.session: ClientSession | None = None self.sequence = 0 self.logins = [] self.running = False - self.heartbeat_task: Optional[asyncio.Task] = None + self.heartbeat_task: asyncio.Task | None = None self.ready_received = False async def send_by_session( - self, session: MessageSession, message_chain: MessageChain + self, + session: MessageSession, + message_chain: MessageChain, ): from .satori_event import SatoriPlatformEvent await SatoriPlatformEvent.send_with_adapter( - self, message_chain, session.session_id + self, + message_chain, + session.session_id, ) await super().send_by_session(session, message_chain) @@ -85,10 +94,9 @@ class SatoriPlatformAdapter(Platform): try: if hasattr(ws, "closed"): return ws.closed - elif hasattr(ws, "close_code"): + if hasattr(ws, "close_code"): return ws.close_code is not None - else: - return False + return False except AttributeError: return False @@ -240,7 +248,7 @@ class SatoriPlatformAdapter(Platform): user_id = user.get("id", "") user_name = user.get("name", "") logger.info( - f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}" + f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}", ) if "sn" in body: @@ -282,7 +290,12 @@ class SatoriPlatformAdapter(Platform): return abm = await self.convert_satori_message( - message, user, channel, guild, login, timestamp + message, + user, + channel, + guild, + login, + timestamp, ) if abm: await self.handle_msg(abm) @@ -295,10 +308,10 @@ class SatoriPlatformAdapter(Platform): message: dict, user: dict, channel: dict, - guild: Optional[dict], + guild: dict | None, login: dict, - timestamp: Optional[int] = None, - ) -> Optional[AstrBotMessage]: + timestamp: int | None = None, + ) -> AstrBotMessage | None: try: abm = AstrBotMessage() abm.message_id = message.get("id", "") @@ -438,7 +451,7 @@ class SatoriPlatformAdapter(Platform): return prefixes - async def _extract_quote_element(self, content: str) -> Optional[dict]: + async def _extract_quote_element(self, content: str) -> dict | None: """提取标签信息""" try: # 处理命名空间前缀问题 @@ -451,7 +464,7 @@ class SatoriPlatformAdapter(Platform): [ f'xmlns:{prefix}="http://temp.uri/{prefix}"' for prefix in prefixes - ] + ], ) # 包装内容 @@ -483,14 +496,17 @@ class SatoriPlatformAdapter(Platform): inner_content += quote_element.text for child in quote_element: inner_content += ET.tostring( - child, encoding="unicode", method="xml" + child, + encoding="unicode", + method="xml", ) if child.tail: inner_content += child.tail # 构造移除了标签的内容 content_without_quote = content.replace( - ET.tostring(quote_element, encoding="unicode", method="xml"), "" + ET.tostring(quote_element, encoding="unicode", method="xml"), + "", ) return { @@ -506,7 +522,7 @@ class SatoriPlatformAdapter(Platform): logger.error(f"提取标签时发生错误: {e}") return None - async def _extract_quote_with_regex(self, content: str) -> Optional[dict]: + async def _extract_quote_with_regex(self, content: str) -> dict | None: """使用正则表达式提取quote标签信息""" import re @@ -529,7 +545,7 @@ class SatoriPlatformAdapter(Platform): "content_without_quote": content_without_quote, } - async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]: + async def _convert_quote_message(self, quote: dict) -> AstrBotMessage | None: """转换引用消息""" try: quote_abm = AstrBotMessage() @@ -587,7 +603,7 @@ class SatoriPlatformAdapter(Platform): [ f'xmlns:{prefix}="http://temp.uri/{prefix}"' for prefix in prefixes - ] + ], ) # 包装内容 @@ -747,13 +763,15 @@ class SatoriPlatformAdapter(Platform): try: async with self.session.request( - method, url, json=data, headers=headers + method, + url, + json=data, + headers=headers, ) as response: if response.status == 200: result = await response.json() return result - else: - return {} + return {} except Exception as e: logger.error(f"Satori HTTP 请求异常: {e}") return {} diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 78325c9a8..81a0d222c 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -1,19 +1,20 @@ from typing import TYPE_CHECKING + from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import ( - Plain, - Image, At, File, - Record, - Video, - Reply, Forward, + Image, Node, Nodes, + Plain, + Record, + Reply, + Video, ) +from astrbot.api.platform import AstrBotMessage, PlatformMetadata if TYPE_CHECKING: from .satori_adapter import SatoriPlatformAdapter @@ -53,14 +54,17 @@ class SatoriPlatformEvent(AstrMessageEvent): @classmethod async def send_with_adapter( - cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str + cls, + adapter: "SatoriPlatformAdapter", + message: MessageChain, + session_id: str, ): try: content_parts = [] for component in message.chain: component_content = await cls._convert_component_to_satori_static( - component + component, ) if component_content: content_parts.append(component_content) @@ -92,12 +96,15 @@ class SatoriPlatformEvent(AstrMessageEvent): user_id = user.get("id", "") if user else "" result = await adapter.send_http_request( - "POST", "/message.create", data, platform, user_id + "POST", + "/message.create", + data, + platform, + user_id, ) if result: return result - else: - return None + return None except Exception as e: logger.error(f"Satori 消息发送异常: {e}") @@ -140,7 +147,11 @@ class SatoriPlatformEvent(AstrMessageEvent): data = {"channel_id": channel_id, "content": content} result = await self.adapter.send_http_request( - "POST", "/message.create", data, platform, user_id + "POST", + "/message.create", + data, + platform, + user_id, ) if not result: logger.error("Satori 消息发送失败") @@ -178,9 +189,9 @@ class SatoriPlatformEvent(AstrMessageEvent): img_chain = MessageChain( [ Plain( - text=f'' - ) - ] + text=f'', + ), + ], ) await self.send(img_chain) except Exception as e: @@ -209,10 +220,10 @@ class SatoriPlatformEvent(AstrMessageEvent): ) return text - elif isinstance(component, At): + if isinstance(component, At): if component.qq: return f'' - elif component.name: + if component.name: return f'' elif isinstance(component, Image): @@ -264,7 +275,7 @@ class SatoriPlatformEvent(AstrMessageEvent): if node.content: for content_component in node.content: component_content = await self._convert_component_to_satori( - content_component + content_component, ) if component_content: content_parts.append(component_content) @@ -302,10 +313,10 @@ class SatoriPlatformEvent(AstrMessageEvent): ) return text - elif isinstance(component, At): + if isinstance(component, At): if component.qq: return f'' - elif component.name: + if component.name: return f'' elif isinstance(component, Image): @@ -358,7 +369,7 @@ class SatoriPlatformEvent(AstrMessageEvent): if node.content: for content_component in node.content: component_content = await cls._convert_component_to_satori_static( - content_component + content_component, ) if component_content: content_parts.append(component_content) @@ -395,8 +406,7 @@ class SatoriPlatformEvent(AstrMessageEvent): if node_parts: return f"{''.join(node_parts)}" - else: - return "" + return "" except Exception as e: logger.error(f"转换合并转发消息失败: {e}") @@ -415,8 +425,7 @@ class SatoriPlatformEvent(AstrMessageEvent): if node_parts: return f"{''.join(node_parts)}" - else: - return "" + return "" except Exception as e: logger.error(f"转换合并转发消息失败: {e}") diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 7877e4f52..0411f73a4 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -1,14 +1,16 @@ -import json -import hmac -import hashlib import asyncio +import hashlib +import hmac +import json import logging -from typing import Callable, Optional -from quart import Quart, request, Response -from slack_sdk.web.async_client import AsyncWebClient +from collections.abc import Callable + +from quart import Quart, Response, request from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse +from slack_sdk.web.async_client import AsyncWebClient + from astrbot.api import logger @@ -22,7 +24,7 @@ class SlackWebhookClient: host: str = "0.0.0.0", port: int = 3000, path: str = "/slack/events", - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.signing_secret = signing_secret @@ -93,7 +95,7 @@ class SlackWebhookClient: async def start(self): """启动 Webhook 服务器""" logger.info( - f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}..." + f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}...", ) await self.app.run_task( @@ -119,7 +121,7 @@ class SlackSocketClient: self, web_client: AsyncWebClient, app_token: str, - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.app_token = app_token diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 7e75f3c20..9f21656ed 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -1,34 +1,42 @@ -import time import asyncio -import uuid -import aiohttp -import re import base64 -from typing import Awaitable, Any -from slack_sdk.web.async_client import AsyncWebClient +import re +import time +import uuid +from collections.abc import Awaitable +from typing import Any + +import aiohttp from slack_sdk.socket_mode.request import SocketModeRequest +from slack_sdk.web.async_client import AsyncWebClient + +from astrbot.api import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import * from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.api.event import MessageChain -from .slack_event import SlackMessageEvent -from .client import SlackWebhookClient, SlackSocketClient -from astrbot.api.message_components import * # noqa: F403 -from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion + from ...register import register_platform_adapter +from .client import SlackSocketClient, SlackWebhookClient +from .slack_event import SlackMessageEvent @register_platform_adapter( - "slack", "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。" + "slack", + "适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", ) class SlackAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -43,7 +51,8 @@ class SlackAdapter(Platform): self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0") self.webhook_port = platform_config.get("slack_webhook_port", 3000) self.webhook_path = platform_config.get( - "slack_webhook_path", "/astrbot-slack-webhook/callback" + "slack_webhook_path", + "/astrbot-slack-webhook/callback", ) if not self.bot_token: @@ -69,10 +78,13 @@ class SlackAdapter(Platform): self.bot_self_id = None async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): blocks, text = SlackMessageEvent._parse_slack_blocks( - message_chain=message_chain, web_client=self.web_client + message_chain=message_chain, + web_client=self.web_client, ) try: @@ -150,7 +162,7 @@ class SlackAdapter(Platform): abm.message = [] # 优先使用 blocks 字段解析消息 - if "blocks" in event and event["blocks"]: + if event.get("blocks"): abm.message = self._parse_blocks(event["blocks"]) # 更新 message_str abm.message_str = "" @@ -166,7 +178,8 @@ class SlackAdapter(Platform): mentioned_user = await self.web_client.users_info(user=mention) user_data = mentioned_user["user"] user_name = user_data.get("real_name") or user_data.get( - "name", mention + "name", + mention, ) abm.message.append(At(qq=mention, name=user_name)) except Exception: @@ -189,7 +202,7 @@ class SlackAdapter(Platform): else: # TODO: 下载鉴权 abm.message.append( - File(name=file_name, file=file_url, url=file_url) + File(name=file_name, file=file_url, url=file_url), ) abm.raw_message = event @@ -224,7 +237,7 @@ class SlackAdapter(Platform): # 将之前的文本内容先添加到组件中 if text_content.strip(): message_components.append( - Plain(text=text_content) + Plain(text=text_content), ) text_content = "" # 添加@提及组件 @@ -307,11 +320,10 @@ class SlackAdapter(Platform): content = await resp.read() base64_content = base64.b64encode(content).decode("utf-8") return base64_content - else: - logger.error( - f"Failed to download slack file: {resp.status} {await resp.text()}" - ) - raise Exception(f"下载文件失败: {resp.status}") + logger.error( + f"Failed to download slack file: {resp.status} {await resp.text()}", + ) + raise Exception(f"下载文件失败: {resp.status}") async def run(self) -> Awaitable[Any]: self.bot_self_id = await self.get_bot_user_id() @@ -323,7 +335,9 @@ class SlackAdapter(Platform): # 创建 Socket 客户端 self.socket_client = SlackSocketClient( - self.web_client, self.app_token, self._handle_socket_event + self.web_client, + self.app_token, + self._handle_socket_event, ) logger.info("Slack 适配器 (Socket Mode) 启动中...") @@ -344,13 +358,13 @@ class SlackAdapter(Platform): ) logger.info( - f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}..." + f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...", ) await self.webhook_client.start() else: raise ValueError( - f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'" + f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'", ) async def _handle_webhook_event(self, event_data: dict): diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 86f9f9764..21c1b0fed 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -1,16 +1,18 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator + from slack_sdk.web.async_client import AsyncWebClient + +from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( + BaseMessageComponent, + File, Image, Plain, - File, - BaseMessageComponent, ) from astrbot.api.platform import Group, MessageMember -from astrbot.api import logger class SlackMessageEvent(AstrMessageEvent): @@ -27,12 +29,13 @@ class SlackMessageEvent(AstrMessageEvent): @staticmethod async def _from_segment_to_slack_block( - segment: BaseMessageComponent, web_client: AsyncWebClient + segment: BaseMessageComponent, + web_client: AsyncWebClient, ) -> dict: """将消息段转换为 Slack 块格式""" if isinstance(segment, Plain): return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}} - elif isinstance(segment, Image): + if isinstance(segment, Image): # upload file url = segment.url or segment.file if url.startswith("http"): @@ -61,7 +64,7 @@ class SlackMessageEvent(AstrMessageEvent): }, "alt_text": "图片", } - elif isinstance(segment, File): + if isinstance(segment, File): # upload file url = segment.url or segment.file response = await web_client.files_upload_v2( @@ -82,12 +85,12 @@ class SlackMessageEvent(AstrMessageEvent): "text": f"文件: <{file_url}|{segment.name or '文件'}>", }, } - else: - return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} + return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} @staticmethod async def _parse_slack_blocks( - message_chain: MessageChain, web_client: AsyncWebClient + message_chain: MessageChain, + web_client: AsyncWebClient, ): """解析成 Slack 块格式""" blocks = [] @@ -103,27 +106,29 @@ class SlackMessageEvent(AstrMessageEvent): { "type": "section", "text": {"type": "mrkdwn", "text": text_content}, - } + }, ) text_content = "" # 添加其他类型的块 block = await SlackMessageEvent._from_segment_to_slack_block( - segment, web_client + segment, + web_client, ) blocks.append(block) # 如果最后还有文本内容 if text_content.strip(): blocks.append( - {"type": "section", "text": {"type": "mrkdwn", "text": text_content}} + {"type": "section", "text": {"type": "mrkdwn", "text": text_content}}, ) return blocks, "" if blocks else text_content async def send(self, message: MessageChain): blocks, text = await SlackMessageEvent._parse_slack_blocks( - message, self.web_client + message, + self.web_client, ) try: @@ -154,17 +159,21 @@ class SlackMessageEvent(AstrMessageEvent): if self.get_group_id(): await self.web_client.chat_postMessage( - channel=self.get_group_id(), text=fallback_text + channel=self.get_group_id(), + text=fallback_text, ) else: await self.web_client.chat_postMessage( - channel=self.get_sender_id(), text=fallback_text + channel=self.get_sender_id(), + text=fallback_text, ) await super().send(message) async def send_streaming( - self, generator: AsyncGenerator, use_fallback: bool = False + self, + generator: AsyncGenerator, + use_fallback: bool = False, ): if not use_fallback: buffer = None @@ -174,7 +183,7 @@ class SlackMessageEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) @@ -211,7 +220,7 @@ class SlackMessageEvent(AstrMessageEvent): # 获取频道成员 members_response = await self.web_client.conversations_members( - channel=channel_id + channel=channel_id, ) members = [] @@ -224,7 +233,7 @@ class SlackMessageEvent(AstrMessageEvent): user_id=member_id, nickname=user_data.get("real_name") or user_data.get("name", member_id), - ) + ), ) except Exception: # 如果获取用户信息失败,使用默认信息 diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 68ee6a980..88a9f7dc6 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -37,7 +37,10 @@ else: @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config @@ -45,13 +48,15 @@ class TelegramPlatformAdapter(Platform): self.client_self_id = uuid.uuid4().hex[:8] base_url = self.config.get( - "telegram_api_base_url", "https://api.telegram.org/bot" + "telegram_api_base_url", + "https://api.telegram.org/bot", ) if not base_url: base_url = "https://api.telegram.org/bot" file_base_url = self.config.get( - "telegram_file_base_url", "https://api.telegram.org/file/bot" + "telegram_file_base_url", + "https://api.telegram.org/file/bot", ) if not file_base_url: file_base_url = "https://api.telegram.org/file/bot" @@ -59,10 +64,12 @@ class TelegramPlatformAdapter(Platform): self.base_url = base_url self.enable_command_register = self.config.get( - "telegram_command_register", True + "telegram_command_register", + True, ) self.enable_command_refresh = self.config.get( - "telegram_command_auto_refresh", True + "telegram_command_auto_refresh", + True, ) self.last_command_hash = None @@ -85,11 +92,15 @@ class TelegramPlatformAdapter(Platform): @override async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): from_username = session.session_id await TelegramPlatformEvent.send_with_client( - self.client, message_chain, from_username + self.client, + message_chain, + from_username, ) await super().send_by_session(session, message_chain) @@ -131,7 +142,7 @@ class TelegramPlatformAdapter(Platform): if commands: current_hash = hash( - tuple((cmd.command, cmd.description) for cmd in commands) + tuple((cmd.command, cmd.description) for cmd in commands), ) if current_hash == self.last_command_hash: return @@ -153,7 +164,9 @@ class TelegramPlatformAdapter(Platform): continue for event_filter in handler_metadata.event_filters: cmd_info = self._extract_command_info( - event_filter, handler_metadata, skip_commands + event_filter, + handler_metadata, + skip_commands, ) if cmd_info: cmd_name, description = cmd_info @@ -164,7 +177,9 @@ class TelegramPlatformAdapter(Platform): @staticmethod def _extract_command_info( - event_filter, handler_metadata, skip_commands: set + event_filter, + handler_metadata, + skip_commands: set, ) -> tuple[str, str] | None: """从事件过滤器中提取指令信息""" cmd_name = None @@ -199,11 +214,12 @@ class TelegramPlatformAdapter(Platform): async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if not update.effective_chat: logger.warning( - "Received a start command without an effective chat, skipping /start reply." + "Received a start command without an effective chat, skipping /start reply.", ) return await context.bot.send_message( - chat_id=update.effective_chat.id, text=self.config["start_message"] + chat_id=update.effective_chat.id, + text=self.config["start_message"], ) async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): @@ -213,7 +229,10 @@ class TelegramPlatformAdapter(Platform): await self.handle_msg(abm) async def convert_message( - self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True + self, + update: Update, + context: ContextTypes.DEFAULT_TYPE, + get_reply=True, ) -> AstrBotMessage | None: """转换 Telegram 的消息对象为 AstrBotMessage 对象。 @@ -244,7 +263,8 @@ class TelegramPlatformAdapter(Platform): logger.warning("[Telegram] Received a message without a from_user.") return None message.sender = MessageMember( - str(_from_user.id), _from_user.username or "Unknown" + str(_from_user.id), + _from_user.username or "Unknown", ) message.self_id = str(context.bot.username) message.raw_message = update @@ -274,7 +294,7 @@ class TelegramPlatformAdapter(Platform): message_str=reply_abm.message_str, text=reply_abm.message_str, qq=reply_abm.sender.user_id, - ) + ), ) if update.message.text: @@ -320,7 +340,7 @@ class TelegramPlatformAdapter(Platform): if message.message_str.strip() == "/start": await self.start(update, context) - return + return None elif update.message.voice: file = await update.message.voice.get_file() @@ -358,7 +378,7 @@ class TelegramPlatformAdapter(Platform): file_path = file.file_path if file_path is None: logger.warning( - f"Telegram document file_path is None, cannot save the file {file_name}." + f"Telegram document file_path is None, cannot save the file {file_name}.", ) else: message.message.append(Comp.File(file=file_path, name=file_name)) @@ -369,7 +389,7 @@ class TelegramPlatformAdapter(Platform): file_path = file.file_path if file_path is None: logger.warning( - f"Telegram video file_path is None, cannot save the file {file_name}." + f"Telegram video file_path is None, cannot save the file {file_name}.", ) else: message.message.append(Comp.Video(file=file_path, path=file.file_path)) diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 2da7aafe5..34fd86ad9 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,22 +1,24 @@ +import asyncio import os import re -import asyncio + import telegramify_markdown +from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji +from telegram.ext import ExtBot + +from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType from astrbot.api.message_components import ( - Plain, - Image, - Reply, At, File, + Image, + Plain, Record, + Reply, ) -from telegram.ext import ExtBot -from astrbot.core.utils.io import download_file -from astrbot import logger +from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from telegram import ReactionTypeEmoji, ReactionTypeCustomEmoji +from astrbot.core.utils.io import download_file class TelegramPlatformEvent(AstrMessageEvent): @@ -68,7 +70,10 @@ class TelegramPlatformEvent(AstrMessageEvent): @classmethod async def send_with_client( - cls, client: ExtBot, message: MessageChain, user_name: str + cls, + client: ExtBot, + message: MessageChain, + user_name: str, ): image_path = None @@ -104,14 +109,18 @@ class TelegramPlatformEvent(AstrMessageEvent): for chunk in chunks: try: md_text = telegramify_markdown.markdownify( - chunk, max_line_length=None, normalize_whitespace=False + chunk, + max_line_length=None, + normalize_whitespace=False, ) await client.send_message( - text=md_text, parse_mode="MarkdownV2", **payload + text=md_text, + parse_mode="MarkdownV2", + **payload, ) except Exception as e: logger.warning( - f"MarkdownV2 send failed: {e}. Using plain text instead." + f"MarkdownV2 send failed: {e}. Using plain text instead.", ) await client.send_message(text=chunk, **payload) elif isinstance(i, Image): @@ -137,8 +146,7 @@ class TelegramPlatformEvent(AstrMessageEvent): await super().send(message) async def react(self, emoji: str | None, big: bool = False): - """ - 给原消息添加 Telegram 反应: + """给原消息添加 Telegram 反应: - 普通 emoji:传入 '👍'、'😂' 等 - 自定义表情:传入其 custom_emoji_id(纯数字字符串) - 取消本机器人的反应:传入 None 或空字符串 @@ -216,7 +224,9 @@ class TelegramPlatformEvent(AstrMessageEvent): i.file = path await self.client.send_document( - document=i.file, filename=i.name, **payload + document=i.file, + filename=i.name, + **payload, ) continue elif isinstance(i, Record): @@ -263,7 +273,9 @@ class TelegramPlatformEvent(AstrMessageEvent): if delta and current_content != delta: try: markdown_text = telegramify_markdown.markdownify( - delta, max_line_length=None, normalize_whitespace=False + delta, + max_line_length=None, + normalize_whitespace=False, ) await self.client.edit_message_text( text=markdown_text, @@ -274,7 +286,9 @@ class TelegramPlatformEvent(AstrMessageEvent): except Exception as e: logger.warning(f"Markdown转换失败,使用普通文本: {e!s}") await self.client.edit_message_text( - text=delta, chat_id=payload["chat_id"], message_id=message_id + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, ) except Exception as e: logger.warning(f"编辑消息失败(streaming): {e!s}") diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index faec122ac..ad6a50775 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -1,24 +1,27 @@ -import time import asyncio -import uuid import os -from typing import Awaitable, Any, Callable +import time +import uuid +from collections.abc import Awaitable, Callable +from typing import Any + +from astrbot import logger +from astrbot.core.message.components import Image, Plain, Record +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.message.components import Plain, Image, Record # noqa: F403 -from astrbot import logger -from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr -from .webchat_event import WebChatMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion -from ...register import register_platform_adapter from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ...register import register_platform_adapter +from .webchat_event import WebChatMessageEvent +from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr + class QueueListener: def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None: @@ -35,7 +38,7 @@ class QueueListener: await self.callback(data) except Exception as e: logger.error( - f"Error processing message from conversation {conversation_id}: {e}" + f"Error processing message from conversation {conversation_id}: {e}", ) break @@ -66,7 +69,10 @@ class QueueListener: @register_platform_adapter("webchat", "webchat") class WebChatAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -77,11 +83,15 @@ class WebChatAdapter(Platform): os.makedirs(self.imgs_dir, exist_ok=True) self.metadata = PlatformMetadata( - name="webchat", description="webchat", id="webchat" + name="webchat", + description="webchat", + id="webchat", ) async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) @@ -106,13 +116,13 @@ class WebChatAdapter(Platform): if isinstance(payload["image_url"], list): for img in payload["image_url"]: abm.message.append( - Image.fromFileSystem(os.path.join(self.imgs_dir, img)) + Image.fromFileSystem(os.path.join(self.imgs_dir, img)), ) else: abm.message.append( Image.fromFileSystem( - os.path.join(self.imgs_dir, payload["image_url"]) - ) + os.path.join(self.imgs_dir, payload["image_url"]), + ), ) if payload["audio_url"]: if isinstance(payload["audio_url"], list): diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 3bf1c0a2a..4d4d3b59e 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,11 +1,13 @@ +import base64 import os import uuid -import base64 + from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Plain, Image, Record -from astrbot.core.utils.io import download_image_by_url +from astrbot.api.message_components import Image, Plain, Record from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_image_by_url + from .webchat_queue_mgr import webchat_queue_mgr imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") @@ -26,7 +28,7 @@ class WebChatMessageEvent(AstrMessageEvent): "type": "end", "data": "", "streaming": False, - } # end means this request is finished + }, # end means this request is finished ) return "" @@ -41,7 +43,7 @@ class WebChatMessageEvent(AstrMessageEvent): "data": data, "streaming": streaming, "chain_type": message.type, - } + }, ) elif isinstance(comp, Image): # save image to local @@ -70,7 +72,7 @@ class WebChatMessageEvent(AstrMessageEvent): "cid": cid, "data": data, "streaming": streaming, - } + }, ) elif isinstance(comp, Record): # save record to local @@ -94,7 +96,7 @@ class WebChatMessageEvent(AstrMessageEvent): "cid": cid, "data": data, "streaming": streaming, - } + }, ) else: logger.debug(f"webchat 忽略: {comp.type}") @@ -118,12 +120,14 @@ class WebChatMessageEvent(AstrMessageEvent): "data": final_data, "streaming": True, "cid": cid, - } + }, ) final_data = "" continue final_data += await WebChatMessageEvent._send( - chain, session_id=self.session_id, streaming=True + chain, + session_id=self.session_id, + streaming=True, ) await web_chat_back_queue.put( @@ -132,6 +136,6 @@ class WebChatMessageEvent(AstrMessageEvent): "data": final_data, "streaming": True, "cid": cid, - } + }, ) await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 6b835ecb5..165375cd5 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -2,24 +2,24 @@ import asyncio import base64 import json import os -import traceback import time -from typing import Optional +import traceback import aiohttp import anyio import websockets + from astrbot import logger -from astrbot.api.message_components import Plain, Image, At, Record +from astrbot.api.message_components import At, Image, Plain, Record from astrbot.api.platform import Platform, PlatformMetadata from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astrbot_message import ( AstrBotMessage, MessageMember, MessageType, ) from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter from .wechatpadpro_message_event import WeChatPadProMessageEvent @@ -28,14 +28,17 @@ try: from .xml_data_parser import GeweDataParser except ImportError as e: logger.warning( - f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}" + f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {e!s}", ) @register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器") class WeChatPadProAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self._shutdown_event = None @@ -55,16 +58,19 @@ class WeChatPadProAdapter(Platform): self.host = self.config.get("host") self.port = self.config.get("port") self.active_mesasge_poll: bool = self.config.get( - "wpp_active_message_poll", False + "wpp_active_message_poll", + False, ) self.active_message_poll_interval: int = self.config.get( - "wpp_active_message_poll_interval", 5 + "wpp_active_message_poll_interval", + 5, ) self.base_url = f"http://{self.host}:{self.port}" self.auth_key = None # 用于保存生成的授权码 self.wxid = None # 用于保存登录成功后的 wxid self.credentials_file = os.path.join( - get_astrbot_data_path(), "wechatpadpro_credentials.json" + get_astrbot_data_path(), + "wechatpadpro_credentials.json", ) # 持久化文件路径 self.ws_handle_task = None @@ -81,9 +87,7 @@ class WeChatPadProAdapter(Platform): self.max_text_cache = 100 async def run(self) -> None: - """ - 启动平台适配器的运行实例。 - """ + """启动平台适配器的运行实例。""" logger.info("WeChatPadPro 适配器正在启动...") if loaded_credentials := self.load_credentials(): @@ -132,12 +136,10 @@ class WeChatPadProAdapter(Platform): logger.info("WeChatPadPro 适配器已停止。") def load_credentials(self): - """ - 从文件中加载 auth_key 和 wxid。 - """ + """从文件中加载 auth_key 和 wxid。""" if os.path.exists(self.credentials_file): try: - with open(self.credentials_file, "r") as f: + with open(self.credentials_file) as f: credentials = json.load(f) logger.info("成功加载 WeChatPadPro 凭据。") return credentials @@ -146,9 +148,7 @@ class WeChatPadProAdapter(Platform): return None def save_credentials(self): - """ - 将 auth_key 和 wxid 保存到文件。 - """ + """将 auth_key 和 wxid 保存到文件。""" credentials = { "auth_key": self.auth_key, "wxid": self.wxid, @@ -163,9 +163,7 @@ class WeChatPadProAdapter(Platform): logger.error(f"保存 WeChatPadPro 凭据失败: {e}") async def check_online_status(self): - """ - 检查 WeChatPadPro 设备是否在线。 - """ + """检查 WeChatPadPro 设备是否在线。""" if not self.auth_key: return False url = f"{self.base_url}/login/GetLoginStatus" @@ -182,25 +180,23 @@ class WeChatPadProAdapter(Platform): logger.info("WeChatPadPro 设备当前在线。") return True # login_state == 3 为离线状态 - elif login_state == 3: + if login_state == 3: logger.info("WeChatPadPro 设备不在线。") return False - else: - logger.error(f"未知的在线状态: {response_data}") - return False + logger.error(f"未知的在线状态: {response_data}") + return False # Code == 300 为微信退出状态。 - elif response.status == 200 and response_data.get("Code") == 300: + if response.status == 200 and response_data.get("Code") == 300: logger.info("WeChatPadPro 设备已退出。") return False - elif response.status == 200 and response_data.get("Code") == -2: + if response.status == 200 and response_data.get("Code") == -2: # 该链接不存在 self.auth_key = None return False - else: - logger.error( - f"检查在线状态失败: {response.status}, {response_data}" - ) - return False + logger.error( + f"检查在线状态失败: {response.status}, {response_data}", + ) + return False except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") @@ -221,9 +217,7 @@ class WeChatPadProAdapter(Platform): return None async def generate_auth_key(self): - """ - 生成授权码。 - """ + """生成授权码。""" url = f"{self.base_url}/admin/GenAuthKey1" params = {"key": self.admin_key} payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码 @@ -235,7 +229,7 @@ class WeChatPadProAdapter(Platform): async with session.post(url, params=params, json=payload) as response: if response.status != 200: logger.error( - f"生成授权码失败: {response.status}, {await response.text()}" + f"生成授权码失败: {response.status}, {await response.text()}", ) return @@ -248,7 +242,7 @@ class WeChatPadProAdapter(Platform): logger.info("成功获取授权码") else: logger.error( - f"生成授权码成功但未找到授权码: {response_data}" + f"生成授权码成功但未找到授权码: {response_data}", ) else: logger.error(f"生成授权码失败: {response_data}") @@ -258,9 +252,7 @@ class WeChatPadProAdapter(Platform): logger.error(f"生成授权码时发生错误: {e}") async def get_login_qr_code(self): - """ - 获取登录二维码地址。 - """ + """获取登录二维码地址。""" url = f"{self.base_url}/login/GetLoginQrCodeNew" params = {"key": self.auth_key} payload = {} # 根据文档,这个接口的 body 可以为空 @@ -272,26 +264,24 @@ class WeChatPadProAdapter(Platform): if response.status == 200 and response_data.get("Code") == 200: # 二维码地址在 Data.QrCodeUrl 字段中 if response_data.get("Data") and response_data["Data"].get( - "QrCodeUrl" + "QrCodeUrl", ): return response_data["Data"]["QrCodeUrl"] - else: - logger.error( - f"获取登录二维码成功但未找到二维码地址: {response_data}" - ) - return None - elif "该 key 无效" in response_data.get("Text"): logger.error( - "授权码无效,已经清除。请重新启动 AstrBot 或者本消息适配器。原因也可能是 WeChatPadPro 的 MySQL 服务没有启动成功,请检查 WeChatPadPro 服务的日志。" + f"获取登录二维码成功但未找到二维码地址: {response_data}", + ) + return None + if "该 key 无效" in response_data.get("Text"): + logger.error( + "授权码无效,已经清除。请重新启动 AstrBot 或者本消息适配器。原因也可能是 WeChatPadPro 的 MySQL 服务没有启动成功,请检查 WeChatPadPro 服务的日志。", ) self.auth_key = None self.save_credentials() return None - else: - logger.error( - f"获取登录二维码失败: {response.status}, {response_data}" - ) - return None + logger.error( + f"获取登录二维码失败: {response.status}, {response_data}", + ) + return None except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return None @@ -300,8 +290,7 @@ class WeChatPadProAdapter(Platform): return None async def check_login_status(self): - """ - 循环检测扫码状态。 + """循环检测扫码状态。 尝试 6 次后跳出循环,添加倒计时。 返回 True 如果登录成功,否则返回 False。 """ @@ -325,31 +314,31 @@ class WeChatPadProAdapter(Platform): ): status = response_data["Data"]["state"] logger.info( - f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒" + f"第 {attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}秒", ) if status == 2: # 状态 2 表示登录成功 self.wxid = response_data["Data"].get("wxid") self.wxnewpass = response_data["Data"].get( - "wxnewpass" + "wxnewpass", ) logger.info( - f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}" + f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}", ) self.save_credentials() # 登录成功后保存凭据 return True - elif status == -2: # 二维码过期 + if status == -2: # 二维码过期 logger.error("二维码已过期,请重新获取。") return False else: logger.error( - f"检测登录状态成功但未找到登录状态: {response_data}" + f"检测登录状态成功但未找到登录状态: {response_data}", ) elif response_data.get("Code") == 300: # "不存在状态" pass else: logger.info( - f"检测登录状态失败: {response.status}, {response_data}" + f"检测登录状态失败: {response.status}, {response_data}", ) except aiohttp.ClientConnectorError as e: @@ -368,13 +357,11 @@ class WeChatPadProAdapter(Platform): return False async def connect_websocket(self): - """ - 建立 WebSocket 连接并处理接收到的消息。 - """ + """建立 WebSocket 连接并处理接收到的消息。""" os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}" ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}" logger.info( - f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***" + f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***", ) while True: try: @@ -389,7 +376,8 @@ class WeChatPadProAdapter(Platform): while True: try: message = await asyncio.wait_for( - websocket.recv(), timeout=wait_time + websocket.recv(), + timeout=wait_time, ) # logger.debug(message) # 不显示原始消息内容 asyncio.create_task(self.handle_websocket_message(message)) @@ -404,14 +392,12 @@ class WeChatPadProAdapter(Platform): break except Exception as e: logger.error( - f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。" + f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。", ) await asyncio.sleep(5) async def handle_websocket_message(self, message: str): - """ - 处理从 WebSocket 接收到的消息。 - """ + """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: message_data = json.loads(message) @@ -441,9 +427,7 @@ class WeChatPadProAdapter(Platform): logger.error(f"处理 WebSocket 消息时发生错误: {e}") async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: - """ - 将 WeChatPadPro 原始消息转换为 AstrBotMessage。 - """ + """将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" abm = AstrBotMessage() abm.raw_message = raw_message abm.message_id = str(raw_message.get("msg_id")) @@ -452,7 +436,7 @@ class WeChatPadProAdapter(Platform): if int(time.time()) - abm.timestamp > 180: logger.warning( - f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。" + f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}。", ) return None @@ -476,7 +460,12 @@ class WeChatPadProAdapter(Platform): # 先判断群聊/私聊并设置基本属性 if await self._process_chat_type( - abm, raw_message, from_user_name, to_user_name, content, push_content + abm, + raw_message, + from_user_name, + to_user_name, + content, + push_content, ): # 再根据消息类型处理消息内容 await self._process_message_content(abm, raw_message, msg_type, content) @@ -493,9 +482,7 @@ class WeChatPadProAdapter(Platform): content: str, push_content: str, ): - """ - 判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。 - """ + """判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。""" if from_user_name == "weixin": return False at_me = False @@ -510,7 +497,8 @@ class WeChatPadProAdapter(Platform): # 获取群聊发送者的nickname if sender_wxid: accurate_nickname = await self._get_group_member_nickname( - abm.group_id, sender_wxid + abm.group_id, + sender_wxid, ) if accurate_nickname: abm.sender.nickname = accurate_nickname @@ -539,11 +527,11 @@ class WeChatPadProAdapter(Platform): return True async def _get_group_member_nickname( - self, group_id: str, member_wxid: str - ) -> Optional[str]: - """ - 通过接口获取群成员的昵称。 - """ + self, + group_id: str, + member_wxid: str, + ) -> str | None: + """通过接口获取群成员的昵称。""" url = f"{self.base_url}/group/GetChatroomMemberDetail" params = {"key": self.auth_key} payload = { @@ -565,11 +553,11 @@ class WeChatPadProAdapter(Platform): if member.get("user_name") == member_wxid: return member.get("nick_name") logger.warning( - f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称" + f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称", ) else: logger.error( - f"获取群成员详情失败: {response.status}, {response_data}" + f"获取群成员详情失败: {response.status}, {response_data}", ) return None except aiohttp.ClientConnectorError as e: @@ -580,7 +568,10 @@ class WeChatPadProAdapter(Platform): return None async def _download_raw_image( - self, from_user_name: str, to_user_name: str, msg_id: int + self, + from_user_name: str, + to_user_name: str, + msg_id: int, ): """下载原始图片。""" url = f"{self.base_url}/message/GetMsgBigImg" @@ -598,9 +589,8 @@ class WeChatPadProAdapter(Platform): async with session.post(url, params=params, json=payload) as response: if response.status == 200: return await response.json() - else: - logger.error(f"下载图片失败: {response.status}") - return None + logger.error(f"下载图片失败: {response.status}") + return None except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return None @@ -609,7 +599,11 @@ class WeChatPadProAdapter(Platform): return None async def download_voice( - self, to_user_name: str, new_msg_id: str, bufid: str, length: int + self, + to_user_name: str, + new_msg_id: str, + bufid: str, + length: int, ): """下载原始音频。""" url = f"{self.base_url}/message/GetMsgVoice" @@ -635,11 +629,13 @@ class WeChatPadProAdapter(Platform): return None async def _process_message_content( - self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str + self, + abm: AstrBotMessage, + raw_message: dict, + msg_type: int, + content: str, ): - """ - 根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。 - """ + """根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。""" if msg_type == 1: # 文本消息 abm.message_str = content if abm.type == MessageType.GROUP_MESSAGE: @@ -671,10 +667,12 @@ class WeChatPadProAdapter(Platform): if at_me: # 被@了,在消息开头插入At组件(参考gewechat的做法) bot_nickname = await self._get_group_member_nickname( - abm.group_id, abm.self_id + abm.group_id, + abm.self_id, ) abm.message.insert( - 0, At(qq=abm.self_id, name=bot_nickname or abm.self_id) + 0, + At(qq=abm.self_id, name=bot_nickname or abm.self_id), ) # 只有当消息内容不仅仅是@时才添加Plain组件 @@ -727,7 +725,9 @@ class WeChatPadProAdapter(Platform): to_user_name = raw_message.get("to_user_name", {}).get("str", "") msg_id = raw_message.get("msg_id") image_resp = await self._download_raw_image( - from_user_name, to_user_name, msg_id + from_user_name, + to_user_name, + msg_id, ) image_bs64_data = ( image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) @@ -789,7 +789,8 @@ class WeChatPadProAdapter(Platform): voice_bs64_data = base64.b64decode(voice_bs64_data) temp_dir = os.path.join(get_astrbot_data_path(), "temp") file_path = os.path.join( - temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk" + temp_dir, + f"wechatpadpro_voice_{abm.message_id}.silk", ) async with await anyio.open_file(file_path, "wb") as f: @@ -819,9 +820,7 @@ class WeChatPadProAdapter(Platform): logger.warning(f"收到未处理的消息类型: {msg_type}。") async def terminate(self): - """ - 终止一个平台的运行实例。 - """ + """终止一个平台的运行实例。""" logger.info("终止 WeChatPadPro 适配器。") try: if self.ws_handle_task: @@ -831,13 +830,13 @@ class WeChatPadProAdapter(Platform): pass def meta(self) -> PlatformMetadata: - """ - 得到一个平台的元数据。 - """ + """得到一个平台的元数据。""" return self.metadata async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): dummy_message_obj = AstrBotMessage() dummy_message_obj.session_id = session.session_id @@ -864,9 +863,7 @@ class WeChatPadProAdapter(Platform): await sending_event.send(message_chain) async def get_contact_list(self): - """ - 获取联系人列表。 - """ + """获取联系人列表。""" url = f"{self.base_url}/friend/GetContactList" params = {"key": self.auth_key} payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0} @@ -884,9 +881,8 @@ class WeChatPadProAdapter(Platform): .get("contactUsernameList", []) ) return contact_list - else: - logger.error(f"获取联系人列表失败: {result}") - return None + logger.error(f"获取联系人列表失败: {result}") + return None except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return None @@ -895,11 +891,11 @@ class WeChatPadProAdapter(Platform): return None async def get_contact_details_list( - self, room_wx_id_list: list[str] = None, user_names: list[str] = None - ) -> Optional[dict]: - """ - 获取联系人详情列表。 - """ + self, + room_wx_id_list: list[str] = None, + user_names: list[str] = None, + ) -> dict | None: + """获取联系人详情列表。""" if room_wx_id_list is None: room_wx_id_list = [] if user_names is None: @@ -917,9 +913,8 @@ class WeChatPadProAdapter(Platform): if result.get("Code") == 200 and result.get("Data"): contact_list = result.get("Data", {}).get("contactList", {}) return contact_list - else: - logger.error(f"获取联系人详情列表失败: {result}") - return None + logger.error(f"获取联系人详情列表失败: {result}") + return None except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return None diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 2bd3a1b89..e6d83d8ea 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -10,8 +10,8 @@ from astrbot import logger from astrbot.core.message.components import ( Image, Plain, - WechatEmoji, Record, + WechatEmoji, ) # Import Image from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -56,8 +56,8 @@ class WeChatPadProMessageEvent(AstrMessageEvent): b64c = self._compress_image(raw) payload = { "MsgItem": [ - {"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id} - ] + {"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id}, + ], } url = f"{self.adapter.base_url}/message/SendImageNewMessage" await self._post(session, url, payload) @@ -66,7 +66,8 @@ class WeChatPadProMessageEvent(AstrMessageEvent): if ( self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息 and self.adapter.settings.get( - "reply_with_mention", False + "reply_with_mention", + False, ) # 检查适配器设置是否启用 reply_with_mention and self.message_obj.sender # 确保有发送者信息 and ( @@ -91,8 +92,8 @@ class WeChatPadProMessageEvent(AstrMessageEvent): "MsgType": 1, "TextContent": message_text, "ToUserName": session_id, - } - ] + }, + ], } url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) @@ -104,8 +105,8 @@ class WeChatPadProMessageEvent(AstrMessageEvent): "EmojiMd5": comp.md5, "EmojiSize": comp.md5_len, "ToUserName": self.session_id, - } - ] + }, + ], } url = f"{self.adapter.base_url}/message/SendEmojiMessage" await self._post(session, url, payload) diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py index 054ca1b48..d372211c9 100644 --- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -1,10 +1,13 @@ from defusedxml import ElementTree as eT + from astrbot.api import logger from astrbot.api.message_components import ( - WechatEmoji as Emoji, - Plain, - Image, BaseMessageComponent, + Image, + Plain, +) +from astrbot.api.message_components import ( + WechatEmoji as Emoji, ) @@ -47,9 +50,7 @@ class GeweDataParser: raise async def parse_mutil_49(self) -> list[BaseMessageComponent] | None: - """ - 处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57) - """ + """处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57)""" try: appmsg_type = self._format_to_xml().findtext(".//appmsg/type") if appmsg_type == "57": @@ -59,9 +60,7 @@ class GeweDataParser: return None async def parse_reply(self) -> list[BaseMessageComponent]: - """ - 处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49) - """ + """处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49)""" components = [] try: @@ -96,7 +95,9 @@ class GeweDataParser: ) if cdn_url and self.downloader: image_resp = await self.downloader( - self.from_user_name, self.to_user_name, self.msg_id + self.from_user_name, + self.to_user_name, + self.msg_id, ) quoted_image_b64 = ( image_resp.get("Data", {}) @@ -111,11 +112,11 @@ class GeweDataParser: [ Image.fromBase64(quoted_image_b64), Plain(f"[引用] {nickname}: [引用的图片]"), - ] + ], ) else: components.append( - Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]") + Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]"), ) case 49: # 嵌套引用 @@ -143,9 +144,7 @@ class GeweDataParser: return components def parse_emoji(self) -> Emoji | None: - """ - 处理 msg_type == 47 的表情消息(emoji) - """ + """处理 msg_type == 47 的表情消息(emoji)""" try: emoji_element = self._format_to_xml().find(".//emoji") if emoji_element is not None: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 50341a8ae..ffd5ec8ee 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -41,10 +41,14 @@ class WecomServer: self.port = int(config.get("port")) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( - "/callback/command", view_func=self.verify, methods=["GET"] + "/callback/command", + view_func=self.verify, + methods=["GET"], ) self.server.add_url_rule( - "/callback/command", view_func=self.callback_command, methods=["POST"] + "/callback/command", + view_func=self.callback_command, + methods=["POST"], ) self.event_queue = event_queue @@ -94,7 +98,7 @@ class WecomServer: async def start_polling(self): logger.info( - f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。" + f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。", ) await self.server.run_task( host=self.callback_server_host, @@ -109,21 +113,24 @@ class WecomServer: @register_platform_adapter("wecom", "wecom 适配器") class WecomPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( - "api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/" + "api_base_url", + "https://qyapi.weixin.qq.com/cgi-bin/", ) if not self.api_base_url: self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/" - if self.api_base_url.endswith("/"): - self.api_base_url = self.api_base_url[:-1] + self.api_base_url = self.api_base_url.removesuffix("/") if not self.api_base_url.endswith("/cgi-bin"): self.api_base_url += "/cgi-bin" @@ -165,7 +172,8 @@ class WecomPlatformAdapter(Platform): return None msg_new = await asyncio.get_event_loop().run_in_executor( - None, get_latest_msg_item + None, + get_latest_msg_item, ) if msg_new: await self.convert_wechat_kf_message(msg_new) @@ -176,7 +184,9 @@ class WecomPlatformAdapter(Platform): @override async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): await super().send_by_session(session, message_chain) @@ -195,10 +205,11 @@ class WecomPlatformAdapter(Platform): try: acc_list = ( await loop.run_in_executor( - None, self.wechat_kf_api.get_account_list + None, + self.wechat_kf_api.get_account_list, ) ).get("account_list", []) - logger.debug(f"获取到微信客服列表: {str(acc_list)}") + logger.debug(f"获取到微信客服列表: {acc_list!s}") for acc in acc_list: name = acc.get("name", None) if name != self.kf_name: @@ -206,7 +217,7 @@ class WecomPlatformAdapter(Platform): open_kfid = acc.get("open_kfid", None) if not open_kfid: logger.error("获取微信客服失败,open_kfid 为空。") - logger.debug(f"Found open_kfid: {str(open_kfid)}") + logger.debug(f"Found open_kfid: {open_kfid!s}") kf_url = ( await loop.run_in_executor( None, @@ -216,7 +227,7 @@ class WecomPlatformAdapter(Platform): ) ).get("url", "") logger.info( - f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}" + f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}", ) except Exception as e: logger.error(e) @@ -256,7 +267,9 @@ class WecomPlatformAdapter(Platform): assert isinstance(msg, VoiceMessage) resp: Response = await asyncio.get_event_loop().run_in_executor( - None, self.client.media.download, msg.media_id + None, + self.client.media.download, + msg.media_id, ) temp_dir = os.path.join(get_astrbot_data_path(), "temp") path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") @@ -294,8 +307,8 @@ class WecomPlatformAdapter(Platform): await self.handle_msg(abm) async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: - msgtype = msg.get("msgtype", None) - external_userid = msg.get("external_userid", None) + msgtype = msg.get("msgtype") + external_userid = msg.get("external_userid") abm = AstrBotMessage() abm.raw_message = msg abm.raw_message["_wechat_kf_flag"] = None # 方便处理 @@ -312,7 +325,9 @@ class WecomPlatformAdapter(Platform): elif msgtype == "image": media_id = msg.get("image", {}).get("media_id", "") resp: Response = await asyncio.get_event_loop().run_in_executor( - None, self.client.media.download, media_id + None, + self.client.media.download, + media_id, ) path = f"data/temp/wechat_kf_{media_id}.jpg" with open(path, "wb") as f: @@ -321,7 +336,9 @@ class WecomPlatformAdapter(Platform): elif msgtype == "voice": media_id = msg.get("voice", {}).get("media_id", "") resp: Response = await asyncio.get_event_loop().run_in_executor( - None, self.client.media.download, media_id + None, + self.client.media.download, + media_id, ) temp_dir = os.path.join(get_astrbot_data_path(), "temp") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index e8078a9ac..ba9ad9a49 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -1,22 +1,23 @@ +import asyncio import os import uuid -import asyncio -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, Record + from wechatpy.enterprise import WeChatClient -from .wecom_kf_message import WeChatKFMessage from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from .wecom_kf_message import WeChatKFMessage + try: import pydub except Exception: logger.warning( - "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。" + "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。", ) - pass class WecomPlatformEvent(AstrMessageEvent): @@ -33,7 +34,9 @@ class WecomPlatformEvent(AstrMessageEvent): @staticmethod async def send_with_client( - client: WeChatClient, message: MessageChain, user_name: str + client: WeChatClient, + message: MessageChain, + user_name: str, ): pass @@ -44,44 +47,44 @@ class WecomPlatformEvent(AstrMessageEvent): plain (str): 要分割的长文本 Returns: list[str]: 分割后的文本列表 + """ if len(plain) <= 2048: return [plain] - else: - result = [] - start = 0 - while start < len(plain): - # 剩下的字符串长度<2048时结束 - if start + 2048 >= len(plain): - result.append(plain[start:]) + result = [] + start = 0 + while start < len(plain): + # 剩下的字符串长度<2048时结束 + if start + 2048 >= len(plain): + result.append(plain[start:]) + break + + # 向前搜索分割标点符号 + end = min(start + 2048, len(plain)) + cut_position = end + for i in range(end, start, -1): + if i < len(plain) and plain[i - 1] in [ + "。", + "!", + "?", + ".", + "!", + "?", + "\n", + ";", + ";", + ]: + cut_position = i break - # 向前搜索分割标点符号 - end = min(start + 2048, len(plain)) + # 没找到合适的位置分割, 直接切分 + if cut_position == end and end < len(plain): cut_position = end - for i in range(end, start, -1): - if i < len(plain) and plain[i - 1] in [ - "。", - "!", - "?", - ".", - "!", - "?", - "\n", - ";", - ";", - ]: - cut_position = i - break - # 没找到合适的位置分割, 直接切分 - if cut_position == end and end < len(plain): - cut_position = end + result.append(plain[start:cut_position]) + start = cut_position - result.append(plain[start:cut_position]) - start = cut_position - - return result + return result async def send(self, message: MessageChain): message_obj = self.message_obj @@ -111,7 +114,7 @@ class WecomPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"微信客服上传图片失败: {e}") await self.send( - MessageChain().message(f"微信客服上传图片失败: {e}") + MessageChain().message(f"微信客服上传图片失败: {e}"), ) return logger.debug(f"微信客服上传图片返回: {response}") @@ -126,7 +129,8 @@ class WecomPlatformEvent(AstrMessageEvent): temp_dir = os.path.join(get_astrbot_data_path(), "temp") record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr") pydub.AudioSegment.from_wav(record_path).export( - record_path_amr, format="amr" + record_path_amr, + format="amr", ) with open(record_path_amr, "rb") as f: @@ -135,7 +139,7 @@ class WecomPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"微信客服上传语音失败: {e}") await self.send( - MessageChain().message(f"微信客服上传语音失败: {e}") + MessageChain().message(f"微信客服上传语音失败: {e}"), ) return logger.info(f"微信客服上传语音返回: {response}") @@ -154,7 +158,9 @@ class WecomPlatformEvent(AstrMessageEvent): plain_chunks = await self.split_plain(comp.text) for chunk in plain_chunks: self.client.message.send_text( - message_obj.self_id, message_obj.session_id, chunk + message_obj.self_id, + message_obj.session_id, + chunk, ) await asyncio.sleep(0.5) # Avoid sending too fast elif isinstance(comp, Image): @@ -166,7 +172,7 @@ class WecomPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"企业微信上传图片失败: {e}") await self.send( - MessageChain().message(f"企业微信上传图片失败: {e}") + MessageChain().message(f"企业微信上传图片失败: {e}"), ) return logger.debug(f"企业微信上传图片返回: {response}") @@ -181,7 +187,8 @@ class WecomPlatformEvent(AstrMessageEvent): temp_dir = os.path.join(get_astrbot_data_path(), "temp") record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr") pydub.AudioSegment.from_wav(record_path).export( - record_path_amr, format="amr" + record_path_amr, + format="amr", ) with open(record_path_amr, "rb") as f: @@ -190,7 +197,7 @@ class WecomPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"企业微信上传语音失败: {e}") await self.send( - MessageChain().message(f"企业微信上传语音失败: {e}") + MessageChain().message(f"企业微信上传语音失败: {e}"), ) return logger.info(f"企业微信上传语音返回: {response}") @@ -212,7 +219,7 @@ class WecomPlatformEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py index 118667975..51f4ee14f 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -1,7 +1,4 @@ -# -*- coding: utf-8 -*- - -""" -The MIT License (MIT) +"""The MIT License (MIT) Copyright (c) 2014-2020 messense @@ -28,15 +25,13 @@ from wechatpy.client.api.base import BaseWeChatAPI class WeChatKF(BaseWeChatAPI): - """ - 微信客服接口 + """微信客服接口 https://work.weixin.qq.com/api/doc/90000/90135/94670 """ def sync_msg(self, token, open_kfid, cursor="", limit=1000): - """ - 微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收) + """微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收) 、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。 支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。 @@ -57,8 +52,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/sync_msg", data=data) def get_service_state(self, open_kfid, external_userid): - """ - 获取会话状态 + """获取会话状态 ID 状态 说明 0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待 @@ -78,10 +72,13 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/service_state/get", data=data) def trans_service_state( - self, open_kfid, external_userid, service_state, servicer_userid="" + self, + open_kfid, + external_userid, + service_state, + servicer_userid="", ): - """ - 变更会话状态 + """变更会话状态 :param open_kfid: 客服帐号ID :param external_userid: 微信客户的external_userid @@ -98,8 +95,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/service_state/trans", data=data) def get_servicer_list(self, open_kfid): - """ - 获取接待人员列表 + """获取接待人员列表 :param open_kfid: 客服帐号ID :return: 接口调用结果 @@ -110,8 +106,7 @@ class WeChatKF(BaseWeChatAPI): return self._get("kf/servicer/list", params=data) def add_servicer(self, open_kfid, userid_list): - """ - 添加接待人员 + """添加接待人员 添加指定客服帐号的接待人员。 :param open_kfid: 客服帐号ID @@ -128,8 +123,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/servicer/add", data=data) def del_servicer(self, open_kfid, userid_list): - """ - 删除接待人员 + """删除接待人员 从客服帐号删除接待人员 :param open_kfid: 客服帐号ID @@ -146,8 +140,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/servicer/del", data=data) def batchget_customer(self, external_userid_list): - """ - 客户基本信息获取 + """客户基本信息获取 :param external_userid_list: external_userid列表 :return: 接口调用结果 @@ -161,16 +154,14 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/customer/batchget", data=data) def get_account_list(self): - """ - 获取客服帐号列表 + """获取客服帐号列表 :return: 接口调用结果 """ return self._get("kf/account/list") def add_contact_way(self, open_kfid, scene): - """ - 获取客服帐号链接 + """获取客服帐号链接 :param open_kfid: 客服帐号ID :param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]* @@ -180,18 +171,21 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/add_contact_way", data=data) def get_upgrade_service_config(self): - """ - 获取配置的专员与客户群 + """获取配置的专员与客户群 :return: 接口调用结果 """ return self._get("kf/customer/get_upgrade_service_config") def upgrade_service( - self, open_kfid, external_userid, service_type, member=None, groupchat=None + self, + open_kfid, + external_userid, + service_type, + member=None, + groupchat=None, ): - """ - 为客户升级为专员或客户群服务 + """为客户升级为专员或客户群服务 :param open_kfid: 客服帐号ID :param external_userid: 微信客户的external_userid @@ -200,7 +194,6 @@ class WeChatKF(BaseWeChatAPI): :param groupchat: 推荐的客户群,type等于2时有效 :return: 接口调用结果 """ - data = { "open_kfid": open_kfid, "external_userid": external_userid, @@ -213,20 +206,17 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/customer/upgrade_service", data=data) def cancel_upgrade_service(self, open_kfid, external_userid): - """ - 为客户取消推荐 + """为客户取消推荐 :param open_kfid: 客服帐号ID :param external_userid: 微信客户的external_userid :return: 接口调用结果 """ - data = {"open_kfid": open_kfid, "external_userid": external_userid} return self._post("kf/customer/cancel_upgrade_service", data=data) def send_msg_on_event(self, code, msgtype, msg_content, msgid=None): - """ - 当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。 + """当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。 支持发送消息类型:文本、菜单消息。 :param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。 @@ -236,7 +226,6 @@ class WeChatKF(BaseWeChatAPI): 字符串取值范围(正则表达式):[0-9a-zA-Z_-]* :return: 接口调用结果 """ - data = {"code": code, "msgtype": msgtype} if msgid: data["msgid"] = msgid @@ -244,8 +233,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/send_msg_on_event", data=data) def get_corp_statistic(self, start_time, end_time, open_kfid=None): - """ - 获取「客户数据统计」企业汇总数据 + """获取「客户数据统计」企业汇总数据 :param start_time: 开始时间 :param end_time: 结束时间 @@ -256,10 +244,13 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/get_corp_statistic", data=data) def get_servicer_statistic( - self, start_time, end_time, open_kfid=None, servicer_userid=None + self, + start_time, + end_time, + open_kfid=None, + servicer_userid=None, ): - """ - 获取「客户数据统计」接待人员明细数据 + """获取「客户数据统计」接待人员明细数据 :param start_time: 开始时间 :param end_time: 结束时间 @@ -276,8 +267,7 @@ class WeChatKF(BaseWeChatAPI): return self._post("kf/get_servicer_statistic", data=data) def account_update(self, open_kfid, name, media_id): - """ - 修改客服账号 + """修改客服账号 :param open_kfid: 客服帐号ID :param name: 客服名称 diff --git a/astrbot/core/platform/sources/wecom/wecom_kf_message.py b/astrbot/core/platform/sources/wecom/wecom_kf_message.py index 42fc20d65..d839134ab 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf_message.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf_message.py @@ -1,5 +1,4 @@ -""" -The MIT License (MIT) +"""The MIT License (MIT) Copyright (c) 2014-2020 messense @@ -23,13 +22,11 @@ SOFTWARE. """ from optionaldict import optionaldict - from wechatpy.client.api.base import BaseWeChatAPI class WeChatKFMessage(BaseWeChatAPI): - """ - 发送微信客服消息 + """发送微信客服消息 https://work.weixin.qq.com/api/doc/90000/90135/94677 @@ -46,8 +43,7 @@ class WeChatKFMessage(BaseWeChatAPI): """ def send(self, user_id, open_kfid, msgid="", msg=None): - """ - 当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。 + """当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。 注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。 支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。 @@ -127,7 +123,13 @@ class WeChatKFMessage(BaseWeChatAPI): ) def send_msgmenu( - self, user_id, open_kfid, head_content, menu_list, tail_content, msgid="" + self, + user_id, + open_kfid, + head_content, + menu_list, + tail_content, + msgid="", ): return self.send( user_id, @@ -144,7 +146,14 @@ class WeChatKFMessage(BaseWeChatAPI): ) def send_location( - self, user_id, open_kfid, name, address, latitude, longitude, msgid="" + self, + user_id, + open_kfid, + name, + address, + latitude, + longitude, + msgid="", ): return self.send( user_id, @@ -162,7 +171,14 @@ class WeChatKFMessage(BaseWeChatAPI): ) def send_miniprogram( - self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid="" + self, + user_id, + open_kfid, + appid, + title, + thumb_media_id, + pagepath, + msgid="", ): return self.send( user_id, diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 5332942b9..d11b8606b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- encoding:utf-8 -*- """对企业微信发送给企业后台的消息加解密示例代码. @copyright: Copyright (c) 1998-2020 Tencent Inc. @@ -7,15 +6,16 @@ """ # ------------------------------------------------------------------------ -import logging import base64 -import random import hashlib -import time -import struct -from Crypto.Cipher import AES -import socket import json +import logging +import random +import socket +import struct +import time + +from Crypto.Cipher import AES from . import ierror @@ -31,7 +31,7 @@ class FormatException(Exception): def throw_exception(message, exception_class=FormatException): - """my define raise exception function""" + """My define raise exception function""" raise exception_class(message) @@ -136,7 +136,7 @@ class PKCS7Encoder: return decrypted[:-pad] -class Prpcrypt(object): +class Prpcrypt: """提供接收和推送给企业微信消息的加解密接口""" def __init__(self, key): @@ -210,7 +210,7 @@ class Prpcrypt(object): return str(random.randint(1000000000000000, 9999999999999999)).encode() -class WXBizJsonMsgCrypt(object): +class WXBizJsonMsgCrypt: # 构造函数 def __init__(self, sToken, sEncodingAESKey, sReceiveId): try: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py index 7da900030..2f87b88b9 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py @@ -1,6 +1,4 @@ -""" -企业微信智能机器人平台适配器包 -""" +"""企业微信智能机器人平台适配器包""" from .wecomai_adapter import WecomAIBotAdapter from .wecomai_api import WecomAIBotAPIClient @@ -9,9 +7,9 @@ from .wecomai_server import WecomAIBotServer from .wecomai_utils import WecomAIBotConstants __all__ = [ - "WecomAIBotAdapter", "WecomAIBotAPIClient", + "WecomAIBotAdapter", + "WecomAIBotConstants", "WecomAIBotMessageEvent", "WecomAIBotServer", - "WecomAIBotConstants", ] diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py index cc1bf221e..0df14a505 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- ######################################################################### # Author: jonyqin # Created Time: Thu 11 Sep 2014 01:53:58 PM CST diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 830d8de58..29ac02653 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -1,38 +1,37 @@ -""" -企业微信智能机器人平台适配器 +"""企业微信智能机器人平台适配器 基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调 参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应 """ -import time import asyncio -import uuid -import hashlib import base64 -from typing import Awaitable, Any, Dict, Optional, Callable - +import hashlib +import time +import uuid +from collections.abc import Awaitable, Callable +from typing import Any +from astrbot.api import logger +from astrbot.api.event import MessageChain +from astrbot.api.message_components import At, Image, Plain from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, MessageType, + Platform, PlatformMetadata, ) -from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, At, Image -from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion -from ...register import register_platform_adapter +from ...register import register_platform_adapter from .wecomai_api import ( WecomAIBotAPIClient, WecomAIBotMessageParser, WecomAIBotStreamMessageBuilder, ) from .wecomai_event import WecomAIBotMessageEvent +from .wecomai_queue_mgr import WecomAIQueueMgr, wecomai_queue_mgr from .wecomai_server import WecomAIBotServer -from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr from .wecomai_utils import ( WecomAIBotConstants, format_session_id, @@ -45,7 +44,9 @@ class WecomAIQueueListener: """企业微信智能机器人队列监听器,参考webchat的QueueListener设计""" def __init__( - self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]] + self, + queue_mgr: WecomAIQueueMgr, + callback: Callable[[dict], Awaitable[None]], ) -> None: self.queue_mgr = queue_mgr self.callback = callback @@ -90,13 +91,17 @@ class WecomAIQueueListener: @register_platform_adapter( - "wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息" + "wecom_ai_bot", + "企业微信智能机器人适配器,支持 HTTP 回调接收消息", ) class WecomAIBotAdapter(Platform): """企业微信智能机器人适配器""" def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) @@ -110,10 +115,12 @@ class WecomAIBotAdapter(Platform): self.host = self.config.get("callback_server_host", "0.0.0.0") self.bot_name = self.config.get("wecom_ai_bot_name", "") self.initial_respond_text = self.config.get( - "wecomaibot_init_respond_text", "💭 思考中..." + "wecomaibot_init_respond_text", + "💭 思考中...", ) self.friend_message_welcome_text = self.config.get( - "wecomaibot_friend_message_welcome_text", "" + "wecomaibot_friend_message_welcome_text", + "", ) # 平台元数据 @@ -139,7 +146,8 @@ class WecomAIBotAdapter(Platform): # 队列监听器 self.queue_listener = WecomAIQueueListener( - wecomai_queue_mgr, self._handle_queued_message + wecomai_queue_mgr, + self._handle_queued_message, ) async def _handle_queued_message(self, data: dict): @@ -151,8 +159,10 @@ class WecomAIBotAdapter(Platform): logger.error(f"处理队列消息时发生异常: {e}") async def _process_message( - self, message_data: Dict[str, Any], callback_params: Dict[str, str] - ) -> Optional[str]: + self, + message_data: dict[str, Any], + callback_params: dict[str, str], + ) -> str | None: """处理接收到的消息 Args: @@ -161,6 +171,7 @@ class WecomAIBotAdapter(Platform): Returns: 加密后的响应消息,无需响应时返回 None + """ msgtype = message_data.get("msgtype") if not msgtype: @@ -173,15 +184,22 @@ class WecomAIBotAdapter(Platform): # create a brand-new unique stream_id for this message session stream_id = f"{session_id}_{generate_random_string(10)}" await self._enqueue_message( - message_data, callback_params, stream_id, session_id + message_data, + callback_params, + stream_id, + session_id, ) wecomai_queue_mgr.set_pending_response(stream_id, callback_params) resp = WecomAIBotStreamMessageBuilder.make_text_stream( - stream_id, self.initial_respond_text, False + stream_id, + self.initial_respond_text, + False, ) return await self.api_client.encrypt_message( - resp, callback_params["nonce"], callback_params["timestamp"] + resp, + callback_params["nonce"], + callback_params["timestamp"], ) except Exception as e: logger.error("处理消息时发生异常: %s", e) @@ -194,7 +212,9 @@ class WecomAIBotAdapter(Platform): # 返回结束标志,告诉微信服务器流已结束 end_message = WecomAIBotStreamMessageBuilder.make_text_stream( - stream_id, "", True + stream_id, + "", + True, ) resp = await self.api_client.encrypt_message( end_message, @@ -205,7 +225,7 @@ class WecomAIBotAdapter(Platform): queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) if queue.empty(): logger.debug( - f"No new messages in back queue for stream_id: {stream_id}" + f"No new messages in back queue for stream_id: {stream_id}", ) return None @@ -227,7 +247,7 @@ class WecomAIBotAdapter(Platform): else: pass logger.debug( - f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}" + f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}", ) if latest_plain_content or image_base64: msg_items = [] @@ -240,12 +260,15 @@ class WecomAIBotAdapter(Platform): { "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, "image": {"base64": img_b64, "md5": img_md5}, - } + }, ) image_base64 = [] plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream( - stream_id, latest_plain_content, msg_items, finish + stream_id, + latest_plain_content, + msg_items, + finish, ) encrypted_message = await self.api_client.encrypt_message( plain_message, @@ -254,7 +277,7 @@ class WecomAIBotAdapter(Platform): ) if encrypted_message: logger.debug( - f"Stream message sent successfully, stream_id: {stream_id}" + f"Stream message sent successfully, stream_id: {stream_id}", ) else: logger.error("消息加密失败") @@ -266,7 +289,7 @@ class WecomAIBotAdapter(Platform): # 用户进入会话,发送欢迎消息 try: resp = WecomAIBotStreamMessageBuilder.make_text( - self.friend_message_welcome_text + self.friend_message_welcome_text, ) return await self.api_client.encrypt_message( resp, @@ -276,17 +299,16 @@ class WecomAIBotAdapter(Platform): except Exception as e: logger.error("处理欢迎消息时发生异常: %s", e) return None - pass - def _extract_session_id(self, message_data: Dict[str, Any]) -> str: + def _extract_session_id(self, message_data: dict[str, Any]) -> str: """从消息数据中提取会话ID""" user_id = message_data.get("from", {}).get("userid", "default_user") return format_session_id("wecomai", user_id) async def _enqueue_message( self, - message_data: Dict[str, Any], - callback_params: Dict[str, str], + message_data: dict[str, Any], + callback_params: dict[str, str], stream_id: str, session_id: str, ): @@ -320,7 +342,7 @@ class WecomAIBotAdapter(Platform): content = WecomAIBotMessageParser.parse_text_message(message_data) elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE: _img_url_to_process.append( - WecomAIBotMessageParser.parse_image_message(message_data) + WecomAIBotMessageParser.parse_image_message(message_data), ) elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED: # 提取混合消息中的文本内容 @@ -390,7 +412,9 @@ class WecomAIBotAdapter(Platform): return abm async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): """通过会话发送消息""" # 企业微信智能机器人主要通过回调响应,这里记录日志 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 540bf06b6..6c448a97e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -1,19 +1,20 @@ -""" -企业微信智能机器人 API 客户端 +"""企业微信智能机器人 API 客户端 处理消息加密解密、API 调用等 """ -import json import base64 import hashlib -from typing import Dict, Any, Optional, Tuple, Union -from Crypto.Cipher import AES -import aiohttp +import json +from typing import Any + +import aiohttp +from Crypto.Cipher import AES -from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt -from .wecomai_utils import WecomAIBotConstants from astrbot import logger +from .wecomai_utils import WecomAIBotConstants +from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt + class WecomAIBotAPIClient: """企业微信智能机器人 API 客户端""" @@ -24,14 +25,19 @@ class WecomAIBotAPIClient: Args: token: 企业微信机器人 Token encoding_aes_key: 消息加密密钥 + """ self.token = token self.encoding_aes_key = encoding_aes_key self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串 async def decrypt_message( - self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str - ) -> Tuple[int, Optional[Dict[str, Any]]]: + self, + encrypted_data: bytes, + msg_signature: str, + timestamp: str, + nonce: str, + ) -> tuple[int, dict[str, Any] | None]: """解密企业微信消息 Args: @@ -42,10 +48,14 @@ class WecomAIBotAPIClient: Returns: (错误码, 解密后的消息数据字典) + """ try: ret, decrypted_msg = self.wxcpt.DecryptMsg( - encrypted_data, msg_signature, timestamp, nonce + encrypted_data, + msg_signature, + timestamp, + nonce, ) if ret != WecomAIBotConstants.SUCCESS: @@ -70,8 +80,11 @@ class WecomAIBotAPIClient: return WecomAIBotConstants.DECRYPT_ERROR, None async def encrypt_message( - self, plain_message: str, nonce: str, timestamp: str - ) -> Optional[str]: + self, + plain_message: str, + nonce: str, + timestamp: str, + ) -> str | None: """加密消息 Args: @@ -81,6 +94,7 @@ class WecomAIBotAPIClient: Returns: 加密后的消息,失败时返回 None + """ try: ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp) @@ -97,7 +111,11 @@ class WecomAIBotAPIClient: return None def verify_url( - self, msg_signature: str, timestamp: str, nonce: str, echostr: str + self, + msg_signature: str, + timestamp: str, + nonce: str, + echostr: str, ) -> str: """验证回调 URL @@ -109,10 +127,14 @@ class WecomAIBotAPIClient: Returns: 验证结果字符串 + """ try: ret, echo_result = self.wxcpt.VerifyURL( - msg_signature, timestamp, nonce, echostr + msg_signature, + timestamp, + nonce, + echostr, ) if ret != WecomAIBotConstants.SUCCESS: @@ -127,8 +149,10 @@ class WecomAIBotAPIClient: return "verify fail" async def process_encrypted_image( - self, image_url: str, aes_key_base64: Optional[str] = None - ) -> Tuple[bool, Union[bytes, str]]: + self, + image_url: str, + aes_key_base64: str | None = None, + ) -> tuple[bool, bytes | str]: """下载并解密加密图片 Args: @@ -137,6 +161,7 @@ class WecomAIBotAPIClient: Returns: (是否成功, 图片数据或错误信息) + """ try: # 下载图片 @@ -161,7 +186,7 @@ class WecomAIBotAPIClient: # Base64 解码密钥 aes_key = base64.b64decode( - aes_key_base64 + "=" * (-len(aes_key_base64) % 4) + aes_key_base64 + "=" * (-len(aes_key_base64) % 4), ) if len(aes_key) != 32: raise ValueError("无效的 AES 密钥长度: 应为 32 字节") @@ -183,17 +208,17 @@ class WecomAIBotAPIClient: return True, decrypted_data except aiohttp.ClientError as e: - error_msg = f"图片下载失败: {str(e)}" + error_msg = f"图片下载失败: {e!s}" logger.error(error_msg) return False, error_msg except ValueError as e: - error_msg = f"参数错误: {str(e)}" + error_msg = f"参数错误: {e!s}" logger.error(error_msg) return False, error_msg except Exception as e: - error_msg = f"图片处理异常: {str(e)}" + error_msg = f"图片处理异常: {e!s}" logger.error(error_msg) return False, error_msg @@ -212,6 +237,7 @@ class WecomAIBotStreamMessageBuilder: Returns: JSON 格式的流消息字符串 + """ plain = { "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, @@ -221,7 +247,9 @@ class WecomAIBotStreamMessageBuilder: @staticmethod def make_image_stream( - stream_id: str, image_data: bytes, finish: bool = False + stream_id: str, + image_data: bytes, + finish: bool = False, ) -> str: """构建图片流消息 @@ -232,6 +260,7 @@ class WecomAIBotStreamMessageBuilder: Returns: JSON 格式的流消息字符串 + """ image_md5 = hashlib.md5(image_data).hexdigest() image_base64 = base64.b64encode(image_data).decode("utf-8") @@ -245,7 +274,7 @@ class WecomAIBotStreamMessageBuilder: { "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, "image": {"base64": image_base64, "md5": image_md5}, - } + }, ], }, } @@ -253,7 +282,10 @@ class WecomAIBotStreamMessageBuilder: @staticmethod def make_mixed_stream( - stream_id: str, content: str, msg_items: list, finish: bool = False + stream_id: str, + content: str, + msg_items: list, + finish: bool = False, ) -> str: """构建混合类型流消息 @@ -265,6 +297,7 @@ class WecomAIBotStreamMessageBuilder: Returns: JSON 格式的流消息字符串 + """ plain = { "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, @@ -283,6 +316,7 @@ class WecomAIBotStreamMessageBuilder: Returns: JSON 格式的文本消息字符串 + """ plain = {"msgtype": "text", "text": {"content": content}} return json.dumps(plain, ensure_ascii=False) @@ -292,7 +326,7 @@ class WecomAIBotMessageParser: """企业微信智能机器人消息解析器""" @staticmethod - def parse_text_message(data: Dict[str, Any]) -> Optional[str]: + def parse_text_message(data: dict[str, Any]) -> str | None: """解析文本消息 Args: @@ -300,6 +334,7 @@ class WecomAIBotMessageParser: Returns: 文本内容,解析失败返回 None + """ try: return data.get("text", {}).get("content") @@ -308,7 +343,7 @@ class WecomAIBotMessageParser: return None @staticmethod - def parse_image_message(data: Dict[str, Any]) -> Optional[str]: + def parse_image_message(data: dict[str, Any]) -> str | None: """解析图片消息 Args: @@ -316,6 +351,7 @@ class WecomAIBotMessageParser: Returns: 图片 URL,解析失败返回 None + """ try: return data.get("image", {}).get("url") @@ -324,7 +360,7 @@ class WecomAIBotMessageParser: return None @staticmethod - def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_stream_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析流消息 Args: @@ -332,6 +368,7 @@ class WecomAIBotMessageParser: Returns: 流消息数据,解析失败返回 None + """ try: stream_data = data.get("stream", {}) @@ -346,7 +383,7 @@ class WecomAIBotMessageParser: return None @staticmethod - def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: + def parse_mixed_message(data: dict[str, Any]) -> list | None: """解析混合消息 Args: @@ -354,6 +391,7 @@ class WecomAIBotMessageParser: Returns: 消息项列表,解析失败返回 None + """ try: return data.get("mixed", {}).get("msg_item", []) @@ -362,7 +400,7 @@ class WecomAIBotMessageParser: return None @staticmethod - def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_event_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析事件消息 Args: @@ -370,6 +408,7 @@ class WecomAIBotMessageParser: Returns: 事件数据,解析失败返回 None + """ try: return data.get("event", {}) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 2d7ec91ca..130182b48 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -1,13 +1,11 @@ -""" -企业微信智能机器人事件处理模块,处理消息事件的发送和接收 -""" +"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收""" +from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( Image, Plain, ) -from astrbot.api import logger from .wecomai_api import WecomAIBotAPIClient from .wecomai_queue_mgr import wecomai_queue_mgr @@ -32,6 +30,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): platform_meta: 平台元数据 session_id: 会话 ID api_client: API 客户端 + """ super().__init__(message_str, message_obj, platform_meta, session_id) self.api_client = api_client @@ -50,7 +49,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "type": "end", "data": "", "streaming": False, - } + }, ) return "" @@ -64,7 +63,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "data": data, "streaming": streaming, "session_id": stream_id, - } + }, ) elif isinstance(comp, Image): # 处理图片消息 @@ -77,7 +76,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "image_data": image_base64, "streaming": streaming, "session_id": stream_id, - } + }, ) else: logger.warning("图片数据为空,跳过") @@ -127,7 +126,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "data": final_data, "streaming": True, "session_id": self.session_id, - } + }, ) final_data = "" continue @@ -144,6 +143,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent): "data": final_data, "streaming": True, "session_id": self.session_id, - } + }, ) await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 1367301c9..eb3455292 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -1,11 +1,11 @@ -""" -企业微信智能机器人队列管理器 +"""企业微信智能机器人队列管理器 参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制 支持异步消息处理和流式响应 """ import asyncio -from typing import Dict, Any, Optional +from typing import Any + from astrbot.api import logger @@ -13,13 +13,13 @@ class WecomAIQueueMgr: """企业微信智能机器人队列管理器""" def __init__(self) -> None: - self.queues: Dict[str, asyncio.Queue] = {} + self.queues: dict[str, asyncio.Queue] = {} """StreamID 到输入队列的映射 - 用于接收用户消息""" - self.back_queues: Dict[str, asyncio.Queue] = {} + self.back_queues: dict[str, asyncio.Queue] = {} """StreamID 到输出队列的映射 - 用于发送机器人响应""" - self.pending_responses: Dict[str, Dict[str, Any]] = {} + self.pending_responses: dict[str, dict[str, Any]] = {} """待处理的响应缓存,用于流式响应""" def get_or_create_queue(self, session_id: str) -> asyncio.Queue: @@ -30,6 +30,7 @@ class WecomAIQueueMgr: Returns: 输入队列实例 + """ if session_id not in self.queues: self.queues[session_id] = asyncio.Queue() @@ -44,6 +45,7 @@ class WecomAIQueueMgr: Returns: 输出队列实例 + """ if session_id not in self.back_queues: self.back_queues[session_id] = asyncio.Queue() @@ -55,6 +57,7 @@ class WecomAIQueueMgr: Args: session_id: 会话ID + """ if session_id in self.queues: del self.queues[session_id] @@ -76,6 +79,7 @@ class WecomAIQueueMgr: Returns: 是否存在队列 + """ return session_id in self.queues @@ -87,15 +91,17 @@ class WecomAIQueueMgr: Returns: 是否存在输出队列 + """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: Dict[str, str]): + def set_pending_response(self, session_id: str, callback_params: dict[str, str]): """设置待处理的响应参数 Args: session_id: 会话ID callback_params: 回调参数(nonce, timestamp等) + """ self.pending_responses[session_id] = { "callback_params": callback_params, @@ -103,7 +109,7 @@ class WecomAIQueueMgr: } logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") - def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]: + def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """获取待处理的响应参数 Args: @@ -111,6 +117,7 @@ class WecomAIQueueMgr: Returns: 响应参数,如果不存在则返回None + """ return self.pending_responses.get(session_id) @@ -119,6 +126,7 @@ class WecomAIQueueMgr: Args: max_age_seconds: 最大存活时间(秒) + """ current_time = asyncio.get_event_loop().time() expired_sessions = [] @@ -131,11 +139,12 @@ class WecomAIQueueMgr: del self.pending_responses[session_id] logger.debug(f"[WecomAI] 清理过期响应: {session_id}") - def get_stats(self) -> Dict[str, int]: + def get_stats(self) -> dict[str, int]: """获取队列统计信息 Returns: 统计信息字典 + """ return { "input_queues": len(self.queues), diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index bbb69d041..35acd9066 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -1,12 +1,13 @@ -""" -企业微信智能机器人 HTTP 服务器 +"""企业微信智能机器人 HTTP 服务器 处理企业微信智能机器人的 HTTP 回调请求 """ import asyncio -from typing import Dict, Any, Optional, Callable +from collections.abc import Callable +from typing import Any import quart + from astrbot.api import logger from .wecomai_api import WecomAIBotAPIClient @@ -21,9 +22,7 @@ class WecomAIBotServer: host: str, port: int, api_client: WecomAIBotAPIClient, - message_handler: Optional[ - Callable[[Dict[str, Any], Dict[str, str]], Any] - ] = None, + message_handler: Callable[[dict[str, Any], dict[str, str]], Any] | None = None, ): """初始化服务器 @@ -32,6 +31,7 @@ class WecomAIBotServer: port: 监听端口 api_client: API客户端实例 message_handler: 消息处理回调函数 + """ self.host = host self.port = port @@ -45,7 +45,6 @@ class WecomAIBotServer: def _setup_routes(self): """设置 Quart 路由""" - # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( "/webhook/wecom-ai-bot", @@ -98,7 +97,7 @@ class WecomAIBotServer: assert nonce is not None logger.debug( - f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}" + f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}", ) try: @@ -111,7 +110,10 @@ class WecomAIBotServer: # 解密消息 ret_code, message_data = await self.api_client.decrypt_message( - post_data, msg_signature, timestamp, nonce + post_data, + msg_signature, + timestamp, + nonce, ) if ret_code != WecomAIBotConstants.SUCCESS or not message_data: @@ -123,7 +125,8 @@ class WecomAIBotServer: if self.message_handler: try: response = await self.message_handler( - message_data, {"nonce": nonce, "timestamp": timestamp} + message_data, + {"nonce": nonce, "timestamp": timestamp}, ) except Exception as e: logger.error("消息处理器执行异常: %s", e) @@ -131,8 +134,7 @@ class WecomAIBotServer: if response: return response, 200, {"Content-Type": "text/plain"} - else: - return "success", 200, {"Content-Type": "text/plain"} + return "success", 200, {"Content-Type": "text/plain"} except Exception as e: logger.error("处理消息时发生异常: %s", e) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index dccb2e260..f0285f998 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -1,16 +1,17 @@ -""" -企业微信智能机器人工具模块 +"""企业微信智能机器人工具模块 提供常量定义、工具函数和辅助方法 """ -import string -import random -import hashlib -import base64 -import aiohttp import asyncio +import base64 +import hashlib +import random +import string +from typing import Any + +import aiohttp from Crypto.Cipher import AES -from typing import Any, Tuple + from astrbot.api import logger @@ -49,6 +50,7 @@ def generate_random_string(length: int = 10) -> str: Returns: 随机字符串 + """ letters = string.ascii_letters + string.digits return "".join(random.choice(letters) for _ in range(length)) @@ -62,6 +64,7 @@ def calculate_image_md5(image_data: bytes) -> str: Returns: MD5 哈希值(十六进制字符串) + """ return hashlib.md5(image_data).hexdigest() @@ -74,6 +77,7 @@ def encode_image_base64(image_data: bytes) -> str: Returns: Base64 编码的字符串 + """ return base64.b64encode(image_data).decode("utf-8") @@ -87,11 +91,12 @@ def format_session_id(session_type: str, session_id: str) -> str: Returns: 格式化后的会话 ID + """ return f"wecom_ai_bot_{session_type}_{session_id}" -def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: +def parse_session_id(formatted_session_id: str) -> tuple[str, str]: """解析格式化的会话 ID Args: @@ -99,6 +104,7 @@ def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: Returns: (会话类型, 原始会话ID) + """ parts = formatted_session_id.split("_", 3) if ( @@ -120,6 +126,7 @@ def safe_json_loads(json_str: str, default: Any = None) -> Any: Returns: 解析结果或默认值 + """ import json @@ -139,13 +146,15 @@ def format_error_response(error_code: int, error_msg: str) -> str: Returns: 格式化的错误响应字符串 + """ return f"Error {error_code}: {error_msg}" async def process_encrypted_image( - image_url: str, aes_key_base64: str -) -> Tuple[bool, str]: + image_url: str, + aes_key_base64: str, +) -> tuple[bool, str]: """下载并解密加密图片 Args: @@ -155,6 +164,7 @@ async def process_encrypted_image( Returns: Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码, status 为 False 时 data 是错误信息 + """ # 1. 下载加密图片 logger.info("开始下载加密图片: %s", image_url) @@ -165,7 +175,7 @@ async def process_encrypted_image( encrypted_data = await response.read() logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) except (aiohttp.ClientError, asyncio.TimeoutError) as e: - error_msg = f"下载图片失败: {str(e)}" + error_msg = f"下载图片失败: {e!s}" logger.error(error_msg) return False, error_msg diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index c67c2037b..f44b06e90 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,28 +1,28 @@ +import asyncio import sys import uuid -import asyncio -import quart +import quart +from requests import Response +from wechatpy import WeChatClient, parse_message +from wechatpy.crypto import WeChatCrypto +from wechatpy.exceptions import InvalidSignatureException +from wechatpy.messages import BaseMessage, ImageMessage, TextMessage, VoiceMessage +from wechatpy.utils import check_signature + +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Image, Plain, Record from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, - PlatformMetadata, MessageType, + Platform, + PlatformMetadata, + register_platform_adapter, ) -from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image, Record -from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.api.platform import register_platform_adapter from astrbot.core import logger -from requests import Response +from astrbot.core.platform.astr_message_event import MessageSesion -from wechatpy.utils import check_signature -from wechatpy.crypto import WeChatCrypto -from wechatpy import WeChatClient -from wechatpy.messages import TextMessage, ImageMessage, VoiceMessage, BaseMessage -from wechatpy.exceptions import InvalidSignatureException -from wechatpy import parse_message from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent if sys.version_info >= (3, 12): @@ -40,10 +40,14 @@ class WecomServer: self.encoding_aes_key = config.get("encoding_aes_key") self.appid = config.get("appid") self.server.add_url_rule( - "/callback/command", view_func=self.verify, methods=["GET"] + "/callback/command", + view_func=self.verify, + methods=["GET"], ) self.server.add_url_rule( - "/callback/command", view_func=self.callback_command, methods=["POST"] + "/callback/command", + view_func=self.callback_command, + methods=["POST"], ) self.crypto = WeChatCrypto(self.token, self.encoding_aes_key, self.appid) @@ -97,7 +101,7 @@ class WecomServer: async def start_polling(self): logger.info( - f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。" + f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。", ) await self.server.run_task( host=self.callback_server_host, @@ -112,22 +116,25 @@ class WecomServer: @register_platform_adapter("weixin_official_account", "微信公众平台 适配器") class WeixinOfficialAccountPlatformAdapter(Platform): def __init__( - self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, ) -> None: super().__init__(event_queue) self.config = platform_config self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( - "api_base_url", "https://api.weixin.qq.com/cgi-bin/" + "api_base_url", + "https://api.weixin.qq.com/cgi-bin/", ) self.active_send_mode = self.config.get("active_send_mode", False) if not self.api_base_url: self.api_base_url = "https://api.weixin.qq.com/cgi-bin/" - if self.api_base_url.endswith("/"): - self.api_base_url = self.api_base_url[:-1] + self.api_base_url = self.api_base_url.removesuffix("/") if not self.api_base_url.endswith("/cgi-bin"): self.api_base_url += "/cgi-bin" @@ -161,7 +168,8 @@ class WeixinOfficialAccountPlatformAdapter(Platform): await self.convert_message(msg, future) # I love shield so much! result = await asyncio.wait_for( - asyncio.shield(future), 60 + asyncio.shield(future), + 60, ) # wait for 60s logger.debug(f"Got future result: {result}") self.wexin_event_workers.pop(msg.id, None) @@ -175,7 +183,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform): @override async def send_by_session( - self, session: MessageSesion, message_chain: MessageChain + self, + session: MessageSesion, + message_chain: MessageChain, ): await super().send_by_session(session, message_chain) @@ -192,7 +202,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform): await self.server.start_polling() async def convert_message( - self, msg, future: asyncio.Future = None + self, + msg, + future: asyncio.Future = None, ) -> AstrBotMessage | None: abm = AstrBotMessage() if isinstance(msg, TextMessage): @@ -224,7 +236,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform): assert isinstance(msg, VoiceMessage) resp: Response = await asyncio.get_event_loop().run_in_executor( - None, self.client.media.download, msg.media_id + None, + self.client.media.download, + msg.media_id, ) path = f"data/temp/wecom_{msg.media_id}.amr" with open(path, "wb") as f: @@ -238,7 +252,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): audio.export(path_wav, format="wav") except Exception as e: logger.error( - f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。" + f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。", ) path_wav = path return diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 4077cc1ab..d138fc80c 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -1,21 +1,20 @@ -import uuid import asyncio -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, Record -from wechatpy import WeChatClient -from wechatpy.replies import TextReply, ImageReply, VoiceReply +import uuid +from wechatpy import WeChatClient +from wechatpy.replies import ImageReply, TextReply, VoiceReply from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.platform import AstrBotMessage, PlatformMetadata try: import pydub except Exception: logger.warning( - "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。" + "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。", ) - pass class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): @@ -32,7 +31,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): @staticmethod async def send_with_client( - client: WeChatClient, message: MessageChain, user_name: str + client: WeChatClient, + message: MessageChain, + user_name: str, ): pass @@ -43,44 +44,44 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): plain (str): 要分割的长文本 Returns: list[str]: 分割后的文本列表 + """ if len(plain) <= 2048: return [plain] - else: - result = [] - start = 0 - while start < len(plain): - # 剩下的字符串长度<2048时结束 - if start + 2048 >= len(plain): - result.append(plain[start:]) + result = [] + start = 0 + while start < len(plain): + # 剩下的字符串长度<2048时结束 + if start + 2048 >= len(plain): + result.append(plain[start:]) + break + + # 向前搜索分割标点符号 + end = min(start + 2048, len(plain)) + cut_position = end + for i in range(end, start, -1): + if i < len(plain) and plain[i - 1] in [ + "。", + "!", + "?", + ".", + "!", + "?", + "\n", + ";", + ";", + ]: + cut_position = i break - # 向前搜索分割标点符号 - end = min(start + 2048, len(plain)) + # 没找到合适的位置分割, 直接切分 + if cut_position == end and end < len(plain): cut_position = end - for i in range(end, start, -1): - if i < len(plain) and plain[i - 1] in [ - "。", - "!", - "?", - ".", - "!", - "?", - "\n", - ";", - ";", - ]: - cut_position = i - break - # 没找到合适的位置分割, 直接切分 - if cut_position == end and end < len(plain): - cut_position = end + result.append(plain[start:cut_position]) + start = cut_position - result.append(plain[start:cut_position]) - start = cut_position - - return result + return result async def send(self, message: MessageChain): message_obj = self.message_obj @@ -111,7 +112,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"微信公众平台上传图片失败: {e}") await self.send( - MessageChain().message(f"微信公众平台上传图片失败: {e}") + MessageChain().message(f"微信公众平台上传图片失败: {e}"), ) return logger.debug(f"微信公众平台上传图片返回: {response}") @@ -136,7 +137,8 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): # 转成amr record_path_amr = f"data/temp/{uuid.uuid4()}.amr" pydub.AudioSegment.from_wav(record_path).export( - record_path_amr, format="amr" + record_path_amr, + format="amr", ) with open(record_path_amr, "rb") as f: @@ -145,7 +147,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): except Exception as e: logger.error(f"微信公众平台上传语音失败: {e}") await self.send( - MessageChain().message(f"微信公众平台上传语音失败: {e}") + MessageChain().message(f"微信公众平台上传语音失败: {e}"), ) return logger.info(f"微信公众平台上传语音返回: {response}") @@ -178,7 +180,7 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): else: buffer.chain.extend(chain.chain) if not buffer: - return + return None buffer.squash_plain() await self.send(buffer) return await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 16e59a5cc..fa9a9733c 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -43,5 +43,7 @@ class PlatformMessageHistoryManager: async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( - platform_id=platform_id, user_id=user_id, offset_sec=offset_sec + platform_id=platform_id, + user_id=user_id, + offset_sec=offset_sec, ) diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index ed7135fe6..abbe08234 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,5 +1,4 @@ -from .provider import Provider, Personality, STTProvider - from .entities import ProviderMetaData +from .provider import Personality, Provider, STTProvider -__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"] +__all__ = ["Personality", "Provider", "ProviderMetaData", "STTProvider"] diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index dbbbca923..af97c4ab6 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -1,19 +1,19 @@ from astrbot.core.provider.entities import ( + AssistantMessageSegment, + LLMResponse, + ProviderMetaData, ProviderRequest, ProviderType, - ProviderMetaData, - ToolCallsResult, - AssistantMessageSegment, ToolCallMessageSegment, - LLMResponse, + ToolCallsResult, ) __all__ = [ + "AssistantMessageSegment", + "LLMResponse", + "ProviderMetaData", "ProviderRequest", "ProviderType", - "ProviderMetaData", - "ToolCallsResult", - "AssistantMessageSegment", "ToolCallMessageSegment", - "LLMResponse", + "ToolCallsResult", ] diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 85687c417..28dc63f72 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,20 +1,22 @@ -import enum import base64 +import enum import json -from astrbot.core.utils.io import download_image_by_url -from astrbot import logger from dataclasses import dataclass, field -from typing import List, Dict, Type, Any -from astrbot.core.agent.tool import ToolSet -from openai.types.chat.chat_completion import ChatCompletion -from google.genai.types import GenerateContentResponse +from typing import Any + from anthropic.types import Message +from google.genai.types import GenerateContentResponse +from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, ) + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain -import astrbot.core.message.components as Comp +from astrbot.core.utils.io import download_image_by_url class ProviderType(enum.Enum): @@ -32,7 +34,7 @@ class ProviderMetaData: desc: str = "" """提供商适配器描述.""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: Type | None = None + cls_type: type | None = None default_config_tmpl: dict | None = None """平台的默认配置模板""" @@ -61,7 +63,7 @@ class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" content: str | None = None - tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) + tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list) role: str = "assistant" def to_dict(self): @@ -84,10 +86,10 @@ class ToolCallsResult: tool_calls_info: AssistantMessageSegment """函数调用的信息""" - tool_calls_result: List[ToolCallMessageSegment] + tool_calls_result: list[ToolCallMessageSegment] """函数调用的结果""" - def to_openai_messages(self) -> List[Dict]: + def to_openai_messages(self) -> list[dict]: ret = [ self.tool_calls_info.to_dict(), *[item.to_dict() for item in self.tool_calls_result], @@ -175,13 +177,13 @@ class ProviderRequest: return result_parts - async def assemble_context(self) -> Dict: + async def assemble_context(self) -> dict: """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" if self.image_urls: user_content = { "role": "user", "content": [ - {"type": "text", "text": self.prompt if self.prompt else "[图片]"} + {"type": "text", "text": self.prompt if self.prompt else "[图片]"}, ], } for image_url in self.image_urls: @@ -197,11 +199,10 @@ class ProviderRequest: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue user_content["content"].append( - {"type": "image_url", "image_url": {"url": image_data}} + {"type": "image_url", "image_url": {"url": image_data}}, ) return user_content - else: - return {"role": "user", "content": self.prompt} + return {"role": "user", "content": self.prompt} async def _encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" @@ -219,15 +220,15 @@ class LLMResponse: """角色, assistant, tool, err""" result_chain: MessageChain | None = None """返回的消息链""" - tools_call_args: List[Dict[str, Any]] = field(default_factory=list) + tools_call_args: list[dict[str, Any]] = field(default_factory=list) """工具调用参数""" - tools_call_name: List[str] = field(default_factory=list) + tools_call_name: list[str] = field(default_factory=list) """工具调用名称""" - tools_call_ids: List[str] = field(default_factory=list) + tools_call_ids: list[str] = field(default_factory=list) """工具调用 ID""" raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None - _new_record: Dict[str, Any] | None = None + _new_record: dict[str, Any] | None = None _completion_text: str = "" @@ -239,11 +240,11 @@ class LLMResponse: role: str, completion_text: str = "", result_chain: MessageChain | None = None, - tools_call_args: List[Dict[str, Any]] | None = None, - tools_call_name: List[str] | None = None, - tools_call_ids: List[str] | None = None, + tools_call_args: list[dict[str, Any]] | None = None, + tools_call_name: list[str] | None = None, + tools_call_ids: list[str] | None = None, raw_completion: ChatCompletion | None = None, - _new_record: Dict[str, Any] | None = None, + _new_record: dict[str, Any] | None = None, is_chunk: bool = False, ): """初始化 LLMResponse @@ -255,6 +256,7 @@ class LLMResponse: tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None. tools_call_name (List[str], optional): 工具调用名称. Defaults to None. raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. + """ if tools_call_args is None: tools_call_args = [] @@ -291,7 +293,7 @@ class LLMResponse: else: self._completion_text = value - def to_openai_tool_calls(self) -> List[Dict]: + def to_openai_tool_calls(self) -> list[dict]: """将工具调用信息转换为 OpenAI 格式""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): @@ -303,7 +305,7 @@ class LLMResponse: "arguments": json.dumps(tool_call_arg), }, "type": "function", - } + }, ) return ret diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 51cde0eb9..b3ef1ed5c 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,17 +1,18 @@ from __future__ import annotations + +import asyncio import json import os -import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + import aiohttp -from typing import Dict, List, Awaitable, Callable, Any from astrbot import logger from astrbot.core import sp - -from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.agent.mcp_client import MCPClient -from astrbot.core.agent.tool import ToolSet, FunctionTool - +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.utils.astrbot_path import get_astrbot_data_path DEFAULT_MCP_CONFIG = {"mcpServers": {}} @@ -30,7 +31,7 @@ FuncTool = FunctionTool def _prepare_config(config: dict) -> dict: """准备配置,处理嵌套格式""" - if "mcpServers" in config and config["mcpServers"]: + if config.get("mcpServers"): first_key = next(iter(config["mcpServers"])) config = config["mcpServers"][first_key] config.pop("active", None) @@ -72,8 +73,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - else: - return False, f"HTTP {response.status}: {response.reason}" + return False, f"HTTP {response.status}: {response.reason}" else: async with session.get( url, @@ -85,8 +85,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - else: - return False, f"HTTP {response.status}: {response.reason}" + return False, f"HTTP {response.status}: {response.reason}" except asyncio.TimeoutError: return False, f"连接超时: {timeout}秒" @@ -96,10 +95,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class FunctionToolManager: def __init__(self) -> None: - self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str, MCPClient] = {} + self.func_list: list[FuncTool] = [] + self.mcp_client_dict: dict[str, MCPClient] = {} """MCP 服务列表""" - self.mcp_client_event: Dict[str, asyncio.Event] = {} + self.mcp_client_event: dict[str, asyncio.Event] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -150,14 +149,12 @@ class FunctionToolManager: func_args=func_args, desc=desc, handler=handler, - ) + ), ) logger.info(f"添加函数调用工具: {name}") def remove_func(self, name: str) -> None: - """ - 删除一个函数调用工具。 - """ + """删除一个函数调用工具。""" for i, f in enumerate(self.func_list): if f.name == name: self.func_list.pop(i) @@ -202,16 +199,16 @@ class FunctionToolManager: logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return - mcp_server_json_obj: Dict[str, Dict] = json.load( - open(mcp_json_file, "r", encoding="utf-8") + mcp_server_json_obj: dict[str, dict] = json.load( + open(mcp_json_file, encoding="utf-8"), )["mcpServers"] - for name in mcp_server_json_obj.keys(): + for name in mcp_server_json_obj: cfg = mcp_server_json_obj[name] if cfg.get("active", True): event = asyncio.Event() asyncio.create_task( - self._init_mcp_client_task_wrapper(name, cfg, event) + self._init_mcp_client_task_wrapper(name, cfg, event), ) self.mcp_client_event[name] = event @@ -325,9 +322,11 @@ class FunctionToolManager: event (asyncio.Event): Event to signal when the MCP client is ready. ready_future (asyncio.Future): Future to signal when the MCP client is ready. timeout (int): Timeout for the initialization. + Raises: TimeoutError: If the initialization does not complete within the specified timeout. Exception: If there is an error during initialization. + """ if not event: event = asyncio.Event() @@ -336,7 +335,7 @@ class FunctionToolManager: if name in self.mcp_client_dict: return asyncio.create_task( - self._init_mcp_client_task_wrapper(name, config, event, ready_future) + self._init_mcp_client_task_wrapper(name, config, event, ready_future), ) try: await asyncio.wait_for(ready_future, timeout=timeout) @@ -349,13 +348,16 @@ class FunctionToolManager: raise exc async def disable_mcp_server( - self, name: str | None = None, timeout: float = 10 + self, + name: str | None = None, + timeout: float = 10, ) -> None: """Disable an MCP server by its name. Args: name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. timeout (int): Timeout. + """ if name: if name not in self.mcp_client_event: @@ -389,27 +391,21 @@ class FunctionToolManager: self.func_list = [f for f in self.func_list if f.origin != "mcp"] def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: - """ - 获得 OpenAI API 风格的**已经激活**的工具描述 - """ + """获得 OpenAI API 风格的**已经激活**的工具描述""" tools = [f for f in self.func_list if f.active] toolset = ToolSet(tools) return toolset.openai_schema( - omit_empty_parameter_field=omit_empty_parameter_field + omit_empty_parameter_field=omit_empty_parameter_field, ) def get_func_desc_anthropic_style(self) -> list: - """ - 获得 Anthropic API 风格的**已经激活**的工具描述 - """ + """获得 Anthropic API 风格的**已经激活**的工具描述""" tools = [f for f in self.func_list if f.active] toolset = ToolSet(tools) return toolset.anthropic_schema() def get_func_desc_google_genai_style(self) -> dict: - """ - 获得 Google GenAI API 风格的**已经激活**的工具描述 - """ + """获得 Google GenAI API 风格的**已经激活**的工具描述""" tools = [f for f in self.func_list if f.active] toolset = ToolSet(tools) return toolset.google_schema() @@ -418,13 +414,18 @@ class FunctionToolManager: """停用一个已经注册的函数调用工具。 Returns: - 如果没找到,会返回 False""" + 如果没找到,会返回 False + + """ func_tool = self.get_func(name) if func_tool is not None: func_tool.active = False inactivated_llm_tools: list = sp.get( - "inactivated_llm_tools", [], scope="global", scope_id="global" + "inactivated_llm_tools", + [], + scope="global", + scope_id="global", ) if name not in inactivated_llm_tools: inactivated_llm_tools.append(name) @@ -445,13 +446,16 @@ class FunctionToolManager: if func_tool.handler_module_path in star_map: if not star_map[func_tool.handler_module_path].activated: raise ValueError( - f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" + f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。", ) func_tool.active = True inactivated_llm_tools: list = sp.get( - "inactivated_llm_tools", [], scope="global", scope_id="global" + "inactivated_llm_tools", + [], + scope="global", + scope_id="global", ) if name in inactivated_llm_tools: inactivated_llm_tools.remove(name) @@ -479,7 +483,7 @@ class FunctionToolManager: return DEFAULT_MCP_CONFIG try: - with open(self.mcp_config_path, "r", encoding="utf-8") as f: + with open(self.mcp_config_path, encoding="utf-8") as f: return json.load(f) except Exception as e: logger.error(f"加载 MCP 配置失败: {e}") @@ -509,7 +513,8 @@ class FunctionToolManager: if response.status == 200: data = await response.json() mcp_server_list = data.get("data", {}).get( - "mcp_server_list", [] + "mcp_server_list", + [], ) local_mcp_config = self.load_mcp_config() @@ -541,23 +546,23 @@ class FunctionToolManager: self.enable_mcp_server( name=name, config=local_mcp_config["mcpServers"][name], - ) + ), ) await asyncio.gather(*tasks) logger.info( - f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器" + f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器", ) else: logger.warning("没有找到可用的 ModelScope MCP 服务器") else: raise Exception( - f"ModelScope API 请求失败: HTTP {response.status}" + f"ModelScope API 请求失败: HTTP {response.status}", ) except aiohttp.ClientError as e: - raise Exception(f"网络连接错误: {str(e)}") + raise Exception(f"网络连接错误: {e!s}") except Exception as e: - raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {str(e)}") + raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") def __str__(self): return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 5a1f51cef..5fc5a4b5e 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -5,16 +5,16 @@ from astrbot.core import logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase +from ..persona_mgr import PersonaManager from .entities import ProviderType from .provider import ( + EmbeddingProvider, Provider, + RerankProvider, STTProvider, TTSProvider, - EmbeddingProvider, - RerankProvider, ) from .register import llm_tools, provider_cls_map -from ..persona_mgr import PersonaManager class ProviderManager: @@ -76,7 +76,10 @@ class ProviderManager: return self.persona_mgr.selected_default_persona_v3 async def set_provider( - self, provider_id: str, provider_type: ProviderType, umo: str | None = None + self, + provider_id: str, + provider_type: ProviderType, + umo: str | None = None, ): """设置提供商。 @@ -86,6 +89,7 @@ class ProviderManager: umo (str, optional): 用户会话 ID,用于提供商会话隔离。 Version 4.0.0: 这个版本下已经默认隔离提供商 + """ if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") @@ -100,17 +104,20 @@ class ProviderManager: prov = self.inst_map[provider_id] if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance( - prov, TTSProvider + prov, + TTSProvider, ): self.curr_tts_provider_inst = prov sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( - prov, STTProvider + prov, + STTProvider, ): self.curr_stt_provider_inst = prov sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global") elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( - prov, Provider + prov, + Provider, ): self.curr_provider_inst = prov sp.put("curr_provider", provider_id, scope="global", scope_id="global") @@ -120,7 +127,9 @@ class ProviderManager: return self.inst_map.get(provider_id) def get_using_provider( - self, provider_type: ProviderType, umo=None + self, + provider_type: ProviderType, + umo=None, ) -> Provider | STTProvider | TTSProvider | None: """获取正在使用的提供商实例。 @@ -130,6 +139,7 @@ class ProviderManager: Returns: Provider: 正在使用的提供商实例。 + """ provider = None if umo: @@ -219,7 +229,7 @@ class ProviderManager: return logger.info( - f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..." + f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", ) # 动态导入 @@ -321,18 +331,18 @@ class ProviderManager: ) except (ImportError, ModuleNotFoundError) as e: logger.critical( - f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" + f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。", ) return except Exception as e: logger.critical( - f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因" + f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因", ) return if provider_config["type"] not in provider_cls_map: logger.error( - f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。" + f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。", ) return @@ -358,7 +368,7 @@ class ProviderManager: ): self.curr_stt_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。" + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", ) if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst @@ -374,7 +384,7 @@ class ProviderManager: if self.provider_settings.get("provider_id") == provider_config["id"]: self.curr_tts_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。" + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", ) if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst @@ -397,7 +407,7 @@ class ProviderManager: ): self.curr_provider_inst = inst logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。" + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", ) if not self.curr_provider_inst: self.curr_provider_inst = inst @@ -416,10 +426,10 @@ class ProviderManager: self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error( - f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}" + f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) raise Exception( - f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}" + f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) async def reload(self, provider_config: dict): @@ -439,7 +449,7 @@ class ProviderManager: elif self.curr_provider_inst is None and len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] logger.info( - f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。" + f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", ) if len(self.stt_provider_insts) == 0: @@ -447,7 +457,7 @@ class ProviderManager: elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: self.curr_stt_provider_inst = self.stt_provider_insts[0] logger.info( - f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。" + f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", ) if len(self.tts_provider_insts) == 0: @@ -455,7 +465,7 @@ class ProviderManager: elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: self.curr_tts_provider_inst = self.tts_provider_insts[0] logger.info( - f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。" + f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", ) def get_insts(self): @@ -464,7 +474,7 @@ class ProviderManager: async def terminate_provider(self, provider_id: str): if provider_id in self.inst_map: logger.info( - f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..." + f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", ) if self.inst_map[provider_id] in self.provider_insts: @@ -491,7 +501,7 @@ class ProviderManager: await self.inst_map[provider_id].terminate() # type: ignore logger.info( - f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" + f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})", ) del self.inst_map[provider_id] diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 9953e9f17..23abfcfc0 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,17 +1,17 @@ import abc import asyncio -from typing import List -from typing import AsyncGenerator +from collections.abc import AsyncGenerator +from dataclasses import dataclass + from astrbot.core.agent.tool import ToolSet +from astrbot.core.db.po import Personality from astrbot.core.provider.entities import ( LLMResponse, - ToolCallsResult, ProviderType, RerankResult, + ToolCallsResult, ) from astrbot.core.provider.register import provider_cls_map -from astrbot.core.db.po import Personality -from dataclasses import dataclass @dataclass @@ -65,21 +65,21 @@ class Provider(AbstractProvider): @abc.abstractmethod def get_current_key(self) -> str: - raise NotImplementedError() + raise NotImplementedError - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: """获得提供商 Key""" keys = self.provider_config.get("key", [""]) return keys or [""] @abc.abstractmethod def set_key(self, key: str): - raise NotImplementedError() + raise NotImplementedError @abc.abstractmethod - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: """获得支持的模型列表""" - raise NotImplementedError() + raise NotImplementedError @abc.abstractmethod async def text_chat( @@ -108,6 +108,7 @@ class Provider(AbstractProvider): Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 + """ ... @@ -137,23 +138,20 @@ class Provider(AbstractProvider): Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 - """ - ... - async def pop_record(self, context: List): - """ - 弹出 context 第一条非系统提示词对话记录 """ + + async def pop_record(self, context: list): + """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] for idx, record in enumerate(context): if record["role"] == "system": continue - else: - indexs_to_pop.append(idx) - poped += 1 - if poped == 2: - break + indexs_to_pop.append(idx) + poped += 1 + if poped == 2: + break for idx in reversed(indexs_to_pop): context.pop(idx) @@ -168,7 +166,7 @@ class STTProvider(AbstractProvider): @abc.abstractmethod async def get_text(self, audio_url: str) -> str: """获取音频的文本""" - raise NotImplementedError() + raise NotImplementedError class TTSProvider(AbstractProvider): @@ -180,7 +178,7 @@ class TTSProvider(AbstractProvider): @abc.abstractmethod async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" - raise NotImplementedError() + raise NotImplementedError class EmbeddingProvider(AbstractProvider): @@ -223,6 +221,7 @@ class EmbeddingProvider(AbstractProvider): Returns: 向量列表 + """ semaphore = asyncio.Semaphore(tasks_limit) all_embeddings: list[list[float]] = [] @@ -246,7 +245,7 @@ class EmbeddingProvider(AbstractProvider): # 最后一次重试失败,记录失败的批次 failed_batches.append((batch_idx, batch_texts)) raise Exception( - f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}" + f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}", ) # 等待一段时间后重试,使用指数退避 await asyncio.sleep(2**attempt) @@ -279,7 +278,10 @@ class RerankProvider(AbstractProvider): @abc.abstractmethod async def rerank( - self, query: str, documents: list[str], top_n: int | None = None + self, + query: str, + documents: list[str], + top_n: int | None = None, ) -> list[RerankResult]: """获取查询和文档的重排序分数""" ... diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 02d7934d1..eb8c72aea 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,11 +1,11 @@ -from typing import List, Dict -from .entities import ProviderMetaData, ProviderType from astrbot.core import logger + +from .entities import ProviderMetaData, ProviderType from .func_tool_manager import FuncCall -provider_registry: List[ProviderMetaData] = [] +provider_registry: list[ProviderMetaData] = [] """维护了通过装饰器注册的 Provider""" -provider_cls_map: Dict[str, ProviderMetaData] = {} +provider_cls_map: dict[str, ProviderMetaData] = {} """维护了 Provider 类型名称和 ProviderMetadata 的映射""" llm_tools = FuncCall() @@ -23,7 +23,7 @@ def register_provider_adapter( def decorator(cls): if provider_type_name in provider_cls_map: raise ValueError( - f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。" + f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。", ) # 添加必备选项 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index cd4206ce7..6f292f076 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,23 +1,24 @@ -import json -import anthropic import base64 -from typing import List +import json +from collections.abc import AsyncGenerator from mimetypes import guess_type +import anthropic from anthropic import AsyncAnthropic from anthropic.types import Message -from astrbot.core.utils.io import download_image_by_url -from astrbot.api.provider import Provider from astrbot import logger -from astrbot.core.provider.func_tool_manager import ToolSet -from ..register import register_provider_adapter +from astrbot.api.provider import Provider from astrbot.core.provider.entities import LLMResponse -from typing import AsyncGenerator +from astrbot.core.provider.func_tool_manager import ToolSet +from astrbot.core.utils.io import download_image_by_url + +from ..register import register_provider_adapter @register_provider_adapter( - "anthropic_chat_completion", "Anthropic Claude API 提供商适配器" + "anthropic_chat_completion", + "Anthropic Claude API 提供商适配器", ) class ProviderAnthropic(Provider): def __init__( @@ -33,7 +34,7 @@ class ProviderAnthropic(Provider): ) self.chosen_api_key: str = "" - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) @@ -41,7 +42,9 @@ class ProviderAnthropic(Provider): self.timeout = int(self.timeout) self.client = AsyncAnthropic( - api_key=self.chosen_api_key, timeout=self.timeout, base_url=self.base_url + api_key=self.chosen_api_key, + timeout=self.timeout, + base_url=self.base_url, ) self.set_model(provider_config["model_config"]["model"]) @@ -54,6 +57,7 @@ class ProviderAnthropic(Provider): Returns: system_prompt: 系统提示内容 new_messages: 处理后的消息列表,去除系统提示 + """ system_prompt = "" new_messages = [] @@ -73,18 +77,19 @@ class ProviderAnthropic(Provider): "input": ( json.loads(tool_call["function"]["arguments"]) if isinstance( - tool_call["function"]["arguments"], str + tool_call["function"]["arguments"], + str, ) else tool_call["function"]["arguments"] ), "id": tool_call["id"], - } + }, ) new_messages.append( { "role": "assistant", "content": blocks, - } + }, ) elif message["role"] == "tool": new_messages.append( @@ -95,9 +100,9 @@ class ProviderAnthropic(Provider): "type": "tool_result", "tool_use_id": message["tool_call_id"], "content": message["content"], - } + }, ], - } + }, ) else: new_messages.append(message) @@ -135,7 +140,9 @@ class ProviderAnthropic(Provider): return llm_response async def _query_stream( - self, payloads: dict, tools: ToolSet | None + self, + payloads: dict, + tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: if tools: if tool_list := tools.get_func_desc_anthropic_style(): @@ -154,7 +161,9 @@ class ProviderAnthropic(Provider): if event.content_block.type == "text": # 文本块开始 yield LLMResponse( - role="assistant", completion_text="", is_chunk=True + role="assistant", + completion_text="", + is_chunk=True, ) elif event.content_block.type == "tool_use": # 工具使用块开始,初始化缓冲区 @@ -198,7 +207,7 @@ class ProviderAnthropic(Provider): "id": tool_info["id"], "name": tool_info["name"], "input": tool_info["input"], - } + }, ) yield LLMResponse( @@ -218,7 +227,9 @@ class ProviderAnthropic(Provider): # 返回最终的完整结果 final_response = LLMResponse( - role="assistant", completion_text=final_text, is_chunk=False + role="assistant", + completion_text=final_text, + is_chunk=False, ) if final_tool_calls: @@ -326,7 +337,7 @@ class ProviderAnthropic(Provider): async for llm_response in self._query_stream(payloads, func_tool): yield llm_response - async def assemble_context(self, text: str, image_urls: List[str] | None = None): + async def assemble_context(self, text: str, image_urls: list[str] | None = None): """组装上下文,支持文本和图片""" if not image_urls: return {"role": "user", "content": text} @@ -365,15 +376,13 @@ class ProviderAnthropic(Provider): else image_data ), }, - } + }, ) return {"role": "user", "content": content} async def encode_image_bs64(self, image_url: str) -> str: - """ - 将图片转换为 base64 - """ + """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: @@ -384,7 +393,7 @@ class ProviderAnthropic(Provider): def get_current_key(self) -> str: return self.chosen_api_key - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: models_str = [] models = await self.client.models.list() models = sorted(models.data, key=lambda x: x.id) diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 6ddf452d4..b67d7e2ab 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -1,15 +1,15 @@ -import uuid -import time -import json -import re -import hashlib -import random import asyncio +import hashlib +import json +import random +import re +import time +import uuid from pathlib import Path -from typing import Dict from xml.sax.saxutils import escape from httpx import AsyncClient, Timeout + from astrbot.core.config.default import VERSION from ..entities import ProviderType @@ -21,7 +21,7 @@ TEMP_DIR.mkdir(parents=True, exist_ok=True) class OTTSProvider: - def __init__(self, config: Dict): + def __init__(self, config: dict): self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -58,7 +58,7 @@ class OTTSProvider: path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/" return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}" - async def get_audio(self, text: str, voice_params: Dict) -> str: + async def get_audio(self, text: str, voice_params: dict) -> str: file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav" signature = await self._generate_signature() for attempt in range(self.retry_count): @@ -86,7 +86,7 @@ class OTTSProvider: return str(file_path.resolve()) except Exception as e: if attempt == self.retry_count - 1: - raise RuntimeError(f"OTTS请求失败: {str(e)}") from e + raise RuntimeError(f"OTTS请求失败: {e!s}") from e await asyncio.sleep(0.5 * (attempt + 1)) @@ -94,7 +94,8 @@ class AzureNativeProvider(TTSProvider): def __init__(self, provider_config: dict, provider_settings: dict): super().__init__(provider_config, provider_settings) self.subscription_key = provider_config.get( - "azure_tts_subscription_key", "" + "azure_tts_subscription_key", + "", ).strip() if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key): raise ValueError("无效的Azure订阅密钥") @@ -119,7 +120,7 @@ class AzureNativeProvider(TTSProvider): "User-Agent": f"AstrBot/{VERSION}", "Content-Type": "application/ssml+xml", "X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm", - } + }, ) return self @@ -132,7 +133,8 @@ class AzureNativeProvider(TTSProvider): f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" ) response = await self.client.post( - token_url, headers={"Ocp-Apim-Subscription-Key": self.subscription_key} + token_url, + headers={"Ocp-Apim-Subscription-Key": self.subscription_key}, ) response.raise_for_status() self.token = response.text diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/provider/sources/coze_api_client.py index a768979c6..e8f3a1e24 100644 --- a/astrbot/core/provider/sources/coze_api_client.py +++ b/astrbot/core/provider/sources/coze_api_client.py @@ -1,8 +1,11 @@ -import json import asyncio -import aiohttp import io -from typing import Dict, List, Any, AsyncGenerator +import json +from collections.abc import AsyncGenerator +from typing import Any + +import aiohttp + from astrbot.core import logger @@ -32,7 +35,9 @@ class CozeAPIClient: "Accept": "text/event-stream", } self.session = aiohttp.ClientSession( - headers=headers, timeout=timeout, connector=connector + headers=headers, + timeout=timeout, + connector=connector, ) return self.session @@ -46,6 +51,7 @@ class CozeAPIClient: file_data (bytes): 文件的二进制数据 Returns: str: 上传成功后返回的 file_id + """ session = await self._ensure_session() url = f"{self.api_base}/v1/files/upload" @@ -64,12 +70,12 @@ class CozeAPIClient: response_text = await response.text() logger.debug( - f"文件上传响应状态: {response.status}, 内容: {response_text}" + f"文件上传响应状态: {response.status}, 内容: {response_text}", ) if response.status != 200: raise Exception( - f"文件上传失败,状态码: {response.status}, 响应: {response_text}" + f"文件上传失败,状态码: {response.status}, 响应: {response_text}", ) try: @@ -88,8 +94,8 @@ class CozeAPIClient: logger.error("文件上传超时") raise Exception("文件上传超时") except Exception as e: - logger.error(f"文件上传失败: {str(e)}") - raise Exception(f"文件上传失败: {str(e)}") + logger.error(f"文件上传失败: {e!s}") + raise Exception(f"文件上传失败: {e!s}") async def download_image(self, image_url: str) -> bytes: """下载图片并返回字节数据 @@ -98,6 +104,7 @@ class CozeAPIClient: image_url (str): 图片的URL Returns: bytes: 图片的二进制数据 + """ session = await self._ensure_session() @@ -110,19 +117,19 @@ class CozeAPIClient: return image_data except Exception as e: - logger.error(f"下载图片失败 {image_url}: {str(e)}") - raise Exception(f"下载图片失败: {str(e)}") + logger.error(f"下载图片失败 {image_url}: {e!s}") + raise Exception(f"下载图片失败: {e!s}") async def chat_messages( self, bot_id: str, user_id: str, - additional_messages: List[Dict] | None = None, + additional_messages: list[dict] | None = None, conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, timeout: float = 120, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 Args: @@ -133,6 +140,7 @@ class CozeAPIClient: auto_save_history: 是否自动保存历史 stream: 是否流式响应 timeout: 超时时间 + """ session = await self._ensure_session() url = f"{self.api_base}/v3/chat" @@ -198,7 +206,7 @@ class CozeAPIClient: except asyncio.TimeoutError: raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") except Exception as e: - raise Exception(f"Coze API 流式请求失败: {str(e)}") + raise Exception(f"Coze API 流式请求失败: {e!s}") async def clear_context(self, conversation_id: str): """清空会话上下文 @@ -207,6 +215,7 @@ class CozeAPIClient: conversation_id: 会话ID Returns: dict: API响应结果 + """ session = await self._ensure_session() url = f"{self.api_base}/v3/conversation/message/clear_context" @@ -230,7 +239,7 @@ class CozeAPIClient: except asyncio.TimeoutError: raise Exception("Coze API 请求超时") except aiohttp.ClientError as e: - raise Exception(f"Coze API 请求失败: {str(e)}") + raise Exception(f"Coze API 请求失败: {e!s}") async def get_message_list( self, @@ -248,6 +257,7 @@ class CozeAPIClient: offset: 偏移量 Returns: dict: API响应结果 + """ session = await self._ensure_session() url = f"{self.api_base}/v3/conversation/message/list" @@ -264,8 +274,8 @@ class CozeAPIClient: return await response.json() except Exception as e: - logger.error(f"获取Coze消息列表失败: {str(e)}") - raise Exception(f"获取Coze消息列表失败: {str(e)}") + logger.error(f"获取Coze消息列表失败: {e!s}") + raise Exception(f"获取Coze消息列表失败: {e!s}") async def close(self): """关闭会话""" @@ -275,8 +285,8 @@ class CozeAPIClient: if __name__ == "__main__": - import os import asyncio + import os async def test_coze_api_client(): api_key = os.getenv("COZE_API_KEY", "") diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py index 639af0814..caee65020 100644 --- a/astrbot/core/provider/sources/coze_source.py +++ b/astrbot/core/provider/sources/coze_source.py @@ -1,13 +1,15 @@ -import json -import os import base64 import hashlib -from typing import AsyncGenerator, Dict -from astrbot.core.message.message_event_result import MessageChain +import json +import os +from collections.abc import AsyncGenerator + import astrbot.core.message.components as Comp -from astrbot.api.provider import Provider from astrbot import logger +from astrbot.api.provider import Provider +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse + from ..register import register_provider_adapter from .coze_api_client import CozeAPIClient @@ -34,18 +36,18 @@ class ProviderCoze(Provider): self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") if not isinstance(self.api_base, str) or not self.api_base.startswith( - ("http://", "https://") + ("http://", "https://"), ): raise Exception( - "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。" + "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", ) self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.auto_save_history = provider_config.get("auto_save_history", True) - self.conversation_ids: Dict[str, str] = {} - self.file_id_cache: Dict[str, Dict[str, str]] = {} + self.conversation_ids: dict[str, str] = {} + self.file_id_cache: dict[str, dict[str, str]] = {} # 创建 API 客户端 self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) @@ -59,8 +61,8 @@ class ProviderCoze(Provider): Returns: str: 缓存键 - """ + """ try: if is_base64 and data.startswith("data:image/"): try: @@ -71,26 +73,24 @@ class ProviderCoze(Provider): except Exception: cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest() return cache_key + elif data.startswith(("http://", "https://")): + # URL图片,使用URL作为缓存键 + cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() + return cache_key else: - if data.startswith(("http://", "https://")): - # URL图片,使用URL作为缓存键 - cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() - return cache_key - else: - clean_path = ( - data.split("_")[0] - if "_" in data and len(data.split("_")) >= 3 - else data - ) + clean_path = ( + data.split("_")[0] + if "_" in data and len(data.split("_")) >= 3 + else data + ) - if os.path.exists(clean_path): - with open(clean_path, "rb") as f: - file_content = f.read() - cache_key = hashlib.md5(file_content).hexdigest() - return cache_key - else: - cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest() - return cache_key + if os.path.exists(clean_path): + with open(clean_path, "rb") as f: + file_content = f.read() + cache_key = hashlib.md5(file_content).hexdigest() + return cache_key + cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest() + return cache_key except Exception as e: cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() @@ -117,7 +117,9 @@ class ProviderCoze(Provider): return file_id async def _download_and_upload_image( - self, image_url: str, session_id: str | None = None + self, + image_url: str, + session_id: str | None = None, ) -> str: """下载图片并上传到 Coze,返回 file_id""" # 计算哈希实现缓存 @@ -142,14 +144,15 @@ class ProviderCoze(Provider): return file_id except Exception as e: - logger.error(f"处理图片失败 {image_url}: {str(e)}") - raise Exception(f"处理图片失败: {str(e)}") + logger.error(f"处理图片失败 {image_url}: {e!s}") + raise Exception(f"处理图片失败: {e!s}") async def _process_context_images( - self, content: str | list, session_id: str + self, + content: str | list, + session_id: str, ) -> str: """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id""" - try: if isinstance(content, str): return content @@ -184,14 +187,15 @@ class ProviderCoze(Provider): continue # 计算哈希用于缓存 cache_key = self._generate_cache_key( - image_data, is_base64=image_data.startswith("data:image/") + image_data, + is_base64=image_data.startswith("data:image/"), ) # 检查缓存 if cache_key in self.file_id_cache[session_id]: file_id = self.file_id_cache[session_id][cache_key] processed_content.append( - {"type": "image", "file_id": file_id} + {"type": "image", "file_id": file_id}, ) else: # 上传图片并缓存 @@ -207,7 +211,8 @@ class ProviderCoze(Provider): elif image_data.startswith(("http://", "https://")): # URL 图片 file_id = await self._download_and_upload_image( - image_data, session_id + image_data, + session_id, ) # 为URL图片也添加缓存 self.file_id_cache[session_id][cache_key] = file_id @@ -222,22 +227,21 @@ class ProviderCoze(Provider): ) else: logger.warning( - f"无法处理的图片格式: {image_data[:50]}..." + f"无法处理的图片格式: {image_data[:50]}...", ) continue processed_content.append( - {"type": "image", "file_id": file_id} + {"type": "image", "file_id": file_id}, ) result = json.dumps(processed_content, ensure_ascii=False) return result except Exception as e: - logger.error(f"处理上下文图片失败: {str(e)}") + logger.error(f"处理上下文图片失败: {e!s}") if isinstance(content, str): return content - else: - return json.dumps(content, ensure_ascii=False) + return json.dumps(content, ensure_ascii=False) async def text_chat( self, @@ -262,8 +266,10 @@ class ProviderCoze(Provider): system_prompt (str): 系统提示语 tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持) model (str): 模型名称(不支持) + Returns: LLMResponse: LLM响应对象 + """ accumulated_content = "" final_response = None @@ -291,8 +297,7 @@ class ProviderCoze(Provider): if accumulated_content: chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) return LLMResponse(role="assistant", result_chain=chain) - else: - return LLMResponse(role="assistant", completion_text="") + return LLMResponse(role="assistant", completion_text="") async def text_chat_stream( self, @@ -319,7 +324,11 @@ class ProviderCoze(Provider): if system_prompt: if not self.auto_save_history or not conversation_id: additional_messages.append( - {"role": "system", "content": system_prompt, "content_type": "text"} + { + "role": "system", + "content": system_prompt, + "content_type": "text", + }, ) if not self.auto_save_history and contexts: @@ -343,14 +352,15 @@ class ProviderCoze(Provider): ) ): processed_content = await self._process_context_images( - content, user_id + content, + user_id, ) additional_messages.append( { "role": ctx["role"], "content": processed_content, "content_type": "object_string", - } + }, ) else: # 纯文本 @@ -363,7 +373,7 @@ class ProviderCoze(Provider): else json.dumps(content, ensure_ascii=False) ), "content_type": "text", - } + }, ) else: logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}") @@ -380,7 +390,8 @@ class ProviderCoze(Provider): if url.startswith(("http://", "https://")): # 网络图片 file_id = await self._download_and_upload_image( - url, user_id + url, + user_id, ) else: # 本地文件或 base64 @@ -389,37 +400,41 @@ class ProviderCoze(Provider): _, encoded = url.split(",", 1) image_data = base64.b64decode(encoded) cache_key = self._generate_cache_key( - url, is_base64=True + url, + is_base64=True, ) file_id = await self._upload_file( - image_data, user_id, cache_key + image_data, + user_id, + cache_key, + ) + # 本地文件 + elif os.path.exists(url): + with open(url, "rb") as f: + image_data = f.read() + # 用文件路径和修改时间来缓存 + file_stat = os.stat(url) + cache_key = self._generate_cache_key( + f"{url}_{file_stat.st_mtime}_{file_stat.st_size}", + is_base64=False, + ) + file_id = await self._upload_file( + image_data, + user_id, + cache_key, ) else: - # 本地文件 - if os.path.exists(url): - with open(url, "rb") as f: - image_data = f.read() - # 用文件路径和修改时间来缓存 - file_stat = os.stat(url) - cache_key = self._generate_cache_key( - f"{url}_{file_stat.st_mtime}_{file_stat.st_size}", - is_base64=False, - ) - file_id = await self._upload_file( - image_data, user_id, cache_key - ) - else: - logger.warning(f"图片文件不存在: {url}") - continue + logger.warning(f"图片文件不存在: {url}") + continue object_string_content.append( { "type": "image", "file_id": file_id, - } + }, ) except Exception as e: - logger.error(f"处理图片失败 {url}: {str(e)}") + logger.error(f"处理图片失败 {url}: {e!s}") continue if object_string_content: @@ -429,18 +444,17 @@ class ProviderCoze(Provider): "role": "user", "content": content, "content_type": "object_string", - } - ) - else: - # 纯文本 - if prompt: - additional_messages.append( - { - "role": "user", - "content": prompt, - "content_type": "text", - } + }, ) + # 纯文本 + elif prompt: + additional_messages.append( + { + "role": "user", + "content": prompt, + "content_type": "text", + }, + ) try: accumulated_content = "" @@ -534,10 +548,10 @@ class ProviderCoze(Provider): ) except Exception as e: - logger.error(f"Coze 流式请求失败: {str(e)}") + logger.error(f"Coze 流式请求失败: {e!s}") yield LLMResponse( role="err", - completion_text=f"Coze 流式请求失败: {str(e)}", + completion_text=f"Coze 流式请求失败: {e!s}", is_chunk=False, ) @@ -558,12 +572,11 @@ class ProviderCoze(Provider): if "code" in response and response["code"] == 0: self.conversation_ids.pop(user_id, None) return True - else: - logger.warning(f"清空 Coze 会话上下文失败: {response}") - return False + logger.warning(f"清空 Coze 会话上下文失败: {response}") + return False except Exception as e: - logger.error(f"清空 Coze 会话失败: {str(e)}") + logger.error(f"清空 Coze 会话失败: {e!s}") return False async def get_current_key(self): @@ -590,7 +603,10 @@ class ProviderCoze(Provider): self.bot_id = model async def get_human_readable_context( - self, session_id: str, page: int = 1, page_size: int = 10 + self, + session_id: str, + page: int = 1, + page_size: int = 10, ): """获取人类可读的上下文历史""" user_id = session_id @@ -627,7 +643,7 @@ class ProviderCoze(Provider): return readable_history except Exception as e: - logger.error(f"获取 Coze 消息历史失败: {str(e)}") + logger.error(f"获取 Coze 消息历史失败: {e!s}") return [] async def terminate(self): diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 0183f7244..92613dc1a 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -1,15 +1,18 @@ -import re import asyncio import functools -from .. import Provider, Personality -from ..entities import LLMResponse -from ..register import register_provider_adapter -from astrbot.core.message.message_event_result import MessageChain -from .openai_source import ProviderOpenAIOfficial -from astrbot.core import logger, sp +import re + from dashscope import Application from dashscope.app.application_response import ApplicationResponse +from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain + +from .. import Personality, Provider +from ..entities import LLMResponse +from ..register import register_provider_adapter +from .openai_source import ProviderOpenAIOfficial + @register_provider_adapter("dashscope", "Dashscope APP 适配器。") class ProviderDashscope(ProviderOpenAIOfficial): @@ -50,6 +53,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): Returns: bool: 是否有 RAG 选项 + """ if self.rag_options and ( len(self.rag_options.get("pipeline_ids", [])) > 0 @@ -127,12 +131,12 @@ class ProviderDashscope(ProviderOpenAIOfficial): if response.status_code != 200: logger.error( - f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code" + f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", ) return LLMResponse( role="err", result_chain=MessageChain().message( - f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + f"阿里云百炼请求失败: message={response.message} code={response.status_code}", ), ) diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index efda31ca9..44e9965cc 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,7 @@ import base64 import logging import os import uuid -from typing import Optional, Tuple + import aiohttp import dashscope from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer @@ -15,14 +15,17 @@ except ( ): # pragma: no cover - older dashscope versions without Qwen TTS support MultiModalConversation = None +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + from ..entities import ProviderType from ..provider import TTSProvider from ..register import register_provider_adapter -from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( - "dashscope_tts", "Dashscope TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "dashscope_tts", + "Dashscope TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderDashscopeTTSAPI(TTSProvider): def __init__( @@ -33,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): super().__init__(provider_config, provider_settings) self.chosen_api_key: str = provider_config.get("api_key", "") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model")) self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 dashscope.api_key = self.chosen_api_key @@ -52,7 +55,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): if not audio_bytes: raise RuntimeError( - "Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable." + "Audio synthesis failed, returned empty content. The model may not be supported or the service is unavailable.", ) path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}") @@ -63,7 +66,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): def _call_qwen_tts(self, model: str, text: str): if MultiModalConversation is None: raise RuntimeError( - "dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models." + "dashscope SDK missing MultiModalConversation. Please upgrade the dashscope package to use Qwen TTS models.", ) kwargs = { @@ -74,24 +77,26 @@ class ProviderDashscopeTTSAPI(TTSProvider): } if not self.voice: logging.warning( - "No voice specified for Qwen TTS model, using default 'Cherry'." + "No voice specified for Qwen TTS model, using default 'Cherry'.", ) return MultiModalConversation.call(**kwargs) async def _synthesize_with_qwen_tts( - self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + self, + model: str, + text: str, + ) -> tuple[bytes | None, str]: loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, self._call_qwen_tts, model, text) audio_bytes = await self._extract_audio_from_response(response) if not audio_bytes: raise RuntimeError( - f"Audio synthesis failed for model '{model}'. {response}" + f"Audio synthesis failed for model '{model}'. {response}", ) ext = ".wav" return audio_bytes, ext - async def _extract_audio_from_response(self, response) -> Optional[bytes]: + async def _extract_audio_from_response(self, response) -> bytes | None: output = getattr(response, "output", None) audio_obj = getattr(output, "audio", None) if output is not None else None if not audio_obj: @@ -102,7 +107,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): try: return base64.b64decode(data_b64) except (ValueError, TypeError): - logging.error("Failed to decode base64 audio data.") + logging.exception("Failed to decode base64 audio data.") return None url = getattr(audio_obj, "url", None) @@ -110,23 +115,28 @@ class ProviderDashscopeTTSAPI(TTSProvider): return await self._download_audio_from_url(url) return None - async def _download_audio_from_url(self, url: str) -> Optional[bytes]: + async def _download_audio_from_url(self, url: str) -> bytes | None: if not url: return None timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20 try: - async with aiohttp.ClientSession() as session: - async with session.get( - url, timeout=aiohttp.ClientTimeout(total=timeout) - ) as response: - return await response.read() + async with ( + aiohttp.ClientSession() as session, + session.get( + url, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response, + ): + return await response.read() except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: - logging.error(f"Failed to download audio from URL {url}: {e}") + logging.exception(f"Failed to download audio from URL {url}: {e}") return None async def _synthesize_with_cosyvoice( - self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + self, + model: str, + text: str, + ) -> tuple[bytes | None, str]: synthesizer = SpeechSynthesizer( model=model, voice=self.voice, @@ -134,13 +144,16 @@ class ProviderDashscopeTTSAPI(TTSProvider): ) loop = asyncio.get_event_loop() audio_bytes = await loop.run_in_executor( - None, synthesizer.call, text, self.timeout_ms + None, + synthesizer.call, + text, + self.timeout_ms, ) if not audio_bytes: resp = synthesizer.get_response() if resp and isinstance(resp, dict): raise RuntimeError( - f"Audio synthesis failed for model '{model}'. {resp}".strip() + f"Audio synthesis failed for model '{model}'. {resp}".strip(), ) return audio_bytes, ".wav" diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index f7c4e63ca..9f9f146aa 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,13 +1,15 @@ -import astrbot.core.message.components as Comp import os -from .. import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_image_by_url, download_file + +import astrbot.core.message.components as Comp from astrbot.core import logger, sp from astrbot.core.message.message_event_result import MessageChain from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.dify_api_client import DifyAPIClient +from astrbot.core.utils.io import download_file, download_image_by_url + +from .. import Provider +from ..entities import LLMResponse +from ..register import register_provider_adapter @register_provider_adapter("dify", "Dify APP 适配器。") @@ -32,10 +34,12 @@ class ProviderDify(Provider): raise Exception("Dify API 类型不能为空。") self.model_name = "dify" self.workflow_output_key = provider_config.get( - "dify_workflow_output_key", "astrbot_wf_output" + "dify_workflow_output_key", + "astrbot_wf_output", ) self.dify_query_input_key = provider_config.get( - "dify_query_input_key", "astrbot_text_query" + "dify_query_input_key", + "astrbot_text_query", ) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" @@ -76,12 +80,13 @@ class ProviderDify(Provider): else image_url ) file_response = await self.api_client.file_upload( - image_path, user=session_id + image_path, + user=session_id, ) logger.debug(f"Dify 上传图片响应:{file_response}") if "id" not in file_response: logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。", ) continue files_payload.append( @@ -89,7 +94,7 @@ class ProviderDify(Provider): "type": "image", "transfer_method": "local_file", "upload_file_id": file_response["id"], - } + }, ) # 获得会话变量 @@ -132,7 +137,7 @@ class ProviderDify(Provider): elif chunk["event"] == "error": logger.error(f"Dify 出现错误:{chunk}") raise Exception( - f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}" + f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}", ) case "workflow": @@ -149,37 +154,37 @@ class ProviderDify(Provider): match chunk["event"]: case "workflow_started": logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。" + f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。", ) case "node_finished": logger.debug( - f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。" + f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。", ) case "workflow_finished": logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束" + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束", ) logger.debug(f"Dify 工作流结果:{chunk}") if chunk["data"]["error"]: logger.error( - f"Dify 工作流出现错误:{chunk['data']['error']}" + f"Dify 工作流出现错误:{chunk['data']['error']}", ) raise Exception( - f"Dify 工作流出现错误:{chunk['data']['error']}" + f"Dify 工作流出现错误:{chunk['data']['error']}", ) if ( self.workflow_output_key not in chunk["data"]["outputs"] ): raise Exception( - f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" + f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}", ) result = chunk case _: raise Exception(f"未知的 Dify API 类型:{self.api_type}") except Exception as e: - logger.error(f"Dify 请求失败:{str(e)}") - return LLMResponse(role="err", completion_text=f"Dify 请求失败:{str(e)}") + logger.error(f"Dify 请求失败:{e!s}") + return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}") if not result: logger.warning("Dify 请求结果为空,请查看 Debug 日志。") diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 44c2d1756..8bbf62325 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -1,14 +1,17 @@ -import uuid -import os -import edge_tts -import subprocess import asyncio -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter +import os +import subprocess +import uuid + +import edge_tts + from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + """ edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 ``` @@ -19,7 +22,9 @@ Windows 如果提示找不到指定文件,以管理员身份运行命令行窗 @register_provider_adapter( - "edge_tts", "Microsoft Edge TTS", provider_type=ProviderType.TEXT_TO_SPEECH + "edge_tts", + "Microsoft Edge TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderEdgeTTS(TTSProvider): def __init__( @@ -31,9 +36,9 @@ class ProviderEdgeTTS(TTSProvider): # 设置默认语音,如果没有指定则使用中文小萱 self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") - self.rate = provider_config.get("rate", None) - self.volume = provider_config.get("volume", None) - self.pitch = provider_config.get("pitch", None) + self.rate = provider_config.get("rate") + self.volume = provider_config.get("volume") + self.pitch = provider_config.get("pitch") self.timeout = provider_config.get("timeout", 30) self.proxy = os.getenv("https_proxy", None) @@ -97,26 +102,25 @@ class ProviderEdgeTTS(TTSProvider): os.remove(mp3_path) if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: return wav_path - else: - logger.error("生成的WAV文件不存在或为空") - raise RuntimeError("生成的WAV文件不存在或为空") + logger.error("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") except subprocess.CalledProcessError as e: logger.error( - f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}" + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", ) try: if os.path.exists(mp3_path): os.remove(mp3_path) except Exception: pass - raise RuntimeError(f"FFmpeg 转换失败: {str(e)}") + raise RuntimeError(f"FFmpeg 转换失败: {e!s}") except Exception as e: - logger.error(f"音频生成失败: {str(e)}") + logger.error(f"音频生成失败: {e!s}") try: if os.path.exists(mp3_path): os.remove(mp3_path) except Exception: pass - raise RuntimeError(f"音频生成失败: {str(e)}") + raise RuntimeError(f"音频生成失败: {e!s}") diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 49c78239e..ca571c3ee 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -1,15 +1,18 @@ import os -import uuid import re -import ormsgpack -from pydantic import BaseModel, conint -from httpx import AsyncClient +import uuid from typing import Annotated, Literal -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter + +import ormsgpack +from httpx import AsyncClient +from pydantic import BaseModel, conint + from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + class ServeReferenceAudio(BaseModel): audio: bytes @@ -35,7 +38,9 @@ class ServeTTSRequest(BaseModel): @register_provider_adapter( - "fishaudio_tts_api", "FishAudio TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "fishaudio_tts_api", + "FishAudio TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderFishAudioTTSAPI(TTSProvider): def __init__( @@ -48,16 +53,16 @@ class ProviderFishAudioTTSAPI(TTSProvider): self.reference_id: str = provider_config.get("fishaudio-tts-reference-id", "") self.character: str = provider_config.get("fishaudio-tts-character", "可莉") self.api_base: str = provider_config.get( - "api_base", "https://api.fish-audio.cn/v1" + "api_base", + "https://api.fish-audio.cn/v1", ) self.headers = { "Authorization": f"Bearer {self.chosen_api_key}", } - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model")) async def _get_reference_id_by_character(self, character: str) -> str: - """ - 获取角色的reference_id + """获取角色的reference_id Args: character: 角色名称 @@ -67,13 +72,16 @@ class ProviderFishAudioTTSAPI(TTSProvider): exception: APIException: 获取语音角色列表为空 + """ sort_options = ["score", "task_count", "created_at"] async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client: for sort_by in sort_options: params = {"title": character, "sort_by": sort_by} response = await client.get( - "/model", params=params, headers=self.headers + "/model", + params=params, + headers=self.headers, ) resp_data = response.json() if resp_data["total"] == 0: @@ -84,14 +92,14 @@ class ProviderFishAudioTTSAPI(TTSProvider): return None def _validate_reference_id(self, reference_id: str) -> bool: - """ - 验证reference_id格式是否有效 + """验证reference_id格式是否有效 Args: reference_id: 参考模型ID Returns: bool: ID是否有效 + """ if not reference_id or not reference_id.strip(): return False @@ -109,7 +117,7 @@ class ProviderFishAudioTTSAPI(TTSProvider): raise ValueError( f"无效的FishAudio参考模型ID: '{self.reference_id}'. " f"请确保ID是32位十六进制字符串(例如: 626bb6d3f3364c9cbc3aa6a67300a664)。" - f"您可以从 https://fish.audio/zh-CN/discovery 获取有效的模型ID。" + f"您可以从 https://fish.audio/zh-CN/discovery 获取有效的模型ID。", ) reference_id = self.reference_id.strip() else: diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 562d11353..8d11cce5f 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -1,9 +1,10 @@ from google import genai from google.genai import types from google.genai.errors import APIError + +from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter -from ..entities import ProviderType @register_provider_adapter( @@ -18,40 +19,38 @@ class GeminiEmbeddingProvider(EmbeddingProvider): self.provider_settings = provider_settings api_key: str = provider_config.get("embedding_api_key") - api_base: str = provider_config.get("embedding_api_base", None) + api_base: str = provider_config.get("embedding_api_base") timeout: int = int(provider_config.get("timeout", 20)) http_options = types.HttpOptions(timeout=timeout * 1000) if api_base: - if api_base.endswith("/"): - api_base = api_base[:-1] + api_base = api_base.removesuffix("/") http_options.base_url = api_base self.client = genai.Client(api_key=api_key, http_options=http_options).aio self.model = provider_config.get( - "embedding_model", "gemini-embedding-exp-03-07" + "embedding_model", + "gemini-embedding-exp-03-07", ) async def get_embedding(self, text: str) -> list[float]: - """ - 获取文本的嵌入 - """ + """获取文本的嵌入""" try: result = await self.client.models.embed_content( - model=self.model, contents=text + model=self.model, + contents=text, ) return result.embeddings[0].values except APIError as e: raise Exception(f"Gemini Embedding API请求失败: {e.message}") async def get_embeddings(self, texts: list[str]) -> list[list[float]]: - """ - 批量获取文本的嵌入 - """ + """批量获取文本的嵌入""" try: result = await self.client.models.embed_content( - model=self.model, contents=texts + model=self.model, + contents=texts, ) return [embedding.values for embedding in result.embeddings] except APIError as e: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b14a9bdcb..f9eef2e92 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,6 @@ import base64 import json import logging import random -from typing import Optional, List from collections.abc import AsyncGenerator from google import genai @@ -32,7 +31,8 @@ logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning()) @register_provider_adapter( - "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" + "googlegenai_chat_completion", + "Google Gemini Chat Completion 提供商适配器", ) class ProviderGoogleGenAI(Provider): CATEGORY_MAPPING = { @@ -60,11 +60,11 @@ class ProviderGoogleGenAI(Provider): provider_settings, default_persona, ) - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else "" self.timeout: int = int(provider_config.get("timeout", 180)) - self.api_base: Optional[str] = provider_config.get("api_base", None) + self.api_base: str | None = provider_config.get("api_base", None) if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] @@ -87,7 +87,8 @@ class ProviderGoogleGenAI(Provider): user_safety_config = self.provider_config.get("gm_safety_settings", {}) self.safety_settings = [ types.SafetySetting( - category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str] + category=harm_category, + threshold=self.THRESHOLD_MAPPING[threshold_str], ) for config_key, harm_category in self.CATEGORY_MAPPING.items() if (threshold_str := user_safety_config.get(config_key)) @@ -104,27 +105,25 @@ class ProviderGoogleGenAI(Provider): if len(keys) > 0: self.set_key(random.choice(keys)) logger.info( - f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..." + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}...", ) await asyncio.sleep(1) return True - else: - logger.error( - f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..." - ) - raise Exception("达到了 Gemini 速率限制, 请稍后再试...") - else: logger.error( - f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}" + f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...", ) - raise e + raise Exception("达到了 Gemini 速率限制, 请稍后再试...") + logger.error( + f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}", + ) + raise e async def _prepare_query_config( self, payloads: dict, - tools: Optional[ToolSet] = None, - system_instruction: Optional[str] = None, - modalities: Optional[list[str]] = None, + tools: ToolSet | None = None, + system_instruction: str | None = None, + modalities: list[str] | None = None, temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" @@ -152,7 +151,7 @@ class ProviderGoogleGenAI(Provider): logger.warning("代码执行工具与搜索工具互斥,已忽略搜索工具") if url_context: logger.warning( - "代码执行工具与URL上下文工具互斥,已忽略URL上下文工具" + "代码执行工具与URL上下文工具互斥,已忽略URL上下文工具", ) else: if native_search: @@ -163,13 +162,13 @@ class ProviderGoogleGenAI(Provider): tool_list.append(types.Tool(url_context=types.UrlContext())) else: logger.warning( - "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包" + "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包", ) elif "gemini-2.0-lite" in model_name: if native_coderunner or native_search or url_context: logger.warning( - "gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置" + "gemini-2.0-lite 不支持代码执行、搜索工具和URL上下文,将忽略这些设置", ) tool_list = None @@ -186,7 +185,7 @@ class ProviderGoogleGenAI(Provider): tool_list.append(types.Tool(url_context=types.UrlContext())) else: logger.warning( - "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包" + "当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包", ) if not tool_list: @@ -196,7 +195,7 @@ class ProviderGoogleGenAI(Provider): logger.warning("已启用原生工具,函数工具将被忽略") elif tools and (func_desc := tools.get_func_desc_google_genai_style()): tool_list = [ - types.Tool(function_declarations=func_desc["function_declarations"]) + types.Tool(function_declarations=func_desc["function_declarations"]), ] return types.GenerateContentConfig( @@ -223,8 +222,9 @@ class ProviderGoogleGenAI(Provider): thinking_budget=min( int( self.provider_config.get("gm_thinking_config", {}).get( - "budget", 0 - ) + "budget", + 0, + ), ), 24576, ), @@ -234,7 +234,7 @@ class ProviderGoogleGenAI(Provider): else None ), automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True + disable=True, ), ) @@ -268,7 +268,7 @@ class ProviderGoogleGenAI(Provider): [ self.provider_config.get("gm_native_coderunner", False), self.provider_config.get("gm_native_search", False), - ] + ], ) for message in payloads["messages"]: role, content = message["role"], message.get("content") @@ -304,7 +304,7 @@ class ProviderGoogleGenAI(Provider): logger.warning("assistant 角色的消息内容为空,已添加空格占位") if native_tool_enabled and "tool_calls" in message: logger.warning( - "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文" + "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文", ) parts = [types.Part.from_text(text=" ")] append_or_extend(gemini_contents, parts, types.ModelContent) @@ -317,7 +317,7 @@ class ProviderGoogleGenAI(Provider): "name": message["tool_call_id"], "content": message["content"], }, - ) + ), ] append_or_extend(gemini_contents, parts, types.UserContent) @@ -328,7 +328,8 @@ class ProviderGoogleGenAI(Provider): @staticmethod def _process_content_parts( - candidate: types.Candidate, llm_response: LLMResponse + candidate: types.Candidate, + llm_response: LLMResponse, ) -> MessageChain: """处理内容部分并构建消息链""" if not candidate.content: @@ -381,7 +382,7 @@ class ProviderGoogleGenAI(Provider): llm_response.tools_call_args.append(part.function_call.args) # gemini 返回的 function_call.id 可能为 None llm_response.tools_call_ids.append( - part.function_call.id or part.function_call.name + part.function_call.id or part.function_call.name, ) elif ( part.inline_data @@ -406,11 +407,15 @@ class ProviderGoogleGenAI(Provider): conversation = self._prepare_conversation(payloads) temperature = payloads.get("temperature", 0.7) - result: Optional[types.GenerateContentResponse] = None + result: types.GenerateContentResponse | None = None while True: try: config = await self._prepare_query_config( - payloads, tools, system_instruction, modalities, temperature + payloads, + tools, + system_instruction, + modalities, + temperature, ) result = await self.client.models.generate_content( model=self.get_model(), @@ -427,7 +432,7 @@ class ProviderGoogleGenAI(Provider): raise Exception("温度参数已超过最大值2,仍然发生recitation") temperature += 0.2 logger.warning( - f"发生了recitation,正在提高温度至{temperature:.1f}重试..." + f"发生了recitation,正在提高温度至{temperature:.1f}重试...", ) continue @@ -438,7 +443,7 @@ class ProviderGoogleGenAI(Provider): e.message = "" if "Developer instruction is not enabled" in e.message: logger.warning( - f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)", ) system_instruction = None elif "Function calling is not enabled" in e.message: @@ -451,7 +456,7 @@ class ProviderGoogleGenAI(Provider): or "only supports text output" in e.message ): logger.warning( - f"{self.get_model()} 不支持多模态输出,降级为文本模态" + f"{self.get_model()} 不支持多模态输出,降级为文本模态", ) modalities = ["Text"] else: @@ -461,12 +466,15 @@ class ProviderGoogleGenAI(Provider): llm_response = LLMResponse("assistant") llm_response.raw_completion = result llm_response.result_chain = self._process_content_parts( - result.candidates[0], llm_response + result.candidates[0], + llm_response, ) return llm_response async def _query_stream( - self, payloads: dict, tools: ToolSet | None + self, + payloads: dict, + tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式请求 Gemini API""" system_instruction = next( @@ -480,7 +488,9 @@ class ProviderGoogleGenAI(Provider): while True: try: config = await self._prepare_query_config( - payloads, tools, system_instruction + payloads, + tools, + system_instruction, ) result = await self.client.models.generate_content_stream( model=self.get_model(), @@ -493,7 +503,7 @@ class ProviderGoogleGenAI(Provider): e.message = "" if "Developer instruction is not enabled" in e.message: logger.warning( - f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)" + f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)", ) system_instruction = None elif "Function calling is not enabled" in e.message: @@ -523,7 +533,8 @@ class ProviderGoogleGenAI(Provider): llm_response = LLMResponse("assistant", is_chunk=False) llm_response.raw_completion = chunk llm_response.result_chain = self._process_content_parts( - chunk.candidates[0], llm_response + chunk.candidates[0], + llm_response, ) yield llm_response return @@ -539,7 +550,8 @@ class ProviderGoogleGenAI(Provider): final_response = LLMResponse("assistant", is_chunk=False) final_response.raw_completion = chunk final_response.result_chain = self._process_content_parts( - chunk.candidates[0], final_response + chunk.candidates[0], + final_response, ) break @@ -550,7 +562,7 @@ class ProviderGoogleGenAI(Provider): # Set the complete accumulated text in the final response if accumulated_text: final_response.result_chain = MessageChain( - chain=[Comp.Plain(accumulated_text)] + chain=[Comp.Plain(accumulated_text)], ) elif not final_response.result_chain: # If no text was accumulated and no final response was set, provide empty space @@ -680,9 +692,7 @@ class ProviderGoogleGenAI(Provider): self._init_client() async def assemble_context(self, text: str, image_urls: list[str] | None = None): - """ - 组装上下文。 - """ + """组装上下文。""" if image_urls: user_content = { "role": "user", @@ -704,16 +714,13 @@ class ProviderGoogleGenAI(Provider): { "type": "image_url", "image_url": {"url": image_data}, - } + }, ) return user_content - else: - return {"role": "user", "content": text} + return {"role": "user", "content": text} async def encode_image_bs64(self, image_url: str) -> str: - """ - 将图片转换为 base64 - """ + """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: diff --git a/astrbot/core/provider/sources/gemini_tts_source.py b/astrbot/core/provider/sources/gemini_tts_source.py index 48cb48335..0bf92b325 100644 --- a/astrbot/core/provider/sources/gemini_tts_source.py +++ b/astrbot/core/provider/sources/gemini_tts_source.py @@ -13,7 +13,9 @@ from ..register import register_provider_adapter @register_provider_adapter( - "gemini_tts", "Gemini TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "gemini_tts", + "Gemini TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderGeminiTTSAPI(TTSProvider): def __init__( @@ -28,13 +30,13 @@ class ProviderGeminiTTSAPI(TTSProvider): http_options = types.HttpOptions(timeout=timeout * 1000) if api_base: - if api_base.endswith("/"): - api_base = api_base[:-1] + api_base = api_base.removesuffix("/") http_options.base_url = api_base self.client = genai.Client(api_key=api_key, http_options=http_options).aio self.model: str = provider_config.get( - "gemini_tts_model", "gemini-2.5-flash-preview-tts" + "gemini_tts_model", + "gemini-2.5-flash-preview-tts", ) self.prefix: str | None = provider_config.get( "gemini_tts_prefix", @@ -54,8 +56,8 @@ class ProviderGeminiTTSAPI(TTSProvider): voice_config=types.VoiceConfig( prebuilt_voice_config=types.PrebuiltVoiceConfig( voice_name=self.voice_name, - ) - ) + ), + ), ), ), ) diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index 6c4d872a9..7f8d39eac 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -3,12 +3,14 @@ import os import uuid import aiohttp -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter + from astrbot import logger from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + @register_provider_adapter( provider_type_name="gsv_tts_selfhost", @@ -24,7 +26,7 @@ class ProviderGSVTTS(TTSProvider): super().__init__(provider_config, provider_settings) self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip( - "/" + "/", ) self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "") self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "") @@ -40,7 +42,7 @@ class ProviderGSVTTS(TTSProvider): async def initialize(self): """异步初始化:在 ProviderManager 中被调用""" self._session = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=self.timeout) + timeout=aiohttp.ClientTimeout(total=self.timeout), ) try: await self._set_model_weights() @@ -52,12 +54,15 @@ class ProviderGSVTTS(TTSProvider): def get_session(self) -> aiohttp.ClientSession: if not self._session or self._session.closed: raise RuntimeError( - "[GSV TTS] Provider HTTP session is not ready or closed." + "[GSV TTS] Provider HTTP session is not ready or closed.", ) return self._session async def _make_request( - self, endpoint: str, params=None, retries: int = 3 + self, + endpoint: str, + params=None, + retries: int = 3, ) -> bytes | None: """发起请求""" for attempt in range(retries): @@ -67,13 +72,13 @@ class ProviderGSVTTS(TTSProvider): if response.status != 200: error_text = await response.text() raise Exception( - f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}" + f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}", ) return await response.read() except Exception as e: if attempt < retries - 1: logger.warning( - f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..." + f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中...", ) await asyncio.sleep(1) else: @@ -98,7 +103,7 @@ class ProviderGSVTTS(TTSProvider): {"weights_path": self.sovits_weights_path}, ) logger.info( - f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}" + f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}", ) else: logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型") @@ -127,12 +132,10 @@ class ProviderGSVTTS(TTSProvider): with open(path, "wb") as f: f.write(result) return path - else: - raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") + raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") def build_synthesis_params(self, text: str) -> dict: - """ - 构建语音合成所需的参数字典。 + """构建语音合成所需的参数字典。 当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。 """ diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index c2444819b..d8b171718 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,15 +1,20 @@ import os -import uuid -import aiohttp import urllib.parse -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter +import uuid + +import aiohttp + from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + @register_provider_adapter( - "gsvi_tts_api", "GSVI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "gsvi_tts_api", + "GSVI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderGSVITTS(TTSProvider): def __init__( @@ -19,8 +24,7 @@ class ProviderGSVITTS(TTSProvider): ) -> None: super().__init__(provider_config, provider_settings) self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") - if self.api_base.endswith("/"): - self.api_base = self.api_base[:-1] + self.api_base = self.api_base.removesuffix("/") self.character = provider_config.get("character") self.emotion = provider_config.get("emotion") @@ -49,7 +53,7 @@ class ProviderGSVITTS(TTSProvider): else: error_text = await response.text() raise Exception( - f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}" + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", ) return path diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 5b210835b..5ffc7cc63 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -1,17 +1,22 @@ import json import os import uuid +from collections.abc import AsyncIterator + import aiohttp -from typing import Dict, List, Union, AsyncIterator -from astrbot.core.utils.astrbot_path import get_astrbot_data_path + from astrbot.api import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + from ..entities import ProviderType from ..provider import TTSProvider from ..register import register_provider_adapter @register_provider_adapter( - "minimax_tts_api", "MiniMax TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "minimax_tts_api", + "MiniMax TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderMiniMaxTTSAPI(TTSProvider): def __init__( @@ -22,19 +27,21 @@ class ProviderMiniMaxTTSAPI(TTSProvider): super().__init__(provider_config, provider_settings) self.chosen_api_key: str = provider_config.get("api_key", "") self.api_base: str = provider_config.get( - "api_base", "https://api.minimax.chat/v1/t2a_v2" + "api_base", + "https://api.minimax.chat/v1/t2a_v2", ) self.group_id: str = provider_config.get("minimax-group-id", "") self.set_model(provider_config.get("model", "")) self.lang_boost: str = provider_config.get("minimax-langboost", "auto") self.is_timber_weight: bool = provider_config.get( - "minimax-is-timber-weight", False + "minimax-is-timber-weight", + False, ) - self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads( + self.timber_weight: list[dict[str, str | int]] = json.loads( provider_config.get( "minimax-timber-weight", '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]', - ) + ), ) self.voice_setting: dict = { @@ -47,7 +54,8 @@ class ProviderMiniMaxTTSAPI(TTSProvider): "emotion": provider_config.get("minimax-voice-emotion", "neutral"), "latex_read": provider_config.get("minimax-voice-latex", False), "english_normalization": provider_config.get( - "minimax-voice-english-normalization", False + "minimax-voice-english-normalization", + False, ), } @@ -66,7 +74,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider): def _build_tts_stream_body(self, text: str): """构建流式请求体""" - dict_body: Dict[str, object] = { + dict_body: dict[str, object] = { "model": self.model_name, "text": text, "stream": True, @@ -82,44 +90,46 @@ class ProviderMiniMaxTTSAPI(TTSProvider): async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: """进行流式请求""" try: - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( self.concat_base_url, headers=self.headers, data=self._build_tts_stream_body(text), timeout=aiohttp.ClientTimeout(total=60), - ) as response: - response.raise_for_status() + ) as response, + ): + response.raise_for_status() - buffer = b"" - while True: - chunk = await response.content.read(8192) - if not chunk: - break + buffer = b"" + while True: + chunk = await response.content.read(8192) + if not chunk: + break - buffer += chunk + buffer += chunk - while b"\n\n" in buffer: - try: - message, buffer = buffer.split(b"\n\n", 1) - if message.startswith(b"data: "): - try: - data = json.loads(message[6:]) - if "extra_info" in data: - continue - audio = data.get("data", {}).get("audio") - if audio is not None: - yield audio - except json.JSONDecodeError: - logger.warning( - "Failed to parse JSON data from SSE message" - ) + while b"\n\n" in buffer: + try: + message, buffer = buffer.split(b"\n\n", 1) + if message.startswith(b"data: "): + try: + data = json.loads(message[6:]) + if "extra_info" in data: continue - except ValueError: - buffer = buffer[-1024:] + audio = data.get("data", {}).get("audio") + if audio is not None: + yield audio + except json.JSONDecodeError: + logger.warning( + "Failed to parse JSON data from SSE message", + ) + continue + except ValueError: + buffer = buffer[-1024:] except aiohttp.ClientError as e: - raise Exception(f"MiniMax TTS API请求失败: {str(e)}") + raise Exception(f"MiniMax TTS API请求失败: {e!s}") async def _audio_play(self, audio_stream: AsyncIterator[str]) -> bytes: """解码数据流到 audio 比特流""" diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index e6f692a35..368e610ec 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,7 +1,8 @@ from openai import AsyncOpenAI + +from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter -from ..entities import ProviderType @register_provider_adapter( @@ -17,23 +18,20 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=provider_config.get( - "embedding_api_base", "https://api.openai.com/v1" + "embedding_api_base", + "https://api.openai.com/v1", ), timeout=int(provider_config.get("timeout", 20)), ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") async def get_embedding(self, text: str) -> list[float]: - """ - 获取文本的嵌入 - """ + """获取文本的嵌入""" embedding = await self.client.embeddings.create(input=text, model=self.model) return embedding.data[0].embedding async def get_embeddings(self, texts: list[str]) -> list[list[float]]: - """ - 批量获取文本的嵌入 - """ + """批量获取文本的嵌入""" embeddings = await self.client.embeddings.create(input=texts, model=self.model) return [item.embedding for item in embeddings.data] diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 361d0a4de..1020075af 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,29 +1,30 @@ +import asyncio import base64 +import inspect import json import os -import inspect import random -import asyncio -import astrbot.core.message.components as Comp - -from openai import AsyncOpenAI, AsyncAzureOpenAI -from openai.types.chat.chat_completion import ChatCompletion +from collections.abc import AsyncGenerator +from openai import AsyncAzureOpenAI, AsyncOpenAI from openai._exceptions import NotFoundError, UnprocessableEntityError from openai.lib.streaming.chat._completions import ChatCompletionStreamState -from astrbot.core.utils.io import download_image_by_url -from astrbot.core.message.message_event_result import MessageChain +from openai.types.chat.chat_completion import ChatCompletion -from astrbot.api.provider import Provider +import astrbot.core.message.components as Comp from astrbot import logger -from astrbot.core.provider.func_tool_manager import ToolSet -from typing import List, AsyncGenerator -from ..register import register_provider_adapter +from astrbot.api.provider import Provider +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, ToolCallsResult +from astrbot.core.provider.func_tool_manager import ToolSet +from astrbot.core.utils.io import download_image_by_url + +from ..register import register_provider_adapter @register_provider_adapter( - "openai_chat_completion", "OpenAI API Chat Completion 提供商适配器" + "openai_chat_completion", + "OpenAI API Chat Completion 提供商适配器", ) class ProviderOpenAIOfficial(Provider): def __init__( @@ -38,7 +39,7 @@ class ProviderOpenAIOfficial(Provider): default_persona, ) self.chosen_api_key = None - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): @@ -61,7 +62,7 @@ class ProviderOpenAIOfficial(Provider): ) self.default_params = inspect.signature( - self.client.chat.completions.create + self.client.chat.completions.create, ).parameters.keys() model_config = provider_config.get("model_config", {}) @@ -106,7 +107,7 @@ class ProviderOpenAIOfficial(Provider): model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model tool_list = tools.get_func_desc_openai_style( - omit_empty_parameter_field=omit_empty_param_field + omit_empty_parameter_field=omit_empty_param_field, ) if tool_list: payloads["tools"] = tool_list @@ -114,7 +115,7 @@ class ProviderOpenAIOfficial(Provider): # 不在默认参数中的参数放在 extra_body 中 extra_body = {} to_del = [] - for key in payloads.keys(): + for key in payloads: if key not in self.default_params: extra_body[key] = payloads[key] to_del.append(key) @@ -133,12 +134,14 @@ class ProviderOpenAIOfficial(Provider): del payloads["tools"] completion = await self.client.chat.completions.create( - **payloads, stream=False, extra_body=extra_body + **payloads, + stream=False, + extra_body=extra_body, ) if not isinstance(completion, ChatCompletion): raise Exception( - f"API 返回的 completion 类型错误:{type(completion)}: {completion}。" + f"API 返回的 completion 类型错误:{type(completion)}: {completion}。", ) logger.debug(f"completion: {completion}") @@ -148,14 +151,16 @@ class ProviderOpenAIOfficial(Provider): return llm_response async def _query_stream( - self, payloads: dict, tools: ToolSet + self, + payloads: dict, + tools: ToolSet, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model tool_list = tools.get_func_desc_openai_style( - omit_empty_parameter_field=omit_empty_param_field + omit_empty_parameter_field=omit_empty_param_field, ) if tool_list: payloads["tools"] = tool_list @@ -169,7 +174,7 @@ class ProviderOpenAIOfficial(Provider): extra_body.update(custom_extra_body) to_del = [] - for key in payloads.keys(): + for key in payloads: if key not in self.default_params: extra_body[key] = payloads[key] to_del.append(key) @@ -177,7 +182,9 @@ class ProviderOpenAIOfficial(Provider): del payloads[key] stream = await self.client.chat.completions.create( - **payloads, stream=True, extra_body=extra_body + **payloads, + stream=True, + extra_body=extra_body, ) llm_response = LLMResponse("assistant", is_chunk=True) @@ -196,7 +203,7 @@ class ProviderOpenAIOfficial(Provider): if delta.content: completion_text = delta.content llm_response.result_chain = MessageChain( - chain=[Comp.Plain(completion_text)] + chain=[Comp.Plain(completion_text)], ) yield llm_response @@ -247,7 +254,7 @@ class ProviderOpenAIOfficial(Provider): if choice.finish_reason == "content_filter": raise Exception( - "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。" + "API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。", ) if llm_response.completion_text is None and not llm_response.tools_call_args: @@ -305,14 +312,14 @@ class ProviderOpenAIOfficial(Provider): context_query: list, func_tool: ToolSet, chosen_key: str, - available_api_keys: List[str], + available_api_keys: list[str], retry_cnt: int, max_retries: int, ) -> tuple: """处理API错误并尝试恢复""" if "429" in str(e): logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" + f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}", ) # 最后一次不等待 if retry_cnt < max_retries - 1: @@ -328,11 +335,10 @@ class ProviderOpenAIOfficial(Provider): context_query, func_tool, ) - else: - raise e - elif "maximum context length" in str(e): + raise e + if "maximum context length" in str(e): logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}", ) await self.pop_record(context_query) payloads["messages"] = context_query @@ -344,7 +350,7 @@ class ProviderOpenAIOfficial(Provider): context_query, func_tool, ) - elif "The model is not a VLM" in str(e): # siliconcloud + if "The model is not a VLM" in str(e): # siliconcloud # 尝试删除所有 image new_contexts = await self._remove_image_from_context(context_query) payloads["messages"] = new_contexts @@ -357,32 +363,30 @@ class ProviderOpenAIOfficial(Provider): context_query, func_tool, ) - elif ( + if ( "Function calling is not enabled" in str(e) or ("tool" in str(e).lower() and "support" in str(e).lower()) or ("function" in str(e).lower() and "support" in str(e).lower()) ): # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 logger.info( - f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。", ) - if "tools" in payloads: - del payloads["tools"] + payloads.pop("tools", None) return False, chosen_key, available_api_keys, payloads, context_query, None - else: - logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") + logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") - if "tool" in str(e).lower() and "support" in str(e).lower(): - logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") + if "tool" in str(e).lower() and "support" in str(e).lower(): + logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") - if "Connection error." in str(e): - proxy = os.environ.get("http_proxy", None) - if proxy: - logger.error( - f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" - ) + if "Connection error." in str(e): + proxy = os.environ.get("http_proxy", None) + if proxy: + logger.error( + f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}", + ) - raise e + raise e async def text_chat( self, @@ -522,10 +526,8 @@ class ProviderOpenAIOfficial(Provider): raise Exception("未知错误") raise last_exception - async def _remove_image_from_context(self, contexts: List): - """ - 从上下文中删除所有带有 image 的记录 - """ + async def _remove_image_from_context(self, contexts: list): + """从上下文中删除所有带有 image 的记录""" new_contexts = [] for context in contexts: @@ -546,14 +548,16 @@ class ProviderOpenAIOfficial(Provider): def get_current_key(self) -> str: return self.client.api_key - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key): self.client.api_key = key async def assemble_context( - self, text: str, image_urls: List[str] | None = None + self, + text: str, + image_urls: list[str] | None = None, ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: @@ -577,16 +581,13 @@ class ProviderOpenAIOfficial(Provider): { "type": "image_url", "image_url": {"url": image_data}, - } + }, ) return user_content - else: - return {"role": "user", "content": text} + return {"role": "user", "content": text} async def encode_image_bs64(self, image_url: str) -> str: - """ - 将图片转换为 base64 - """ + """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index c5fb467b7..d71e98112 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,14 +1,19 @@ import os import uuid -from openai import AsyncOpenAI, NOT_GIVEN -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter + +from openai import NOT_GIVEN, AsyncOpenAI + from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + @register_provider_adapter( - "openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH + "openai_tts_api", + "OpenAI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderOpenAITTSAPI(TTSProvider): def __init__( @@ -26,7 +31,7 @@ class ProviderOpenAITTSAPI(TTSProvider): self.client = AsyncOpenAI( api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), + base_url=provider_config.get("api_base"), timeout=timeout, ) @@ -36,7 +41,10 @@ class ProviderOpenAITTSAPI(TTSProvider): temp_dir = os.path.join(get_astrbot_data_path(), "temp") path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav") async with self.client.audio.speech.with_streaming_response.create( - model=self.model_name, voice=self.voice, response_format="wav", input=text + model=self.model_name, + voice=self.voice, + response_format="wav", + input=text, ) as response: with open(path, "wb") as f: async for chunk in response.iter_bytes(chunk_size=1024): diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index b6e3331f8..67947c685 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -1,22 +1,24 @@ -""" -Author: diudiu62 +"""Author: diudiu62 Date: 2025-02-24 18:04:18 LastEditTime: 2025-02-25 14:06:30 """ import asyncio -from datetime import datetime import os import re +from datetime import datetime + from funasr_onnx import SenseVoiceSmall from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess -from ..provider import STTProvider -from ..entities import ProviderType -from astrbot.core.utils.io import download_file -from ..register import register_provider_adapter + from astrbot.core import logger +from astrbot.core.utils.io import download_file from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav +from ..entities import ProviderType +from ..provider import STTProvider +from ..register import register_provider_adapter + @register_provider_adapter( "sensevoice_stt_selfhost", @@ -30,7 +32,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("stt_model", None)) + self.set_model(provider_config.get("stt_model")) self.model = None self.is_emotion = provider_config.get("is_emotion", False) @@ -39,7 +41,8 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): # 将模型加载放到线程池中执行 self.model = await asyncio.get_event_loop().run_in_executor( - None, lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16) + None, + lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16), ) logger.info("SenseVoice 模型加载完成。") @@ -55,8 +58,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): if silk_header in file_header: return True - else: - return False + return False async def get_text(self, audio_url: str) -> str: try: diff --git a/astrbot/core/provider/sources/vllm_rerank_source.py b/astrbot/core/provider/sources/vllm_rerank_source.py index 2620e3456..3e6f3d33c 100644 --- a/astrbot/core/provider/sources/vllm_rerank_source.py +++ b/astrbot/core/provider/sources/vllm_rerank_source.py @@ -1,8 +1,10 @@ import aiohttp + from astrbot import logger + +from ..entities import ProviderType, RerankResult from ..provider import RerankProvider from ..register import register_provider_adapter -from ..entities import ProviderType, RerankResult @register_provider_adapter( @@ -30,7 +32,10 @@ class VLLMRerankProvider(RerankProvider): ) async def rerank( - self, query: str, documents: list[str], top_n: int | None = None + self, + query: str, + documents: list[str], + top_n: int | None = None, ) -> list[RerankResult]: payload = { "query": query, @@ -40,14 +45,15 @@ class VLLMRerankProvider(RerankProvider): if top_n is not None: payload["top_n"] = top_n async with self.client.post( - f"{self.base_url}/v1/rerank", json=payload + f"{self.base_url}/v1/rerank", + json=payload, ) as response: response_data = await response.json() results = response_data.get("results", []) if not results: logger.warning( - f"Rerank API 返回了空的列表数据。原始响应: {response_data}" + f"Rerank API 返回了空的列表数据。原始响应: {response_data}", ) return [ diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py index 12e7ed9cd..f5d758f5c 100644 --- a/astrbot/core/provider/sources/volcengine_tts.py +++ b/astrbot/core/provider/sources/volcengine_tts.py @@ -1,18 +1,23 @@ -import uuid +import asyncio import base64 import json import os import traceback -import asyncio +import uuid + import aiohttp -from ..provider import TTSProvider -from ..entities import ProviderType -from ..register import register_provider_adapter + from astrbot import logger +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + @register_provider_adapter( - "volcengine_tts", "火山引擎 TTS", provider_type=ProviderType.TEXT_TO_SPEECH + "volcengine_tts", + "火山引擎 TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, ) class ProviderVolcengineTTS(TTSProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: @@ -23,7 +28,8 @@ class ProviderVolcengineTTS(TTSProvider): self.voice_type = provider_config.get("volcengine_voice_type", "") self.speed_ratio = provider_config.get("volcengine_speed_ratio", 1.0) self.api_base = provider_config.get( - "api_base", "https://openspeech.bytedance.com/api/v1/tts" + "api_base", + "https://openspeech.bytedance.com/api/v1/tts", ) self.timeout = provider_config.get("timeout", 20) @@ -66,43 +72,44 @@ class ProviderVolcengineTTS(TTSProvider): logger.debug(f"请求体: {json.dumps(payload, ensure_ascii=False)[:100]}...") try: - async with aiohttp.ClientSession() as session: - async with session.post( + async with ( + aiohttp.ClientSession() as session, + session.post( self.api_base, data=json.dumps(payload), headers=headers, timeout=self.timeout, - ) as response: - logger.debug(f"响应状态码: {response.status}") + ) as response, + ): + logger.debug(f"响应状态码: {response.status}") - response_text = await response.text() - logger.debug(f"响应内容: {response_text[:200]}...") + response_text = await response.text() + logger.debug(f"响应内容: {response_text[:200]}...") - if response.status == 200: - resp_data = json.loads(response_text) + if response.status == 200: + resp_data = json.loads(response_text) - if "data" in resp_data: - audio_data = base64.b64decode(resp_data["data"]) + if "data" in resp_data: + audio_data = base64.b64decode(resp_data["data"]) - os.makedirs("data/temp", exist_ok=True) + os.makedirs("data/temp", exist_ok=True) - file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3" + file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3" - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, lambda: open(file_path, "wb").write(audio_data) - ) - - return file_path - else: - error_msg = resp_data.get("message", "未知错误") - raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}") - else: - raise Exception( - f"火山引擎 TTS API 请求失败: {response.status}, {response_text}" + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + lambda: open(file_path, "wb").write(audio_data), ) + return file_path + error_msg = resp_data.get("message", "未知错误") + raise Exception(f"火山引擎 TTS API 返回错误: {error_msg}") + raise Exception( + f"火山引擎 TTS API 请求失败: {response.status}, {response_text}", + ) + except Exception as e: error_details = traceback.format_exc() logger.debug(f"火山引擎 TTS 异常详情: {error_details}") - raise Exception(f"火山引擎 TTS 异常: {str(e)}") + raise Exception(f"火山引擎 TTS 异常: {e!s}") diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index dfe286978..8f6d9e292 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,13 +1,16 @@ -import uuid import os -from openai import AsyncOpenAI, NOT_GIVEN -from ..provider import STTProvider -from ..entities import ProviderType -from astrbot.core.utils.io import download_file -from ..register import register_provider_adapter +import uuid + +from openai import NOT_GIVEN, AsyncOpenAI + from astrbot.core import logger -from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file +from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav + +from ..entities import ProviderType +from ..provider import STTProvider +from ..register import register_provider_adapter @register_provider_adapter( @@ -26,11 +29,11 @@ class ProviderOpenAIWhisperAPI(STTProvider): self.client = AsyncOpenAI( api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), + base_url=provider_config.get("api_base"), timeout=provider_config.get("timeout", NOT_GIVEN), ) - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model")) async def _is_silk_file(self, file_path): silk_header = b"SILK" @@ -39,11 +42,10 @@ class ProviderOpenAIWhisperAPI(STTProvider): if silk_header in file_header: return True - else: - return False + return False async def get_text(self, audio_url: str) -> str: - """only supports mp3, mp4, mpeg, m4a, wav, webm""" + """Only supports mp3, mp4, mpeg, m4a, wav, webm""" is_tencent = False if audio_url.startswith("http"): diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 7cb76cc4c..fbdc7d626 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,14 +1,17 @@ -import uuid -import os import asyncio +import os +import uuid + import whisper -from ..provider import STTProvider -from ..entities import ProviderType -from astrbot.core.utils.io import download_file -from ..register import register_provider_adapter + from astrbot.core import logger -from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file +from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav + +from ..entities import ProviderType +from ..provider import STTProvider +from ..register import register_provider_adapter @register_provider_adapter( @@ -23,14 +26,16 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model")) self.model = None async def initialize(self): loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( - None, whisper.load_model, self.model_name + None, + whisper.load_model, + self.model_name, ) logger.info("Whisper 模型加载完成。") @@ -41,8 +46,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): if silk_header in file_header: return True - else: - return False + return False async def get_text(self, audio_url: str) -> str: loop = asyncio.get_event_loop() diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 3c27d7c3a..29f3ab095 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -1,10 +1,12 @@ from xinference_client.client.restful.async_restful_client import ( AsyncClient as Client, ) + from astrbot import logger + +from ..entities import ProviderType, RerankResult from ..provider import RerankProvider from ..register import register_provider_adapter -from ..entities import ProviderType, RerankResult @register_provider_adapter( @@ -23,7 +25,8 @@ class XinferenceRerankProvider(RerankProvider): self.model_name = provider_config.get("rerank_model", "BAAI/bge-reranker-base") self.api_key = provider_config.get("rerank_api_key") self.launch_model_if_not_running = provider_config.get( - "launch_model_if_not_running", False + "launch_model_if_not_running", + False, ) self.client = None self.model = None @@ -42,7 +45,7 @@ class XinferenceRerankProvider(RerankProvider): for uid, model_spec in running_models.items(): if model_spec.get("model_name") == self.model_name: logger.info( - f"Model '{self.model_name}' is already running with UID: {uid}" + f"Model '{self.model_name}' is already running with UID: {uid}", ) self.model_uid = uid break @@ -51,12 +54,13 @@ class XinferenceRerankProvider(RerankProvider): if self.launch_model_if_not_running: logger.info(f"Launching {self.model_name} model...") self.model_uid = await self.client.launch_model( - model_name=self.model_name, model_type="rerank" + model_name=self.model_name, + model_type="rerank", ) logger.info("Model launched.") else: logger.warning( - f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available." + f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available.", ) return @@ -66,12 +70,16 @@ class XinferenceRerankProvider(RerankProvider): except Exception as e: logger.error(f"Failed to initialize Xinference model: {e}") logger.debug( - f"Xinference initialization failed with exception: {e}", exc_info=True + f"Xinference initialization failed with exception: {e}", + exc_info=True, ) self.model = None async def rerank( - self, query: str, documents: list[str], top_n: int | None = None + self, + query: str, + documents: list[str], + top_n: int | None = None, ) -> list[RerankResult]: if not self.model: logger.error("Xinference rerank model is not initialized.") @@ -83,7 +91,7 @@ class XinferenceRerankProvider(RerankProvider): if not results: logger.warning( - f"Rerank API returned an empty list. Original response: {response}" + f"Rerank API returned an empty list. Original response: {response}", ) return [ diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index d8e908159..9c69a0039 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -1,15 +1,18 @@ -import uuid import os +import uuid + import aiohttp from xinference_client.client.restful.async_restful_client import ( AsyncClient as Client, ) -from ..provider import STTProvider -from ..entities import ProviderType -from ..register import register_provider_adapter + from astrbot.core import logger -from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav + +from ..entities import ProviderType +from ..provider import STTProvider +from ..register import register_provider_adapter @register_provider_adapter( @@ -28,7 +31,8 @@ class ProviderXinferenceSTT(STTProvider): self.model_name = provider_config.get("model", "whisper-large-v3") self.api_key = provider_config.get("api_key") self.launch_model_if_not_running = provider_config.get( - "launch_model_if_not_running", False + "launch_model_if_not_running", + False, ) self.client = None self.model_uid = None @@ -46,7 +50,7 @@ class ProviderXinferenceSTT(STTProvider): for uid, model_spec in running_models.items(): if model_spec.get("model_name") == self.model_name: logger.info( - f"Model '{self.model_name}' is already running with UID: {uid}" + f"Model '{self.model_name}' is already running with UID: {uid}", ) self.model_uid = uid break @@ -55,19 +59,21 @@ class ProviderXinferenceSTT(STTProvider): if self.launch_model_if_not_running: logger.info(f"Launching {self.model_name} model...") self.model_uid = await self.client.launch_model( - model_name=self.model_name, model_type="audio" + model_name=self.model_name, + model_type="audio", ) logger.info("Model launched.") else: logger.warning( - f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available." + f"Model '{self.model_name}' is not running and auto-launch is disabled. Provider will not be available.", ) return except Exception as e: logger.error(f"Failed to initialize Xinference model: {e}") logger.debug( - f"Xinference initialization failed with exception: {e}", exc_info=True + f"Xinference initialization failed with exception: {e}", + exc_info=True, ) async def get_text(self, audio_url: str) -> str: @@ -90,16 +96,15 @@ class ProviderXinferenceSTT(STTProvider): audio_bytes = await resp.read() else: logger.error( - f"Failed to download audio from {audio_url}, status: {resp.status}" + f"Failed to download audio from {audio_url}, status: {resp.status}", ) return "" + elif os.path.exists(audio_url): + with open(audio_url, "rb") as f: + audio_bytes = f.read() else: - if os.path.exists(audio_url): - with open(audio_url, "rb") as f: - audio_bytes = f.read() - else: - logger.error(f"File not found: {audio_url}") - return "" + logger.error(f"File not found: {audio_url}") + return "" if not audio_bytes: logger.error("Audio bytes are empty.") @@ -145,23 +150,28 @@ class ProviderXinferenceSTT(STTProvider): data = aiohttp.FormData() data.add_field("model", self.model_uid) data.add_field( - "file", audio_bytes, filename="audio.wav", content_type="audio/wav" + "file", + audio_bytes, + filename="audio.wav", + content_type="audio/wav", ) async with self.client.session.post( - url, data=data, headers=headers, timeout=self.timeout + url, + data=data, + headers=headers, + timeout=self.timeout, ) as resp: if resp.status == 200: result = await resp.json() text = result.get("text", "") logger.debug(f"Xinference STT result: {text}") return text - else: - error_text = await resp.text() - logger.error( - f"Xinference STT transcription failed with status {resp.status}: {error_text}" - ) - return "" + error_text = await resp.text() + logger.error( + f"Xinference STT transcription failed with status {resp.status}: {error_text}", + ) + return "" except Exception as e: logger.error(f"Xinference STT failed: {e}") diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 70e06d0d5..e27db7405 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -1,10 +1,11 @@ +from astrbot.core import html_renderer +from astrbot.core.provider import Provider +from astrbot.core.star.star_tools import StarTools +from astrbot.core.utils.command_parser import CommandParserMixin + +from .context import Context from .star import StarMetadata, star_map, star_registry from .star_manager import PluginManager -from .context import Context -from astrbot.core.provider import Provider -from astrbot.core.utils.command_parser import CommandParserMixin -from astrbot.core import html_renderer -from astrbot.core.star.star_tools import StarTools class Star(CommandParserMixin): @@ -36,24 +37,28 @@ class Star(CommandParserMixin): ) async def html_render( - self, tmpl: str, data: dict, return_url=True, options: dict | None = None + self, + tmpl: str, + data: dict, + return_url=True, + options: dict | None = None, ) -> str: """渲染 HTML""" return await html_renderer.render_custom_template( - tmpl, data, return_url=return_url, options=options + tmpl, + data, + return_url=return_url, + options=options, ) async def initialize(self): """当插件被激活时会调用这个方法""" - pass async def terminate(self): """当插件被禁用、重载插件时会调用这个方法""" - pass def __del__(self): """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" - pass -__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"] +__all__ = ["Context", "PluginManager", "Provider", "Star", "StarMetadata", "StarTools"] diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index 23a522dc1..a9af974c5 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -1,23 +1,20 @@ -""" -此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta -""" +"""此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta""" -from typing import Union -import os import json +import os + from astrbot.core.utils.astrbot_path import get_astrbot_data_path -def load_config(namespace: str) -> Union[dict, bool]: - """ - 从配置文件中加载配置。 +def load_config(namespace: str) -> dict | bool: + """从配置文件中加载配置。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。 """ path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): return False - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: ret = {} data = json.load(f) for k in data: @@ -26,8 +23,7 @@ def load_config(namespace: str) -> Union[dict, bool]: def put_config(namespace: str, name: str, key: str, value, description: str): - """ - 将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 + """将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 key: str, 配置项的键。 @@ -51,7 +47,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): if not os.path.exists(path): with open(path, "w", encoding="utf-8-sig") as f: f.write("{}") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: @@ -69,8 +65,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): def update_config(namespace: str, key: str, value): - """ - 更新配置文件中的配置项。 + """更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 value: str, int, float, bool, list, 配置项的值。 @@ -78,7 +73,7 @@ def update_config(namespace: str, key: str, value): path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index d878a64c5..620e7e907 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,43 +1,43 @@ from asyncio import Queue +from collections.abc import Awaitable, Callable +from typing import Any -from astrbot.core.provider.provider import ( - Provider, - TTSProvider, - STTProvider, - EmbeddingProvider, - RerankProvider, -) -from astrbot.core.provider.entities import ProviderType -from astrbot.core.db import BaseDatabase +from deprecated import deprecated + +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.provider.func_tool_manager import FunctionToolManager, FunctionTool -from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.db import BaseDatabase +from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.manager import ProviderManager +from astrbot.core.persona_mgr import PersonaManager from astrbot.core.platform import Platform +from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager -from astrbot.core.astrbot_config_mgr import AstrBotConfigManager -from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager -from astrbot.core.persona_mgr import PersonaManager -from .star import star_registry, StarMetadata, star_map -from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType +from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager +from astrbot.core.provider.manager import ProviderManager +from astrbot.core.provider.provider import ( + EmbeddingProvider, + Provider, + RerankProvider, + STTProvider, + TTSProvider, +) +from astrbot.core.star.filter.platform_adapter_type import ( + ADAPTER_NAME_2_TYPE, + PlatformAdapterType, +) + from .filter.command import CommandFilter from .filter.regex import RegexFilter -from typing import Any -from collections.abc import Awaitable, Callable -from astrbot.core.conversation_mgr import ConversationManager -from astrbot.core.star.filter.platform_adapter_type import ( - PlatformAdapterType, - ADAPTER_NAME_2_TYPE, -) -from deprecated import deprecated +from .star import StarMetadata, star_map, star_registry +from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry class Context: - """ - 暴露给插件的接口上下文。 - """ + """暴露给插件的接口上下文。""" registered_web_apis: list = [] @@ -91,6 +91,7 @@ class Context: Returns: 如果没找到,会返回 False + """ return self.provider_manager.llm_tools.activate_llm_tool(name, star_map) @@ -98,17 +99,18 @@ class Context: """停用一个已经注册的函数调用工具。 Returns: - 如果没找到,会返回 False""" + 如果没找到,会返回 False + + """ return self.provider_manager.llm_tools.deactivate_llm_tool(name) def register_provider(self, provider: Provider): - """ - 注册一个 LLM Provider(Chat_Completion 类型)。 - """ + """注册一个 LLM Provider(Chat_Completion 类型)。""" self.provider_manager.provider_insts.append(provider) def get_provider_by_id( - self, provider_id: str + self, + provider_id: str, ) -> ( Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None ): @@ -133,11 +135,11 @@ class Context: return self.provider_manager.embedding_provider_insts def get_using_provider(self, umo: str | None = None) -> Provider | None: - """ - 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 + """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 + """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, @@ -148,11 +150,11 @@ class Context: return prov def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: - """ - 获取当前使用的用于 TTS 任务的 Provider。 + """获取当前使用的用于 TTS 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.TEXT_TO_SPEECH, @@ -163,11 +165,11 @@ class Context: return prov def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: - """ - 获取当前使用的用于 STT 任务的 Provider。 + """获取当前使用的用于 STT 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 + """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.SPEECH_TO_TEXT, @@ -182,23 +184,19 @@ class Context: if not umo: # using default config return self._config - else: - return self.astrbot_config_mgr.get_conf(umo) + return self.astrbot_config_mgr.get_conf(umo) def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" return self._db def get_event_queue(self) -> Queue: - """ - 获取事件队列。 - """ + """获取事件队列。""" return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: - """ - 获取指定类型的平台适配器。 + """获取指定类型的平台适配器。 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) """ @@ -207,32 +205,32 @@ class Context: if isinstance(platform_type, str): if name == platform_type: return platform - else: - if ( - name in ADAPTER_NAME_2_TYPE - and ADAPTER_NAME_2_TYPE[name] & platform_type - ): - return platform + elif ( + name in ADAPTER_NAME_2_TYPE + and ADAPTER_NAME_2_TYPE[name] & platform_type + ): + return platform def get_platform_inst(self, platform_id: str) -> Platform | None: - """ - 获取指定 ID 的平台适配器实例。 + """获取指定 ID 的平台适配器实例。 Args: platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 Returns: Platform: 平台适配器实例,如果未找到则返回 None。 + """ for platform in self.platform_manager.platform_insts: if platform.meta().id == platform_id: return platform async def send_message( - self, session: str | MessageSesion, message_chain: MessageChain + self, + session: str | MessageSesion, + message_chain: MessageChain, ) -> bool: - """ - 根据 session(unified_msg_origin) 主动发送消息。 + """根据 session(unified_msg_origin) 主动发送消息。 @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 @param message_chain: 消息链。 @@ -243,7 +241,6 @@ class Context: NOTE: qq_official(QQ 官方 API 平台) 不支持此方法 """ - if isinstance(session, str): try: session = MessageSesion.from_str(session) @@ -272,8 +269,7 @@ class Context: desc: str, func_obj: Callable[..., Awaitable[Any]], ) -> None: - """ - 为函数调用(function-calling / tools-use)添加工具。 + """为函数调用(function-calling / tools-use)添加工具。 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @@ -308,8 +304,7 @@ class Context: use_regex=False, ignore_prefix=False, ): - """ - 注册一个命令。 + """注册一个命令。 [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 @@ -333,18 +328,20 @@ class Context: md.event_filters.append(RegexFilter(regex=command_name)) else: md.event_filters.append( - CommandFilter(command_name=command_name, handler_md=md) + CommandFilter(command_name=command_name, handler_md=md), ) star_handlers_registry.append(md) def register_task(self, task: Awaitable, desc: str): - """ - 注册一个异步任务。 - """ + """注册一个异步任务。""" self._register_tasks.append(task) def register_web_api( - self, route: str, view_handler: Awaitable, methods: list, desc: str + self, + route: str, + view_handler: Awaitable, + methods: list, + desc: str, ): for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: diff --git a/astrbot/core/star/filter/__init__.py b/astrbot/core/star/filter/__init__.py index c2f78e275..e550017ae 100644 --- a/astrbot/core/star/filter/__init__.py +++ b/astrbot/core/star/filter/__init__.py @@ -1,7 +1,8 @@ import abc -from astrbot.core.platform.message_type import MessageType -from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.message_type import MessageType class HandlerFilter(abc.ABC): @@ -11,4 +12,4 @@ class HandlerFilter(abc.ABC): raise NotImplementedError -__all__ = ["HandlerFilter", "MessageType", "AstrMessageEvent", "AstrBotConfig"] +__all__ = ["AstrBotConfig", "AstrMessageEvent", "HandlerFilter", "MessageType"] diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 3d67cb750..6e0283a0e 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -1,20 +1,20 @@ -import re import inspect +import re import types import typing -from typing import List, Any, Type, Dict -from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent +from typing import Any + from astrbot.core.config import AstrBotConfig -from .custom_filter import CustomFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent + from ..star_handler import StarHandlerMetadata +from . import HandlerFilter +from .custom_filter import CustomFilter class GreedyStr(str): """标记指令完成其他参数接收后的所有剩余文本。""" - pass - def unwrap_optional(annotation) -> tuple: """去掉 Optional[T] / Union[T, None] / T|None,返回 T""" @@ -22,10 +22,9 @@ def unwrap_optional(annotation) -> tuple: non_none_args = [a for a in args if a is not type(None)] if len(non_none_args) == 1: return (non_none_args[0],) - elif len(non_none_args) > 1: + if len(non_none_args) > 1: return tuple(non_none_args) - else: - return () + return () # 标准指令受到 wake_prefix 的制约。 @@ -37,14 +36,14 @@ class CommandFilter(HandlerFilter): command_name: str, alias: set | None = None, handler_md: StarHandlerMetadata | None = None, - parent_command_names: List[str] = [""], + parent_command_names: list[str] = [""], ): self.command_name = command_name self.alias = alias if alias else set() self.parent_command_names = parent_command_names if handler_md: self.init_handler_md(handler_md) - self.custom_filter_list: List[CustomFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] # Cache for complete command names list self._cmpl_cmd_names: list | None = None @@ -89,8 +88,10 @@ class CommandFilter(HandlerFilter): return True def validate_and_convert_params( - self, params: List[Any], param_type: Dict[str, Type] - ) -> Dict[str, Any]: + self, + params: list[Any], + param_type: dict[str, type], + ) -> dict[str, Any]: """将参数列表 params 根据 param_type 转换为参数字典。""" result = {} param_items = list(param_type.items()) @@ -101,7 +102,7 @@ class CommandFilter(HandlerFilter): # GreedyStr 必须是最后一个参数 if i != len(param_items) - 1: raise ValueError( - f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。" + f"参数 '{param_name}' (GreedyStr) 必须是最后一个参数。", ) # 将剩余的所有部分合并成一个字符串 @@ -111,17 +112,16 @@ class CommandFilter(HandlerFilter): # 没有 GreedyStr 的情况 if i >= len(params): if ( - isinstance(param_type_or_default_val, (Type, types.UnionType)) + isinstance(param_type_or_default_val, (type, types.UnionType)) or typing.get_origin(param_type_or_default_val) is typing.Union or param_type_or_default_val is inspect.Parameter.empty ): # 是类型 raise ValueError( - f"必要参数缺失。该指令完整参数: {self.print_types()}" + f"必要参数缺失。该指令完整参数: {self.print_types()}", ) - else: - # 是默认值 - result[param_name] = param_type_or_default_val + # 是默认值 + result[param_name] = param_type_or_default_val else: # 尝试强制转换 try: @@ -142,7 +142,7 @@ class CommandFilter(HandlerFilter): result[param_name] = False else: raise ValueError( - f"参数 {param_name} 必须是布尔值(true/false, yes/no, 1/0)。" + f"参数 {param_name} 必须是布尔值(true/false, yes/no, 1/0)。", ) elif isinstance(param_type_or_default_val, int): result[param_name] = int(params[i]) @@ -165,7 +165,7 @@ class CommandFilter(HandlerFilter): result[param_name] = param_type_or_default_val(params[i]) except ValueError: raise ValueError( - f"参数 {param_name} 类型错误。完整参数: {self.print_types()}" + f"参数 {param_name} 类型错误。完整参数: {self.print_types()}", ) return result diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index e01fa2c58..0f5c19ec5 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import List, Union +from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent + from . import HandlerFilter from .command import CommandFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.config import AstrBotConfig from .custom_filter import CustomFilter @@ -18,25 +18,27 @@ class CommandGroupFilter(HandlerFilter): ): self.group_name = group_name self.alias = alias if alias else set() - self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] - self.custom_filter_list: List[CustomFilter] = [] + self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] self.parent_group = parent_group # Cache for complete command names list self._cmpl_cmd_names: list | None = None def add_sub_command_filter( - self, sub_command_filter: Union[CommandFilter, CommandGroupFilter] + self, + sub_command_filter: CommandFilter | CommandGroupFilter, ): self.sub_command_filters.append(sub_command_filter) def add_custom_filter(self, custom_filter: CustomFilter): self.custom_filter_list.append(custom_filter) - def get_complete_command_names(self) -> List[str]: + def get_complete_command_names(self) -> list[str]: """遍历父节点获取完整的指令名。 - 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。""" + 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。 + """ if self._cmpl_cmd_names is not None: return self._cmpl_cmd_names @@ -59,7 +61,7 @@ class CommandGroupFilter(HandlerFilter): # 以树的形式打印出来 def print_cmd_tree( self, - sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], + sub_command_filters: list[CommandFilter | CommandGroupFilter], prefix: str = "", event: AstrMessageEvent | None = None, cfg: AstrBotConfig | None = None, @@ -125,7 +127,7 @@ class CommandGroupFilter(HandlerFilter): + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) ) raise ValueError( - f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree + f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree, ) return self.startswith(event.message_str) diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index 9a76b74f2..d57b5cac0 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -1,8 +1,9 @@ -from abc import abstractmethod, ABCMeta +from abc import ABCMeta, abstractmethod + +from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.config import AstrBotConfig class CustomFilterMeta(ABCMeta): @@ -38,7 +39,7 @@ class CustomFilterOr(CustomFilter): super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( - "CustomFilter lass can only operate with other CustomFilter." + "CustomFilter lass can only operate with other CustomFilter.", ) self.filter1 = filter1 self.filter2 = filter2 @@ -52,7 +53,7 @@ class CustomFilterAnd(CustomFilter): super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( - "CustomFilter lass can only operate with other CustomFilter." + "CustomFilter lass can only operate with other CustomFilter.", ) self.filter1 = filter1 self.filter2 = filter2 diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index ce36ec9ed..7f350bd38 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -1,9 +1,11 @@ import enum -from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType +from . import HandlerFilter + class EventMessageType(enum.Flag): GROUP_MESSAGE = enum.auto() diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 307b492a4..3374544c2 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -1,7 +1,9 @@ import enum -from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from . import HandlerFilter class PermissionType(enum.Flag): diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 4c5510783..1182ff9b0 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -1,7 +1,9 @@ import enum -from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from . import HandlerFilter class PlatformAdapterType(enum.Flag): diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index af9cb3a5a..cd5bebdb4 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -1,7 +1,9 @@ import re -from . import HandlerFilter -from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +from . import HandlerFilter # 正则表达式过滤器不会受到 wake_prefix 的制约。 diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index 0519e8ca1..15fe1e9c5 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -1,37 +1,37 @@ from .star import register_star from .star_handler import ( + register_after_message_sent, + register_agent, register_command, register_command_group, - register_event_message_type, - register_platform_adapter_type, - register_regex, - register_permission_type, register_custom_filter, + register_event_message_type, + register_llm_tool, register_on_astrbot_loaded, - register_on_platform_loaded, + register_on_decorating_result, register_on_llm_request, register_on_llm_response, - register_llm_tool, - register_agent, - register_on_decorating_result, - register_after_message_sent, + register_on_platform_loaded, + register_permission_type, + register_platform_adapter_type, + register_regex, ) __all__ = [ - "register_star", + "register_after_message_sent", + "register_agent", "register_command", "register_command_group", - "register_event_message_type", - "register_platform_adapter_type", - "register_regex", - "register_permission_type", "register_custom_filter", + "register_event_message_type", + "register_llm_tool", "register_on_astrbot_loaded", - "register_on_platform_loaded", + "register_on_decorating_result", "register_on_llm_request", "register_on_llm_response", - "register_llm_tool", - "register_agent", - "register_on_decorating_result", - "register_after_message_sent", + "register_on_platform_loaded", + "register_permission_type", + "register_platform_adapter_type", + "register_regex", + "register_star", ] diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index a5190dd5c..617cd5ff7 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -6,7 +6,11 @@ _warned_register_star = False def register_star( - name: str, author: str, desc: str, version: str, repo: str | None = None + name: str, + author: str, + desc: str, + version: str, + repo: str | None = None, ): """注册一个插件(Star)。 @@ -29,8 +33,8 @@ def register_star( ... 帮助信息会被自动提取。使用 `/plugin <插件名> 可以查看帮助信息。` - """ + """ global _warned_register_star if not _warned_register_star: _warned_register_star = True diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index d1c5a6dce..7ce5febd5 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -1,26 +1,30 @@ from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + import docstring_parser -from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType -from ..filter.command import CommandFilter -from ..filter.command_group import CommandGroupFilter -from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType -from ..filter.platform_adapter_type import ( - PlatformAdapterTypeFilter, - PlatformAdapterType, -) -from ..filter.permission import PermissionTypeFilter, PermissionType -from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr -from ..filter.regex import RegexFilter -from typing import Awaitable, Any, Callable -from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES -from astrbot.core.provider.register import llm_tools +from astrbot.core import logger from astrbot.core.agent.agent import Agent -from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks +from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core import logger +from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES +from astrbot.core.provider.register import llm_tools + +from ..filter.command import CommandFilter +from ..filter.command_group import CommandGroupFilter +from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr +from ..filter.event_message_type import EventMessageType, EventMessageTypeFilter +from ..filter.permission import PermissionType, PermissionTypeFilter +from ..filter.platform_adapter_type import ( + PlatformAdapterType, + PlatformAdapterTypeFilter, +) +from ..filter.regex import RegexFilter +from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: @@ -39,27 +43,26 @@ def get_handler_or_create( md = star_handlers_registry.get_handler_by_full_name(handler_full_name) if md: return md - else: - md = StarHandlerMetadata( - event_type=event_type, - handler_full_name=handler_full_name, - handler_name=handler.__name__, - handler_module_path=handler.__module__, - handler=handler, - event_filters=[], - ) + md = StarHandlerMetadata( + event_type=event_type, + handler_full_name=handler_full_name, + handler_name=handler.__name__, + handler_module_path=handler.__module__, + handler=handler, + event_filters=[], + ) - # 插件handler的附加额外信息 - if handler.__doc__: - md.desc = handler.__doc__.strip() - if "desc" in kwargs: - md.desc = kwargs["desc"] - del kwargs["desc"] - md.extras_configs = kwargs + # 插件handler的附加额外信息 + if handler.__doc__: + md.desc = handler.__doc__.strip() + if "desc" in kwargs: + md.desc = kwargs["desc"] + del kwargs["desc"] + md.extras_configs = kwargs - if not dont_add: - star_handlers_registry.append(md) - return md + if not dont_add: + star_handlers_registry.append(md) + return md def register_command( @@ -78,20 +81,22 @@ def register_command( command_name.parent_group.get_complete_command_names() ) new_command = CommandFilter( - sub_command, alias, None, parent_command_names=parent_command_names + sub_command, + alias, + None, + parent_command_names=parent_command_names, ) command_name.parent_group.add_sub_command_filter(new_command) else: logger.warning( - f"注册指令{command_name} 的子指令时未提供 sub_command 参数。" + f"注册指令{command_name} 的子指令时未提供 sub_command 参数。", ) + # 裸指令 + elif command_name is None: + logger.warning("注册裸指令时未提供 command_name 参数。") else: - # 裸指令 - if command_name is None: - logger.warning("注册裸指令时未提供 command_name 参数。") - else: - new_command = CommandFilter(command_name, alias, None) - add_to_event_filters = True + new_command = CommandFilter(command_name, alias, None) + add_to_event_filters = True def decorator(awaitable): if not add_to_event_filters: @@ -99,7 +104,9 @@ def register_command( True # 打一个标记,表示这是一个子指令,再 wakingstage 阶段这个 handler 将会直接被跳过(其父指令会接管) ) handler_md = get_handler_or_create( - awaitable, EventType.AdapterMessageEvent, **kwargs + awaitable, + EventType.AdapterMessageEvent, + **kwargs, ) if new_command: new_command.init_handler_md(handler_md) @@ -116,6 +123,7 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): custom_type_filter: 在裸指令时为CustomFilter对象 在指令组时为父指令的RegisteringCommandable对象,即self或者command_group的返回 raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True + """ add_to_event_filters = False raise_error = True @@ -140,19 +148,20 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): def decorator(awaitable): # 裸指令,子指令与指令组的区分,指令组会因为标记跳过wake。 if ( - not add_to_event_filters - and isinstance(awaitable, RegisteringCommandable) - or (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)) - ): + not add_to_event_filters and isinstance(awaitable, RegisteringCommandable) + ) or (add_to_event_filters and isinstance(awaitable, RegisteringCommandable)): # 指令组 与 根指令组,添加到本层的grouphandle中一起判断 awaitable.parent_group.add_custom_filter(custom_filter) else: handler_md = get_handler_or_create( - awaitable, EventType.AdapterMessageEvent, **kwargs + awaitable, + EventType.AdapterMessageEvent, + **kwargs, ) if not add_to_event_filters and not isinstance( - awaitable, RegisteringCommandable + awaitable, + RegisteringCommandable, ): # 底层子指令 handle_full_name = get_handler_full_name(awaitable) @@ -171,7 +180,9 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): else: # 裸指令 handler_md = get_handler_or_create( - awaitable, EventType.AdapterMessageEvent, **kwargs + awaitable, + EventType.AdapterMessageEvent, + **kwargs, ) handler_md.event_filters.append(custom_filter) @@ -194,20 +205,23 @@ def register_command_group( logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定") else: new_group = CommandGroupFilter( - sub_command, alias, parent_group=command_group_name.parent_group + sub_command, + alias, + parent_group=command_group_name.parent_group, ) command_group_name.parent_group.add_sub_command_filter(new_group) + # 根指令组 + elif command_group_name is None: + logger.warning("根指令组的名称未指定") else: - # 根指令组 - if command_group_name is None: - logger.warning("根指令组的名称未指定") - else: - new_group = CommandGroupFilter(command_group_name, alias) + new_group = CommandGroupFilter(command_group_name, alias) def decorator(obj): if new_group: handler_md = get_handler_or_create( - obj, EventType.AdapterMessageEvent, **kwargs + obj, + EventType.AdapterMessageEvent, + **kwargs, ) handler_md.event_filters.append(new_group) @@ -220,9 +234,7 @@ def register_command_group( class RegisteringCommandable: """用于指令组级联注册""" - group: Callable[..., Callable[..., "RegisteringCommandable"]] = ( - register_command_group - ) + group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., None]] = register_custom_filter @@ -235,7 +247,9 @@ def register_event_message_type(event_message_type: EventMessageType, **kwargs): def decorator(awaitable): handler_md = get_handler_or_create( - awaitable, EventType.AdapterMessageEvent, **kwargs + awaitable, + EventType.AdapterMessageEvent, + **kwargs, ) handler_md.event_filters.append(EventMessageTypeFilter(event_message_type)) return awaitable @@ -244,14 +258,15 @@ def register_event_message_type(event_message_type: EventMessageType, **kwargs): def register_platform_adapter_type( - platform_adapter_type: PlatformAdapterType, **kwargs + platform_adapter_type: PlatformAdapterType, + **kwargs, ): """注册一个 PlatformAdapterType""" def decorator(awaitable): handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append( - PlatformAdapterTypeFilter(platform_adapter_type) + PlatformAdapterTypeFilter(platform_adapter_type), ) return awaitable @@ -263,7 +278,9 @@ def register_regex(regex: str, **kwargs): def decorator(awaitable): handler_md = get_handler_or_create( - awaitable, EventType.AdapterMessageEvent, **kwargs + awaitable, + EventType.AdapterMessageEvent, + **kwargs, ) handler_md.event_filters.append(RegexFilter(regex)) return awaitable @@ -277,12 +294,13 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool Args: permission_type: PermissionType raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True + """ def decorator(awaitable): handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append( - PermissionTypeFilter(permission_type, raise_error) + PermissionTypeFilter(permission_type, raise_error), ) return awaitable @@ -300,9 +318,7 @@ def register_on_astrbot_loaded(**kwargs): def register_on_platform_loaded(**kwargs): - """ - 当平台加载完成时 - """ + """当平台加载完成时""" def decorator(awaitable): _ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs) @@ -324,6 +340,7 @@ def register_on_llm_request(**kwargs): ``` 请务必接收两个参数:event, request + """ def decorator(awaitable): @@ -346,6 +363,7 @@ def register_on_llm_response(**kwargs): ``` 请务必接收两个参数:event, request + """ def decorator(awaitable): @@ -365,7 +383,7 @@ def register_llm_tool(name: str | None = None, **kwargs): async def get_weather(event: AstrMessageEvent, location: str): \'\'\'获取天气信息。 - Args: + Args: location(string): 地点 \'\'\' # 处理逻辑 @@ -386,8 +404,8 @@ def register_llm_tool(name: str | None = None, **kwargs): event.stop_event() yield ``` - """ + """ name_ = name registering_agent = None if kwargs.get("registering_agent"): @@ -401,14 +419,14 @@ def register_llm_tool(name: str | None = None, **kwargs): for arg in docstring.params: if arg.type_name not in SUPPORTED_TYPES: raise ValueError( - f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}" + f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}", ) args.append( { "type": arg.type_name, "name": arg.arg_name, "description": arg.description, - } + }, ) # print(llm_tool_name, registering_agent) if not registering_agent: @@ -454,6 +472,7 @@ def register_agent( instruction: Agent 的指令 tools: Agent 使用的工具列表 run_hooks: Agent 运行时的钩子函数 + """ tools_ = tools or [] @@ -478,7 +497,9 @@ def register_on_decorating_result(**kwargs): def decorator(awaitable): _ = get_handler_or_create( - awaitable, EventType.OnDecoratingResultEvent, **kwargs + awaitable, + EventType.OnDecoratingResultEvent, + **kwargs, ) return awaitable @@ -490,7 +511,9 @@ def register_after_message_sent(**kwargs): def decorator(awaitable): _ = get_handler_or_create( - awaitable, EventType.OnAfterMessageSentEvent, **kwargs + awaitable, + EventType.OnAfterMessageSentEvent, + **kwargs, ) return awaitable diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 8fb88c6b8..8c40f25c1 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -1,6 +1,4 @@ -""" -会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态 -""" +"""会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态""" from astrbot.core import logger, sp from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -22,10 +20,14 @@ class SessionServiceManager: Returns: bool: True表示启用,False表示禁用 + """ # 获取会话服务配置 session_services = sp.get( - "session_service_config", {}, scope="umo", scope_id=session_id + "session_service_config", + {}, + scope="umo", + scope_id=session_id, ) # 如果配置了该会话的LLM状态,返回该状态 @@ -43,13 +45,17 @@ class SessionServiceManager: Args: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 + """ session_config = ( sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} ) session_config["llm_enabled"] = enabled sp.put( - "session_service_config", session_config, scope="umo", scope_id=session_id + "session_service_config", + session_config, + scope="umo", + scope_id=session_id, ) @staticmethod @@ -61,6 +67,7 @@ class SessionServiceManager: Returns: bool: True表示应该处理,False表示跳过 + """ session_id = event.unified_msg_origin return SessionServiceManager.is_llm_enabled_for_session(session_id) @@ -78,10 +85,14 @@ class SessionServiceManager: Returns: bool: True表示启用,False表示禁用 + """ # 获取会话服务配置 session_services = sp.get( - "session_service_config", {}, scope="umo", scope_id=session_id + "session_service_config", + {}, + scope="umo", + scope_id=session_id, ) # 如果配置了该会话的TTS状态,返回该状态 @@ -99,17 +110,21 @@ class SessionServiceManager: Args: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 + """ session_config = ( sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} ) session_config["tts_enabled"] = enabled sp.put( - "session_service_config", session_config, scope="umo", scope_id=session_id + "session_service_config", + session_config, + scope="umo", + scope_id=session_id, ) logger.info( - f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}" + f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}", ) @staticmethod @@ -121,6 +136,7 @@ class SessionServiceManager: Returns: bool: True表示应该处理,False表示跳过 + """ session_id = event.unified_msg_origin return SessionServiceManager.is_tts_enabled_for_session(session_id) @@ -138,10 +154,14 @@ class SessionServiceManager: Returns: bool: True表示启用,False表示禁用 + """ # 获取会话服务配置 session_services = sp.get( - "session_service_config", {}, scope="umo", scope_id=session_id + "session_service_config", + {}, + scope="umo", + scope_id=session_id, ) # 如果配置了该会话的整体状态,返回该状态 @@ -159,17 +179,21 @@ class SessionServiceManager: Args: session_id: 会话ID (unified_msg_origin) enabled: True表示启用,False表示禁用 + """ session_config = ( sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} ) session_config["session_enabled"] = enabled sp.put( - "session_service_config", session_config, scope="umo", scope_id=session_id + "session_service_config", + session_config, + scope="umo", + scope_id=session_id, ) logger.info( - f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}" + f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}", ) @staticmethod @@ -181,6 +205,7 @@ class SessionServiceManager: Returns: bool: True表示应该处理,False表示跳过 + """ session_id = event.unified_msg_origin return SessionServiceManager.is_session_enabled(session_id) @@ -198,9 +223,13 @@ class SessionServiceManager: Returns: str: 自定义名称,如果没有设置则返回None + """ session_services = sp.get( - "session_service_config", {}, scope="umo", scope_id=session_id + "session_service_config", + {}, + scope="umo", + scope_id=session_id, ) return session_services.get("custom_name") @@ -211,6 +240,7 @@ class SessionServiceManager: Args: session_id: 会话ID (unified_msg_origin) custom_name: 自定义名称,可以为空字符串来清除名称 + """ session_config = ( sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {} @@ -221,11 +251,14 @@ class SessionServiceManager: # 如果传入空名称,则删除自定义名称 session_config.pop("custom_name", None) sp.put( - "session_service_config", session_config, scope="umo", scope_id=session_id + "session_service_config", + session_config, + scope="umo", + scope_id=session_id, ) logger.info( - f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}" + f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}", ) @staticmethod @@ -237,6 +270,7 @@ class SessionServiceManager: Returns: str: 显示名称 + """ custom_name = SessionServiceManager.get_session_custom_name(session_id) if custom_name: diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index 94a0c8a4d..c74546fe7 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -1,9 +1,6 @@ -""" -会话插件管理器 - 负责管理每个会话的插件启停状态 -""" +"""会话插件管理器 - 负责管理每个会话的插件启停状态""" -from astrbot.core import sp, logger -from typing import Dict, List +from astrbot.core import logger, sp from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -20,10 +17,14 @@ class SessionPluginManager: Returns: bool: True表示启用,False表示禁用 + """ # 获取会话插件配置 session_plugin_config = sp.get( - "session_plugin_config", {}, scope="umo", scope_id=session_id + "session_plugin_config", + {}, + scope="umo", + scope_id=session_id, ) session_config = session_plugin_config.get(session_id, {}) @@ -43,7 +44,9 @@ class SessionPluginManager: @staticmethod def set_plugin_status_for_session( - session_id: str, plugin_name: str, enabled: bool + session_id: str, + plugin_name: str, + enabled: bool, ) -> None: """设置插件在指定会话中的启停状态 @@ -51,10 +54,14 @@ class SessionPluginManager: session_id: 会话ID (unified_msg_origin) plugin_name: 插件名称 enabled: True表示启用,False表示禁用 + """ # 获取当前配置 session_plugin_config = sp.get( - "session_plugin_config", {}, scope="umo", scope_id=session_id + "session_plugin_config", + {}, + scope="umo", + scope_id=session_id, ) if session_id not in session_plugin_config: session_plugin_config[session_id] = { @@ -91,11 +98,11 @@ class SessionPluginManager: ) logger.info( - f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}" + f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}", ) @staticmethod - def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]: + def get_session_plugin_config(session_id: str) -> dict[str, list[str]]: """获取指定会话的插件配置 Args: @@ -103,16 +110,21 @@ class SessionPluginManager: Returns: Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典 + """ session_plugin_config = sp.get( - "session_plugin_config", {}, scope="umo", scope_id=session_id + "session_plugin_config", + {}, + scope="umo", + scope_id=session_id, ) return session_plugin_config.get( - session_id, {"enabled_plugins": [], "disabled_plugins": []} + session_id, + {"enabled_plugins": [], "disabled_plugins": []}, ) @staticmethod - def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List: + def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list: """根据会话配置过滤处理器列表 Args: @@ -121,6 +133,7 @@ class SessionPluginManager: Returns: List: 过滤后的处理器列表 + """ from astrbot.core.star.star import star_map @@ -145,12 +158,13 @@ class SessionPluginManager: # 检查插件是否在当前会话中启用 if SessionPluginManager.is_plugin_enabled_for_session( - session_id, plugin.name + session_id, + plugin.name, ): filtered_handlers.append(handler) else: logger.debug( - f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}" + f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}", ) return filtered_handlers diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index bd16cb216..c5b7b1243 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -16,8 +16,7 @@ if TYPE_CHECKING: @dataclass class StarMetadata: - """ - 插件的元数据。 + """插件的元数据。 当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。 """ diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 80b5adb60..141f9180a 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,7 +1,10 @@ from __future__ import annotations + import enum +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic +from typing import Any, Generic, TypeVar + from .filter import HandlerFilter from .star import star_map @@ -10,8 +13,8 @@ T = TypeVar("T", bound="StarHandlerMetadata") class StarHandlerRegistry(Generic[T]): def __init__(self): - self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} - self._handlers: List[StarHandlerMetadata] = [] + self.star_handlers_map: dict[str, StarHandlerMetadata] = {} + self._handlers: list[StarHandlerMetadata] = [] def append(self, handler: StarHandlerMetadata): """添加一个 Handler,并保持按优先级有序""" @@ -31,7 +34,7 @@ class StarHandlerRegistry(Generic[T]): event_type: EventType, only_activated=True, plugins_name: list[str] | None = None, - ) -> List[StarHandlerMetadata]: + ) -> list[StarHandlerMetadata]: handlers = [] for handler in self._handlers: # 过滤事件类型 @@ -64,8 +67,9 @@ class StarHandlerRegistry(Generic[T]): return self.star_handlers_map.get(full_name, None) def get_handlers_by_module_name( - self, module_name: str - ) -> List[StarHandlerMetadata]: + self, + module_name: str, + ) -> list[StarHandlerMetadata]: return [ handler for handler in self._handlers @@ -126,7 +130,7 @@ class StarHandlerMetadata: handler: Callable[..., Awaitable[Any]] """Handler 的函数对象,应当是一个异步函数""" - event_filters: List[HandlerFilter] + event_filters: list[HandlerFilter] """一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件""" desc: str = "" @@ -138,5 +142,6 @@ class StarHandlerMetadata: def __lt__(self, other: StarHandlerMetadata): """定义小于运算符以支持优先队列""" return self.extras_configs.get("priority", 0) < other.extras_configs.get( - "priority", 0 + "priority", + 0, ) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index c1057e4b6..ef50917fe 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -1,6 +1,4 @@ -""" -插件的重载、启停、安装、卸载等操作。 -""" +"""插件的重载、启停、安装、卸载等操作。""" import asyncio import functools @@ -15,6 +13,7 @@ from types import ModuleType import yaml from astrbot.core import logger, pip_installer, sp +from astrbot.core.agent.handoff import FunctionTool, HandoffTool from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import ( @@ -22,7 +21,6 @@ from astrbot.core.utils.astrbot_path import ( get_astrbot_plugin_path, ) from astrbot.core.utils.io import remove_dir -from astrbot.core.agent.handoff import HandoffTool, FunctionTool from . import StarMetadata from .context import Context @@ -52,8 +50,9 @@ class PluginManager: """存储插件配置的路径。data/config""" self.reserved_plugin_path = os.path.abspath( os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../../packages" - ) + os.path.dirname(os.path.abspath(__file__)), + "../../../packages", + ), ) """保留插件的路径。在 packages 目录下""" self.conf_schema_fname = "_conf_schema.json" @@ -80,7 +79,7 @@ class PluginManager: except asyncio.CancelledError: pass except Exception as e: - logger.error(f"插件热重载监视任务异常: {str(e)}") + logger.error(f"插件热重载监视任务异常: {e!s}") logger.error(traceback.format_exc()) async def _handle_file_changes(self, changes): @@ -95,11 +94,13 @@ class PluginManager: continue if star.reserved: plugin_dir_path = os.path.join( - self.reserved_plugin_path, star.root_dir_name + self.reserved_plugin_path, + star.root_dir_name, ) else: plugin_dir_path = os.path.join( - self.plugin_store_path, star.root_dir_name + self.plugin_store_path, + star.root_dir_name, ) plugins_to_check.append((plugin_dir_path, star.name)) reloaded_plugins = set() @@ -143,14 +144,14 @@ class PluginManager: logger.info(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。") continue if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists( - os.path.join(path, d, d + ".py") + os.path.join(path, d, d + ".py"), ): modules.append( { "pname": d, "module": module_str, "module_path": os.path.join(path, d, module_str), - } + }, ) return modules @@ -186,7 +187,7 @@ class PluginManager: try: await pip_installer.install(requirements_path=pth) except Exception as e: - logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") + logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}") @staticmethod def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None: @@ -201,7 +202,8 @@ class PluginManager: if os.path.exists(os.path.join(plugin_path, "metadata.yaml")): with open( - os.path.join(plugin_path, "metadata.yaml"), encoding="utf-8" + os.path.join(plugin_path, "metadata.yaml"), + encoding="utf-8", ) as f: metadata = yaml.safe_load(f) elif plugin_obj and hasattr(plugin_obj, "info"): @@ -219,7 +221,7 @@ class PluginManager: or "author" not in metadata ): raise Exception( - "插件元数据信息不完整。name, desc, version, author 是必须的字段。" + "插件元数据信息不完整。name, desc, version, author 是必须的字段。", ) metadata = StarMetadata( name=metadata["name"], @@ -234,7 +236,8 @@ class PluginManager: @staticmethod def _get_plugin_related_modules( - plugin_root_dir: str, is_reserved: bool = False + plugin_root_dir: str, + is_reserved: bool = False, ) -> list[str]: """获取与指定插件相关的所有已加载模块名 @@ -246,6 +249,7 @@ class PluginManager: Returns: list[str]: 与该插件相关的模块名列表 + """ prefix = "packages." if is_reserved else "data.plugins." return [ @@ -268,6 +272,7 @@ class PluginManager: module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"]) root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块 is_reserved: 插件是否为保留插件(影响模块路径前缀) + """ if module_patterns: for pattern in module_patterns: @@ -278,7 +283,8 @@ class PluginManager: if root_dir_name: for module_name in self._get_plugin_related_modules( - root_dir_name, is_reserved + root_dir_name, + is_reserved, ): try: del sys.modules[module_name] @@ -297,6 +303,7 @@ class PluginManager: tuple: 返回 load() 方法的结果,包含 (success, error_message) - success (bool): 重载是否成功 - error_message (str|None): 错误信息,成功时为 None + """ async with self._pm_lock: specified_module_path = None @@ -315,7 +322,7 @@ class PluginManager: except Exception as e: logger.warning(traceback.format_exc()) logger.warning( - f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" + f"插件 {smd.name} 未被正常终止: {e!s}, 可能会导致该插件运行不正常。", ) if smd.name and smd.module_path: await self._unbind_plugin(smd.name, smd.module_path) @@ -332,7 +339,7 @@ class PluginManager: except Exception as e: logger.warning(traceback.format_exc()) logger.warning( - f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" + f"插件 {smd.name} 未被正常终止: {e!s}, 可能会导致该插件运行不正常。", ) if smd.name: await self._unbind_plugin(smd.name, specified_module_path) @@ -353,6 +360,7 @@ class PluginManager: tuple: (success, error_message) - success (bool): 是否全部加载成功 - error_message (str|None): 错误信息,成功时为 None + """ inactivated_plugins = await sp.global_get("inactivated_plugins", []) inactivated_llm_tools = await sp.global_get("inactivated_llm_tools", []) @@ -371,7 +379,8 @@ class PluginManager: # module_path = plugin_module['module_path'] root_dir_name = plugin_module["pname"] # 插件的目录名 reserved = plugin_module.get( - "reserved", False + "reserved", + False, ) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。 path = "data.plugins." if not reserved else "packages." @@ -394,7 +403,7 @@ class PluginManager: module = __import__(path, fromlist=[module_str]) except Exception as e: logger.error(traceback.format_exc()) - logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}") + logger.error(f"插件 {root_dir_name} 导入失败。原因:{e!s}") continue # 检查 _conf_schema.json @@ -405,14 +414,16 @@ class PluginManager: else os.path.join(self.reserved_plugin_path, root_dir_name) ) plugin_schema_path = os.path.join( - plugin_dir_path, self.conf_schema_fname + plugin_dir_path, + self.conf_schema_fname, ) if os.path.exists(plugin_schema_path): # 加载插件配置 with open(plugin_schema_path, encoding="utf-8") as f: plugin_config = AstrBotConfig( config_path=os.path.join( - self.plugin_config_path, f"{root_dir_name}_config.json" + self.plugin_config_path, + f"{root_dir_name}_config.json", ), schema=json.loads(f.read()), ) @@ -425,7 +436,7 @@ class PluginManager: try: # yaml 文件的元数据优先 metadata_yaml = self._load_plugin_metadata( - plugin_path=plugin_dir_path + plugin_path=plugin_dir_path, ) if metadata_yaml: metadata.name = metadata_yaml.name @@ -436,7 +447,7 @@ class PluginManager: metadata.display_name = metadata_yaml.display_name except Exception as e: logger.warning( - f"插件 {root_dir_name} 元数据载入失败: {str(e)}。使用默认元数据。" + f"插件 {root_dir_name} 元数据载入失败: {e!s}。使用默认元数据。", ) logger.info(metadata) metadata.config = plugin_config @@ -445,15 +456,16 @@ class PluginManager: if plugin_config and metadata.star_cls_type: try: metadata.star_cls = metadata.star_cls_type( - context=self.context, config=plugin_config + context=self.context, + config=plugin_config, ) except TypeError as _: metadata.star_cls = metadata.star_cls_type( - context=self.context + context=self.context, ) elif metadata.star_cls_type: metadata.star_cls = metadata.star_cls_type( - context=self.context + context=self.context, ) else: logger.info(f"插件 {metadata.name} 已被禁用。") @@ -469,7 +481,7 @@ class PluginManager: # 绑定 handler related_handlers = ( star_handlers_registry.get_handlers_by_module_name( - metadata.module_path + metadata.module_path, ) ) for handler in related_handlers: @@ -505,7 +517,7 @@ class PluginManager: else: # v3.4.0 以前的方式注册插件 logger.debug( - f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。" + f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。", ) classes = self._get_classes(module) @@ -514,19 +526,21 @@ class PluginManager: if plugin_config: try: obj = getattr(module, classes[0])( - context=self.context, config=plugin_config + context=self.context, + config=plugin_config, ) # 实例化插件类 except TypeError as _: obj = getattr(module, classes[0])( - context=self.context + context=self.context, ) # 实例化插件类 else: obj = getattr(module, classes[0])( - context=self.context + context=self.context, ) # 实例化插件类 metadata = self._load_plugin_metadata( - plugin_path=plugin_dir_path, plugin_obj=obj + plugin_path=plugin_dir_path, + plugin_obj=obj, ) if not metadata: raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。") @@ -552,7 +566,7 @@ class PluginManager: full_names = [] for handler in star_handlers_registry.get_handlers_by_module_name( - metadata.module_path + metadata.module_path, ): full_names.append(handler.handler_full_name) @@ -562,7 +576,8 @@ class PluginManager: and handler.handler_name in alter_cmd[metadata.name] ): cmd_type = alter_cmd[metadata.name][handler.handler_name].get( - "permission", "member" + "permission", + "member", ) found_permission_filter = False for filter_ in handler.event_filters: @@ -578,12 +593,12 @@ class PluginManager: PermissionTypeFilter( PermissionType.ADMIN if cmd_type == "admin" - else PermissionType.MEMBER - ) + else PermissionType.MEMBER, + ), ) logger.debug( - f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。" + f"插入权限过滤器 {cmd_type} 到 {metadata.name} 的 {handler.handler_name} 方法。", ) metadata.star_handler_full_names = full_names @@ -598,7 +613,7 @@ class PluginManager: for line in errors.split("\n"): logger.error(f"| {line}") logger.error("----------------------------------") - fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}。\n" + fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {e!s}。\n" # 清除 pip.main 导致的多余的 logging handlers for handler in logging.root.handlers[:]: @@ -606,9 +621,8 @@ class PluginManager: if not fail_rec: return True, None - else: - self.failed_plugin_info = fail_rec - return False, fail_rec + self.failed_plugin_info = fail_rec + return False, fail_rec async def install_plugin(self, repo_url: str, proxy=""): """从仓库 URL 安装插件 @@ -624,6 +638,7 @@ class PluginManager: - repo: 插件的仓库 URL - readme: README.md 文件的内容(如果存在) 如果找不到插件元数据则返回 None。 + """ async with self._pm_lock: plugin_path = await self.updator.install(repo_url, proxy) @@ -652,7 +667,7 @@ class PluginManager: readme_content = f.read() except Exception as e: logger.warning( - f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}" + f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", ) plugin_info = None @@ -673,6 +688,7 @@ class PluginManager: Raises: Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常 + """ async with self._pm_lock: plugin = self.context.get_registered_star(plugin_name) @@ -689,7 +705,7 @@ class PluginManager: except Exception as e: logger.warning(traceback.format_exc()) logger.warning( - f"插件 {plugin_name} 未被正常终止 {str(e)}, 可能会导致资源泄露等问题。" + f"插件 {plugin_name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。", ) # 从 star_registry 和 star_map 中删除 @@ -702,7 +718,7 @@ class PluginManager: remove_dir(os.path.join(ppath, root_dir_name)) except Exception as e: raise Exception( - f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。" + f"移除插件成功,但是删除插件文件夹失败: {e!s}。您可以手动删除该文件夹,位于 addons/plugins/ 下。", ) async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): @@ -711,6 +727,7 @@ class PluginManager: Args: plugin_name: 要解绑的插件名称 plugin_module_path: 插件的完整模块路径 + """ plugin = None del star_map[plugin_module_path] @@ -720,10 +737,10 @@ class PluginManager: del star_registry[i] break for handler in star_handlers_registry.get_handlers_by_module_name( - plugin_module_path + plugin_module_path, ): logger.info( - f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})" + f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})", ) star_handlers_registry.remove(handler) @@ -738,7 +755,8 @@ class PluginManager: return self._purge_modules( - root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved + root_dir_name=plugin.root_dir_name, + is_reserved=plugin.reserved, ) async def update_plugin(self, plugin_name: str, proxy=""): @@ -753,8 +771,7 @@ class PluginManager: await self.reload(plugin_name) async def turn_off_plugin(self, plugin_name: str): - """ - 禁用一个插件。 + """禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 并且同时将插件启用的 llm_tool 禁用。 @@ -773,7 +790,7 @@ class PluginManager: inactivated_plugins.append(plugin.module_path) inactivated_llm_tools: list = list( - set(await sp.global_get("inactivated_llm_tools", [])) + set(await sp.global_get("inactivated_llm_tools", [])), ) # 后向兼容 # 禁用插件启用的 llm_tool @@ -803,7 +820,8 @@ class PluginManager: if "__del__" in star_metadata.star_cls_type.__dict__: asyncio.get_event_loop().run_in_executor( - None, star_metadata.star_cls.__del__ + None, + star_metadata.star_cls.__del__, ) elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() @@ -842,7 +860,7 @@ class PluginManager: try: os.remove(zip_file_path) except BaseException as e: - logger.warning(f"删除插件压缩包失败: {str(e)}") + logger.warning(f"删除插件压缩包失败: {e!s}") # await self.reload() await self.load(specified_dir_name=dir_name) @@ -866,7 +884,7 @@ class PluginManager: with open(readme_path, encoding="utf-8") as f: readme_content = f.read() except Exception as e: - logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}") + logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") plugin_info = None if plugin: diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 6f9dfe2fa..7a66449b4 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -1,5 +1,4 @@ -""" -插件开发工具集 +"""插件开发工具集 封装了许多常用的操作,方便插件开发者使用 说明: @@ -21,47 +20,49 @@ import inspect import os import uuid +from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar +from typing import Any, ClassVar + +from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain -from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.star.context import Context -from astrbot.core.star.star import star_map -from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( AiocqhttpMessageEvent, ) from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( AiocqhttpAdapter, ) +from astrbot.core.star.context import Context +from astrbot.core.star.star import star_map +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class StarTools: - """ - 提供给插件使用的便捷工具函数集合 + """提供给插件使用的便捷工具函数集合 这些方法封装了一些常用操作,使插件开发更加简单便捷! """ - _context: ClassVar[Optional[Context]] = None + _context: ClassVar[Context | None] = None @classmethod def initialize(cls, context: Context) -> None: - """ - 初始化StarTools,设置context引用 + """初始化StarTools,设置context引用 Args: context: 暴露给插件的上下文 + """ cls._context = context @classmethod async def send_message( - cls, session: Union[str, MessageSesion], message_chain: MessageChain + cls, + session: str | MessageSesion, + message_chain: MessageChain, ) -> bool: - """ - 根据session(unified_msg_origin)主动发送消息 + """根据session(unified_msg_origin)主动发送消息 Args: session: 消息会话。通过event.session或者event.unified_msg_origin获取 @@ -75,6 +76,7 @@ class StarTools: Note: qq_official(QQ官方API平台)不支持此方法 + """ if cls._context is None: raise ValueError("StarTools not initialized") @@ -88,21 +90,22 @@ class StarTools: message_chain: MessageChain, platform: str = "aiocqhttp", ): - """ - 根据 id(例如qq号, 群号等) 直接, 主动地发送消息 + """根据 id(例如qq号, 群号等) 直接, 主动地发送消息 Args: type (str): 消息类型, 可选: PrivateMessage, GroupMessage id (str): 目标ID, 例如QQ号, 群号等 message_chain (MessageChain): 消息链 platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp + """ if cls._context is None: raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": adapter = next( - (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None + (p for p in platforms if isinstance(p, AiocqhttpAdapter)), + None, ) if adapter is None: raise ValueError("未找到适配器: AiocqhttpAdapter") @@ -122,14 +125,13 @@ class StarTools: self_id: str, session_id: str, sender: MessageMember, - message: List[BaseMessageComponent], + message: list[BaseMessageComponent], message_str: str, message_id: str = "", raw_message: object = None, group_id: str = "", ) -> AstrBotMessage: - """ - 创建一个AstrBot消息对象 + """创建一个AstrBot消息对象 Args: type (str): 消息类型, 例如 "GroupMessage" "FriendMessage" "OtherMessage" @@ -145,6 +147,7 @@ class StarTools: Returns: AstrBotMessage: 创建的消息对象 + """ abm = AstrBotMessage() abm.type = MessageType(type) @@ -162,23 +165,27 @@ class StarTools: @classmethod async def create_event( - cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True + cls, + abm: AstrBotMessage, + platform: str = "aiocqhttp", + is_wake: bool = True, ) -> None: - """ - 创建并提交事件到指定平台 + """创建并提交事件到指定平台 当有需要创建一个事件, 触发某些处理流程时, 使用该方法 Args: abm (AstrBotMessage): 要提交的消息对象, 请先使用 create_message 创建 platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp is_wake (bool): 是否标记为唤醒事件, 默认为 True, 只有唤醒事件才会被 llm 响应 + """ if cls._context is None: raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": adapter = next( - (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None + (p for p in platforms if isinstance(p, AiocqhttpAdapter)), + None, ) if adapter is None: raise ValueError("未找到适配器: AiocqhttpAdapter") @@ -196,12 +203,12 @@ class StarTools: @classmethod def activate_llm_tool(cls, name: str) -> bool: - """ - 激活一个已经注册的函数调用工具 + """激活一个已经注册的函数调用工具 注册的工具默认是激活状态 Args: name (str): 工具名称 + """ if cls._context is None: raise ValueError("StarTools not initialized") @@ -209,11 +216,11 @@ class StarTools: @classmethod def deactivate_llm_tool(cls, name: str) -> bool: - """ - 停用一个已经注册的函数调用工具 + """停用一个已经注册的函数调用工具 Args: name (str): 工具名称 + """ if cls._context is None: raise ValueError("StarTools not initialized") @@ -227,14 +234,14 @@ class StarTools: desc: str, func_obj: Callable[..., Awaitable[Any]], ) -> None: - """ - 为函数调用(function-calling/tools-use)添加工具 + """为函数调用(function-calling/tools-use)添加工具 Args: name (str): 工具名称 func_args (list): 函数参数列表 desc (str): 工具描述 func_obj (Awaitable): 函数对象,必须是异步函数 + """ if cls._context is None: raise ValueError("StarTools not initialized") @@ -242,21 +249,20 @@ class StarTools: @classmethod def unregister_llm_tool(cls, name: str) -> None: - """ - 删除一个函数调用工具 + """删除一个函数调用工具 如果再要启用,需要重新注册 Args: name (str): 工具名称 + """ if cls._context is None: raise ValueError("StarTools not initialized") cls._context.unregister_llm_tool(name) @classmethod - def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: - """ - 返回插件数据目录的绝对路径。 + def get_data_dir(cls, plugin_name: str | None = None) -> Path: + """返回插件数据目录的绝对路径。 此方法会在 data/plugin_data 目录下为插件创建一个专属的数据目录。如果未提供插件名称, 会自动从调用栈中获取插件信息。 @@ -272,6 +278,7 @@ class StarTools: - 无法获取调用者模块信息 - 无法获取模块的元数据信息 - 创建目录失败(权限不足或其他IO错误) + """ if not plugin_name: frame = inspect.currentframe() @@ -294,7 +301,7 @@ class StarTools: raise ValueError("无法获取插件名称") data_dir = Path( - os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name) + os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name), ) try: diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index a22455377..8793ad505 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -1,12 +1,13 @@ import os -import zipfile import shutil +import zipfile -from ..updator import RepoZipUpdator -from astrbot.core.utils.io import remove_dir, on_error -from ..star.star import StarMetadata from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path +from astrbot.core.utils.io import on_error, remove_dir + +from ..star.star import StarMetadata +from ..updator import RepoZipUpdator class PluginUpdator(RepoZipUpdator): @@ -44,7 +45,7 @@ class PluginUpdator(RepoZipUpdator): remove_dir(plugin_path) except BaseException as e: logger.error( - f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。" + f"删除旧版本插件 {plugin_path} 文件夹失败: {e!s},使用覆盖安装。", ) self.unzip_file(plugin_path + ".zip", plugin_path) @@ -64,18 +65,17 @@ class PluginUpdator(RepoZipUpdator): if os.path.isdir(os.path.join(target_dir, update_dir, f)): if os.path.exists(os.path.join(target_dir, f)): shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) - else: - if os.path.exists(os.path.join(target_dir, f)): - os.remove(os.path.join(target_dir, f)) + elif os.path.exists(os.path.join(target_dir, f)): + os.remove(os.path.join(target_dir, f)) shutil.move(os.path.join(target_dir, update_dir, f), target_dir) try: logger.info( - f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}" + f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}", ) shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except BaseException: logger.warning( - f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}" + f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}", ) diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index dd2063e56..07858da5f 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -15,7 +15,10 @@ class UmopConfigRouter: """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 sp_data = self.sp.get( - "umop_config_routing", {}, scope="global", scope_id="global" + "umop_config_routing", + {}, + scope="global", + scope_id="global", ) self.umop_to_conf_id = sp_data @@ -37,6 +40,7 @@ class UmopConfigRouter: Returns: str | None: 配置文件 ID,如果没有找到则返回 None + """ for pattern, conf_id in self.umop_to_conf_id.items(): if self._is_umo_match(pattern, umo): @@ -52,11 +56,12 @@ class UmopConfigRouter: Raises: ValueError: 如果 new_routing 中的 key 格式不正确 + """ - for part in new_routing.keys(): + for part in new_routing: if not isinstance(part, str) or len(part.split(":")) != 3: raise ValueError( - "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all" + "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) self.umop_to_conf_id = new_routing @@ -71,10 +76,11 @@ class UmopConfigRouter: Raises: ValueError: 如果 umo 格式不正确 + """ if not isinstance(umo, str) or len(umo.split(":")) != 3: raise ValueError( - "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all" + "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", ) self.umop_to_conf_id[umo] = conf_id diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 68e4a6c58..d13bab687 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -1,12 +1,15 @@ import os -import psutil import sys import time -from .zip_updator import ReleaseInfo, RepoZipUpdator + +import psutil + from astrbot.core import logger from astrbot.core.config.default import VERSION -from astrbot.core.utils.io import download_file from astrbot.core.utils.astrbot_path import get_astrbot_path +from astrbot.core.utils.io import download_file + +from .zip_updator import ReleaseInfo, RepoZipUpdator class AstrBotUpdator(RepoZipUpdator): @@ -67,11 +70,16 @@ class AstrBotUpdator(RepoZipUpdator): raise e async def check_update( - self, url: str, current_version: str, consider_prerelease: bool = True + self, + url: str, + current_version: str, + consider_prerelease: bool = True, ) -> ReleaseInfo: """检查更新""" return await super().check_update( - self.ASTRBOT_RELEASE_API, VERSION, consider_prerelease + self.ASTRBOT_RELEASE_API, + VERSION, + consider_prerelease, ) async def get_releases(self) -> list: diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index 64ed9229f..e13379b92 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -1,5 +1,4 @@ -""" -Astrbot统一路径获取 +"""Astrbot统一路径获取 项目路径:固定为源码所在路径 根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定 @@ -14,7 +13,7 @@ import os def get_astrbot_path() -> str: """获取Astrbot项目路径""" return os.path.realpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"), ) @@ -22,8 +21,7 @@ def get_astrbot_root() -> str: """获取Astrbot根目录路径""" if path := os.environ.get("ASTRBOT_ROOT"): return os.path.realpath(path) - else: - return os.path.realpath(os.getcwd()) + return os.path.realpath(os.getcwd()) def get_astrbot_data_path() -> str: diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 15a6b71fb..2500e69a5 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -1,8 +1,11 @@ import codecs import json +from collections.abc import AsyncGenerator +from typing import Any + +from aiohttp import ClientResponse, ClientSession + from astrbot.core import logger -from aiohttp import ClientSession, ClientResponse -from typing import Dict, List, Any, AsyncGenerator async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: @@ -25,7 +28,6 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: yield json.loads(buffer[5:]) except json.JSONDecodeError: logger.warning(f"Drop invalid dify json data: {buffer[5:]}") - pass class DifyAPIClient: @@ -39,36 +41,39 @@ class DifyAPIClient: async def chat_messages( self, - inputs: Dict, + inputs: dict, query: str, user: str, response_mode: str = "streaming", conversation_id: str = "", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") payload.pop("timeout") logger.info(f"chat_messages payload: {payload}") async with self.session.post( - url, json=payload, headers=self.headers, timeout=timeout + url, + json=payload, + headers=self.headers, + timeout=timeout, ) as resp: if resp.status != 200: text = await resp.text() raise Exception( - f"Dify /chat-messages 接口请求失败:{resp.status}. {text}" + f"Dify /chat-messages 接口请求失败:{resp.status}. {text}", ) async for event in _stream_sse(resp): yield event async def workflow_run( self, - inputs: Dict, + inputs: dict, user: str, response_mode: str = "streaming", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, ): url = f"{self.api_base}/workflows/run" @@ -77,12 +82,15 @@ class DifyAPIClient: payload.pop("timeout") logger.info(f"workflow_run payload: {payload}") async with self.session.post( - url, json=payload, headers=self.headers, timeout=timeout + url, + json=payload, + headers=self.headers, + timeout=timeout, ) as resp: if resp.status != 200: text = await resp.text() raise Exception( - f"Dify /workflows/run 接口请求失败:{resp.status}. {text}" + f"Dify /workflows/run 接口请求失败:{resp.status}. {text}", ) async for event in _stream_sse(resp): yield event @@ -91,7 +99,7 @@ class DifyAPIClient: self, file_path: str, user: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = f"{self.api_base}/files/upload" with open(file_path, "rb") as f: payload = { @@ -99,7 +107,9 @@ class DifyAPIClient: "file": f, } async with self.session.post( - url, data=payload, headers=self.headers + url, + data=payload, + headers=self.headers, ) as resp: return await resp.json() # {"id": "xxx", ...} @@ -126,7 +136,11 @@ class DifyAPIClient: return await resp.json() async def rename( - self, conversation_id: str, name: str, user: str, auto_generate: bool = False + self, + conversation_id: str, + name: str, + user: str, + auto_generate: bool = False, ): # /conversations/:conversation_id/name url = f"{self.api_base}/conversations/{conversation_id}/name" diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 1d0f77b76..bd0bea920 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,29 +1,26 @@ +import base64 +import logging import os -from pathlib import Path -import ssl import shutil import socket +import ssl import time -import aiohttp -import base64 -import zipfile import uuid -import psutil -import logging +import zipfile +from pathlib import Path +import aiohttp import certifi - - +import psutil from PIL import Image + from .astrbot_path import get_astrbot_data_path logger = logging.getLogger("astrbot") def on_error(func, path, exc_info): - """ - a callback of the rmtree function. - """ + """A callback of the rmtree function.""" import stat if not os.access(path, os.W_OK): @@ -78,35 +75,35 @@ def save_temp_img(img: Image.Image | str) -> str: async def download_image_by_url( - url: str, post: bool = False, post_data: dict = None, path=None + url: str, + post: bool = False, + post_data: dict = None, + path=None, ) -> str: - """ - 下载图片, 返回 path - """ + """下载图片, 返回 path""" try: ssl_context = ssl.create_default_context( - cafile=certifi.where() + cafile=certifi.where(), ) # 使用 certifi 提供的 CA 证书 connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书 async with aiohttp.ClientSession( - trust_env=True, connector=connector + trust_env=True, + connector=connector, ) as session: if post: async with session.post(url, json=post_data) as resp: if not path: return save_temp_img(await resp.read()) - else: - with open(path, "wb") as f: - f.write(await resp.read()) - return path + with open(path, "wb") as f: + f.write(await resp.read()) + return path else: async with session.get(url) as resp: if not path: return save_temp_img(await resp.read()) - else: - with open(path, "wb") as f: - f.write(await resp.read()) - return path + with open(path, "wb") as f: + f.write(await resp.read()) + return path except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证 ssl_context = ssl.create_default_context() @@ -123,16 +120,15 @@ async def download_image_by_url( async def download_file(url: str, path: str, show_progress: bool = False): - """ - 从指定 url 下载文件到指定路径 path - """ + """从指定 url 下载文件到指定路径 path""" try: ssl_context = ssl.create_default_context( - cafile=certifi.where() + cafile=certifi.where(), ) # 使用 certifi 提供的 CA 证书 connector = aiohttp.TCPConnector(ssl=ssl_context) async with aiohttp.ClientSession( - trust_env=True, connector=connector + trust_env=True, + connector=connector, ) as session: async with session.get(url, timeout=1800) as resp: if resp.status != 200: @@ -227,7 +223,6 @@ async def download_dashboard( proxy: str | None = None, ) -> None: """下载管理面板文件""" - if path is None: zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip" else: @@ -237,11 +232,13 @@ async def download_dashboard( ver_name = "latest" if latest else version dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip" logger.info( - f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}" + f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}", ) try: await download_file( - dashboard_release_url, str(zip_path), show_progress=True + dashboard_release_url, + str(zip_path), + show_progress=True, ) except BaseException as _: if latest: @@ -251,7 +248,9 @@ async def download_dashboard( if proxy: dashboard_release_url = f"{proxy}/{dashboard_release_url}" await download_file( - dashboard_release_url, str(zip_path), show_progress=True + dashboard_release_url, + str(zip_path), + show_progress=True, ) else: url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip" diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index bf5402f17..2e931dd81 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -1,5 +1,5 @@ -import threading import os +import threading from logging import Logger diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 7fe9bde05..f12019e3c 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -1,10 +1,12 @@ -import aiohttp -import sys import os import socket +import sys import uuid -from astrbot.core.config import VERSION + +import aiohttp + from astrbot.core import db_helper, logger +from astrbot.core.config import VERSION class Metric: @@ -21,7 +23,7 @@ class Metric: if os.path.exists(id_file): try: - with open(id_file, "r") as f: + with open(id_file) as f: Metric._iid_cache = f.read().strip() return Metric._iid_cache except Exception: @@ -39,8 +41,7 @@ class Metric: @staticmethod async def upload(**kwargs): - """ - 上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 + """上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 Powered by TickStats. """ @@ -64,7 +65,6 @@ class Metric: ) except Exception as e: logger.error(f"保存指标到数据库失败: {e}") - pass try: async with aiohttp.ClientSession(trust_env=True) as session: diff --git a/astrbot/core/utils/path_util.py b/astrbot/core/utils/path_util.py index 0d8511f0c..9520d481d 100644 --- a/astrbot/core/utils/path_util.py +++ b/astrbot/core/utils/path_util.py @@ -19,24 +19,23 @@ def path_Mapping(mappings, srcPath: str) -> str: # 切割后大于4个项目,或者只有1个项目,那肯定是错误的,只能是2,3,4个项目 logger.warning(f"路径映射规则错误: {mapping}") continue - else: - # rule.len == 3 or 4 - if os.path.exists(rule[0] + ":" + rule[1]): - # 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接 - from_ = rule[0] + ":" + rule[1] - if len(rule) == 3: - to_ = rule[2] - else: - to_ = rule[2] + ":" + rule[3] + # rule.len == 3 or 4 + elif os.path.exists(rule[0] + ":" + rule[1]): + # 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接 + from_ = rule[0] + ":" + rule[1] + if len(rule) == 3: + to_ = rule[2] else: - # 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。 - from_ = rule[0] - if len(rule) == 3: - to_ = rule[1] + ":" + rule[2] - else: - # 这种情况下存在四个项目,说明规则也是错误的 - logger.warning(f"路径映射规则错误: {mapping}") - continue + to_ = rule[2] + ":" + rule[3] + else: + # 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。 + from_ = rule[0] + if len(rule) == 3: + to_ = rule[1] + ":" + rule[2] + else: + # 这种情况下存在四个项目,说明规则也是错误的 + logger.warning(f"路径映射规则错误: {mapping}") + continue from_ = from_.removesuffix("/") from_ = from_.removesuffix("\\") diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 88cc21306..abe247146 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -1,5 +1,5 @@ -import logging import asyncio +import logging import sys logger = logging.getLogger("astrbot") diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index c27a54113..33b7cb17a 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -1,24 +1,22 @@ -""" -会话控制 -""" +"""会话控制""" import abc import asyncio -import time -import functools import copy +import functools +import time +from collections.abc import Awaitable, Callable +from typing import Any + import astrbot.core.message.components as Comp -from typing import Dict, Any, Callable, Awaitable, List from astrbot.core.platform import AstrMessageEvent -USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 -FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例 +USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 +FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例 class SessionController: - """ - 控制一个 Session 是否已经结束 - """ + """控制一个 Session 是否已经结束""" def __init__(self): self.future = asyncio.Future() @@ -29,7 +27,7 @@ class SessionController: self.timeout: float | int = None """上次保持(keep)开始时的超时时间""" - self.history_chains: List[List[Comp.BaseMessageComponent]] = [] + self.history_chains: list[list[Comp.BaseMessageComponent]] = [] def stop(self, error: Exception = None): """立即结束这个会话""" @@ -39,13 +37,14 @@ class SessionController: else: self.future.set_result(None) - def keep(self, timeout: float | int = 0, reset_timeout=False): + def keep(self, timeout: float = 0, reset_timeout=False): """保持这个会话 Args: timeout (float): 必填。会话超时时间。 当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。 当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0) + """ new_ts = time.time() @@ -81,7 +80,7 @@ class SessionController: pass # 避免报错 # finally: - def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]: + def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]: """获取历史消息链""" return self.history_chains @@ -92,7 +91,6 @@ class SessionFilter: @abc.abstractmethod def filter(self, event: AstrMessageEvent) -> str: """根据事件返回一个会话标识符""" - pass class DefaultSessionFilter(SessionFilter): @@ -120,7 +118,9 @@ class SessionWaiter: """需要保证一个 session 同时只有一个 trigger""" async def register_wait( - self, handler: Callable[[str], Awaitable[Any]], timeout: int = 30 + self, + handler: Callable[[str], Awaitable[Any]], + timeout: int = 30, ) -> Any: """等待外部输入并处理""" self.handler = handler @@ -149,7 +149,7 @@ class SessionWaiter: @classmethod async def trigger(cls, session_id: str, event: AstrMessageEvent): """外部输入触发会话处理""" - session = USER_SESSIONS.get(session_id, None) + session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): return @@ -157,7 +157,7 @@ class SessionWaiter: if not session.session_controller.future.done(): if session.record_history_chains: session.session_controller.history_chains.append( - [copy.deepcopy(comp) for comp in event.get_messages()] + [copy.deepcopy(comp) for comp in event.get_messages()], ) try: # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 @@ -167,8 +167,7 @@ class SessionWaiter: def session_waiter(timeout: int = 30, record_history_chains: bool = False): - """ - 装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。 + """装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。 :param timeout: 超时时间(秒) :param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。 diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index c1368f186..c6b4c5ede 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,11 +1,12 @@ -from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import Preference -import threading import asyncio import os -from typing import TypeVar, Any, overload -from .astrbot_path import get_astrbot_data_path +import threading +from typing import Any, TypeVar, overload +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Preference + +from .astrbot_path import get_astrbot_data_path _VT = TypeVar("_VT") @@ -14,7 +15,8 @@ class SharedPreferences: def __init__(self, db_helper: BaseDatabase, json_storage_path=None): if json_storage_path is None: json_storage_path = os.path.join( - get_astrbot_data_path(), "shared_preferences.json" + get_astrbot_data_path(), + "shared_preferences.json", ) self.path = json_storage_path self.db_helper = db_helper @@ -38,13 +40,15 @@ class SharedPreferences: else: ret = default return ret - else: - raise ValueError( - "scope_id and key cannot be None when getting a specific preference." - ) + raise ValueError( + "scope_id and key cannot be None when getting a specific preference.", + ) async def range_get_async( - self, scope: str, scope_id: str | None = None, key: str | None = None + self, + scope: str, + scope_id: str | None = None, + key: str | None = None, ) -> list[Preference]: """获取指定范围的偏好设置 Note: 返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。scope_id 和 key 可以为 None,这时返回该范围下所有的偏好设置。 @@ -54,21 +58,33 @@ class SharedPreferences: @overload async def session_get( - self, umo: None, key: str, default: Any = None + self, + umo: None, + key: str, + default: Any = None, ) -> list[Preference]: ... @overload async def session_get( - self, umo: str, key: None, default: Any = None + self, + umo: str, + key: None, + default: Any = None, ) -> list[Preference]: ... @overload async def session_get( - self, umo: None, key: None, default: Any = None + self, + umo: None, + key: None, + default: Any = None, ) -> list[Preference]: ... async def session_get( - self, umo: str | None, key: str | None = None, default: _VT = None + self, + umo: str | None, + key: str | None = None, + default: _VT = None, ) -> _VT | list[Preference]: """获取会话范围的偏好设置 @@ -85,7 +101,9 @@ class SharedPreferences: async def global_get(self, key: str, default: _VT = None) -> _VT: ... async def global_get( - self, key: str | None, default: _VT = None + self, + key: str | None, + default: _VT = None, ) -> _VT | list[Preference]: """获取全局范围的偏好设置 @@ -98,7 +116,10 @@ class SharedPreferences: async def put_async(self, scope: str, scope_id: str, key: str, value: Any): """设置指定范围和键的偏好设置""" await self.db_helper.insert_preference_or_update( - scope, scope_id, key, {"val": value} + scope, + scope_id, + key, + {"val": value}, ) async def session_put(self, umo: str, key: str, value: Any): @@ -139,7 +160,7 @@ class SharedPreferences: if scope_id is None or key is None: # result = asyncio.run(self.range_get_async(scope, scope_id, key)) raise ValueError( - "scope_id and key cannot be None when getting a specific preference." + "scope_id and key cannot be None when getting a specific preference.", ) result = asyncio.run_coroutine_threadsafe( self.get_async(scope or "unknown", scope_id or "unknown", key, default), @@ -149,11 +170,15 @@ class SharedPreferences: return result if result is not None else default def range_get( - self, scope: str, scope_id: str | None = None, key: str | None = None + self, + scope: str, + scope_id: str | None = None, + key: str | None = None, ) -> list[Preference]: """获取指定范围的偏好设置(已弃用)""" result = asyncio.run_coroutine_threadsafe( - self.range_get_async(scope, scope_id, key), self._sync_loop + self.range_get_async(scope, scope_id, key), + self._sync_loop, ).result() return result diff --git a/astrbot/core/utils/t2i/__init__.py b/astrbot/core/utils/t2i/__init__.py index 8ce209ad3..5038a46f7 100644 --- a/astrbot/core/utils/t2i/__init__.py +++ b/astrbot/core/utils/t2i/__init__.py @@ -8,6 +8,9 @@ class RenderStrategy(ABC): @abstractmethod def render_custom_template( - self, tmpl_str: str, tmpl_data: dict, return_url: bool + self, + tmpl_str: str, + tmpl_data: dict, + return_url: bool, ) -> str: pass diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index c43f9ed2e..7ebba5669 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -1,14 +1,17 @@ -import aiohttp import asyncio -import ssl -import certifi import logging import random -from . import RenderStrategy +import ssl + +import aiohttp +import certifi + from astrbot.core.config import VERSION from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.t2i.template_manager import TemplateManager +from . import RenderStrategy + ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" logger = logging.getLogger("astrbot") @@ -38,7 +41,7 @@ class NetworkRenderStrategy(RenderStrategy): try: async with aiohttp.ClientSession() as session: async with session.get( - "https://api.soulter.top/astrbot/t2i-endpoints" + "https://api.soulter.top/astrbot/t2i-endpoints", ) as resp: if resp.status == 200: data = await resp.json() @@ -49,14 +52,13 @@ class NetworkRenderStrategy(RenderStrategy): if ep.get("active") and ep.get("url") ] logger.info( - f"Successfully got {len(self.endpoints)} official T2I endpoints." + f"Successfully got {len(self.endpoints)} official T2I endpoints.", ) except Exception as e: logger.error(f"Failed to get official endpoints: {e}") def _clean_url(self, url: str): - if url.endswith("/"): - url = url[:-1] + url = url.removesuffix("/") if not url.endswith("text2img"): url += "/text2img" return url @@ -69,7 +71,6 @@ class NetworkRenderStrategy(RenderStrategy): options: dict | None = None, ) -> str: """使用自定义文转图模板""" - default_options = {"full_page": True, "type": "jpeg", "quality": 40} if options: default_options |= options @@ -89,21 +90,26 @@ class NetworkRenderStrategy(RenderStrategy): if return_url: ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession( - trust_env=True, connector=connector - ) as session: - async with session.post( - f"{endpoint}/generate", json=post_data - ) as resp: - if resp.status == 200: - ret = await resp.json() - return f"{endpoint}/{ret['data']['id']}" - else: - raise Exception(f"HTTP {resp.status}") + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session, + session.post( + f"{endpoint}/generate", + json=post_data, + ) as resp, + ): + if resp.status == 200: + ret = await resp.json() + return f"{endpoint}/{ret['data']['id']}" + raise Exception(f"HTTP {resp.status}") else: # download_image_by_url 失败时抛异常 return await download_image_by_url( - f"{endpoint}/generate", post=True, post_data=post_data + f"{endpoint}/generate", + post=True, + post_data=post_data, ) except Exception as e: last_exception = e @@ -114,15 +120,18 @@ class NetworkRenderStrategy(RenderStrategy): raise RuntimeError(f"All endpoints failed: {last_exception}") async def render( - self, text: str, return_url: bool = False, template_name: str | None = "base" + self, + text: str, + return_url: bool = False, + template_name: str | None = "base", ) -> str: - """ - 返回图像的文件路径 - """ + """返回图像的文件路径""" if not template_name: template_name = "base" tmpl_str = await self.get_template(name=template_name) text = text.replace("`", "\\`") return await self.render_custom_template( - tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url + tmpl_str, + {"text": text, "version": f"v{VERSION}"}, + return_url, ) diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 122189f93..2ce7a5ebf 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -1,7 +1,8 @@ -from .network_strategy import NetworkRenderStrategy -from .local_strategy import LocalRenderStrategy from astrbot.core.log import LogManager +from .local_strategy import LocalRenderStrategy +from .network_strategy import NetworkRenderStrategy + logger = LogManager.GetLogger(log_name="astrbot") @@ -30,7 +31,10 @@ class HtmlRenderer: @example: 参见 https://astrbot.app 插件开发部分。 """ return await self.network_strategy.render_custom_template( - tmpl_str, tmpl_data, return_url, options + tmpl_str, + tmpl_data, + return_url, + options, ) async def render_t2i( @@ -44,11 +48,13 @@ class HtmlRenderer: if use_network: try: return await self.network_strategy.render( - text, return_url=return_url, template_name=template_name + text, + return_url=return_url, + template_name=template_name, ) except BaseException as e: logger.error( - f"Failed to render image via AstrBot API: {e}. Falling back to local rendering." + f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.", ) return await self.local_strategy.render(text) else: diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index b441a908e..6d44f735b 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -2,12 +2,12 @@ import os import shutil + from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path class TemplateManager: - """ - 负责管理 t2i HTML 模板的 CRUD 和重置操作。 + """负责管理 t2i HTML 模板的 CRUD 和重置操作。 采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。 所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。 """ @@ -16,7 +16,12 @@ class TemplateManager: def __init__(self): self.builtin_template_dir = os.path.join( - get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template" + get_astrbot_path(), + "astrbot", + "core", + "utils", + "t2i", + "template", ) self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates") @@ -43,12 +48,11 @@ class TemplateManager: def _read_file(self, path: str) -> str: """读取文件内容。""" - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return f.read() def list_templates(self) -> list[dict]: - """ - 列出所有可用模板。 + """列出所有可用模板。 该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。 """ dirs_to_scan = [self.builtin_template_dir, self.user_template_dir] @@ -63,8 +67,7 @@ class TemplateManager: ] def get_template(self, name: str) -> str: - """ - 获取指定模板的内容。 + """获取指定模板的内容。 优先从用户目录加载,如果不存在则回退到内置目录。 """ user_path = self._get_user_template_path(name) @@ -86,8 +89,7 @@ class TemplateManager: f.write(content) def update_template(self, name: str, content: str): - """ - 更新一个模板。此操作始终写入用户目录。 + """更新一个模板。此操作始终写入用户目录。 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, 从而实现对内置模板的“覆盖”。 """ @@ -96,8 +98,7 @@ class TemplateManager: f.write(content) def delete_template(self, name: str): - """ - 仅删除用户目录中的模板文件。 + """仅删除用户目录中的模板文件。 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 """ path = self._get_user_template_path(name) @@ -106,7 +107,5 @@ class TemplateManager: os.remove(path) def reset_default_template(self): - """ - 将核心模板从内置目录强制重置到用户目录。 - """ + """将核心模板从内置目录强制重置到用户目录。""" self._copy_core_templates(overwrite=True) diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 2c97a01ed..9cc36571e 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -1,10 +1,11 @@ +import asyncio import base64 -import wave import os import subprocess -from io import BytesIO -import asyncio import tempfile +import wave +from io import BytesIO + from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -35,7 +36,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: import pilk except (ImportError, ModuleNotFoundError) as _: raise Exception( - "pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库" + "pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库", ) # with wave.open(wav_path, 'rb') as wav: # wav_data = wav.readframes(wav.getnframes()) @@ -60,8 +61,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: - """ - 将 MP3 或其他音频格式转换为 PCM 16bit WAV,采样率24000Hz,单声道。 + """将 MP3 或其他音频格式转换为 PCM 16bit WAV,采样率24000Hz,单声道。 若转换失败则抛出异常。 """ try: @@ -99,13 +99,11 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: if os.path.exists(output_path) and os.path.getsize(output_path) > 0: return output_path - else: - raise RuntimeError("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: - """ - 将 MP3/WAV 文件转为 Tencent Silk 并返回 base64 编码与时长(秒)。 + """将 MP3/WAV 文件转为 Tencent Silk 并返回 base64 编码与时长(秒)。 参数: - audio_path: 输入音频文件路径(.mp3 或 .wav) @@ -125,7 +123,9 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: # 是否需要转换为 WAV ext = os.path.splitext(audio_path)[1].lower() temp_wav = tempfile.NamedTemporaryFile( - suffix=".wav", delete=False, dir=temp_dir + suffix=".wav", + delete=False, + dir=temp_dir, ).name if ext != ".wav": @@ -140,12 +140,18 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: rate = wav_file.getframerate() silk_path = tempfile.NamedTemporaryFile( - suffix=".silk", delete=False, dir=temp_dir + suffix=".silk", + delete=False, + dir=temp_dir, ).name try: duration = await asyncio.to_thread( - pilk.encode, wav_path, silk_path, pcm_rate=rate, tencent=True + pilk.encode, + wav_path, + silk_path, + pcm_rate=rate, + tencent=True, ) with open(silk_path, "rb") as f: diff --git a/astrbot/core/utils/version_comparator.py b/astrbot/core/utils/version_comparator.py index f7ad65fcd..e3bf74951 100644 --- a/astrbot/core/utils/version_comparator.py +++ b/astrbot/core/utils/version_comparator.py @@ -38,15 +38,15 @@ class VersionComparator: for i in range(length): if v1_parts[i] > v2_parts[i]: return 1 - elif v1_parts[i] < v2_parts[i]: + if v1_parts[i] < v2_parts[i]: return -1 # 比较预发布标签 if v1_prerelease is None and v2_prerelease is not None: return 1 # 没有预发布标签的版本高于有预发布标签的版本 - elif v1_prerelease is not None and v2_prerelease is None: + if v1_prerelease is not None and v2_prerelease is None: return -1 # 有预发布标签的版本低于没有预发布标签的版本 - elif v1_prerelease is not None and v2_prerelease is not None: + if v1_prerelease is not None and v2_prerelease is not None: len_pre = max(len(v1_prerelease), len(v2_prerelease)) for i in range(len_pre): p1 = v1_prerelease[i] if i < len(v1_prerelease) else None @@ -54,21 +54,18 @@ class VersionComparator: if p1 is None and p2 is not None: return -1 - elif p1 is not None and p2 is None: + if p1 is not None and p2 is None: return 1 - elif isinstance(p1, int) and isinstance(p2, str): + if isinstance(p1, int) and isinstance(p2, str): return -1 - elif isinstance(p1, str) and isinstance(p2, int): + if isinstance(p1, str) and isinstance(p2, int): return 1 - elif isinstance(p1, int) and isinstance(p2, int): + if (isinstance(p1, int) and isinstance(p2, int)) or ( + isinstance(p1, str) and isinstance(p2, str) + ): if p1 > p2: return 1 - elif p1 < p2: - return -1 - elif isinstance(p1, str) and isinstance(p2, str): - if p1 > p2: - return 1 - elif p1 < p2: + if p1 < p2: return -1 return 0 # 预发布标签完全相同 diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 7e5f3bfbb..728dfdabb 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -1,14 +1,14 @@ -import aiohttp import os import re -import zipfile import shutil - import ssl +import zipfile + +import aiohttp import certifi -from astrbot.core.utils.io import on_error, download_file from astrbot.core import logger +from astrbot.core.utils.io import download_file, on_error from astrbot.core.utils.version_comparator import VersionComparator @@ -18,7 +18,10 @@ class ReleaseInfo: body: str def __init__( - self, version: str = "", published_at: str = "", body: str = "" + self, + version: str = "", + published_at: str = "", + body: str = "", ) -> None: self.version = version self.published_at = published_at @@ -34,29 +37,31 @@ class RepoZipUpdator: self.rm_on_error = on_error async def fetch_release_info(self, url: str, latest: bool = True) -> list: - """ - 请求版本信息。 + """请求版本信息。 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 """ try: ssl_context = ssl.create_default_context( - cafile=certifi.where() + cafile=certifi.where(), ) # 新增:创建基于 certifi 的 SSL 上下文 connector = aiohttp.TCPConnector( - ssl=ssl_context + ssl=ssl_context, ) # 新增:使用 TCPConnector 指定 SSL 上下文 - async with aiohttp.ClientSession( - trust_env=True, connector=connector - ) as session: - async with session.get(url) as response: - # 检查 HTTP 状态码 - if response.status != 200: - text = await response.text() - logger.error( - f"请求 {url} 失败,状态码: {response.status}, 内容: {text}" - ) - raise Exception(f"请求失败,状态码: {response.status}") - result = await response.json() + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session, + session.get(url) as response, + ): + # 检查 HTTP 状态码 + if response.status != 200: + text = await response.text() + logger.error( + f"请求 {url} 失败,状态码: {response.status}, 内容: {text}", + ) + raise Exception(f"请求失败,状态码: {response.status}") + result = await response.json() if not result: return [] # if latest: @@ -72,7 +77,7 @@ class RepoZipUpdator: "body": release["body"], "tag_name": release["tag_name"], "zipball_url": release["zipball_url"], - } + }, ) except Exception as e: logger.error(f"解析版本信息时发生异常: {e}") @@ -80,8 +85,7 @@ class RepoZipUpdator: return ret def github_api_release_parser(self, releases: list) -> list: - """ - 解析 GitHub API 返回的 releases 信息。 + """解析 GitHub API 返回的 releases 信息。 返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。 """ ret = [] @@ -93,22 +97,25 @@ class RepoZipUpdator: "body": release["body"], "tag_name": release["tag_name"], "zipball_url": release["zipball_url"], - } + }, ) return ret def unzip(self): - raise NotImplementedError() + raise NotImplementedError async def update(self): - raise NotImplementedError() + raise NotImplementedError def compare_version(self, v1: str, v2: str) -> int: """Semver 版本比较""" return VersionComparator.compare_version(v1, v2) async def check_update( - self, url: str, current_version: str, consider_prerelease: bool = True + self, + url: str, + current_version: str, + consider_prerelease: bool = True, ) -> ReleaseInfo | None: update_data = await self.fetch_release_info(url) @@ -157,7 +164,7 @@ class RepoZipUpdator: releases = await self.fetch_release_info(url=release_url) except Exception as e: logger.warning( - f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支" + f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支", ) releases = [] if not releases: @@ -173,7 +180,7 @@ class RepoZipUpdator: proxy = proxy.rstrip("/") release_url = f"{proxy}/{release_url}" logger.info( - f"检查到设置了镜像站,将使用镜像站下载 {author}/{repo} 仓库源码: {release_url}" + f"检查到设置了镜像站,将使用镜像站下载 {author}/{repo} 仓库源码: {release_url}", ) await download_file(release_url, target_path + ".zip") @@ -194,13 +201,10 @@ class RepoZipUpdator: repo = match.group(2) branch = match.group(4) return author, repo, branch - else: - raise ValueError("无效的 GitHub URL") + raise ValueError("无效的 GitHub URL") def unzip_file(self, zip_path: str, target_dir: str): - """ - 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir - """ + """解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir""" os.makedirs(target_dir, exist_ok=True) update_dir = "" with zipfile.ZipFile(zip_path, "r") as z: @@ -213,20 +217,19 @@ class RepoZipUpdator: if os.path.isdir(os.path.join(target_dir, update_dir, f)): if os.path.exists(os.path.join(target_dir, f)): shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) - else: - if os.path.exists(os.path.join(target_dir, f)): - os.remove(os.path.join(target_dir, f)) + elif os.path.exists(os.path.join(target_dir, f)): + os.remove(os.path.join(target_dir, f)) shutil.move(os.path.join(target_dir, update_dir, f), target_dir) try: logger.debug( - f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}" + f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}", ) shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except BaseException: logger.warning( - f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}" + f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}", ) def format_name(self, name: str) -> str: diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index e1d58f622..b7997cf8e 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -1,31 +1,31 @@ from .auth import AuthRoute -from .plugin import PluginRoute -from .config import ConfigRoute -from .update import UpdateRoute -from .stat import StatRoute -from .log import LogRoute -from .static_file import StaticFileRoute from .chat import ChatRoute -from .tools import ToolsRoute +from .config import ConfigRoute from .conversation import ConversationRoute from .file import FileRoute -from .session_management import SessionManagementRoute -from .persona import PersonaRoute from .knowledge_base import KnowledgeBaseRoute +from .log import LogRoute +from .persona import PersonaRoute +from .plugin import PluginRoute +from .session_management import SessionManagementRoute +from .stat import StatRoute +from .static_file import StaticFileRoute +from .tools import ToolsRoute +from .update import UpdateRoute __all__ = [ "AuthRoute", - "PluginRoute", - "ConfigRoute", - "UpdateRoute", - "StatRoute", - "LogRoute", - "StaticFileRoute", "ChatRoute", - "ToolsRoute", + "ConfigRoute", "ConversationRoute", "FileRoute", - "SessionManagementRoute", - "PersonaRoute", "KnowledgeBaseRoute", + "LogRoute", + "PersonaRoute", + "PluginRoute", + "SessionManagementRoute", + "StatRoute", + "StaticFileRoute", + "ToolsRoute", + "UpdateRoute", ] diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 87af4b61e..4ee0d57d4 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -1,10 +1,13 @@ -import jwt -import datetime import asyncio -from .route import Route, Response, RouteContext +import datetime + +import jwt from quart import request -from astrbot.core import DEMO_MODE + from astrbot import logger +from astrbot.core import DEMO_MODE + +from .route import Response, Route, RouteContext class AuthRoute(Route): @@ -37,13 +40,12 @@ class AuthRoute(Route): "token": self.generate_jwt(username), "username": username, "change_pwd_hint": change_pwd_hint, - } + }, ) .__dict__ ) - else: - await asyncio.sleep(3) - return Response().error("用户名或密码错误").__dict__ + await asyncio.sleep(3) + return Response().error("用户名或密码错误").__dict__ async def edit_account(self): if DEMO_MODE: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 71fd3472b..6954f2d6a 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,16 +1,20 @@ -import uuid +import asyncio import json import os -import asyncio +import uuid from contextlib import asynccontextmanager -from .route import Route, Response, RouteContext -from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr -from quart import request, Response as QuartResponse, g, make_response -from astrbot.core.db import BaseDatabase + +from quart import Response as QuartResponse +from quart import g, make_response, request + from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.db import BaseDatabase from astrbot.core.platform.astr_message_event import MessageSession +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +from .route import Response, Route, RouteContext @asynccontextmanager @@ -70,10 +74,9 @@ class ChatRoute(Route): if filename_ext == ".wav": return QuartResponse(f.read(), mimetype="audio/wav") - elif filename_ext[1:] in self.supported_imgs: + if filename_ext[1:] in self.supported_imgs: return QuartResponse(f.read(), mimetype="image/jpeg") - else: - return QuartResponse(f.read()) + return QuartResponse(f.read()) except (FileNotFoundError, OSError): return Response().error("File access error").__dict__ @@ -96,7 +99,7 @@ class ChatRoute(Route): return Response().error("Missing key: file").__dict__ file = post_data["file"] - filename = f"{str(uuid.uuid4())}" + filename = f"{uuid.uuid4()!s}" # 通过文件格式判断文件类型 if file.content_type.startswith("audio"): filename += ".wav" @@ -179,7 +182,7 @@ class ChatRoute(Route): except Exception as e: if not client_disconnected: logger.debug( - f"[WebChat] 用户 {username} 断开聊天长连接。 {e}" + f"[WebChat] 用户 {username} 断开聊天长连接。 {e}", ) client_disconnected = True @@ -222,7 +225,7 @@ class ChatRoute(Route): "selected_provider": selected_provider, "selected_model": selected_model, }, - ) + ), ) response = await make_response( @@ -243,7 +246,8 @@ class ChatRoute(Route): NOTE: 关于这里为什么要单独做一个 WebChat 的 Conversation ID 出来,这个是为了向前兼容。 """ conversation = await self.conv_mgr.get_conversation( - unified_msg_origin="webchat", conversation_id=conversation_id + unified_msg_origin="webchat", + conversation_id=conversation_id, ) if not conversation: raise ValueError(f"Conversation with ID {conversation_id} not found.") @@ -267,7 +271,9 @@ class ChatRoute(Route): conversation_id=conversation_id, ) await self.platform_history_mgr.delete( - platform_id="webchat", user_id=webchat_conv_id, offset_sec=99999999 + platform_id="webchat", + user_id=webchat_conv_id, + offset_sec=99999999, ) return Response().ok().__dict__ @@ -314,7 +320,10 @@ class ChatRoute(Route): # Get platform message history history_ls = await self.platform_history_mgr.get( - platform_id="webchat", user_id=webchat_conv_id, page=1, page_size=1000 + platform_id="webchat", + user_id=webchat_conv_id, + page=1, + page_size=1000, ) history_res = [history.model_dump() for history in history_ls] @@ -325,7 +334,7 @@ class ChatRoute(Route): data={ "history": history_res, "is_running": self.running_convs.get(webchat_conv_id, False), - } + }, ) .__dict__ ) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 998240c99..59b07e47d 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,26 +1,29 @@ -import traceback -import os +import asyncio import inspect -from .route import Route, Response, RouteContext -from astrbot.core.provider.entities import ProviderType +import os +import traceback + from quart import request + +from astrbot.core import file_token_service, logger +from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import ( - DEFAULT_CONFIG, CONFIG_METADATA_2, - DEFAULT_VALUE_MAP, CONFIG_METADATA_3, CONFIG_METADATA_3_SYSTEM, + DEFAULT_CONFIG, + DEFAULT_VALUE_MAP, ) -from astrbot.core.utils.astrbot_path import get_astrbot_path -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.platform.register import platform_registry, platform_cls_map +from astrbot.core.platform.register import platform_cls_map, platform_registry +from astrbot.core.provider import Provider +from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.provider import RerankProvider from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core import logger, file_token_service -from astrbot.core.provider import Provider -from astrbot.core.provider.provider import RerankProvider -import asyncio +from astrbot.core.utils.astrbot_path import get_astrbot_path + +from .route import Response, Route, RouteContext def try_cast(value: str, type_: str): @@ -33,9 +36,7 @@ def try_cast(value: str, type_: str): type_ == "float" and isinstance(value, str) and value.replace(".", "", 1).isdigit() - ): - return float(value) - elif type_ == "float" and isinstance(value, int): + ) or (type_ == "float" and isinstance(value, int)): return float(value) elif type_ == "float": try: @@ -61,7 +62,7 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict] continue if meta["type"] == "list" and not isinstance(value, list): errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", ) elif ( meta["type"] == "list" @@ -80,31 +81,31 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict] casted = try_cast(value, "int") if casted is None: errors.append( - f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}", ) data[key] = casted elif meta["type"] == "float" and not isinstance(value, float): casted = try_cast(value, "float") if casted is None: errors.append( - f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}", ) data[key] = casted elif meta["type"] == "bool" and not isinstance(value, bool): errors.append( - f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}", ) elif meta["type"] in ["string", "text"] and not isinstance(value, str): errors.append( - f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}", ) elif meta["type"] == "list" and not isinstance(value, list): errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", ) elif meta["type"] == "object" and not isinstance(value, dict): errors.append( - f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}" + f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}", ) if is_core: @@ -127,7 +128,9 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) try: if is_core: errors, post_config = validate_config( - post_config, CONFIG_METADATA_2, is_core + post_config, + CONFIG_METADATA_2, + is_core, ) else: errors, post_config = validate_config(post_config, config.schema, is_core) @@ -143,7 +146,9 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) class ConfigRoute(Route): def __init__( - self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle @@ -199,7 +204,7 @@ class ConfigRoute(Route): return Response().ok(message="更新成功").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {str(e)}").__dict__ + return Response().error(f"更新路由表失败: {e!s}").__dict__ async def update_ucr(self): """更新 UMOP 配置路由表""" @@ -218,7 +223,7 @@ class ConfigRoute(Route): return Response().ok(message="更新成功").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {str(e)}").__dict__ + return Response().error(f"更新路由表失败: {e!s}").__dict__ async def delete_ucr(self): """删除 UMOP 配置路由表中的一项""" @@ -238,7 +243,7 @@ class ConfigRoute(Route): return Response().ok(message="删除成功").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"删除路由表项失败: {str(e)}").__dict__ + return Response().error(f"删除路由表项失败: {e!s}").__dict__ async def get_default_config(self): """获取默认配置文件""" @@ -305,13 +310,12 @@ class ConfigRoute(Route): success = self.acm.delete_conf(conf_id) if success: return Response().ok(message="删除成功").__dict__ - else: - return Response().error("删除失败").__dict__ + return Response().error("删除失败").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"删除配置文件失败: {str(e)}").__dict__ + return Response().error(f"删除配置文件失败: {e!s}").__dict__ async def update_abconf(self): """更新指定 AstrBot 配置文件信息""" @@ -329,13 +333,12 @@ class ConfigRoute(Route): success = self.acm.update_conf_info(conf_id, name=name) if success: return Response().ok(message="更新成功").__dict__ - else: - return Response().error("更新失败").__dict__ + return Response().error("更新失败").__dict__ except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"更新配置文件失败: {str(e)}").__dict__ + return Response().error(f"更新配置文件失败: {e!s}").__dict__ async def _test_single_provider(self, provider): """辅助函数:测试单个 provider 的可用性""" @@ -352,17 +355,18 @@ class ConfigRoute(Route): "error": None, } logger.debug( - f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})" + f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", ) if provider_capability_type == ProviderType.CHAT_COMPLETION: try: logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") response = await asyncio.wait_for( - provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 + provider.text_chat(prompt="REPLY `PONG` ONLY"), + timeout=45.0, ) logger.debug( - f"Received response from {status_info['name']}: {response}" + f"Received response from {status_info['name']}: {response}", ) if response is not None: status_info["status"] = "available" @@ -386,14 +390,14 @@ class ConfigRoute(Route): except Exception as _: pass logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" + f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'", ) else: status_info["error"] = ( "Test call returned None, but expected an LLMResponse object." ) logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None." + f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.", ) except asyncio.TimeoutError: @@ -401,16 +405,16 @@ class ConfigRoute(Route): "Connection timed out after 45 seconds during test call." ) logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) timed out." + f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.", ) except Exception as e: error_message = str(e) status_info["error"] = error_message logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}" + f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", ) logger.debug( - f"Traceback for {status_info['name']}:\n{traceback.format_exc()}" + f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", ) elif provider_capability_type == ProviderType.EMBEDDING: @@ -432,7 +436,7 @@ class ConfigRoute(Route): exc_info=True, ) status_info["status"] = "unavailable" - status_info["error"] = f"Embedding test failed: {str(e)}" + status_info["error"] = f"Embedding test failed: {e!s}" elif provider_capability_type == ProviderType.TEXT_TO_SPEECH: try: @@ -447,17 +451,20 @@ class ConfigRoute(Route): ) except Exception as e: logger.error( - f"Error testing TTS provider {provider_name}: {e}", exc_info=True + f"Error testing TTS provider {provider_name}: {e}", + exc_info=True, ) status_info["status"] = "unavailable" - status_info["error"] = f"TTS test failed: {str(e)}" + status_info["error"] = f"TTS test failed: {e!s}" elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: try: logger.debug( - f"Sending health check audio to provider: {status_info['name']}" + f"Sending health check audio to provider: {status_info['name']}", ) sample_audio_path = os.path.join( - get_astrbot_path(), "samples", "stt_health_check.wav" + get_astrbot_path(), + "samples", + "stt_health_check.wav", ) if not os.path.exists(sample_audio_path): status_info["status"] = "unavailable" @@ -465,7 +472,7 @@ class ConfigRoute(Route): "STT test failed: sample audio file not found." ) logger.warning( - f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}" + f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}", ) else: text_result = await provider.get_text(sample_audio_path) @@ -477,7 +484,7 @@ class ConfigRoute(Route): else text_result ) logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'" + f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'", ) else: status_info["status"] = "unavailable" @@ -485,14 +492,15 @@ class ConfigRoute(Route): f"STT test failed: unexpected result type {type(text_result)}" ) logger.warning( - f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}" + f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}", ) except Exception as e: logger.error( - f"Error testing STT provider {provider_name}: {e}", exc_info=True + f"Error testing STT provider {provider_name}: {e}", + exc_info=True, ) status_info["status"] = "unavailable" - status_info["error"] = f"STT test failed: {str(e)}" + status_info["error"] = f"STT test failed: {e!s}" elif provider_capability_type == ProviderType.RERANK: try: assert isinstance(provider, RerankProvider) @@ -504,11 +512,11 @@ class ConfigRoute(Route): exc_info=True, ) status_info["status"] = "unavailable" - status_info["error"] = f"Rerank test failed: {str(e)}" + status_info["error"] = f"Rerank test failed: {e!s}" else: logger.debug( - f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}" + f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}", ) status_info["status"] = "available" status_info["error"] = ( @@ -518,7 +526,10 @@ class ConfigRoute(Route): return status_info def _error_response( - self, message: str, status_code: int = 500, log_fn=logger.error + self, + message: str, + status_code: int = 500, + log_fn=logger.error, ): log_fn(message) # 记录更详细的traceback信息,但只在是严重错误时 @@ -531,7 +542,9 @@ class ConfigRoute(Route): provider_id = request.args.get("id") if not provider_id: return self._error_response( - "Missing provider_id parameter", 400, logger.warning + "Missing provider_id parameter", + 400, + logger.warning, ) logger.info(f"API call: /config/provider/check_one id={provider_id}") @@ -541,7 +554,7 @@ class ConfigRoute(Route): if not target: logger.warning( - f"Provider with id '{provider_id}' not found in provider_manager." + f"Provider with id '{provider_id}' not found in provider_manager.", ) return ( Response() @@ -554,7 +567,8 @@ class ConfigRoute(Route): except Exception as e: return self._error_response( - f"Critical error checking provider {provider_id}: {e}", 500 + f"Critical error checking provider {provider_id}: {e}", + 500, ) async def get_configs(self): @@ -646,13 +660,13 @@ class ConfigRoute(Route): dim = len(vec) logger.info( - f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}" + f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}", ) return Response().ok({"embedding_dimensions": dim}).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"获取嵌入维度失败: {str(e)}").__dict__ + return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ async def get_platform_list(self): """获取所有平台的列表""" @@ -693,7 +707,7 @@ class ConfigRoute(Route): try: save_config(self.config, self.config, is_core=True) await self.core_lifecycle.platform_manager.load_platform( - new_platform_config + new_platform_config, ) except Exception as e: return Response().error(str(e)).__dict__ @@ -705,7 +719,7 @@ class ConfigRoute(Route): try: save_config(self.config, self.config, is_core=True) await self.core_lifecycle.provider_manager.load_provider( - new_provider_config + new_provider_config, ) except Exception as e: return Response().error(str(e)).__dict__ @@ -802,9 +816,9 @@ class ConfigRoute(Route): if cache_key in self._logo_token_cache: cached_token = self._logo_token_cache[cache_key] # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl: - platform_default_tmpl[platform.name] = {} - elif not isinstance(platform_default_tmpl[platform.name], dict): + if platform.name not in platform_default_tmpl or not isinstance( + platform_default_tmpl[platform.name], dict + ): platform_default_tmpl[platform.name] = {} platform_default_tmpl[platform.name]["logo_token"] = cached_token logger.debug(f"Using cached logo token for platform {platform.name}") @@ -826,13 +840,14 @@ class ConfigRoute(Route): # 检查文件是否存在并注册令牌 if os.path.exists(logo_file_path): logo_token = await file_token_service.register_file( - logo_file_path, timeout=3600 + logo_file_path, + timeout=3600, ) # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl: - platform_default_tmpl[platform.name] = {} - elif not isinstance(platform_default_tmpl[platform.name], dict): + if platform.name not in platform_default_tmpl or not isinstance( + platform_default_tmpl[platform.name], dict + ): platform_default_tmpl[platform.name] = {} platform_default_tmpl[platform.name]["logo_token"] = logo_token @@ -843,18 +858,18 @@ class ConfigRoute(Route): logger.debug(f"Logo token registered for platform {platform.name}") else: logger.warning( - f"Platform {platform.name} logo file not found: {logo_file_path}" + f"Platform {platform.name} logo file not found: {logo_file_path}", ) except (ImportError, AttributeError) as e: logger.warning( - f"Failed to import required modules for platform {platform.name}: {e}" + f"Failed to import required modules for platform {platform.name}: {e}", ) except OSError as e: logger.warning(f"File system error for platform {platform.name} logo: {e}") except Exception as e: logger.warning( - f"Unexpected error registering logo for platform {platform.name}: {e}" + f"Unexpected error registering logo for platform {platform.name}: {e}", ) async def _get_astrbot_config(self): @@ -873,7 +888,7 @@ class ConfigRoute(Route): # 收集logo注册任务 if platform.logo_path: logo_registration_tasks.append( - self._register_platform_logo(platform, platform_default_tmpl) + self._register_platform_logo(platform, platform_default_tmpl), ) # 并行执行logo注册 @@ -905,7 +920,7 @@ class ConfigRoute(Route): "description": f"{plugin_name} 配置", "type": "object", "items": plugin_md.config.schema, # 初始化时通过 __setattr__ 存入了 schema - } + }, } break diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py index 56f892f24..d19fdf793 100644 --- a/astrbot/dashboard/routes/conversation.py +++ b/astrbot/dashboard/routes/conversation.py @@ -1,10 +1,13 @@ -import traceback import json -from .route import Route, Response, RouteContext -from astrbot.core import logger +import traceback + from quart import request -from astrbot.core.db import BaseDatabase + +from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase + +from .route import Response, Route, RouteContext class ConversationRoute(Route): @@ -55,12 +58,10 @@ class ConversationRoute(Route): exclude_platforms.split(",") if exclude_platforms else [] ) - if page < 1: - page = 1 + page = max(page, 1) if page_size < 1: page_size = 20 - if page_size > 100: - page_size = 100 + page_size = min(page_size, 100) try: ( @@ -76,8 +77,8 @@ class ConversationRoute(Route): exclude_platforms=exclude_platform_list, ) except Exception as e: - logger.error(f"数据库查询出错: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"数据库查询出错: {str(e)}").__dict__ + logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}") + return Response().error(f"数据库查询出错: {e!s}").__dict__ # 计算总页数 total_pages = ( @@ -96,9 +97,9 @@ class ConversationRoute(Route): return Response().ok(result).__dict__ except Exception as e: - error_msg = f"获取对话列表失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"获取对话列表失败: {str(e)}").__dict__ + return Response().error(f"获取对话列表失败: {e!s}").__dict__ async def get_conv_detail(self): """获取指定对话详情(通过POST请求)""" @@ -111,7 +112,8 @@ class ConversationRoute(Route): return Response().error("缺少必要参数: user_id 和 cid").__dict__ conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, conversation_id=cid + unified_msg_origin=user_id, + conversation_id=cid, ) if not conversation: return Response().error("对话不存在").__dict__ @@ -127,14 +129,14 @@ class ConversationRoute(Route): "history": conversation.history, "created_at": conversation.created_at, "updated_at": conversation.updated_at, - } + }, ) .__dict__ ) except Exception as e: - logger.error(f"获取对话详情失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"获取对话详情失败: {str(e)}").__dict__ + logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取对话详情失败: {e!s}").__dict__ async def upd_conv(self): """更新对话信息(标题和角色ID)""" @@ -148,7 +150,8 @@ class ConversationRoute(Route): if not user_id or not cid: return Response().error("缺少必要参数: user_id 和 cid").__dict__ conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, conversation_id=cid + unified_msg_origin=user_id, + conversation_id=cid, ) if not conversation: return Response().error("对话不存在").__dict__ @@ -162,8 +165,8 @@ class ConversationRoute(Route): return Response().ok({"message": "对话信息更新成功"}).__dict__ except Exception as e: - logger.error(f"更新对话信息失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"更新对话信息失败: {str(e)}").__dict__ + logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"更新对话信息失败: {e!s}").__dict__ async def del_conv(self): """删除对话""" @@ -188,17 +191,18 @@ class ConversationRoute(Route): if not user_id or not cid: failed_items.append( - f"user_id:{user_id}, cid:{cid} - 缺少必要参数" + f"user_id:{user_id}, cid:{cid} - 缺少必要参数", ) continue try: await self.core_lifecycle.conversation_manager.delete_conversation( - unified_msg_origin=user_id, conversation_id=cid + unified_msg_origin=user_id, + conversation_id=cid, ) deleted_count += 1 except Exception as e: - failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}") + failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}") message = f"成功删除 {deleted_count} 个对话" if failed_items: @@ -212,26 +216,26 @@ class ConversationRoute(Route): "deleted_count": deleted_count, "failed_count": len(failed_items), "failed_items": failed_items, - } + }, ) .__dict__ ) - else: - # 单个删除 - user_id = data.get("user_id") - cid = data.get("cid") + # 单个删除 + user_id = data.get("user_id") + cid = data.get("cid") - if not user_id or not cid: - return Response().error("缺少必要参数: user_id 和 cid").__dict__ + if not user_id or not cid: + return Response().error("缺少必要参数: user_id 和 cid").__dict__ - await self.core_lifecycle.conversation_manager.delete_conversation( - unified_msg_origin=user_id, conversation_id=cid - ) - return Response().ok({"message": "对话删除成功"}).__dict__ + await self.core_lifecycle.conversation_manager.delete_conversation( + unified_msg_origin=user_id, + conversation_id=cid, + ) + return Response().ok({"message": "对话删除成功"}).__dict__ except Exception as e: - logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"删除对话失败: {str(e)}").__dict__ + logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"删除对话失败: {e!s}").__dict__ async def update_history(self): """更新对话历史内容""" @@ -260,7 +264,8 @@ class ConversationRoute(Route): ) conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, conversation_id=cid + unified_msg_origin=user_id, + conversation_id=cid, ) if not conversation: return Response().error("对话不存在").__dict__ @@ -268,11 +273,13 @@ class ConversationRoute(Route): history = json.loads(history) if isinstance(history, str) else history await self.conv_mgr.update_conversation( - unified_msg_origin=user_id, conversation_id=cid, history=history + unified_msg_origin=user_id, + conversation_id=cid, + history=history, ) return Response().ok({"message": "对话历史更新成功"}).__dict__ except Exception as e: - logger.error(f"更新对话历史失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"更新对话历史失败: {str(e)}").__dict__ + logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"更新对话历史失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py index 8ea73d084..71d867fe1 100644 --- a/astrbot/dashboard/routes/file.py +++ b/astrbot/dashboard/routes/file.py @@ -1,8 +1,10 @@ -from .route import Route, RouteContext -from astrbot import logger from quart import abort, send_file + +from astrbot import logger from astrbot.core import file_token_service +from .route import Route, RouteContext + class FileRoute(Route): def __init__( diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d8d0434d1..b4e21382a 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -1,17 +1,20 @@ """知识库管理 API 路由""" -import uuid -import aiofiles +import asyncio import os import traceback -import asyncio +import uuid + +import aiofiles from quart import request + from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from .route import Route, Response, RouteContext -from ..utils import generate_tsne_visualization from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider +from ..utils import generate_tsne_visualization +from .route import Response, Route, RouteContext + class KnowledgeBaseRoute(Route): """知识库管理路由 @@ -108,7 +111,7 @@ class KnowledgeBaseRoute(Route): "stage": "parsing", "current": 0, "total": 100, - } + }, ) # 创建进度回调函数 @@ -122,7 +125,7 @@ class KnowledgeBaseRoute(Route): "stage": stage, "current": current, "total": total, - } + }, ) doc = await kb_helper.upload_document( @@ -141,7 +144,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") failed_docs.append( - {"file_name": file_info["file_name"], "error": str(e)} + {"file_name": file_info["file_name"], "error": str(e)}, ) # 更新任务完成状态 @@ -202,7 +205,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取知识库列表失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取知识库列表失败: {str(e)}").__dict__ + return Response().error(f"获取知识库列表失败: {e!s}").__dict__ async def create_kb(self): """创建知识库 @@ -240,7 +243,7 @@ class KnowledgeBaseRoute(Route): if not embedding_provider_id: return Response().error("缺少参数 embedding_provider_id").__dict__ prv = await kb_manager.provider_manager.get_provider_by_id( - embedding_provider_id + embedding_provider_id, ) # type: ignore if not prv or not isinstance(prv, EmbeddingProvider): return ( @@ -250,15 +253,15 @@ class KnowledgeBaseRoute(Route): vec = await prv.get_embedding("astrbot") if len(vec) != prv.get_dim(): raise ValueError( - f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}" + f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}", ) except Exception as e: - return Response().error(f"测试嵌入模型失败: {str(e)}").__dict__ + return Response().error(f"测试嵌入模型失败: {e!s}").__dict__ # pre-check rerank if rerank_provider_id: rerank_prv: RerankProvider = ( await kb_manager.provider_manager.get_provider_by_id( - rerank_provider_id + rerank_provider_id, ) ) # type: ignore if not rerank_prv: @@ -266,14 +269,15 @@ class KnowledgeBaseRoute(Route): # 检查重排序模型可用性 try: res = await rerank_prv.rerank( - query="astrbot", documents=["astrbot knowledge base"] + query="astrbot", + documents=["astrbot knowledge base"], ) if not res: raise ValueError("重排序模型返回结果异常") except Exception as e: return ( Response() - .error(f"测试重排序模型失败: {str(e)},请检查控制台日志输出。") + .error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。") .__dict__ ) @@ -298,7 +302,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"创建知识库失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"创建知识库失败: {str(e)}").__dict__ + return Response().error(f"创建知识库失败: {e!s}").__dict__ async def get_kb(self): """获取知识库详情 @@ -324,7 +328,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取知识库详情失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取知识库详情失败: {str(e)}").__dict__ + return Response().error(f"获取知识库详情失败: {e!s}").__dict__ async def update_kb(self): """更新知识库 @@ -404,7 +408,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"更新知识库失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"更新知识库失败: {str(e)}").__dict__ + return Response().error(f"更新知识库失败: {e!s}").__dict__ async def delete_kb(self): """删除知识库 @@ -431,7 +435,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"删除知识库失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"删除知识库失败: {str(e)}").__dict__ + return Response().error(f"删除知识库失败: {e!s}").__dict__ async def get_kb_stats(self): """获取知识库统计信息 @@ -466,7 +470,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取知识库统计失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取知识库统计失败: {str(e)}").__dict__ + return Response().error(f"获取知识库统计失败: {e!s}").__dict__ # ===== 文档管理 API ===== @@ -508,7 +512,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取文档列表失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取文档列表失败: {str(e)}").__dict__ + return Response().error(f"获取文档列表失败: {e!s}").__dict__ async def upload_document(self): """上传文档 @@ -597,7 +601,7 @@ class KnowledgeBaseRoute(Route): "file_name": file_name, "file_content": file_content, "file_type": file_type, - } + }, ) finally: # 清理临时文件 @@ -630,7 +634,7 @@ class KnowledgeBaseRoute(Route): batch_size=batch_size, tasks_limit=tasks_limit, max_retries=max_retries, - ) + ), ) return ( @@ -640,7 +644,7 @@ class KnowledgeBaseRoute(Route): "task_id": task_id, "file_count": len(files_to_upload), "message": "task created, processing in background", - } + }, ) .__dict__ ) @@ -650,7 +654,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"上传文档失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"上传文档失败: {str(e)}").__dict__ + return Response().error(f"上传文档失败: {e!s}").__dict__ async def get_upload_progress(self): """获取上传进度和结果 @@ -703,7 +707,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取上传进度失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取上传进度失败: {str(e)}").__dict__ + return Response().error(f"获取上传进度失败: {e!s}").__dict__ async def get_document(self): """获取文档详情 @@ -734,7 +738,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取文档详情失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取文档详情失败: {str(e)}").__dict__ + return Response().error(f"获取文档详情失败: {e!s}").__dict__ async def delete_document(self): """删除文档 @@ -766,7 +770,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"删除文档失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"删除文档失败: {str(e)}").__dict__ + return Response().error(f"删除文档失败: {e!s}").__dict__ async def delete_chunk(self): """删除文本块 @@ -801,7 +805,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"删除文本块失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"删除文本块失败: {str(e)}").__dict__ + return Response().error(f"删除文本块失败: {e!s}").__dict__ async def list_chunks(self): """获取块列表 @@ -827,7 +831,9 @@ class KnowledgeBaseRoute(Route): if not kb_helper: return Response().error("知识库不存在").__dict__ chunk_list = await kb_helper.get_chunks_by_doc_id( - doc_id=doc_id, offset=offset, limit=limit + doc_id=doc_id, + offset=offset, + limit=limit, ) return ( Response() @@ -837,7 +843,7 @@ class KnowledgeBaseRoute(Route): "page": page, "page_size": page_size, "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), - } + }, ) .__dict__ ) @@ -846,7 +852,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"获取块列表失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"获取块列表失败: {str(e)}").__dict__ + return Response().error(f"获取块列表失败: {e!s}").__dict__ # ===== 检索 API ===== @@ -893,7 +899,9 @@ class KnowledgeBaseRoute(Route): if debug: try: img_base64 = await generate_tsne_visualization( - query, kb_names, kb_manager + query, + kb_names, + kb_manager, ) if img_base64: response_data["visualization"] = img_base64 @@ -909,7 +917,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"检索失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"检索失败: {str(e)}").__dict__ + return Response().error(f"检索失败: {e!s}").__dict__ # ===== 会话知识库配置 API ===== @@ -945,7 +953,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True) - return Response().error(f"获取会话知识库配置失败: {str(e)}").__dict__ + return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__ async def set_session_kb_config(self): """设置会话的知识库配置 @@ -1024,13 +1032,12 @@ class KnowledgeBaseRoute(Route): ) .__dict__ ) - else: - logger.error("[KB配置] 配置保存失败,验证不匹配") - return Response().error("配置保存失败").__dict__ + logger.error("[KB配置] 配置保存失败,验证不匹配") + return Response().error("配置保存失败").__dict__ except Exception as e: logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True) - return Response().error(f"设置会话知识库配置失败: {str(e)}").__dict__ + return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__ async def delete_session_kb_config(self): """删除会话的知识库配置 @@ -1062,4 +1069,4 @@ class KnowledgeBaseRoute(Route): except Exception as e: logger.error(f"删除会话知识库配置失败: {e}") logger.error(traceback.format_exc()) - return Response().error(f"删除会话知识库配置失败: {str(e)}").__dict__ + return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index e47f9d77c..eb02fdf40 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,8 +1,11 @@ import asyncio import json + from quart import make_response -from astrbot.core import logger, LogBroker -from .route import Route, RouteContext, Response + +from astrbot.core import LogBroker, logger + +from .route import Response, Route, RouteContext class LogRoute(Route): @@ -11,7 +14,9 @@ class LogRoute(Route): self.log_broker = log_broker self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"]) self.app.add_url_rule( - "/api/log-history", view_func=self.log_history, methods=["GET"] + "/api/log-history", + view_func=self.log_history, + methods=["GET"], ) async def log(self): @@ -55,7 +60,7 @@ class LogRoute(Route): .ok( data={ "logs": logs, - } + }, ) .__dict__ ) diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py index 032471ee4..7ddb75f17 100644 --- a/astrbot/dashboard/routes/persona.py +++ b/astrbot/dashboard/routes/persona.py @@ -1,9 +1,12 @@ import traceback -from .route import Route, Response, RouteContext -from astrbot.core import logger + from quart import request -from astrbot.core.db import BaseDatabase + +from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase + +from .route import Response, Route, RouteContext class PersonaRoute(Route): @@ -46,13 +49,13 @@ class PersonaRoute(Route): else None, } for persona in personas - ] + ], ) .__dict__ ) except Exception as e: - logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"获取人格列表失败: {str(e)}").__dict__ + logger.error(f"获取人格列表失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取人格列表失败: {e!s}").__dict__ async def get_persona_detail(self): """获取指定人格的详细信息""" @@ -81,13 +84,13 @@ class PersonaRoute(Route): "updated_at": persona.updated_at.isoformat() if persona.updated_at else None, - } + }, ) .__dict__ ) except Exception as e: - logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"获取人格详情失败: {str(e)}").__dict__ + logger.error(f"获取人格详情失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"获取人格详情失败: {e!s}").__dict__ async def create_persona(self): """创建新人格""" @@ -136,15 +139,15 @@ class PersonaRoute(Route): if persona.updated_at else None, }, - } + }, ) .__dict__ ) except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: - logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"创建人格失败: {str(e)}").__dict__ + logger.error(f"创建人格失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"创建人格失败: {e!s}").__dict__ async def update_persona(self): """更新人格信息""" @@ -177,8 +180,8 @@ class PersonaRoute(Route): except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: - logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"更新人格失败: {str(e)}").__dict__ + logger.error(f"更新人格失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"更新人格失败: {e!s}").__dict__ async def delete_persona(self): """删除人格""" @@ -195,5 +198,5 @@ class PersonaRoute(Route): except ValueError as e: return Response().error(str(e)).__dict__ except Exception as e: - logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}") - return Response().error(f"删除人格失败: {str(e)}").__dict__ + logger.error(f"删除人格失败: {e!s}\n{traceback.format_exc()}") + return Response().error(f"删除人格失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 2df06dcbb..71b238dd7 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,24 +1,23 @@ -import traceback -import aiohttp -import os import json +import os +import ssl +import traceback from datetime import datetime -import ssl +import aiohttp import certifi - -from .route import Route, Response, RouteContext -from astrbot.core import logger, file_token_service from quart import request -from astrbot.core.star.star_manager import PluginManager + +from astrbot.core import DEMO_MODE, file_token_service, logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.filter.regex import RegexFilter -from astrbot.core.star.star_handler import EventType -from astrbot.core import DEMO_MODE +from astrbot.core.star.star_handler import EventType, star_handlers_registry +from astrbot.core.star.star_manager import PluginManager + +from .route import Response, Route, RouteContext class PluginRoute(Route): @@ -106,29 +105,33 @@ class PluginRoute(Route): for url in urls: try: - async with aiohttp.ClientSession( - trust_env=True, connector=connector - ) as session: - async with session.get(url) as response: - if response.status == 200: - remote_data = await response.json() + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session, + session.get(url) as response, + ): + if response.status == 200: + remote_data = await response.json() - # 检查远程数据是否为空 - if not remote_data or ( - isinstance(remote_data, dict) and len(remote_data) == 0 - ): - logger.warning(f"远程插件市场数据为空: {url}") - continue # 继续尝试其他URL或使用缓存 + # 检查远程数据是否为空 + if not remote_data or ( + isinstance(remote_data, dict) and len(remote_data) == 0 + ): + logger.warning(f"远程插件市场数据为空: {url}") + continue # 继续尝试其他URL或使用缓存 - logger.info("成功获取远程插件市场数据") - # 获取最新的MD5并保存到缓存 - current_md5 = await self._get_remote_md5() - self._save_plugin_cache( - cache_file, remote_data, current_md5 - ) - return Response().ok(remote_data).__dict__ - else: - logger.error(f"请求 {url} 失败,状态码:{response.status}") + logger.info("成功获取远程插件市场数据") + # 获取最新的MD5并保存到缓存 + current_md5 = await self._get_remote_md5() + self._save_plugin_cache( + cache_file, + remote_data, + current_md5, + ) + return Response().ok(remote_data).__dict__ + logger.error(f"请求 {url} 失败,状态码:{response.status}") except Exception as e: logger.error(f"请求 {url} 失败,错误:{e}") @@ -165,7 +168,7 @@ class PluginRoute(Route): is_valid = cached_md5 == remote_md5 logger.debug( - f"插件数据MD5: 本地={cached_md5}, 远程={remote_md5}, 有效={is_valid}" + f"插件数据MD5: 本地={cached_md5}, 远程={remote_md5}, 有效={is_valid}", ) return is_valid @@ -179,18 +182,20 @@ class PluginRoute(Route): ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession( - trust_env=True, connector=connector - ) as session: - async with session.get( - "https://api.soulter.top/astrbot/plugins-md5" - ) as response: - if response.status == 200: - data = await response.json() - return data.get("md5", "") - else: - logger.error(f"获取MD5失败,状态码:{response.status}") - return "" + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session, + session.get( + "https://api.soulter.top/astrbot/plugins-md5", + ) as response, + ): + if response.status == 200: + data = await response.json() + return data.get("md5", "") + logger.error(f"获取MD5失败,状态码:{response.status}") + return "" except Exception as e: logger.error(f"获取远程MD5失败: {e}") return "" @@ -204,7 +209,7 @@ class PluginRoute(Route): # 检查缓存是否有效 if "data" in cache_data and "timestamp" in cache_data: logger.debug( - f"加载缓存文件: {cache_file}, 缓存时间: {cache_data['timestamp']}" + f"加载缓存文件: {cache_file}, 缓存时间: {cache_data['timestamp']}", ) return cache_data["data"] except Exception as e: @@ -260,7 +265,7 @@ class PluginRoute(Route): "activated": plugin.activated, "online_vesion": "", "handlers": await self.get_plugin_handlers_info( - plugin.star_handler_full_names + plugin.star_handler_full_names, ), "display_name": plugin.display_name, "logo": f"/api/file/{logo_url}" if logo_url else None, @@ -279,13 +284,15 @@ class PluginRoute(Route): for handler_full_name in handler_full_names: info = {} handler = star_handlers_registry.star_handlers_map.get( - handler_full_name, None + handler_full_name, + None, ) if handler is None: continue info["event_type"] = handler.event_type.name info["event_type_h"] = self.translated_event_type.get( - handler.event_type, handler.event_type.name + handler.event_type, + handler.event_type.name, ) info["handler_full_name"] = handler.handler_full_name info["desc"] = handler.desc @@ -308,7 +315,7 @@ class PluginRoute(Route): info["cmd"] = filter.get_complete_command_names()[0] info["cmd"] = info["cmd"].strip() info["sub_command"] = filter.print_cmd_tree( - filter.sub_command_filters + filter.sub_command_filters, ) elif isinstance(filter, RegexFilter): info["type"] = "正则匹配" @@ -474,7 +481,8 @@ class PluginRoute(Route): return Response().error(f"插件 {plugin_name} 不存在").__dict__ plugin_dir = os.path.join( - self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name + self.plugin_manager.plugin_store_path, + plugin_obj.root_dir_name, ) if not os.path.isdir(plugin_dir): @@ -498,4 +506,4 @@ class PluginRoute(Route): ) except Exception as e: logger.error(f"/api/plugin/readme: {traceback.format_exc()}") - return Response().error(f"读取README文件失败: {str(e)}").__dict__ + return Response().error(f"读取README文件失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index ec455ce3d..1105b69a7 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,7 +1,9 @@ -from astrbot.core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass + from quart import Quart +from astrbot.core.config.astrbot_config import AstrBotConfig + @dataclass class RouteContext: diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index 1d632171d..0b16c0949 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -47,7 +47,10 @@ class SessionManagementRoute(Route): # 获取活跃的会话数据(处于对话内的会话) sessions_data, total = await self.db_helper.get_session_conversations( - page, page_size, search_query, platform + page, + page_size, + search_query, + platform, ) provider_manager = self.core_lifecycle.provider_manager @@ -80,13 +83,13 @@ class SessionManagementRoute(Route): "stt_provider_id": None, "tts_provider_id": None, "session_enabled": SessionServiceManager.is_session_enabled( - session_id + session_id, ), "llm_enabled": SessionServiceManager.is_llm_enabled_for_session( - session_id + session_id, ), "tts_enabled": SessionServiceManager.is_tts_enabled_for_session( - session_id + session_id, ), "platform": session_id.split(":")[0] if ":" in session_id @@ -95,7 +98,7 @@ class SessionManagementRoute(Route): if session_id.count(":") >= 1 else "unknown", "session_name": SessionServiceManager.get_session_display_name( - session_id + session_id, ), "session_raw_name": session_id.split(":")[2] if session_id.count(":") >= 2 @@ -105,13 +108,16 @@ class SessionManagementRoute(Route): # 获取 provider 信息 chat_provider = provider_manager.get_using_provider( - provider_type=ProviderType.CHAT_COMPLETION, umo=session_id + provider_type=ProviderType.CHAT_COMPLETION, + umo=session_id, ) tts_provider = provider_manager.get_using_provider( - provider_type=ProviderType.TEXT_TO_SPEECH, umo=session_id + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=session_id, ) stt_provider = provider_manager.get_using_provider( - provider_type=ProviderType.SPEECH_TO_TEXT, umo=session_id + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=session_id, ) if chat_provider: meta = chat_provider.meta() @@ -139,7 +145,7 @@ class SessionManagementRoute(Route): "name": meta.id, "model": meta.model, "type": meta.type, - } + }, ) available_stt_providers = [] @@ -151,7 +157,7 @@ class SessionManagementRoute(Route): "name": meta.id, "model": meta.model, "type": meta.type, - } + }, ) available_tts_providers = [] @@ -163,7 +169,7 @@ class SessionManagementRoute(Route): "name": meta.id, "model": meta.model, "type": meta.type, - } + }, ) result = { @@ -185,15 +191,15 @@ class SessionManagementRoute(Route): return Response().ok(result).__dict__ except Exception as e: - error_msg = f"获取会话列表失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"获取会话列表失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"获取会话列表失败: {str(e)}").__dict__ + return Response().error(f"获取会话列表失败: {e!s}").__dict__ async def _update_single_session_persona(self, session_id: str, persona_name: str): """更新单个会话的 persona 的内部方法""" conversation_manager = self.core_lifecycle.star_context.conversation_manager conversation_id = await conversation_manager.get_curr_conversation_id( - session_id + session_id, ) conv = None @@ -207,11 +213,16 @@ class SessionManagementRoute(Route): # 更新 persona await conversation_manager.update_conversation_persona_id( - session_id, persona_name + session_id, + persona_name, ) async def _handle_batch_operation( - self, session_ids: list, operation_func, operation_name: str, **kwargs + self, + session_ids: list, + operation_func, + operation_name: str, + **kwargs, ): """通用的批量操作处理方法""" success_count = 0 @@ -222,7 +233,7 @@ class SessionManagementRoute(Route): await operation_func(session_id, **kwargs) success_count += 1 except Exception as e: - logger.error(f"批量{operation_name} 会话 {session_id} 失败: {str(e)}") + logger.error(f"批量{operation_name} 会话 {session_id} 失败: {e!s}") error_sessions.append(session_id) if error_sessions: @@ -234,21 +245,20 @@ class SessionManagementRoute(Route): "success_count": success_count, "error_count": len(error_sessions), "error_sessions": error_sessions, - } + }, ) .__dict__ ) - else: - return ( - Response() - .ok( - { - "message": f"成功批量{operation_name} {success_count} 个会话", - "success_count": success_count, - } - ) - .__dict__ + return ( + Response() + .ok( + { + "message": f"成功批量{operation_name} {success_count} 个会话", + "success_count": success_count, + }, ) + .__dict__ + ) async def update_session_persona(self): """更新指定会话的 persona,支持批量操作""" @@ -271,29 +281,31 @@ class SessionManagementRoute(Route): "更新人格", persona_name=persona_name, ) - else: - session_id = data.get("session_id") - if not session_id: - return Response().error("缺少必要参数: session_id").__dict__ + session_id = data.get("session_id") + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ - await self._update_single_session_persona(session_id, persona_name) - return ( - Response() - .ok( - { - "message": f"成功更新会话 {session_id} 的人格为 {persona_name}" - } - ) - .__dict__ + await self._update_single_session_persona(session_id, persona_name) + return ( + Response() + .ok( + { + "message": f"成功更新会话 {session_id} 的人格为 {persona_name}", + }, ) + .__dict__ + ) except Exception as e: - error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话人格失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话人格失败: {str(e)}").__dict__ + return Response().error(f"更新会话人格失败: {e!s}").__dict__ async def _update_single_session_provider( - self, session_id: str, provider_id: str, provider_type_enum + self, + session_id: str, + provider_id: str, + provider_type_enum, ): """更新单个会话的 provider 的内部方法""" provider_manager = self.core_lifecycle.star_context.provider_manager @@ -344,28 +356,29 @@ class SessionManagementRoute(Route): provider_id=provider_id, provider_type_enum=provider_type_enum, ) - else: - session_id = data.get("session_id") - if not session_id: - return Response().error("缺少必要参数: session_id").__dict__ + session_id = data.get("session_id") + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ - await self._update_single_session_provider( - session_id, provider_id, provider_type_enum - ) - return ( - Response() - .ok( - { - "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}" - } - ) - .__dict__ + await self._update_single_session_provider( + session_id, + provider_id, + provider_type_enum, + ) + return ( + Response() + .ok( + { + "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}", + }, ) + .__dict__ + ) except Exception as e: - error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话提供商失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话提供商失败: {str(e)}").__dict__ + return Response().error(f"更新会话提供商失败: {e!s}").__dict__ async def get_session_plugins(self): """获取指定会话的插件配置信息""" @@ -384,7 +397,8 @@ class SessionManagementRoute(Route): if plugin.activated and not plugin.reserved: plugin_name = plugin.name or "" plugin_enabled = SessionPluginManager.is_plugin_enabled_for_session( - session_id, plugin_name + session_id, + plugin_name, ) all_plugins.append( @@ -393,7 +407,7 @@ class SessionManagementRoute(Route): "author": plugin.author, "desc": plugin.desc, "enabled": plugin_enabled, - } + }, ) return ( @@ -402,15 +416,15 @@ class SessionManagementRoute(Route): { "session_id": session_id, "plugins": all_plugins, - } + }, ) .__dict__ ) except Exception as e: - error_msg = f"获取会话插件配置失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"获取会话插件配置失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"获取会话插件配置失败: {str(e)}").__dict__ + return Response().error(f"获取会话插件配置失败: {e!s}").__dict__ async def update_session_plugin(self): """更新指定会话的插件启停状态""" @@ -448,7 +462,9 @@ class SessionManagementRoute(Route): # 使用 SessionPluginManager 更新插件状态 SessionPluginManager.set_plugin_status_for_session( - session_id, plugin_name, enabled + session_id, + plugin_name, + enabled, ) return ( @@ -459,15 +475,15 @@ class SessionManagementRoute(Route): "session_id": session_id, "plugin_name": plugin_name, "enabled": enabled, - } + }, ) .__dict__ ) except Exception as e: - error_msg = f"更新会话插件状态失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话插件状态失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__ + return Response().error(f"更新会话插件状态失败: {e!s}").__dict__ async def _update_single_session_llm(self, session_id: str, enabled: bool): """更新单个会话的LLM状态的内部方法""" @@ -495,28 +511,27 @@ class SessionManagementRoute(Route): enabled=enabled, ) return result - else: - session_id = data.get("session_id") - if not session_id: - return Response().error("缺少必要参数: session_id").__dict__ + session_id = data.get("session_id") + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ - await self._update_single_session_llm(session_id, enabled) - return ( - Response() - .ok( - { - "message": f"LLM已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "llm_enabled": enabled, - } - ) - .__dict__ + await self._update_single_session_llm(session_id, enabled) + return ( + Response() + .ok( + { + "message": f"LLM已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "llm_enabled": enabled, + }, ) + .__dict__ + ) except Exception as e: - error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话LLM状态失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__ + return Response().error(f"更新会话LLM状态失败: {e!s}").__dict__ async def _update_single_session_tts(self, session_id: str, enabled: bool): """更新单个会话的TTS状态的内部方法""" @@ -544,28 +559,27 @@ class SessionManagementRoute(Route): enabled=enabled, ) return result - else: - session_id = data.get("session_id") - if not session_id: - return Response().error("缺少必要参数: session_id").__dict__ + session_id = data.get("session_id") + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ - await self._update_single_session_tts(session_id, enabled) - return ( - Response() - .ok( - { - "message": f"TTS已{'启用' if enabled else '禁用'}", - "session_id": session_id, - "tts_enabled": enabled, - } - ) - .__dict__ + await self._update_single_session_tts(session_id, enabled) + return ( + Response() + .ok( + { + "message": f"TTS已{'启用' if enabled else '禁用'}", + "session_id": session_id, + "tts_enabled": enabled, + }, ) + .__dict__ + ) except Exception as e: - error_msg = f"更新会话TTS状态失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话TTS状态失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话TTS状态失败: {str(e)}").__dict__ + return Response().error(f"更新会话TTS状态失败: {e!s}").__dict__ async def update_session_name(self): """更新指定会话的自定义名称""" @@ -588,17 +602,17 @@ class SessionManagementRoute(Route): "session_id": session_id, "custom_name": custom_name, "display_name": SessionServiceManager.get_session_display_name( - session_id + session_id, ), - } + }, ) .__dict__ ) except Exception as e: - error_msg = f"更新会话名称失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话名称失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话名称失败: {str(e)}").__dict__ + return Response().error(f"更新会话名称失败: {e!s}").__dict__ async def update_session_status(self): """更新指定会话的整体启停状态""" @@ -623,15 +637,15 @@ class SessionManagementRoute(Route): "message": f"会话整体状态已更新为: {'启用' if session_enabled else '禁用'}", "session_id": session_id, "session_enabled": session_enabled, - } + }, ) .__dict__ ) except Exception as e: - error_msg = f"更新会话整体状态失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"更新会话整体状态失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"更新会话整体状态失败: {str(e)}").__dict__ + return Response().error(f"更新会话整体状态失败: {e!s}").__dict__ async def delete_session(self): """删除指定会话及其所有相关数据""" @@ -649,13 +663,13 @@ class SessionManagementRoute(Route): try: await conversation_manager.delete_conversations_by_user_id(session_id) except Exception as e: - logger.warning(f"删除会话 {session_id} 的对话失败: {str(e)}") + logger.warning(f"删除会话 {session_id} 的对话失败: {e!s}") # 2. 清除会话的偏好设置数据(清空该会话的所有配置) try: await sp.clear_async("umo", session_id) except Exception as e: - logger.warning(f"清除会话 {session_id} 的偏好设置失败: {str(e)}") + logger.warning(f"清除会话 {session_id} 的偏好设置失败: {e!s}") return ( Response() @@ -663,12 +677,12 @@ class SessionManagementRoute(Route): { "message": f"会话 {session_id} 及其相关所有对话数据已成功删除", "session_id": session_id, - } + }, ) .__dict__ ) except Exception as e: - error_msg = f"删除会话失败: {str(e)}\n{traceback.format_exc()}" + error_msg = f"删除会话失败: {e!s}\n{traceback.format_exc()}" logger.error(error_msg) - return Response().error(f"删除会话失败: {str(e)}").__dict__ + return Response().error(f"删除会话失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index d13eb802c..8df690cc2 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,17 +1,19 @@ -import traceback -import psutil -import time import threading +import time +import traceback + import aiohttp -from .route import Route, Response, RouteContext -from astrbot.core import logger +import psutil from quart import request + +from astrbot.core import DEMO_MODE, logger +from astrbot.core.config import VERSION from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.config import VERSION -from astrbot.core.utils.io import get_dashboard_version -from astrbot.core import DEMO_MODE from astrbot.core.db.migration.helper import check_migration_needed_v4 +from astrbot.core.utils.io import get_dashboard_version + +from .route import Response, Route, RouteContext class StatRoute(Route): @@ -70,7 +72,7 @@ class StatRoute(Route): "dashboard_version": await get_dashboard_version(), "change_pwd_hint": self.is_default_cred(), "need_migration": need_migration, - } + }, ) .__dict__ ) @@ -116,17 +118,17 @@ class StatRoute(Route): # 计算运行时长组件 running_time = self._get_running_time_components( - int(time.time()) - self.core_lifecycle.start_time + int(time.time()) - self.core_lifecycle.start_time, ) stat_dict.update( { "platform": self.db_helper.get_grouped_base_stats( - offset_sec + offset_sec, ).platform, "message_count": self.db_helper.get_total_message_count() or 0, "platform_count": len( - self.core_lifecycle.platform_manager.get_insts() + self.core_lifecycle.platform_manager.get_insts(), ), "plugin_count": len(plugins), "plugins": plugin_info, @@ -139,7 +141,7 @@ class StatRoute(Route): "cpu_percent": round(cpu_percent, 1), "thread_count": thread_count, "start_time": self.core_lifecycle.start_time, - } + }, ) return Response().ok(stat_dict).__dict__ @@ -148,9 +150,7 @@ class StatRoute(Route): return Response().error(e.__str__()).__dict__ async def test_ghproxy_connection(self): - """ - 测试 GitHub 代理连接是否可用。 - """ + """测试 GitHub 代理连接是否可用。""" try: data = await request.get_json() proxy_url: str = data.get("proxy_url") @@ -163,23 +163,23 @@ class StatRoute(Route): test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version" start_time = time.time() - async with aiohttp.ClientSession() as session: - async with session.get( - test_url, timeout=aiohttp.ClientTimeout(total=10) - ) as response: - if response.status == 200: - end_time = time.time() - _ = await response.text() - ret = { - "latency": round((end_time - start_time) * 1000, 2), - } - return Response().ok(data=ret).__dict__ - else: - return ( - Response() - .error(f"Failed. Status code: {response.status}") - .__dict__ - ) + async with ( + aiohttp.ClientSession() as session, + session.get( + test_url, + timeout=aiohttp.ClientTimeout(total=10), + ) as response, + ): + if response.status == 200: + end_time = time.time() + _ = await response.text() + ret = { + "latency": round((end_time - start_time) * 1000, 2), + } + return Response().ok(data=ret).__dict__ + return ( + Response().error(f"Failed. Status code: {response.status}").__dict__ + ) except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"Error: {str(e)}").__dict__ + return Response().error(f"Error: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index 04f87bc99..db70a8820 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -1,11 +1,13 @@ # astrbot/dashboard/routes/t2i.py from dataclasses import asdict + from quart import jsonify, request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.utils.t2i.template_manager import TemplateManager + from .route import Response, Route, RouteContext @@ -49,7 +51,7 @@ class T2iRoute(Route): try: active_template = self.config.get("t2i_active_template", "base") return jsonify( - asdict(Response().ok(data={"active_template": active_template})) + asdict(Response().ok(data={"active_template": active_template})), ) except Exception as e: logger.error("Error in get_active_template", exc_info=True) @@ -62,7 +64,7 @@ class T2iRoute(Route): try: content = self.manager.get_template(name) return jsonify( - asdict(Response().ok(data={"name": name, "content": content})) + asdict(Response().ok(data={"name": name, "content": content})), ) except FileNotFoundError: response = jsonify(asdict(Response().error("Template not found"))) @@ -81,7 +83,7 @@ class T2iRoute(Route): content = data.get("content") if not name or not content: response = jsonify( - asdict(Response().error("Name and content are required.")) + asdict(Response().error("Name and content are required.")), ) response.status_code = 400 return response @@ -91,15 +93,16 @@ class T2iRoute(Route): response = jsonify( asdict( Response().ok( - data={"name": name}, message="Template created successfully." - ) - ) + data={"name": name}, + message="Template created successfully.", + ), + ), ) response.status_code = 201 return response except FileExistsError: response = jsonify( - asdict(Response().error("Template with this name already exists.")) + asdict(Response().error("Template with this name already exists.")), ) response.status_code = 409 return response @@ -149,7 +152,7 @@ class T2iRoute(Route): name = name.strip() self.manager.delete_template(name) return jsonify( - asdict(Response().ok(message="Template deleted successfully.")) + asdict(Response().ok(message="Template deleted successfully.")), ) except FileNotFoundError: response = jsonify(asdict(Response().error("Template not found."))) @@ -189,7 +192,7 @@ class T2iRoute(Route): except FileNotFoundError: response = jsonify( - asdict(Response().error(f"模板 '{name}' 不存在,无法应用。")) + asdict(Response().error(f"模板 '{name}' 不存在,无法应用。")), ) response.status_code = 404 return response @@ -215,9 +218,9 @@ class T2iRoute(Route): return jsonify( asdict( Response().ok( - message="Default template has been reset and activated." - ) - ) + message="Default template has been reset and activated.", + ), + ), ) except FileNotFoundError as e: response = jsonify(asdict(Response().error(str(e)))) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 8fd89919a..4fbc494d1 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -13,7 +13,9 @@ DEFAULT_MCP_CONFIG = {"mcpServers": {}} class ToolsRoute(Route): def __init__( - self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle @@ -64,7 +66,7 @@ class ToolsRoute(Route): return Response().ok(servers).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"获取 MCP 服务器列表失败: {str(e)}").__dict__ + return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__ async def add_mcp_server(self): try: @@ -105,23 +107,22 @@ class ToolsRoute(Route): if self.tool_mgr.save_mcp_config(config): try: await self.tool_mgr.enable_mcp_server( - name, server_config, timeout=30 + name, + server_config, + timeout=30, ) except TimeoutError: return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ except Exception as e: logger.error(traceback.format_exc()) return ( - Response() - .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") - .__dict__ + Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__ ) return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ - else: - return Response().error("保存配置失败").__dict__ + return Response().error("保存配置失败").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"添加 MCP 服务器失败: {str(e)}").__dict__ + return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__ async def update_mcp_server(self): try: @@ -139,7 +140,8 @@ class ToolsRoute(Route): # 获取活动状态 active = server_data.get( - "active", config["mcpServers"][name].get("active", True) + "active", + config["mcpServers"][name].get("active", True), ) # 创建新的配置对象 @@ -177,19 +179,21 @@ class ToolsRoute(Route): except TimeoutError as e: return ( Response() - .error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}") + .error(f"启用前停用 MCP 服务器时 {name} 超时: {e!s}") .__dict__ ) except Exception as e: logger.error(traceback.format_exc()) return ( Response() - .error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}") + .error(f"启用前停用 MCP 服务器时 {name} 失败: {e!s}") .__dict__ ) try: await self.tool_mgr.enable_mcp_server( - name, config["mcpServers"][name], timeout=30 + name, + config["mcpServers"][name], + timeout=30, ) except TimeoutError: return ( @@ -199,34 +203,30 @@ class ToolsRoute(Route): logger.error(traceback.format_exc()) return ( Response() - .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") + .error(f"启用 MCP 服务器 {name} 失败: {e!s}") + .__dict__ + ) + # 如果要停用服务器 + elif name in self.tool_mgr.mcp_client_dict: + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError: + return ( + Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"停用 MCP 服务器 {name} 失败: {e!s}") .__dict__ ) - else: - # 如果要停用服务器 - if name in self.tool_mgr.mcp_client_dict: - try: - await self.tool_mgr.disable_mcp_server(name, timeout=10) - except TimeoutError: - return ( - Response() - .error(f"停用 MCP 服务器 {name} 超时。") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return ( - Response() - .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") - .__dict__ - ) return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ - else: - return Response().error("保存配置失败").__dict__ + return Response().error("保存配置失败").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"更新 MCP 服务器失败: {str(e)}").__dict__ + return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__ async def delete_mcp_server(self): try: @@ -255,20 +255,17 @@ class ToolsRoute(Route): logger.error(traceback.format_exc()) return ( Response() - .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") + .error(f"停用 MCP 服务器 {name} 失败: {e!s}") .__dict__ ) return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ - else: - return Response().error("保存配置失败").__dict__ + return Response().error("保存配置失败").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__ + return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__ async def test_mcp_connection(self): - """ - 测试 MCP 服务器连接 - """ + """测试 MCP 服务器连接""" try: server_data = await request.json config = server_data.get("mcp_server_config", None) @@ -283,9 +280,8 @@ class ToolsRoute(Route): if len(keys) > 1: return Response().error("一次只能配置一个 MCP 服务器配置").__dict__ config = config["mcpServers"][keys[0]] - else: - if not config: - return Response().error("MCP 服务器配置不能为空").__dict__ + elif not config: + return Response().error("MCP 服务器配置不能为空").__dict__ tools_name = await self.tool_mgr.test_mcp_server_connection(config) return ( @@ -294,7 +290,7 @@ class ToolsRoute(Route): except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__ + return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__ async def get_tool_list(self): """获取所有注册的工具列表""" @@ -304,7 +300,7 @@ class ToolsRoute(Route): return Response().ok(data=tools_dict).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"获取工具列表失败: {str(e)}").__dict__ + return Response().error(f"获取工具列表失败: {e!s}").__dict__ async def toggle_tool(self): """启用或停用指定的工具""" @@ -320,18 +316,17 @@ class ToolsRoute(Route): try: ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) except ValueError as e: - return Response().error(f"启用工具失败: {str(e)}").__dict__ + return Response().error(f"启用工具失败: {e!s}").__dict__ else: ok = self.tool_mgr.deactivate_llm_tool(tool_name) if ok: return Response().ok(None, "操作成功。").__dict__ - else: - return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__ + return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"操作工具失败: {str(e)}").__dict__ + return Response().error(f"操作工具失败: {e!s}").__dict__ async def sync_provider(self): """同步 MCP 提供者配置""" @@ -348,4 +343,4 @@ class ToolsRoute(Route): return Response().ok(message="同步成功").__dict__ except Exception as e: logger.error(traceback.format_exc()) - return Response().error(f"同步失败: {str(e)}").__dict__ + return Response().error(f"同步失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index 426deb38a..b0520c315 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,13 +1,15 @@ import traceback -from .route import Route, Response, RouteContext + from quart import request -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.updator import AstrBotUpdator -from astrbot.core import logger, pip_installer -from astrbot.core.utils.io import download_dashboard, get_dashboard_version + +from astrbot.core import DEMO_MODE, logger, pip_installer from astrbot.core.config.default import VERSION -from astrbot.core import DEMO_MODE -from astrbot.core.db.migration.helper import do_migration_v4, check_migration_needed_v4 +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db.migration.helper import check_migration_needed_v4, do_migration_v4 +from astrbot.core.updator import AstrBotUpdator +from astrbot.core.utils.io import download_dashboard, get_dashboard_version + +from .route import Response, Route, RouteContext CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'} @@ -40,12 +42,14 @@ class UpdateRoute(Route): data = await request.json pim = data.get("platform_id_map", {}) await do_migration_v4( - self.core_lifecycle.db, pim, self.core_lifecycle.astrbot_config + self.core_lifecycle.db, + pim, + self.core_lifecycle.astrbot_config, ) return Response().ok(None, "迁移成功。").__dict__ except Exception as e: logger.error(f"迁移失败: {traceback.format_exc()}") - return Response().error(f"迁移失败: {str(e)}").__dict__ + return Response().error(f"迁移失败: {e!s}").__dict__ async def check_update(self): type_ = request.args.get("type", None) @@ -58,20 +62,19 @@ class UpdateRoute(Route): .ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv}) .__dict__ ) - else: - ret = await self.astrbot_updator.check_update(None, None, False) - return Response( - status="success", - message=str(ret) if ret is not None else "已经是最新版本了。", - data={ - "version": f"v{VERSION}", - "has_new_version": ret is not None, - "dashboard_version": dv, - "dashboard_has_new_version": bool(dv and dv != f"v{VERSION}"), - }, - ).__dict__ + ret = await self.astrbot_updator.check_update(None, None, False) + return Response( + status="success", + message=str(ret) if ret is not None else "已经是最新版本了。", + data={ + "version": f"v{VERSION}", + "has_new_version": ret is not None, + "dashboard_version": dv, + "dashboard_has_new_version": bool(dv and dv != f"v{VERSION}"), + }, + ).__dict__ except Exception as e: - logger.warning(f"检查更新失败: {str(e)} (不影响除项目更新外的正常使用)") + logger.warning(f"检查更新失败: {e!s} (不影响除项目更新外的正常使用)") return Response().error(e.__str__()).__dict__ async def get_releases(self): @@ -98,7 +101,9 @@ class UpdateRoute(Route): try: await self.astrbot_updator.update( - latest=latest, version=version, proxy=proxy + latest=latest, + version=version, + proxy=proxy, ) try: @@ -121,13 +126,12 @@ class UpdateRoute(Route): .__dict__ ) return ret, 200, CLEAR_SITE_DATA_HEADERS - else: - ret = ( - Response() - .ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。") - .__dict__ - ) - return ret, 200, CLEAR_SITE_DATA_HEADERS + ret = ( + Response() + .ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。") + .__dict__ + ) + return ret, 200, CLEAR_SITE_DATA_HEADERS except Exception as e: logger.error(f"/api/update_project: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 31507e2ce..775983ef8 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -39,7 +39,7 @@ class AstrBotDashboard: self.data_path = os.path.abspath(webui_dir) else: self.data_path = os.path.abspath( - os.path.join(get_astrbot_data_path(), "dist") + os.path.join(get_astrbot_data_path(), "dist"), ) self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") @@ -53,11 +53,15 @@ class AstrBotDashboard: logging.getLogger(self.app.name).removeHandler(default_handler) self.context = RouteContext(self.config, self.app) self.ur = UpdateRoute( - self.context, core_lifecycle.astrbot_updator, core_lifecycle + self.context, + core_lifecycle.astrbot_updator, + core_lifecycle, ) self.sr = StatRoute(self.context, db, core_lifecycle) self.pr = PluginRoute( - self.context, core_lifecycle, core_lifecycle.plugin_manager + self.context, + core_lifecycle, + core_lifecycle.plugin_manager, ) self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) @@ -68,7 +72,9 @@ class AstrBotDashboard: self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) self.file_route = FileRoute(self.context) self.session_management_route = SessionManagementRoute( - self.context, db, core_lifecycle + self.context, + db, + core_lifecycle, ) self.persona_route = PersonaRoute(self.context, db, core_lifecycle) self.t2i_route = T2iRoute(self.context, core_lifecycle) @@ -85,9 +91,7 @@ class AstrBotDashboard: self._init_jwt_secret() async def srv_plug_route(self, subpath, *args, **kwargs): - """ - 插件路由 - """ + """插件路由""" registered_web_apis = self.core_lifecycle.star_context.registered_web_apis for api in registered_web_apis: route, view_handler, methods, _ = api @@ -97,18 +101,17 @@ class AstrBotDashboard: async def auth_middleware(self): if not request.path.startswith("/api"): - return + return None allowed_endpoints = ["/api/auth/login", "/api/file"] if any(request.path.startswith(prefix) for prefix in allowed_endpoints): - return + return None # claim jwt token = request.headers.get("Authorization") if not token: r = jsonify(Response().error("未授权").__dict__) r.status_code = 401 return r - if token.startswith("Bearer "): - token = token[7:] + token = token.removeprefix("Bearer ") try: payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) g.username = payload["username"] @@ -122,9 +125,7 @@ class AstrBotDashboard: return r def check_port_in_use(self, port: int) -> bool: - """ - 跨平台检测端口是否被占用 - """ + """跨平台检测端口是否被占用""" try: # 创建 IPv4 TCP Socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -136,7 +137,7 @@ class AstrBotDashboard: # result 为 0 表示端口被占用 return result == 0 except Exception as e: - logger.warning(f"检查端口 {port} 时发生错误: {str(e)}") + logger.warning(f"检查端口 {port} 时发生错误: {e!s}") # 如果出现异常,保守起见认为端口可能被占用 return True @@ -157,10 +158,10 @@ class AstrBotDashboard: ] return "\n ".join(proc_info) except (psutil.NoSuchProcess, psutil.AccessDenied) as e: - return f"无法获取进程详细信息(可能需要管理员权限): {str(e)}" + return f"无法获取进程详细信息(可能需要管理员权限): {e!s}" return "未找到占用进程" except Exception as e: - return f"获取进程信息失败: {str(e)}" + return f"获取进程信息失败: {e!s}" def _init_jwt_secret(self): if not self.config.get("dashboard", {}).get("jwt_secret", None): @@ -182,13 +183,13 @@ class AstrBotDashboard: if not enable: logger.info("WebUI 已被禁用") - return + return None logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}") if host == "0.0.0.0": logger.info( - "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)" + "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)", ) if host not in ["localhost", "127.0.0.1"]: @@ -207,7 +208,7 @@ class AstrBotDashboard: f"请确保:\n" f"1. 没有其他 AstrBot 实例正在运行\n" f"2. 端口 {port} 没有被其他程序占用\n" - f"3. 如需使用其他端口,请修改配置文件" + f"3. 如需使用其他端口,请修改配置文件", ) raise Exception(f"端口 {port} 已被占用") @@ -226,7 +227,9 @@ class AstrBotDashboard: logger.info(display) return self.app.run_task( - host=host, port=port, shutdown_trigger=self.shutdown_trigger + host=host, + port=port, + shutdown_trigger=self.shutdown_trigger, ) async def shutdown_trigger(self): diff --git a/astrbot/dashboard/utils.py b/astrbot/dashboard/utils.py index 4bdaf43c4..b81faad06 100644 --- a/astrbot/dashboard/utils.py +++ b/astrbot/dashboard/utils.py @@ -2,14 +2,17 @@ import base64 import os import traceback from io import BytesIO + from astrbot.api import logger +from astrbot.core.db.vec_db.faiss_impl import FaissVecDB from astrbot.core.knowledge_base.kb_helper import KBHelper from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager -from astrbot.core.db.vec_db.faiss_impl import FaissVecDB async def generate_tsne_visualization( - query: str, kb_names: list[str], kb_manager: KnowledgeBaseManager + query: str, + kb_names: list[str], + kb_manager: KnowledgeBaseManager, ) -> str | None: """生成 t-SNE 可视化图片 @@ -20,18 +23,19 @@ async def generate_tsne_visualization( Returns: 图片路径或 None + """ try: import faiss - import numpy as np import matplotlib + import numpy as np matplotlib.use("Agg") # 使用非交互式后端 import matplotlib.pyplot as plt from sklearn.manifold import TSNE except ImportError as e: raise Exception( - "缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}" + "缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}", ) from e try: diff --git a/main.py b/main.py index be0d4f307..b453cdfb5 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,18 @@ -import os -import asyncio -import sys -import mimetypes import argparse -from astrbot.core.initial_loader import InitialLoader -from astrbot.core import db_helper -from astrbot.core import logger, LogManager, LogBroker +import asyncio +import mimetypes +import os +import sys +from pathlib import Path + +from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.config.default import VERSION -from astrbot.core.utils.io import download_dashboard, get_dashboard_version +from astrbot.core.initial_loader import InitialLoader from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_dashboard, get_dashboard_version # add parent path to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(Path(__file__).parent.as_posix()) logo_tmpl = r""" ___ _______.___________..______ .______ ______ .___________. @@ -46,8 +47,7 @@ async def check_dashboard_files(webui_dir: str | None = None): if os.path.exists(webui_dir): logger.info(f"使用指定的 WebUI 目录: {webui_dir}") return webui_dir - else: - logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。") + logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。") data_dist_path = os.path.join(get_astrbot_data_path(), "dist") if os.path.exists(data_dist_path): @@ -58,12 +58,12 @@ async def check_dashboard_files(webui_dir: str | None = None): logger.info("WebUI 版本已是最新。") else: logger.warning( - f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。" + f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。", ) return data_dist_path logger.info( - "开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。" + "开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。", ) try: @@ -79,7 +79,10 @@ async def check_dashboard_files(webui_dir: str | None = None): if __name__ == "__main__": parser = argparse.ArgumentParser(description="AstrBot") parser.add_argument( - "--webui-dir", type=str, help="指定 WebUI 静态文件目录路径", default=None + "--webui-dir", + type=str, + help="指定 WebUI 静态文件目录路径", + default=None, ) args = parser.parse_args() diff --git a/packages/astrbot/commands/__init__.py b/packages/astrbot/commands/__init__.py index 995022a14..8f1f9bafa 100644 --- a/packages/astrbot/commands/__init__.py +++ b/packages/astrbot/commands/__init__.py @@ -1,31 +1,31 @@ # Commands module +from .admin import AdminCommands +from .alter_cmd import AlterCmdCommands +from .conversation import ConversationCommands from .help import HelpCommand from .llm import LLMCommands -from .tool import ToolCommands -from .plugin import PluginCommands -from .admin import AdminCommands -from .conversation import ConversationCommands -from .provider import ProviderCommands from .persona import PersonaCommands -from .alter_cmd import AlterCmdCommands +from .plugin import PluginCommands +from .provider import ProviderCommands from .setunset import SetUnsetCommands -from .t2i import T2ICommand -from .tts import TTSCommand from .sid import SIDCommand +from .t2i import T2ICommand +from .tool import ToolCommands +from .tts import TTSCommand __all__ = [ + "AdminCommands", + "AlterCmdCommands", + "ConversationCommands", "HelpCommand", "LLMCommands", - "ToolCommands", - "PluginCommands", - "AdminCommands", - "ConversationCommands", - "ProviderCommands", "PersonaCommands", - "AlterCmdCommands", + "PluginCommands", + "ProviderCommands", + "SIDCommand", "SetUnsetCommands", "T2ICommand", "TTSCommand", - "SIDCommand", + "ToolCommands", ] diff --git a/packages/astrbot/commands/admin.py b/packages/astrbot/commands/admin.py index 4ea3188f1..2073f45a2 100644 --- a/packages/astrbot/commands/admin.py +++ b/packages/astrbot/commands/admin.py @@ -1,7 +1,7 @@ -import astrbot.api.star as star -from astrbot.api.event import AstrMessageEvent, MessageEventResult, MessageChain -from astrbot.core.utils.io import download_dashboard +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult from astrbot.core.config.default import VERSION +from astrbot.core.utils.io import download_dashboard class AdminCommands: @@ -13,8 +13,8 @@ class AdminCommands: if not admin_id: event.set_result( MessageEventResult().message( - "使用方法: /op 授权管理员;/deop 取消管理员。可通过 /sid 获取 ID。" - ) + "使用方法: /op 授权管理员;/deop 取消管理员。可通过 /sid 获取 ID。", + ), ) return self.context.get_config()["admins_id"].append(str(admin_id)) @@ -26,8 +26,8 @@ class AdminCommands: if not admin_id: event.set_result( MessageEventResult().message( - "使用方法: /deop 取消管理员。可通过 /sid 获取 ID。" - ) + "使用方法: /deop 取消管理员。可通过 /sid 获取 ID。", + ), ) return try: @@ -36,7 +36,7 @@ class AdminCommands: event.set_result(MessageEventResult().message("取消授权成功。")) except ValueError: event.set_result( - MessageEventResult().message("此用户 ID 不在管理员名单内。") + MessageEventResult().message("此用户 ID 不在管理员名单内。"), ) async def wl(self, event: AstrMessageEvent, sid: str = ""): @@ -44,8 +44,8 @@ class AdminCommands: if not sid: event.set_result( MessageEventResult().message( - "使用方法: /wl 添加白名单;/dwl 删除白名单。可通过 /sid 获取 ID。" - ) + "使用方法: /wl 添加白名单;/dwl 删除白名单。可通过 /sid 获取 ID。", + ), ) return cfg = self.context.get_config(umo=event.unified_msg_origin) @@ -58,8 +58,8 @@ class AdminCommands: if not sid: event.set_result( MessageEventResult().message( - "使用方法: /dwl 删除白名单。可通过 /sid 获取 ID。" - ) + "使用方法: /dwl 删除白名单。可通过 /sid 获取 ID。", + ), ) return try: diff --git a/packages/astrbot/commands/alter_cmd.py b/packages/astrbot/commands/alter_cmd.py index 18d6c1305..50007f6c0 100644 --- a/packages/astrbot/commands/alter_cmd.py +++ b/packages/astrbot/commands/alter_cmd.py @@ -1,11 +1,12 @@ -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.core.utils.command_parser import CommandParserMixin -from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata -from astrbot.core.star.star import star_map from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.filter.permission import PermissionTypeFilter +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry +from astrbot.core.utils.command_parser import CommandParserMixin + from .utils.rst_scene import RstScene @@ -34,8 +35,8 @@ class AlterCmdCommands(CommandParserMixin): "格式: /alter_cmd \n" "例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n" "例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n" - "/alter_cmd reset config 打开 reset 权限配置" - ) + "/alter_cmd reset config 打开 reset 权限配置", + ), ) return @@ -75,13 +76,13 @@ class AlterCmdCommands(CommandParserMixin): if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3: await event.send( - MessageChain().message("场景编号必须是 1-3 之间的数字") + MessageChain().message("场景编号必须是 1-3 之间的数字"), ) return if perm_type not in ["admin", "member"]: await event.send( - MessageChain().message("权限类型错误,只能是 admin 或 member") + MessageChain().message("权限类型错误,只能是 admin 或 member"), ) return @@ -93,14 +94,14 @@ class AlterCmdCommands(CommandParserMixin): await event.send( MessageChain().message( - f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}" - ) + f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}", + ), ) return if cmd_type not in ["admin", "member"]: await event.send( - MessageChain().message("指令类型错误,可选类型有 admin, member") + MessageChain().message("指令类型错误,可选类型有 admin, member"), ) return @@ -144,29 +145,29 @@ class AlterCmdCommands(CommandParserMixin): for filter_ in found_command.event_filters: if isinstance(filter_, PermissionTypeFilter): if cmd_type == "admin": - import astrbot.api.event.filter as filter + from astrbot.api.event import filter filter_.permission_type = filter.PermissionType.ADMIN else: - import astrbot.api.event.filter as filter + from astrbot.api.event import filter filter_.permission_type = filter.PermissionType.MEMBER found_permission_filter = True break if not found_permission_filter: - import astrbot.api.event.filter as filter + from astrbot.api.event import filter found_command.event_filters.insert( 0, PermissionTypeFilter( filter.PermissionType.ADMIN if cmd_type == "admin" - else filter.PermissionType.MEMBER + else filter.PermissionType.MEMBER, ), ) cmd_group_str = "指令组" if cmd_group else "指令" await event.send( MessageChain().message( - f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。" - ) + f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。", + ), ) diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py index 1a8ce746b..82b661773 100644 --- a/packages/astrbot/commands/conversation.py +++ b/packages/astrbot/commands/conversation.py @@ -1,14 +1,14 @@ import datetime -import astrbot.api.star as star + +from astrbot.api import logger, sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.sources.dify_source import ProviderDify from astrbot.core.provider.sources.coze_source import ProviderCoze -from astrbot.api import sp, logger +from astrbot.core.provider.sources.dify_source import ProviderDify + from ..long_term_memory import LongTermMemory from .utils.rst_scene import RstScene -from typing import Union class ConversationCommands: @@ -18,12 +18,13 @@ class ConversationCommands: async def _get_current_persona_id(self, session_id): curr = await self.context.conversation_manager.get_curr_conversation_id( - session_id + session_id, ) if not curr: return None conv = await self.context.conversation_manager.get_conversation( - session_id, curr + session_id, + curr, ) return conv.persona_id @@ -37,7 +38,6 @@ class ConversationCommands: async def reset(self, message: AstrMessageEvent): """重置 LLM 会话""" - is_unique_session = self.context.get_config()["platform_settings"][ "unique_session" ] @@ -50,21 +50,22 @@ class ConversationCommands: reset_cfg = plugin_config.get("reset", {}) required_perm = reset_cfg.get( - scene.key, "admin" if is_group and not is_unique_session else "member" + scene.key, + "admin" if is_group and not is_unique_session else "member", ) if required_perm == "admin" and message.role != "admin": message.set_result( MessageEventResult().message( f"在{scene.name}场景下,reset命令需要管理员权限," - f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。" - ) + f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。", + ), ) return if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return @@ -76,25 +77,27 @@ class ConversationCommands: await provider.forget(message.unified_msg_origin) message.set_result( MessageEventResult().message( - "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。" - ) + "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。", + ), ) return cid = await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin + message.unified_msg_origin, ) if not cid: message.set_result( MessageEventResult().message( - "当前未处于对话状态,请 /switch 切换或者 /new 创建。" - ) + "当前未处于对话状态,请 /switch 切换或者 /new 创建。", + ), ) return await self.context.conversation_manager.update_conversation( - message.unified_msg_origin, cid, [] + message.unified_msg_origin, + cid, + [], ) ret = "清除会话 LLM 聊天历史成功。" @@ -108,7 +111,7 @@ class ConversationCommands: """查看对话记录""" if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return @@ -120,11 +123,15 @@ class ConversationCommands: if not session_curr_cid: session_curr_cid = await conv_mgr.new_conversation( - umo, message.get_platform_id() + umo, + message.get_platform_id(), ) contexts, total_pages = await conv_mgr.get_human_readable_context( - umo, session_curr_cid, page, size_per_page + umo, + session_curr_cid, + page, + size_per_page, ) history = "" @@ -144,7 +151,6 @@ class ConversationCommands: async def convs(self, message: AstrMessageEvent, page: int = 1): """查看对话列表""" - provider = self.context.get_using_provider(message.unified_msg_origin) if provider and provider.meta().type == "dify": """原有的Dify处理逻辑保持不变""" @@ -154,7 +160,7 @@ class ConversationCommands: idx = 1 for conv in data["data"]: ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( - "%m-%d %H:%M" + "%m-%d %H:%M", ) ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" idx += 1 @@ -168,7 +174,7 @@ class ConversationCommands: size_per_page = 6 """获取所有对话列表""" conversations_all = await self.context.conversation_manager.get_conversations( - message.unified_msg_origin + message.unified_msg_origin, ) """计算总页数""" total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page @@ -194,7 +200,7 @@ class ConversationCommands: persona_id = conv.persona_id if not persona_id or persona_id == "[%None]": persona = await self.context.persona_manager.get_default_persona_v3( - umo=message.unified_msg_origin + umo=message.unified_msg_origin, ) persona_id = persona["name"] title = _titles.get(conv.cid, "新对话") @@ -203,7 +209,7 @@ class ConversationCommands: ret += "---\n" curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin + message.unified_msg_origin, ) if curr_cid: """从所有对话的标题字典中获取标题""" @@ -227,9 +233,7 @@ class ConversationCommands: return async def new_conv(self, message: AstrMessageEvent): - """ - 创建新对话 - """ + """创建新对话""" provider = self.context.get_using_provider(message.unified_msg_origin) if provider and provider.meta().type in ["dify", "coze"]: assert isinstance(provider, (ProviderDify, ProviderCoze)), ( @@ -237,13 +241,15 @@ class ConversationCommands: ) await provider.forget(message.unified_msg_origin) message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。") + MessageEventResult().message("成功,下次聊天将是新对话。"), ) return cpersona = await self._get_current_persona_id(message.unified_msg_origin) cid = await self.context.conversation_manager.new_conversation( - message.unified_msg_origin, message.get_platform_id(), persona_id=cpersona + message.unified_msg_origin, + message.get_platform_id(), + persona_id=cpersona, ) # 长期记忆 @@ -254,7 +260,7 @@ class ConversationCommands: logger.error(f"清理聊天增强记录失败: {e}") message.set_result( - MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。") + MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"), ) async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): @@ -266,7 +272,7 @@ class ConversationCommands: ) await provider.forget(message.unified_msg_origin) message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。") + MessageEventResult().message("成功,下次聊天将是新对话。"), ) return if sid: @@ -275,31 +281,34 @@ class ConversationCommands: platform_name=message.platform_meta.id, message_type=MessageType("GroupMessage"), session_id=sid, - ) + ), ) cpersona = await self._get_current_persona_id(session) cid = await self.context.conversation_manager.new_conversation( - session, message.get_platform_id(), persona_id=cpersona + session, + message.get_platform_id(), + persona_id=cpersona, ) message.set_result( MessageEventResult().message( - f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。" - ) + f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。", + ), ) else: message.set_result( - MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。") + MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。"), ) async def switch_conv( - self, message: AstrMessageEvent, index: Union[int, None] = None + self, + message: AstrMessageEvent, + index: int | None = None, ): """通过 /ls 前面的序号切换对话""" - if not isinstance(index, int): message.set_result( - MessageEventResult().message("类型错误,请输入数字对话序号。") + MessageEventResult().message("类型错误,请输入数字对话序号。"), ) return @@ -316,7 +325,7 @@ class ConversationCommands: selected_conv = data["data"][index - 1] except IndexError: message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看") + MessageEventResult().message("对话序号错误,请使用 /ls 查看"), ) return else: @@ -331,27 +340,28 @@ class ConversationCommands: if index is None: message.set_result( MessageEventResult().message( - "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话" - ) + "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话", + ), ) return conversations = await self.context.conversation_manager.get_conversations( - message.unified_msg_origin + message.unified_msg_origin, ) if index > len(conversations) or index < 1: message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看") + MessageEventResult().message("对话序号错误,请使用 /ls 查看"), ) else: conversation = conversations[index - 1] title = conversation.title if conversation.title else "新对话" await self.context.conversation_manager.switch_conversation( - message.unified_msg_origin, conversation.cid + message.unified_msg_origin, + conversation.cid, ) message.set_result( MessageEventResult().message( - f"切换到对话: {title}({conversation.cid[:4]})。" - ) + f"切换到对话: {title}({conversation.cid[:4]})。", + ), ) async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): @@ -373,7 +383,8 @@ class ConversationCommands: return await self.context.conversation_manager.update_conversation_title( - message.unified_msg_origin, new_name + message.unified_msg_origin, + new_name, ) message.set_result(MessageEventResult().message("重命名对话成功。")) @@ -386,8 +397,8 @@ class ConversationCommands: # 群聊,没开独立会话,发送人不是管理员 message.set_result( MessageEventResult().message( - f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。" - ) + f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。", + ), ) return @@ -397,31 +408,33 @@ class ConversationCommands: dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None) if dify_cid: await provider.api_client.delete_chat_conv( - message.unified_msg_origin, dify_cid + message.unified_msg_origin, + dify_cid, ) message.set_result( MessageEventResult().message( - "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" - ) + "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。", + ), ) return session_curr_cid = ( await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin + message.unified_msg_origin, ) ) if not session_curr_cid: message.set_result( MessageEventResult().message( - "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。" - ) + "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。", + ), ) return await self.context.conversation_manager.delete_conversation( - message.unified_msg_origin, session_curr_cid + message.unified_msg_origin, + session_curr_cid, ) ret = "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" diff --git a/packages/astrbot/commands/help.py b/packages/astrbot/commands/help.py index de192ce3d..7f5b6c170 100644 --- a/packages/astrbot/commands/help.py +++ b/packages/astrbot/commands/help.py @@ -1,5 +1,6 @@ import aiohttp -import astrbot.api.star as star + +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.config.default import VERSION from astrbot.core.utils.io import get_dashboard_version @@ -13,7 +14,8 @@ class HelpCommand: try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - "https://astrbot.app/notice.json", timeout=2 + "https://astrbot.app/notice.json", + timeout=2, ) as resp: return (await resp.json())["notice"] except BaseException: diff --git a/packages/astrbot/commands/llm.py b/packages/astrbot/commands/llm.py index 51f8d9923..85977df40 100644 --- a/packages/astrbot/commands/llm.py +++ b/packages/astrbot/commands/llm.py @@ -1,4 +1,4 @@ -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageChain diff --git a/packages/astrbot/commands/persona.py b/packages/astrbot/commands/persona.py index 9971df6f0..53582ce8e 100644 --- a/packages/astrbot/commands/persona.py +++ b/packages/astrbot/commands/persona.py @@ -1,5 +1,6 @@ import builtins -import astrbot.api.star as star + +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult @@ -14,7 +15,7 @@ class PersonaCommands: curr_persona_name = "无" cid = await self.context.conversation_manager.get_curr_conversation_id(umo) default_persona = await self.context.persona_manager.get_default_persona_v3( - umo=umo + umo=umo, ) curr_cid_title = "无" if cid: @@ -26,8 +27,8 @@ class PersonaCommands: if conv is None: message.set_result( MessageEventResult().message( - "当前对话不存在,请先使用 /new 新建一个对话。" - ) + "当前对话不存在,请先使用 /new 新建一个对话。", + ), ) return if not conv.persona_id and conv.persona_id != "[%None]": @@ -53,9 +54,9 @@ class PersonaCommands: 当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} 配置人格情景请前往管理面板-配置页 -""" +""", ) - .use_t2i(False) + .use_t2i(False), ) elif l[1] == "list": msg = "人格列表:\n" @@ -83,11 +84,12 @@ class PersonaCommands: elif l[1] == "unset": if not cid: message.set_result( - MessageEventResult().message("当前没有对话,无法取消人格。") + MessageEventResult().message("当前没有对话,无法取消人格。"), ) return await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, "[%None]" + message.unified_msg_origin, + "[%None]", ) message.set_result(MessageEventResult().message("取消人格成功。")) else: @@ -95,8 +97,8 @@ class PersonaCommands: if not cid: message.set_result( MessageEventResult().message( - "当前没有对话,请先开始对话或使用 /new 创建一个对话。" - ) + "当前没有对话,请先开始对话或使用 /new 创建一个对话。", + ), ) return if persona := next( @@ -107,16 +109,17 @@ class PersonaCommands: None, ): await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, ps + message.unified_msg_origin, + ps, ) message.set_result( MessageEventResult().message( - "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。" - ) + "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。", + ), ) else: message.set_result( MessageEventResult().message( - "不存在该人格情景。使用 /persona list 查看所有。" - ) + "不存在该人格情景。使用 /persona list 查看所有。", + ), ) diff --git a/packages/astrbot/commands/plugin.py b/packages/astrbot/commands/plugin.py index 8f705b417..f9092ff97 100644 --- a/packages/astrbot/commands/plugin.py +++ b/packages/astrbot/commands/plugin.py @@ -1,10 +1,10 @@ -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata +from astrbot.core import DEMO_MODE, logger from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry from astrbot.core.star.star_manager import PluginManager -from astrbot.core import DEMO_MODE, logger class PluginCommands: @@ -24,7 +24,7 @@ class PluginCommands: plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" event.set_result( - MessageEventResult().message(f"{plugin_list_info}").use_t2i(False) + MessageEventResult().message(f"{plugin_list_info}").use_t2i(False), ) async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): @@ -34,7 +34,7 @@ class PluginCommands: return if not plugin_name: event.set_result( - MessageEventResult().message("/plugin off <插件名> 禁用插件。") + MessageEventResult().message("/plugin off <插件名> 禁用插件。"), ) return await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore @@ -47,7 +47,7 @@ class PluginCommands: return if not plugin_name: event.set_result( - MessageEventResult().message("/plugin on <插件名> 启用插件。") + MessageEventResult().message("/plugin on <插件名> 启用插件。"), ) return await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore @@ -60,7 +60,7 @@ class PluginCommands: return if not plugin_repo: event.set_result( - MessageEventResult().message("/plugin get <插件仓库地址> 安装插件") + MessageEventResult().message("/plugin get <插件仓库地址> 安装插件"), ) return logger.info(f"准备从 {plugin_repo} 安装插件。") @@ -78,7 +78,7 @@ class PluginCommands: """获取插件帮助""" if not plugin_name: event.set_result( - MessageEventResult().message("/plugin help <插件名> 查看插件信息。") + MessageEventResult().message("/plugin help <插件名> 查看插件信息。"), ) return plugin = self.context.get_registered_star(plugin_name) @@ -98,7 +98,7 @@ class PluginCommands: command_handlers.append(handler) command_names.append(filter_.command_name) break - elif isinstance(filter_, CommandGroupFilter): + if isinstance(filter_, CommandGroupFilter): command_handlers.append(handler) command_names.append(filter_.group_name) diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py index 85754d0b3..750e9de5a 100644 --- a/packages/astrbot/commands/provider.py +++ b/packages/astrbot/commands/provider.py @@ -1,6 +1,6 @@ import re -from typing import Union -import astrbot.api.star as star + +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType @@ -12,8 +12,8 @@ class ProviderCommands: async def provider( self, event: AstrMessageEvent, - idx: Union[str, int, None] = None, - idx2: Union[int, None] = None, + idx: str | int | None = None, + idx2: int | None = None, ): """查看或者切换 LLM Provider""" umo = event.unified_msg_origin @@ -62,32 +62,30 @@ class ProviderCommands: if idx2 is None: event.set_result(MessageEventResult().message("请输入序号。")) return - else: - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_tts_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.TEXT_TO_SPEECH, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_tts_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) elif idx == "stt": if idx2 is None: event.set_result(MessageEventResult().message("请输入序号。")) return - else: - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_stt_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) elif isinstance(idx, int): if idx > len(self.context.get_all_providers()) or idx < 1: event.set_result(MessageEventResult().message("无效的序号。")) @@ -104,13 +102,15 @@ class ProviderCommands: event.set_result(MessageEventResult().message("无效的参数。")) async def model_ls( - self, message: AstrMessageEvent, idx_or_name: Union[int, str, None] = None + self, + message: AstrMessageEvent, + idx_or_name: int | str | None = None, ): """查看或者切换模型""" prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return # 定义正则表达式匹配 API 密钥 @@ -125,7 +125,7 @@ class ProviderCommands: message.set_result( MessageEventResult() .message("获取模型列表失败: " + err_msg) - .use_t2i(False) + .use_t2i(False), ) return i = 1 @@ -139,42 +139,41 @@ class ProviderCommands: ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" message.set_result(MessageEventResult().message(ret).use_t2i(False)) - else: - if isinstance(idx_or_name, int): - models = [] + elif isinstance(idx_or_name, int): + models = [] + try: + models = await prov.get_models() + except BaseException as e: + message.set_result( + MessageEventResult().message("获取模型列表失败: " + str(e)), + ) + return + if idx_or_name > len(models) or idx_or_name < 1: + message.set_result(MessageEventResult().message("模型序号错误。")) + else: try: - models = await prov.get_models() + new_model = models[idx_or_name - 1] + prov.set_model(new_model) except BaseException as e: message.set_result( - MessageEventResult().message("获取模型列表失败: " + str(e)) + MessageEventResult().message("切换模型未知错误: " + str(e)), ) - return - if idx_or_name > len(models) or idx_or_name < 1: - message.set_result(MessageEventResult().message("模型序号错误。")) - else: - try: - new_model = models[idx_or_name - 1] - prov.set_model(new_model) - except BaseException as e: - message.set_result( - MessageEventResult().message("切换模型未知错误: " + str(e)) - ) - message.set_result( - MessageEventResult().message( - f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]" - ) - ) - else: - prov.set_model(idx_or_name) message.set_result( - MessageEventResult().message(f"切换模型到 {prov.get_model()}。") + MessageEventResult().message( + f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", + ), ) + else: + prov.set_model(idx_or_name) + message.set_result( + MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), + ) - async def key(self, message: AstrMessageEvent, index: Union[int, None] = None): + async def key(self, message: AstrMessageEvent, index: int | None = None): prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return @@ -200,6 +199,6 @@ class ProviderCommands: prov.set_key(new_key) except BaseException as e: message.set_result( - MessageEventResult().message(f"切换 Key 未知错误: {str(e)}") + MessageEventResult().message(f"切换 Key 未知错误: {e!s}"), ) message.set_result(MessageEventResult().message("切换 Key 成功。")) diff --git a/packages/astrbot/commands/setunset.py b/packages/astrbot/commands/setunset.py index a82fcdca3..79e5d5d1c 100644 --- a/packages/astrbot/commands/setunset.py +++ b/packages/astrbot/commands/setunset.py @@ -1,6 +1,5 @@ -import astrbot.api.star as star +from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.api import sp class SetUnsetCommands: @@ -16,8 +15,8 @@ class SetUnsetCommands: event.set_result( MessageEventResult().message( - f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。" - ) + f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。", + ), ) async def unset_variable(self, event: AstrMessageEvent, key: str): @@ -27,11 +26,11 @@ class SetUnsetCommands: if key not in session_var: event.set_result( - MessageEventResult().message("没有那个变量名。格式 /unset 变量名。") + MessageEventResult().message("没有那个变量名。格式 /unset 变量名。"), ) else: del session_var[key] await sp.session_put(uid, "session_variables", session_var) event.set_result( - MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。") + MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。"), ) diff --git a/packages/astrbot/commands/sid.py b/packages/astrbot/commands/sid.py index 101b22134..4d95c5a60 100644 --- a/packages/astrbot/commands/sid.py +++ b/packages/astrbot/commands/sid.py @@ -1,6 +1,6 @@ """会话ID命令""" -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult diff --git a/packages/astrbot/commands/t2i.py b/packages/astrbot/commands/t2i.py index 28c1d4eb6..7766b342f 100644 --- a/packages/astrbot/commands/t2i.py +++ b/packages/astrbot/commands/t2i.py @@ -1,6 +1,6 @@ """文本转图片命令""" -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult diff --git a/packages/astrbot/commands/tool.py b/packages/astrbot/commands/tool.py index 335ed5580..9a6c507e6 100644 --- a/packages/astrbot/commands/tool.py +++ b/packages/astrbot/commands/tool.py @@ -1,4 +1,4 @@ -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult @@ -9,23 +9,23 @@ class ToolCommands: async def tool_ls(self, event: AstrMessageEvent): """查看函数工具列表""" event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""): """启用一个函数工具""" event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""): """停用一个函数工具""" event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) async def tool_all_off(self, event: AstrMessageEvent): """停用所有函数工具""" event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) diff --git a/packages/astrbot/commands/tts.py b/packages/astrbot/commands/tts.py index a0102fb76..d733ba1ea 100644 --- a/packages/astrbot/commands/tts.py +++ b/packages/astrbot/commands/tts.py @@ -1,6 +1,6 @@ """文本转语音命令""" -import astrbot.api.star as star +from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.star.session_llm_manager import SessionServiceManager @@ -27,10 +27,10 @@ class TTSCommand: if new_status and not tts_enable: event.set_result( MessageEventResult().message( - f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。" - ) + f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。", + ), ) else: event.set_result( - MessageEventResult().message(f"{status_text}当前会话的文本转语音。") + MessageEventResult().message(f"{status_text}当前会话的文本转语音。"), ) diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index dc2484860..a686d35b2 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -1,13 +1,14 @@ import datetime -import uuid import random -import astrbot.api.star as star -from astrbot.api.event import AstrMessageEvent -from astrbot.api.platform import MessageType -from astrbot.api.provider import ProviderRequest, Provider -from astrbot.api.message_components import Plain, Image -from astrbot import logger +import uuid from collections import defaultdict + +from astrbot import logger +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent +from astrbot.api.message_components import Image, Plain +from astrbot.api.platform import MessageType +from astrbot.api.provider import Provider, ProviderRequest from astrbot.core.astrbot_config_mgr import AstrBotConfigManager """ @@ -66,7 +67,10 @@ class LongTermMemory: return cnt async def get_image_caption( - self, image_url: str, image_caption_provider_id: str, image_caption_prompt: str + self, + image_url: str, + image_caption_provider_id: str, + image_caption_prompt: str, ) -> str: if not image_caption_provider_id: provider = self.context.get_using_provider() diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 6fd0b0e5a..8b33b887d 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,29 +1,28 @@ import traceback -import astrbot.api.star as star -import astrbot.api.event.filter as filter -from astrbot.api.event import AstrMessageEvent + +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, filter +from astrbot.api.message_components import Image, Plain from astrbot.api.provider import ProviderRequest -from astrbot.core.provider.sources.dify_source import ProviderDify -from .long_term_memory import LongTermMemory from astrbot.core import logger -from astrbot.api.message_components import Plain, Image -from typing import Union +from astrbot.core.provider.sources.dify_source import ProviderDify from .commands import ( + AdminCommands, + AlterCmdCommands, + ConversationCommands, HelpCommand, LLMCommands, - ToolCommands, - PluginCommands, - AdminCommands, - ConversationCommands, - ProviderCommands, PersonaCommands, - AlterCmdCommands, + PluginCommands, + ProviderCommands, SetUnsetCommands, - T2ICommand, - TTSCommand, SIDCommand, + T2ICommand, + ToolCommands, + TTSCommand, ) +from .long_term_memory import LongTermMemory from .process_llm_request import ProcessLLMRequest @@ -182,7 +181,9 @@ class Main(star.Star): @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("model") async def model_ls( - self, message: AstrMessageEvent, idx_or_name: Union[int, str, None] = None + self, + message: AstrMessageEvent, + idx_or_name: int | str | None = None, ): """查看或者切换模型""" await self.provider_c.model_ls(message, idx_or_name) @@ -199,9 +200,7 @@ class Main(star.Star): @filter.command("new") async def new_conv(self, message: AstrMessageEvent): - """ - 创建新对话 - """ + """创建新对话""" await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @@ -253,7 +252,6 @@ class Main(star.Star): @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) async def on_message(self, event: AstrMessageEvent): """群聊记忆增强""" - has_image_or_plain = False for comp in event.message_obj.message: if isinstance(comp, Plain) or isinstance(comp, Image): @@ -283,27 +281,29 @@ class Main(star.Star): conv = None if provider.meta().type != "dify": session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - event.unified_msg_origin + event.unified_msg_origin, ) if not session_curr_cid: logger.error( - "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。" + "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", ) return conv = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, session_curr_cid + event.unified_msg_origin, + session_curr_cid, ) else: # Dify 自己有维护对话,不需要 bot 端维护。 assert isinstance(provider, ProviderDify) cid = provider.conversation_ids.get( - event.unified_msg_origin, None + event.unified_msg_origin, + None, ) if cid is None: logger.error( - "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。" + "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", ) return diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 8f17dd0dc..6d8c896f4 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -1,14 +1,13 @@ -import copy -import astrbot.api.star as star import builtins +import copy import datetime import zoneinfo -from astrbot.api import logger + +from astrbot.api import logger, star from astrbot.api.event import AstrMessageEvent -from astrbot.api.provider import Provider -from astrbot.api.provider import ProviderRequest -from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.api.message_components import Image, Reply +from astrbot.api.provider import Provider, ProviderRequest +from astrbot.core.provider.func_tool_manager import ToolSet class ProcessLLMRequest: @@ -64,11 +63,16 @@ class ProcessLLMRequest: logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}") async def _ensure_img_caption( - self, req: ProviderRequest, cfg: dict, img_cap_prov_id: str + self, + req: ProviderRequest, + cfg: dict, + img_cap_prov_id: str, ): try: caption = await self._request_img_caption( - img_cap_prov_id, cfg, req.image_urls + img_cap_prov_id, + cfg, + req.image_urls, ) if caption: req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}" @@ -77,12 +81,16 @@ class ProcessLLMRequest: logger.error(f"处理图片描述失败: {e}") async def _request_img_caption( - self, provider_id: str, cfg: dict, image_urls: list[str] + self, + provider_id: str, + cfg: dict, + image_urls: list[str], ) -> str: if prov := self.ctx.get_provider_by_id(provider_id): if isinstance(prov, Provider): img_cap_prompt = cfg.get( - "image_caption_prompt", "Please describe the image." + "image_caption_prompt", + "Please describe the image.", ) logger.debug(f"Processing image caption with provider: {provider_id}") llm_resp = await prov.text_chat( @@ -90,14 +98,12 @@ class ProcessLLMRequest: image_urls=image_urls, ) return llm_resp.completion_text - else: - raise ValueError( - f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}." - ) - else: raise ValueError( - f"Cannot get image caption because provider `{provider_id}` is not exist." + f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.", ) + raise ValueError( + f"Cannot get image caption because provider `{provider_id}` is not exist.", + ) async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 505b57903..35a2f2698 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -1,21 +1,21 @@ -import os -import json -import shutil -import aiohttp -import uuid import asyncio +import json +import os import re -import aiodocker +import shutil import time -import astrbot.api.star as star +import uuid from collections import defaultdict -from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.api import llm_tool, logger -from astrbot.api.event import filter + +import aiodocker +import aiohttp + +from astrbot.api import llm_tool, logger, star +from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter +from astrbot.api.message_components import File, Image from astrbot.api.provider import ProviderRequest -from astrbot.api.message_components import Image, File -from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file, download_image_by_url PROMPT = """ ## Task @@ -120,23 +120,21 @@ class Main(star.Star): self.config = DEFAULT_CONFIG self._save_config() else: - with open(PATH, "r") as f: + with open(PATH) as f: self.config = json.load(f) async def initialize(self): ok = await self.is_docker_available() if not ok: logger.info( - "Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。" + "Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。", ) # await self.context._star_manager.turn_off_plugin( # "astrbot-python-interpreter" # ) async def file_upload(self, file_path: str): - """ - 上传图像文件到 S3 - """ + """上传图像文件到 S3""" ext = os.path.splitext(file_path)[1] S3_URL = "https://s3.neko.soulter.top/astrbot-s3" with open(file_path, "rb") as f: @@ -144,13 +142,16 @@ class Main(star.Star): s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}" - async with aiohttp.ClientSession( - headers={"Accept": "application/json"}, trust_env=True - ) as session: - async with session.put(s3_file_url, data=file) as resp: - if resp.status != 200: - raise Exception(f"Failed to upload image: {resp.status}") - return s3_file_url + async with ( + aiohttp.ClientSession( + headers={"Accept": "application/json"}, + trust_env=True, + ) as session, + session.put(s3_file_url, data=file) as resp, + ): + if resp.status != 200: + raise Exception(f"Failed to upload image: {resp.status}") + return s3_file_url async def is_docker_available(self) -> bool: """Check if docker is available""" @@ -177,7 +178,10 @@ class Main(star.Star): return uuid.uuid4().hex[:8] async def download_image( - self, image_url: str, workplace_path: str, filename: str + self, + image_url: str, + workplace_path: str, + filename: str, ) -> str: """Download image from url to workplace_path""" async with aiohttp.ClientSession(trust_env=True) as session: @@ -247,7 +251,7 @@ class Main(star.Star): """设置 Docker 宿主机绝对路径""" if not path: yield event.plain_result( - f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}" + f"当前 Docker 宿主机绝对路径: {self.config.get('docker_host_astrbot_abs_path', '')}", ) else: self.config["docker_host_astrbot_abs_path"] = path @@ -290,7 +294,7 @@ class Main(star.Star): await asyncio.sleep(60) if uid in self.user_waiting: yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。" + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。", ) self.user_waiting.pop(uid) @@ -301,11 +305,11 @@ class Main(star.Star): if uid in self.user_waiting: self.user_waiting.pop(uid) yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 已清理。" + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 已清理。", ) else: yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有等待上传文件。" + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有等待上传文件。", ) @pi.command("list") @@ -315,11 +319,11 @@ class Main(star.Star): if uid in self.user_file_msg_buffer: files = self.user_file_msg_buffer[uid] yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 上传的文件: {files}" + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 上传的文件: {files}", ) else: yield event.plain_result( - f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有上传文件。" + f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 没有上传文件。", ) @llm_tool("python_interpreter") @@ -373,11 +377,12 @@ class Main(star.Star): ) provider = self.context.get_using_provider() llm_response = await provider.text_chat( - prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}" + prompt=PROMPT_, + session_id=f"{event.session_id}_{magic_code}_{i!s}", ) logger.debug( - "code interpreter llm gened code:" + llm_response.completion_text + "code interpreter llm gened code:" + llm_response.completion_text, ) # 整理代码并保存 @@ -398,21 +403,25 @@ class Main(star.Star): await docker.images.pull(image_name) yield event.plain_result( - f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})" + f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})", ) self.docker_host_astrbot_abs_path = self.config.get( - "docker_host_astrbot_abs_path", "" + "docker_host_astrbot_abs_path", + "", ) if self.docker_host_astrbot_abs_path: host_shared = os.path.join( - self.docker_host_astrbot_abs_path, self.shared_path + self.docker_host_astrbot_abs_path, + self.shared_path, ) host_output = os.path.join( - self.docker_host_astrbot_abs_path, output_path + self.docker_host_astrbot_abs_path, + output_path, ) host_workplace = os.path.join( - self.docker_host_astrbot_abs_path, workplace_path + self.docker_host_astrbot_abs_path, + workplace_path, ) else: @@ -421,7 +430,7 @@ class Main(star.Star): host_workplace = os.path.abspath(workplace_path) logger.debug( - f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}" + f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}", ) container = await docker.containers.run( @@ -435,11 +444,11 @@ class Main(star.Star): f"{host_shared}:/astrbot_sandbox/shared:ro", f"{host_output}:/astrbot_sandbox/output:rw", f"{host_workplace}:/astrbot_sandbox:rw", - ] + ], }, "Env": [f"MAGIC_CODE={magic_code}"], "AutoRemove": True, - } + }, ) logger.debug(f"Container {container.id} created.") @@ -479,7 +488,7 @@ class Main(star.Star): obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code." else: logger.warning( - f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}" + f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}", ) break else: @@ -488,7 +497,7 @@ class Main(star.Star): return yield event.plain_result( - "经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。" + "经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。", ) @pi.command("cleanfile") @@ -504,7 +513,9 @@ class Main(star.Star): yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。") async def run_container( - self, container: aiodocker.docker.DockerContainer, timeout: int = 20 + self, + container: aiodocker.docker.DockerContainer, + timeout: int = 20, ) -> list[str]: """Run the container and get the output""" try: diff --git a/packages/reminder/main.py b/packages/reminder/main.py index e5fb1c864..0349b9eb4 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -1,13 +1,13 @@ -import os -import json import datetime +import json +import os import uuid import zoneinfo -import astrbot.api.star as star -from astrbot.api.event import filter + from apscheduler.schedulers.asyncio import AsyncIOScheduler -from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.api import llm_tool, logger + +from astrbot.api import llm_tool, logger, star +from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -31,7 +31,7 @@ class Main(star.Star): if not os.path.exists(reminder_file): with open(reminder_file, "w", encoding="utf-8") as f: f.write("{}") - with open(reminder_file, "r", encoding="utf-8") as f: + with open(reminder_file, encoding="utf-8") as f: self.reminder_data = json.load(f) self._init_scheduler() @@ -56,7 +56,8 @@ class Main(star.Star): trigger="date", args=[group, reminder], run_date=datetime.datetime.strptime( - reminder["datetime"], "%Y-%m-%d %H:%M" + reminder["datetime"], + "%Y-%m-%d %H:%M", ), misfire_grace_time=60, ) @@ -74,7 +75,8 @@ class Main(star.Star): """Check if the reminder is outdated.""" if "datetime" in reminder: reminder_time = datetime.datetime.strptime( - reminder["datetime"], "%Y-%m-%d %H:%M" + reminder["datetime"], + "%Y-%m-%d %H:%M", ).replace(tzinfo=self.timezone) return reminder_time < datetime.datetime.now(self.timezone) return False @@ -111,6 +113,7 @@ class Main(star.Star): datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder. Monday is 0 and Sunday is 6. human_readable_cron(string): Optional. The human readable cron expression of the reminder. + """ if event.get_platform_name() == "qq_official": yield event.plain_result("reminder 暂不支持 QQ 官方机器人。") @@ -121,7 +124,7 @@ class Main(star.Star): if not cron_expression and not datetime_str: raise ValueError( - "The cron_expression and datetime_str cannot be both None." + "The cron_expression and datetime_str cannot be both None.", ) reminder_time = "" @@ -150,7 +153,8 @@ class Main(star.Star): d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())} self.reminder_data[event.unified_msg_origin].append(d) datetime_scheduled = datetime.datetime.strptime( - datetime_str, "%Y-%m-%d %H:%M" + datetime_str, + "%Y-%m-%d %H:%M", ) self.scheduler.add_job( self._reminder_callback, @@ -167,13 +171,12 @@ class Main(star.Star): + text + "\n时间: " + reminder_time - + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。" + + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。", ) @filter.command_group("reminder") def reminder(self): """The command group of the reminder.""" - pass async def get_upcoming_reminders(self, unified_msg_origin: str): """Get upcoming reminders.""" @@ -186,7 +189,8 @@ class Main(star.Star): for reminder in reminders if "datetime" not in reminder or datetime.datetime.strptime( - reminder["datetime"], "%Y-%m-%d %H:%M" + reminder["datetime"], + "%Y-%m-%d %H:%M", ).replace(tzinfo=self.timezone) >= now ] @@ -233,7 +237,7 @@ class Main(star.Star): except Exception as e: logger.error(f"Remove job error: {e}") yield event.plain_result( - f"成功移除对应的待办事项。删除定时任务失败: {str(e)} 可能需要重启 AstrBot 以取消该提醒任务。" + f"成功移除对应的待办事项。删除定时任务失败: {e!s} 可能需要重启 AstrBot 以取消该提醒任务。", ) await self._save_data() yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) @@ -248,7 +252,7 @@ class Main(star.Star): + d["text"] + "\n时间: " + d.get("datetime", "") - + d.get("cron_h", "") + + d.get("cron_h", ""), ), ) diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 86c8a24fb..4d4a42528 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -1,16 +1,17 @@ -import astrbot.api.message_components as Comp import copy +from sys import maxsize + +import astrbot.api.message_components as Comp from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.star import Context, Star from astrbot.core.utils.session_waiter import ( - SessionWaiter, - USER_SESSIONS, FILTERS, - session_waiter, + USER_SESSIONS, SessionController, + SessionWaiter, + session_waiter, ) -from sys import maxsize class Waiter(Star): @@ -52,13 +53,14 @@ class Waiter(Star): # 获取用户当前的对话信息 curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - event.unified_msg_origin + event.unified_msg_origin, ) conversation = None if curr_cid: conversation = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, curr_cid + event.unified_msg_origin, + curr_cid, ) else: # 创建新对话 @@ -81,16 +83,18 @@ class Waiter(Star): conversation=conversation, ) except Exception as e: - logger.error(f"LLM response failed: {str(e)}") + logger.error(f"LLM response failed: {e!s}") # LLM 回复失败,使用原始预设回复 yield event.plain_result("想要问什么呢?😄") @session_waiter(60) async def empty_mention_waiter( - controller: SessionController, event: AstrMessageEvent + controller: SessionController, + event: AstrMessageEvent, ): event.message_obj.message.insert( - 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()) + 0, + Comp.At(qq=event.get_self_id(), name=event.get_self_id()), ) new_event = copy.copy(event) # 重新推入事件队列 diff --git a/packages/thinking_filter/main.py b/packages/thinking_filter/main.py index 3d2729669..a3bc65d20 100644 --- a/packages/thinking_filter/main.py +++ b/packages/thinking_filter/main.py @@ -1,13 +1,14 @@ -import re import json import logging -from typing import Any, Tuple +import re +from typing import Any -from astrbot.api.event import filter, AstrMessageEvent -from astrbot.api.star import Context, Star -from astrbot.api.provider import LLMResponse from openai.types.chat.chat_completion import ChatCompletion +from astrbot.api.event import AstrMessageEvent, filter +from astrbot.api.provider import LLMResponse +from astrbot.api.star import Context, Star + try: # 谨慎引入,避免在未安装 google-genai 的环境下报错 from google.genai.types import GenerateContentResponse @@ -22,7 +23,8 @@ class R1Filter(Star): @filter.on_llm_response() async def resp(self, event: AstrMessageEvent, response: LLMResponse): cfg = self.context.get_config(umo=event.unified_msg_origin).get( - "provider_settings", {} + "provider_settings", + {}, ) show_reasoning = cfg.get("display_reasoning_text", False) @@ -30,10 +32,11 @@ class R1Filter(Star): # Gemini 可能在 parts 中注入 {"thought": true, "text": "..."} # 官方 SDK 默认不会返回此字段。 if GenerateContentResponse is not None and isinstance( - response.raw_completion, GenerateContentResponse + response.raw_completion, + GenerateContentResponse, ): thought_text, answer_text = self._extract_gemini_texts( - response.raw_completion + response.raw_completion, ) if thought_text or answer_text: @@ -46,11 +49,10 @@ class R1Filter(Star): if merged: response.completion_text = merged return - else: - # 默认隐藏思考内容,仅保留正文 - if answer_text: - response.completion_text = answer_text - return + # 默认隐藏思考内容,仅保留正文 + elif answer_text: + response.completion_text = answer_text + return # --- 非 Gemini 或无明确 thought:true 情况 --- if show_reasoning: @@ -88,7 +90,10 @@ class R1Filter(Star): if r"" in completion_text or r"" in completion_text: # 移除配对的标签及其内容 completion_text = re.sub( - r".*?", "", completion_text, flags=re.DOTALL + r".*?", + "", + completion_text, + flags=re.DOTALL, ).strip() # 移除可能残留的单个标签 @@ -126,7 +131,7 @@ class R1Filter(Star): continue except Exception as e: logging.exception( - f"Unexpected error when calling {getter} on {type(p).__name__}: {e}" + f"Unexpected error when calling {getter} on {type(p).__name__}: {e}", ) continue try: @@ -137,7 +142,7 @@ class R1Filter(Star): pass except Exception as e: logging.exception( - f"Unexpected error when accessing __dict__ on {type(p).__name__}: {e}" + f"Unexpected error when accessing __dict__ on {type(p).__name__}: {e}", ) return {} @@ -175,7 +180,7 @@ class R1Filter(Star): continue return False - def _extract_gemini_texts(self, resp: Any) -> Tuple[str, str]: + def _extract_gemini_texts(self, resp: Any) -> tuple[str, str]: """从 GenerateContentResponse 中提取 (思考文本, 正文文本)。""" try: cand0 = next(iter(getattr(resp, "candidates", []) or []), None) diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py index 38b3ede10..706cfa87b 100644 --- a/packages/web_searcher/engines/__init__.py +++ b/packages/web_searcher/engines/__init__.py @@ -1,9 +1,9 @@ import random -from bs4 import BeautifulSoup -from aiohttp import ClientSession -from dataclasses import dataclass -from typing import List import urllib.parse +from dataclasses import dataclass + +from aiohttp import ClientSession +from bs4 import BeautifulSoup HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0", @@ -38,9 +38,7 @@ class SearchResult: class SearchEngine: - """ - 搜索引擎爬虫基类 - """ + """搜索引擎爬虫基类""" def __init__(self) -> None: self.TIMEOUT = 10 @@ -48,37 +46,44 @@ class SearchEngine: self.headers = HEADERS def _set_selector(self, selector: str) -> None: - raise NotImplementedError() + raise NotImplementedError def _get_next_page(self): - raise NotImplementedError() + raise NotImplementedError async def _get_html(self, url: str, data: dict = None) -> str: headers = self.headers headers["Referer"] = url headers["User-Agent"] = random.choice(USER_AGENTS) if data: - async with ClientSession() as session: - async with session.post( - url, headers=headers, data=data, timeout=self.TIMEOUT - ) as resp: - ret = await resp.text(encoding="utf-8") - return ret + async with ( + ClientSession() as session, + session.post( + url, + headers=headers, + data=data, + timeout=self.TIMEOUT, + ) as resp, + ): + ret = await resp.text(encoding="utf-8") + return ret else: - async with ClientSession() as session: - async with session.get( - url, headers=headers, timeout=self.TIMEOUT - ) as resp: - ret = await resp.text(encoding="utf-8") - return ret + async with ( + ClientSession() as session, + session.get( + url, + headers=headers, + timeout=self.TIMEOUT, + ) as resp, + ): + ret = await resp.text(encoding="utf-8") + return ret def tidy_text(self, text: str) -> str: - """ - 清理文本,去除空格、换行符等 - """ + """清理文本,去除空格、换行符等""" return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") - async def search(self, query: str, num_results: int) -> List[SearchResult]: + async def search(self, query: str, num_results: int) -> list[SearchResult]: query = urllib.parse.quote(query) try: @@ -88,7 +93,7 @@ class SearchEngine: results = [] for link in links: title = self.tidy_text( - link.select_one(self._set_selector("title")).text + link.select_one(self._set_selector("title")).text, ) url = link.select_one(self._set_selector("url")) snippet = "" diff --git a/packages/web_searcher/engines/bing.py b/packages/web_searcher/engines/bing.py index 01bec4d45..4c2ec319d 100644 --- a/packages/web_searcher/engines/bing.py +++ b/packages/web_searcher/engines/bing.py @@ -1,6 +1,4 @@ -from typing import List -from . import SearchEngine, SearchResult -from . import USER_AGENT_BING +from . import USER_AGENT_BING, SearchEngine, SearchResult class Bing(SearchEngine): @@ -31,7 +29,7 @@ class Bing(SearchEngine): continue raise Exception("Bing search failed") - async def search(self, query: str, num_results: int) -> List[SearchResult]: + async def search(self, query: str, num_results: int) -> list[SearchResult]: results = await super().search(query, num_results) for result in results: if not isinstance(result.url, str): diff --git a/packages/web_searcher/engines/sogo.py b/packages/web_searcher/engines/sogo.py index 9a505782f..382e7c937 100644 --- a/packages/web_searcher/engines/sogo.py +++ b/packages/web_searcher/engines/sogo.py @@ -1,10 +1,9 @@ import random import re -from bs4 import BeautifulSoup -from . import SearchEngine, SearchResult -from . import USER_AGENTS -from typing import List +from bs4 import BeautifulSoup + +from . import USER_AGENTS, SearchEngine, SearchResult class Sogo(SearchEngine): @@ -27,7 +26,7 @@ class Sogo(SearchEngine): url = f"{self.base_url}/web?query={query}" return await self._get_html(url, None) - async def search(self, query: str, num_results: int) -> List[SearchResult]: + async def search(self, query: str, num_results: int) -> list[SearchResult]: results = await super().search(query, num_results) for result in results: result.url = result.url.get("href") @@ -42,6 +41,6 @@ class Sogo(SearchEngine): script = soup.find("script") if script: url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group( - 1 + 1, ) return url diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 635f3ebb7..118ef2483 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -1,18 +1,18 @@ -import aiohttp import asyncio import random -import astrbot.api.star as star -import astrbot.api.event.filter as filter -from astrbot.api.event import AstrMessageEvent, MessageEventResult + +import aiohttp +from bs4 import BeautifulSoup +from readability import Document + +from astrbot.api import AstrBotConfig, llm_tool, logger, star +from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.api.provider import ProviderRequest -from astrbot.api import llm_tool, logger, AstrBotConfig from astrbot.core.provider.func_tool_manager import FunctionToolManager -from .engines import SearchResult + +from .engines import HEADERS, USER_AGENTS, SearchResult from .engines.bing import Bing from .engines.sogo import Sogo -from readability import Document -from bs4 import BeautifulSoup -from .engines import HEADERS, USER_AGENTS class Main(star.Star): @@ -35,7 +35,7 @@ class Main(star.Star): tavily_key = provider_settings.get("websearch_tavily_key") if isinstance(tavily_key, str): logger.info( - "检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。" + "检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。", ) if tavily_key: provider_settings["websearch_tavily_key"] = [tavily_key] @@ -65,7 +65,10 @@ class Main(star.Star): return ret async def _process_search_result( - self, result: SearchResult, idx: int, websearch_link: bool + self, + result: SearchResult, + idx: int, + websearch_link: bool, ) -> str: """处理单个搜索结果""" logger.info(f"web_searcher - scraping web: {result.title} - {result.url}") @@ -85,7 +88,9 @@ class Main(star.Star): return f"{header}\n{result.snippet}\n{site_result}\n\n" async def _web_search_default( - self, query, num_results: int = 5 + self, + query, + num_results: int = 5, ) -> list[SearchResult]: results = [] try: @@ -116,7 +121,9 @@ class Main(star.Star): return key async def _web_search_tavily( - self, cfg: AstrBotConfig, payload: dict + self, + cfg: AstrBotConfig, + payload: dict, ) -> list[SearchResult]: """使用 Tavily 搜索引擎进行搜索""" tavily_key = await self._get_tavily_key(cfg) @@ -127,12 +134,15 @@ class Main(star.Star): } async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - url, json=payload, headers=header, timeout=6 + url, + json=payload, + headers=header, + timeout=6, ) as response: if response.status != 200: reason = await response.text() raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}" + f"Tavily web search failed: {reason}, status: {response.status}", ) data = await response.json() results = [] @@ -155,18 +165,21 @@ class Main(star.Star): } async with aiohttp.ClientSession(trust_env=True) as session: async with session.post( - url, json=payload, headers=header, timeout=6 + url, + json=payload, + headers=header, + timeout=6, ) as response: if response.status != 200: reason = await response.text() raise Exception( - f"Tavily web search failed: {reason}, status: {response.status}" + f"Tavily web search failed: {reason}, status: {response.status}", ) data = await response.json() results: list[dict] = data.get("results", []) if not results: raise ValueError( - "Error: Tavily web searcher does not return any results." + "Error: Tavily web searcher does not return any results.", ) return results @@ -174,19 +187,23 @@ class Main(star.Star): async def websearch(self, event: AstrMessageEvent, oper: str | None = None): event.set_result( MessageEventResult().message( - "此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。" - ) + "此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。", + ), ) @llm_tool(name="web_search") async def search_from_search_engine( - self, event: AstrMessageEvent, query: str, max_results: int = 5 + self, + event: AstrMessageEvent, + query: str, + max_results: int = 5, ) -> str: """搜索网络以回答用户的问题。当用户需要搜索网络以获取即时性的信息时调用此工具。 Args: query(string): 和用户的问题最相关的搜索关键词,用于在 Google 上搜索。 max_results(number): 返回的最大搜索结果数量,默认为 5。 + """ logger.info(f"web_searcher - search_from_search_engine: {query}") cfg = self.context.get_config(umo=event.unified_msg_origin) @@ -218,11 +235,12 @@ class Main(star.Star): return cfg = self.context.get_config(umo=umo) key = cfg.get("provider_settings", {}).get( - "websearch_baidu_app_builder_key", "" + "websearch_baidu_app_builder_key", + "", ) if not key: raise ValueError( - "Error: Baidu AI Search API key is not configured in AstrBot." + "Error: Baidu AI Search API key is not configured in AstrBot.", ) func_tool_mgr = self.context.get_llm_tool_manager() await func_tool_mgr.enable_mcp_server( @@ -239,10 +257,11 @@ class Main(star.Star): @llm_tool(name="fetch_url") async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str: - """fetch the content of a website with the given web url + """Fetch the content of a website with the given web url Args: url(string): The url of the website to fetch content from + """ resp = await self._get_from_url(url) return resp @@ -272,6 +291,7 @@ class Main(star.Star): time_range(string): Optional. The time range back from the current date to include in the search results. This feature is available for both 'general' and 'news' search topics. Must be one of 'day', 'week', 'month', 'year'. start_date(string): Optional. The start date for the search results in the format 'YYYY-MM-DD'. end_date(string): Optional. The end date for the search results in the format 'YYYY-MM-DD'. + """ logger.info(f"web_searcher - search_from_tavily: {query}") cfg = self.context.get_config(umo=event.unified_msg_origin) @@ -319,13 +339,17 @@ class Main(star.Star): @llm_tool("tavily_extract_web_page") async def tavily_extract_web_page( - self, event: AstrMessageEvent, url: str = "", extract_depth: str = "basic" + self, + event: AstrMessageEvent, + url: str = "", + extract_depth: str = "basic", ) -> str: """Extract the content of a web page using Tavily. Args: url(string): Required. An URl to extract content from. extract_depth(string): Optional. The depth of the extraction, must be one of 'basic', 'advanced'. Default is "basic". + """ cfg = self.context.get_config(umo=event.unified_msg_origin) if not cfg.get("provider_settings", {}).get("websearch_tavily_key", []): @@ -351,7 +375,9 @@ class Main(star.Star): @filter.on_llm_request(priority=-10000) async def edit_web_search_tools( - self, event: AstrMessageEvent, req: ProviderRequest + self, + event: AstrMessageEvent, + req: ProviderRequest, ): """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) diff --git a/pyproject.toml b/pyproject.toml index 5868c5c99..c83fdf2dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,16 @@ [project] name = "AstrBot" -version = "4.5.1" +dynamic = ["version"] description = "易上手的多平台 LLM 聊天机器人及开发框架" readme = "README.md" requires-python = ">=3.10" + +keywords = [ + "Astrbot", + "Astrbot Module", + "Astrbot Plugin" +] + dependencies = [ "aiocqhttp>=1.4.4", "aiodocker>=0.24.0", @@ -58,35 +65,6 @@ dependencies = [ "xinference-client", ] -[project.scripts] -astrbot = "astrbot.cli.__main__:cli" - -[build-system] -requires = ["hatchling", "uv-dynamic-versioning"] -build-backend = "hatchling.build" - -[tool.ruff] -exclude = [ - "astrbot/core/utils/t2i/local_strategy.py", - "astrbot/api/all.py", -] -line-length = 88 -lint.ignore = [ - "F403", - "F405", - "E501", - "ASYNC230" # TODO: handle ASYNC230 in AstrBot -] -lint.select = [ - "F", # Pyflakes - "W", # pycodestyle warnings - "E", # pycodestyle errors - "ASYNC", # flake8-async - "C4", # flake8-comprehensions - "Q", # flake8-quotes -] -target-version = "py310" - [dependency-groups] dev = [ "commitizen>=4.9.1", @@ -95,3 +73,54 @@ dev = [ "pytest-cov>=6.2.1", "ruff>=0.12.8", ] + +[project.scripts] +astrbot = "astrbot.cli.__main__:cli" + +[tool.ruff] +exclude = [ + "astrbot/core/utils/t2i/local_strategy.py", + "astrbot/api/all.py", +] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = [ + "F", # Pyflakes + "W", # pycodestyle warnings + "E", # pycodestyle errors + "ASYNC", # flake8-async + "C4", # flake8-comprehensions + "Q", # flake8-quotes + "I", # import-order + "UP", # pyupgrade + # "SIM", # flake8-simplify +] +ignore = [ + "F403", + "F405", + "E501", + "ASYNC230" # TODO: handle ASYNC230 in AstrBot +] + +[tool.pyright] +typeCheckingMode = "basic" +pythonVersion = "3.10" +reportMissingTypeStubs = false +reportMissingImports = false +include = ["astrbot","packages"] +exclude = ["dashboard", "node_modules", "dist", "data", "tests"] + +[tool.hatch.version] +source = "uv-dynamic-versioning" + +[tool.uv-dynamic-versioning] +vcs = "git" +style = "pep440" +bump = true + +[build-system] +requires = ["hatchling", "uv-dynamic-versioning"] +build-backend = "hatchling.build" + diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 8fd8ce5f9..a2710c841 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,14 +1,16 @@ +import asyncio +import os + import pytest import pytest_asyncio -import os -import asyncio from quart import Quart -from astrbot.dashboard.server import AstrBotDashboard -from astrbot.core.db.sqlite import SQLiteDatabase -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + from astrbot.core import LogBroker -from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.star import star_registry +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.dashboard.server import AstrBotDashboard @pytest_asyncio.fixture(scope="module") @@ -53,7 +55,8 @@ async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): """Tests the login functionality with both wrong and correct credentials.""" test_client = app.test_client() response = await test_client.post( - "/api/auth/login", json={"username": "wrong", "password": "password"} + "/api/auth/login", + json={"username": "wrong", "password": "password"}, ) data = await response.get_json() assert data["status"] == "error" @@ -91,7 +94,8 @@ async def test_plugins(app: Quart, authenticated_header: dict): # 插件市场 response = await test_client.get( - "/api/plugin/market_list", headers=authenticated_header + "/api/plugin/market_list", + headers=authenticated_header, ) assert response.status_code == 200 data = await response.get_json() @@ -172,7 +176,6 @@ async def test_do_update( async def mock_update(*args, **kwargs): """Mocks the update process by creating a directory in the temp path.""" os.makedirs(release_path, exist_ok=True) - return async def mock_download_dashboard(*args, **kwargs): """Mocks the dashboard download to prevent network access.""" @@ -184,10 +187,12 @@ async def test_do_update( monkeypatch.setattr(core_lifecycle_td.astrbot_updator, "update", mock_update) monkeypatch.setattr( - "astrbot.dashboard.routes.update.download_dashboard", mock_download_dashboard + "astrbot.dashboard.routes.update.download_dashboard", + mock_download_dashboard, ) monkeypatch.setattr( - "astrbot.dashboard.routes.update.pip_installer.install", mock_pip_install + "astrbot.dashboard.routes.update.pip_installer.install", + mock_pip_install, ) response = await test_client.post( diff --git a/tests/test_main.py b/tests/test_main.py index d7e45b01d..0453a51ee 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,9 +4,11 @@ import sys # 将项目根目录添加到 sys.path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -import pytest from unittest import mock -from main import check_env, check_dashboard_files + +import pytest + +from main import check_dashboard_files, check_env class _version_info: diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 44486a11d..277f8fa4d 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,19 +1,20 @@ -import pytest import os +from asyncio import Queue from unittest.mock import MagicMock -from astrbot.core.star.star_manager import PluginManager -from astrbot.core.star.star_handler import star_handlers_registry -from astrbot.core.star.star import star_registry -from astrbot.core.star.context import Context + +import pytest + from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.db.sqlite import SQLiteDatabase -from asyncio import Queue +from astrbot.core.star.context import Context +from astrbot.core.star.star import star_registry +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.star_manager import PluginManager @pytest.fixture def plugin_manager_pm(tmp_path): - """ - Provides a fully isolated PluginManager instance for testing. + """Provides a fully isolated PluginManager instance for testing. - Uses a temporary directory for plugins. - Uses a temporary database. - Creates a fresh context for each test. @@ -53,7 +54,7 @@ def plugin_manager_pm(tmp_path): # Create the PluginManager instance manager = PluginManager(star_context, config) - yield manager + return manager def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): @@ -75,7 +76,8 @@ async def test_install_plugin(plugin_manager_pm: PluginManager): test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_info = await plugin_manager_pm.install_plugin(test_repo) plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential" + plugin_manager_pm.plugin_store_path, + "astrbot_plugin_essential", ) assert plugin_info is not None @@ -90,7 +92,7 @@ async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( - "https://github.com/Soulter/non_existent_repo" + "https://github.com/Soulter/non_existent_repo", ) @@ -119,7 +121,8 @@ async def test_uninstall_plugin(plugin_manager_pm: PluginManager): test_repo = "https://github.com/Soulter/astrbot_plugin_essential" await plugin_manager_pm.install_plugin(test_repo) plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, "astrbot_plugin_essential" + plugin_manager_pm.plugin_store_path, + "astrbot_plugin_essential", ) assert os.path.exists(plugin_path) # Pre-condition