From bdfc77d3490089c87099f5e9409d600d05c137b3 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 9 Dec 2024 22:38:42 +0800 Subject: [PATCH] refactor: im so tired :) --- .gitignore | 7 +- astrbot/api/__init__.py | 14 +- astrbot/api/all.py | 40 ++ astrbot/api/event/__init__.py | 5 + astrbot/api/event/filter/__init__.py | 10 + astrbot/api/platform/__init__.py | 5 + astrbot/api/provider/__init__.py | 1 + astrbot/api/star/__init__.py | 6 + astrbot/core/config/__init__.py | 4 +- astrbot/core/config/astrbot_config.py | 349 ++++-------------- astrbot/core/config/default.py | 238 ++++++------ astrbot/core/core_lifecycle.py | 68 +++- astrbot/core/db/__init__.py | 21 +- astrbot/core/db/po.py | 15 +- astrbot/core/db/sqlite.py | 101 +++-- astrbot/core/db/sqlite_init.sql | 14 + astrbot/core/event_bus.py | 8 +- astrbot/core/message/components.py | 1 + astrbot/core/message/message_event_handler.py | 177 --------- astrbot/core/message/message_event_result.py | 104 +++++- astrbot/core/pipeline/__init__.py | 18 + .../pipeline/content_safety_check/stage.py | 31 ++ .../strategies/__init__.py | 8 + .../strategies/baidu_aip.py | 31 ++ .../strategies/keywords.py | 21 ++ .../strategies/strategy.py | 27 ++ .../strategies}/unfit_words | 0 astrbot/core/pipeline/context.py | 8 + .../process_stage/method/llm_request.py | 71 ++++ .../process_stage/method/star_request.py | 48 +++ astrbot/core/pipeline/process_stage/stage.py | 36 ++ .../core/pipeline/rate_limit_check/stage.py | 87 +++++ astrbot/core/pipeline/respond/stage.py | 25 ++ .../core/pipeline/result_decorate/stage.py | 45 +++ astrbot/core/pipeline/scheduler.py | 44 +++ astrbot/core/pipeline/stage.py | 32 ++ astrbot/core/pipeline/waking_check/stage.py | 96 +++++ .../core/pipeline/whitelist_check/stage.py | 19 + astrbot/core/platform/astr_message_event.py | 100 ++++- astrbot/core/platform/astrbot_message.py | 3 +- astrbot/core/platform/manager.py | 42 +++ astrbot/core/platform/platform_metadata.py | 2 +- astrbot/core/platform/register.py | 25 ++ .../aiocqhttp}/aiocqhttp_message_event.py | 2 +- .../aiocqhttp}/aiocqhttp_platform_adapter.py | 46 +-- .../qqofficial}/qqofficial_message_event.py | 3 +- .../qqofficial_platform_adapter.py | 24 +- .../sources/vchat/vchat_message_event.py | 8 +- .../sources/vchat/vchat_platform_adapter.py | 32 +- astrbot/core/plugin/__init__.py | 4 - astrbot/core/plugin/context.py | 217 ----------- astrbot/core/plugin/plugin.py | 43 --- astrbot/core/provider/__init__.py | 2 +- astrbot/core/provider/llm_response.py | 13 + astrbot/core/provider/manager.py | 49 +++ astrbot/core/provider/provider.py | 157 +++++--- astrbot/core/provider/provider_metadata.py | 6 + astrbot/core/provider/register.py | 25 ++ .../core/provider/sources/openai_source.py | 216 +++++++++++ .../{utils/func_call.py => provider/tool.py} | 55 +-- .../provider/tools/websearch/engines}/bing.py | 0 .../tools/websearch/engines}/config.py | 0 .../tools/websearch/engines}/engine.py | 0 .../tools/websearch/engines}/google.py | 0 .../provider/tools/websearch/engines}/sogo.py | 0 .../provider/tools/websearch}/web_searcher.py | 13 +- astrbot/core/star/README.md | 5 + astrbot/core/star/__init__.py | 4 + astrbot/core/{plugin => star}/config.py | 0 astrbot/core/star/context.py | 174 +++++++++ astrbot/core/star/filter/__init__.py | 10 + astrbot/core/star/filter/command.py | 67 ++++ astrbot/core/star/filter/command_group.py | 70 ++++ .../core/star/filter/event_message_type.py | 28 ++ .../core/star/filter/platform_adapter_type.py | 27 ++ astrbot/core/star/filter/regex.py | 14 + astrbot/core/star/register/__init__.py | 8 + astrbot/core/star/register/star.py | 18 + astrbot/core/star/register/star_handler.py | 115 ++++++ astrbot/core/star/star.py | 43 +++ astrbot/core/star/star_handler.py | 31 ++ .../star_manager.py} | 198 +++++----- astrbot/core/{plugin => star}/updator.py | 36 +- astrbot/core/utils/command_parser.py | 9 +- astrbot/core/utils/image_uploader.py | 21 -- astrbot/core/utils/param_validation_mixin.py | 31 ++ astrbot/dashboard/dashboard_lifecycle.py | 2 +- astrbot/dashboard/routes/auth.py | 2 +- astrbot/dashboard/routes/config.py | 19 +- astrbot/dashboard/routes/plugin.py | 16 +- astrbot/dashboard/routes/stat.py | 4 +- astrbot/dashboard/server.py | 4 +- main.py | 3 +- packages/astrbot/main.py | 303 +++++++++++---- packages/astrbot/metadata.yaml | 6 - packages/astrbot_adapter_aiocqhttp/main.py | 13 - .../astrbot_adapter_aiocqhttp/metadata.yaml | 6 - packages/astrbot_adapter_qqofficial/main.py | 18 - .../astrbot_adapter_qqofficial/metadata.yaml | 6 - packages/astrbot_adapter_wechat/main.py | 18 - packages/astrbot_adapter_wechat/metadata.yaml | 6 - packages/astrbot_plugin_openai/__init__.py | 1 - packages/astrbot_plugin_openai/commands.py | 169 --------- packages/astrbot_plugin_openai/main.py | 253 ------------- packages/astrbot_plugin_openai/metadata.yaml | 6 - .../astrbot_plugin_openai/openai_adapter.py | 254 ------------- requirements.txt | 3 +- requirements_atri_base.txt | 2 + requirements_atri_ft.txt | 2 + 109 files changed, 2843 insertions(+), 2104 deletions(-) create mode 100644 astrbot/api/all.py create mode 100644 astrbot/api/event/__init__.py create mode 100644 astrbot/api/event/filter/__init__.py create mode 100644 astrbot/api/platform/__init__.py create mode 100644 astrbot/api/provider/__init__.py create mode 100644 astrbot/api/star/__init__.py delete mode 100644 astrbot/core/message/message_event_handler.py create mode 100644 astrbot/core/pipeline/__init__.py create mode 100644 astrbot/core/pipeline/content_safety_check/stage.py create mode 100644 astrbot/core/pipeline/content_safety_check/strategies/__init__.py create mode 100644 astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py create mode 100644 astrbot/core/pipeline/content_safety_check/strategies/keywords.py create mode 100644 astrbot/core/pipeline/content_safety_check/strategies/strategy.py rename {packages/astrbot => astrbot/core/pipeline/content_safety_check/strategies}/unfit_words (100%) create mode 100644 astrbot/core/pipeline/context.py create mode 100644 astrbot/core/pipeline/process_stage/method/llm_request.py create mode 100644 astrbot/core/pipeline/process_stage/method/star_request.py create mode 100644 astrbot/core/pipeline/process_stage/stage.py create mode 100644 astrbot/core/pipeline/rate_limit_check/stage.py create mode 100644 astrbot/core/pipeline/respond/stage.py create mode 100644 astrbot/core/pipeline/result_decorate/stage.py create mode 100644 astrbot/core/pipeline/scheduler.py create mode 100644 astrbot/core/pipeline/stage.py create mode 100644 astrbot/core/pipeline/waking_check/stage.py create mode 100644 astrbot/core/pipeline/whitelist_check/stage.py create mode 100644 astrbot/core/platform/manager.py create mode 100644 astrbot/core/platform/register.py rename {packages/astrbot_adapter_aiocqhttp => astrbot/core/platform/sources/aiocqhttp}/aiocqhttp_message_event.py (96%) rename {packages/astrbot_adapter_aiocqhttp => astrbot/core/platform/sources/aiocqhttp}/aiocqhttp_platform_adapter.py (71%) rename {packages/astrbot_adapter_qqofficial => astrbot/core/platform/sources/qqofficial}/qqofficial_message_event.py (96%) rename {packages/astrbot_adapter_qqofficial => astrbot/core/platform/sources/qqofficial}/qqofficial_platform_adapter.py (87%) rename packages/astrbot_adapter_wechat/wechat_message_event.py => astrbot/core/platform/sources/vchat/vchat_message_event.py (84%) rename packages/astrbot_adapter_wechat/wechat_platform_adapter.py => astrbot/core/platform/sources/vchat/vchat_platform_adapter.py (77%) delete mode 100644 astrbot/core/plugin/__init__.py delete mode 100644 astrbot/core/plugin/context.py delete mode 100644 astrbot/core/plugin/plugin.py create mode 100644 astrbot/core/provider/llm_response.py create mode 100644 astrbot/core/provider/manager.py create mode 100644 astrbot/core/provider/provider_metadata.py create mode 100644 astrbot/core/provider/register.py create mode 100644 astrbot/core/provider/sources/openai_source.py rename astrbot/core/{utils/func_call.py => provider/tool.py} (83%) rename {packages/astrbot_plugin_openai/websearch => astrbot/core/provider/tools/websearch/engines}/bing.py (100%) rename {packages/astrbot_plugin_openai/websearch => astrbot/core/provider/tools/websearch/engines}/config.py (100%) rename {packages/astrbot_plugin_openai/websearch => astrbot/core/provider/tools/websearch/engines}/engine.py (100%) rename {packages/astrbot_plugin_openai/websearch => astrbot/core/provider/tools/websearch/engines}/google.py (100%) rename {packages/astrbot_plugin_openai/websearch => astrbot/core/provider/tools/websearch/engines}/sogo.py (100%) rename {packages/astrbot_plugin_openai => astrbot/core/provider/tools/websearch}/web_searcher.py (91%) create mode 100644 astrbot/core/star/README.md create mode 100644 astrbot/core/star/__init__.py rename astrbot/core/{plugin => star}/config.py (100%) create mode 100644 astrbot/core/star/context.py create mode 100644 astrbot/core/star/filter/__init__.py create mode 100644 astrbot/core/star/filter/command.py create mode 100644 astrbot/core/star/filter/command_group.py create mode 100644 astrbot/core/star/filter/event_message_type.py create mode 100644 astrbot/core/star/filter/platform_adapter_type.py create mode 100644 astrbot/core/star/filter/regex.py create mode 100644 astrbot/core/star/register/__init__.py create mode 100644 astrbot/core/star/register/star.py create mode 100644 astrbot/core/star/register/star_handler.py create mode 100644 astrbot/core/star/star.py create mode 100644 astrbot/core/star/star_handler.py rename astrbot/core/{plugin/plugin_manager.py => star/star_manager.py} (57%) rename astrbot/core/{plugin => star}/updator.py (69%) delete mode 100644 astrbot/core/utils/image_uploader.py create mode 100644 astrbot/core/utils/param_validation_mixin.py delete mode 100644 packages/astrbot/metadata.yaml delete mode 100644 packages/astrbot_adapter_aiocqhttp/main.py delete mode 100644 packages/astrbot_adapter_aiocqhttp/metadata.yaml delete mode 100644 packages/astrbot_adapter_qqofficial/main.py delete mode 100644 packages/astrbot_adapter_qqofficial/metadata.yaml delete mode 100644 packages/astrbot_adapter_wechat/main.py delete mode 100644 packages/astrbot_adapter_wechat/metadata.yaml delete mode 100644 packages/astrbot_plugin_openai/__init__.py delete mode 100644 packages/astrbot_plugin_openai/commands.py delete mode 100644 packages/astrbot_plugin_openai/main.py delete mode 100644 packages/astrbot_plugin_openai/metadata.yaml delete mode 100644 packages/astrbot_plugin_openai/openai_adapter.py create mode 100644 requirements_atri_base.txt create mode 100644 requirements_atri_ft.txt diff --git a/.gitignore b/.gitignore index 19e9c7060..77409966d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__ botpy.log .vscode data_v2.db +data_v3.db configs/session configs/config.yaml **/.DS_Store @@ -11,4 +12,8 @@ data cookies.json logs/ addons/plugins -.coverage \ No newline at end of file +.coverage + + +tests/astrbot_plugin_openai +chroma \ No newline at end of file diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index 40a31adbb..feb3c162b 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -1,16 +1,4 @@ - -from astrbot.core.plugin import Context -from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult -from astrbot.core.provider import Provider, Personality from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core.utils.personality import personalities - -from astrbot.core.utils.command_parser import CommandParser, CommandTokens -from astrbot.core.utils.func_call import FuncCall -from astrbot.core import html_renderer - -from astrbot.core.plugin.config import * - -command_parser = CommandParser() \ No newline at end of file +from astrbot.core import html_renderer \ No newline at end of file diff --git a/astrbot/api/all.py b/astrbot/api/all.py new file mode 100644 index 000000000..fd44f6481 --- /dev/null +++ b/astrbot/api/all.py @@ -0,0 +1,40 @@ + +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot import logger +from astrbot.core.utils.personality import personalities +from astrbot.core import html_renderer + +# event +from astrbot.core.message.message_event_result import ( + MessageEventResult, MessageChain, CommandResult, EventResultType +) +from astrbot.core.platform import AstrMessageEvent + +# star register +from astrbot.core.star.register import ( + register_command as command, + register_command_group as command_group, + register_event_message_type as event_message_type, + register_regex as regex, + register_platform_adapter_type as platform_adapter_type, +) +from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType +from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType +from astrbot.core.star.register import ( + register_star as register # 注册插件(Star) +) +from astrbot.core.star import Context, Star +from astrbot.core.star.config import * + + +# provider +from astrbot.core.provider import Provider, Personality, ProviderMetaData + +# platform +from astrbot.core.platform import ( + AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +) + +from astrbot.core.platform.register import register_platform_adapter + +from .message_components import * \ No newline at end of file diff --git a/astrbot/api/event/__init__.py b/astrbot/api/event/__init__.py new file mode 100644 index 000000000..21129f645 --- /dev/null +++ b/astrbot/api/event/__init__.py @@ -0,0 +1,5 @@ +from astrbot.core.message.message_event_result import ( + MessageEventResult, MessageChain, CommandResult, EventResultType +) + +from astrbot.core.platform import AstrMessageEvent \ No newline at end of file diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py new file mode 100644 index 000000000..430968feb --- /dev/null +++ b/astrbot/api/event/filter/__init__.py @@ -0,0 +1,10 @@ +from astrbot.core.star.register import ( + register_command as command, + register_command_group as command_group, + register_event_message_type as event_message_type, + register_regex as regex, + register_platform_adapter_type as platform_adapter_type, +) + +from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType +from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType \ No newline at end of file diff --git a/astrbot/api/platform/__init__.py b/astrbot/api/platform/__init__.py new file mode 100644 index 000000000..7947671ac --- /dev/null +++ b/astrbot/api/platform/__init__.py @@ -0,0 +1,5 @@ +from astrbot.core.platform import ( + AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +) + +from astrbot.core.platform.register import register_platform_adapter \ No newline at end of file diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py new file mode 100644 index 000000000..52c5c59d4 --- /dev/null +++ b/astrbot/api/provider/__init__.py @@ -0,0 +1 @@ +from astrbot.core.provider import Provider, Personality, ProviderMetaData diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py new file mode 100644 index 000000000..dfe1eb77d --- /dev/null +++ b/astrbot/api/star/__init__.py @@ -0,0 +1,6 @@ +from astrbot.core.star.register import ( + register_star as register # 注册插件(Star) +) + +from astrbot.core.star import Context, Star +from astrbot.core.star.config import * diff --git a/astrbot/core/config/__init__.py b/astrbot/core/config/__init__.py index a781e688a..095d8c773 100644 --- a/astrbot/core/config/__init__.py +++ b/astrbot/core/config/__init__.py @@ -1,2 +1,2 @@ -from .default import DEFAULT_CONFIG_VERSION_2, VERSION, DB_PATH -from .astrbot_config import AstrBotConfig \ No newline at end of file +from .default import DEFAULT_CONFIG, VERSION, DB_PATH +from .astrbot_config import * \ No newline at end of file diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 85e3a6252..3113e468d 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -1,295 +1,82 @@ import os import json -import shutil import logging -from . import DEFAULT_CONFIG_VERSION_2 -from dataclasses import dataclass, field, asdict -from typing import List, Dict, Optional +import enum +from .default import DEFAULT_CONFIG +from typing import List, Dict ASTRBOT_CONFIG_PATH = "data/cmd_config.json" logger = logging.getLogger("astrbot") -@dataclass -class RateLimit: - time: int = 60 - count: int = 30 +class RateLimitStrategy(enum.Enum): + STALL = "stall" + DISCARD = "discard" -@dataclass -class PlatformSettings: - unique_session: bool = False - rate_limit: RateLimit = field(default_factory=RateLimit) - reply_prefix: str = "" - forward_threshold: int = 200 +class AstrBotConfig(dict): + '''从配置文件中加载的配置,支持直接通过点号操作符访问配置项''' - def __post_init__(self): - self.rate_limit = RateLimit(**self.rate_limit) - -@dataclass -class PlatformConfig(): - id: str = "" - name: str = "" - enable: bool = False - -@dataclass -class QQOfficialPlatformConfig(PlatformConfig): - appid: str = "" - secret: str = "" - enable_group_c2c: bool = True - enable_guild_direct_message: bool = True - -@dataclass -class AiocqhttpPlatformConfig(PlatformConfig): - ws_reverse_host: str = "" - ws_reverse_port: int = 6199 - qq_id_whitelist: List[str] = field(default_factory=list) - qq_group_id_whitelist: List[str] = field(default_factory=list) - -@dataclass -class WechatPlatformConfig(PlatformConfig): - wechat_id_whitelist: List[str] = field(default_factory=list) - -@dataclass -class ModelConfig: - model: str = "gpt-4o" - max_tokens: int = 6000 - temperature: float = 0.9 - top_p: float = 1 - - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - -@dataclass -class ImageGenerationModelConfig: - enable: bool = True - model: str = "dall-e-3" - size: str = "1024x1024" - style: str = "vivid" - quality: str = "standard" - -@dataclass -class EmbeddingModel: - enable: bool = False - model: str = "" - -@dataclass -class LLMConfig: - id: str = "" - name: str = "openai" - enable: bool = True - key: List[str] = field(default_factory=list) - api_base: str = "" - prompt_prefix: str = "" - default_personality: str = "" - model_config: ModelConfig = field(default_factory=ModelConfig) - image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig) - embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel) - - def __post_init__(self): - if isinstance(self.model_config, dict): - self.model_config = ModelConfig(**self.model_config) - if isinstance(self.image_generation_model_config, dict): - self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None - if isinstance(self.embedding_model, dict): - self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None - -@dataclass -class LLMSettings: - wake_prefix: str = "" - web_search: bool = False - identifier: bool = False - -@dataclass -class BaiduAIPConfig: - enable: bool = False - app_id: str = "" - api_key: str = "" - secret_key: str = "" - -@dataclass -class InternalKeywordsConfig: - enable: bool = True - extra_keywords: List[str] = field(default_factory=list) - -@dataclass -class ContentSafetyConfig: - baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig) - internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig) - - def __post_init__(self): - self.baidu_aip = BaiduAIPConfig(**self.baidu_aip) - self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords) - -@dataclass -class DashboardConfig: - enable: bool = True - username: str = "" - password: str = "" - - -@dataclass -class ATRILongTermMemory: - enable: bool = False - summary_threshold_cnt: int = 5 - -@dataclass -class ATRIActiveMessage: - enable: bool = False - -@dataclass -class ProjectATRI: - enable: bool = False - long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory) - active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage) - persona: str = "" - split_response: bool = True - embedding_provider_id: str = "" - summarize_provider_id: str = "" - chat_provider_id: str = "" - chat_base_model_path: str = "" - chat_adapter_model_path: str = "" - quantization_bit: int = 4 - - def __post_init__(self): - if isinstance(self.long_term_memory, dict): - self.long_term_memory = ATRILongTermMemory(**self.long_term_memory) - if isinstance(self.active_message, dict): - self.active_message = ATRIActiveMessage(**self.active_message) - -@dataclass -class AstrBotConfig(): - config_version: int = 2 - platform_settings: PlatformSettings = field(default_factory=PlatformSettings) - llm: List[LLMConfig] = field(default_factory=list) - llm_settings: LLMSettings = field(default_factory=LLMSettings) - content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig) - t2i: bool = True - admins_id: List[str] = field(default_factory=list) - https_proxy: str = "" - http_proxy: str = "" - dashboard: DashboardConfig = field(default_factory=DashboardConfig) - platform: List[PlatformConfig] = field(default_factory=list) - wake_prefix: List[str] = field(default_factory=list) - log_level: str = "INFO" - t2i_endpoint: str = "" - pip_install_arg: str = "" - plugin_repo_mirror: str = "" - project_atri: ProjectATRI = field(default_factory=ProjectATRI) - - def __init__(self) -> None: - self.init_configs() - - # compability - if isinstance(self.wake_prefix, str): - self.wake_prefix = [self.wake_prefix] - - if len(self.wake_prefix) == 0: - self.wake_prefix.append("/") - - def load_from_dict(self, data: Dict): - '''从字典中加载配置到对象。 - - @note: 适用于 version 2 配置文件。 - ''' - self.config_version=data.get("version", 2) - self.platform=[] - - left_platforms = ["qq_official", "aiocqhttp", "wechat"] - for p in data.get("platform", []): - if 'name' not in p: - logger.warning("A platform config missing name, skipping.") - continue - if p["name"] == "qq_official": - self.platform.append(QQOfficialPlatformConfig(**p)) - left_platforms.remove(p["name"]) - elif p["name"] == "aiocqhttp": - self.platform.append(AiocqhttpPlatformConfig(**p)) - left_platforms.remove(p["name"]) - elif p["name"] == "wechat": - self.platform.append(WechatPlatformConfig(**p)) - left_platforms.remove(p["name"]) - # 注入默认配置 - for p in left_platforms: - if p == "qq_official": - self.platform.append(QQOfficialPlatformConfig(id="default", name=p)) - elif p == "aiocqhttp": - self.platform.append(AiocqhttpPlatformConfig(id="default", name=p)) - elif p == "wechat": - self.platform.append(WechatPlatformConfig(id="default", name=p)) - - self.platform_settings=PlatformSettings(**data.get("platform_settings", {})) - self.llm=[LLMConfig(**l) for l in data.get("llm", [])] - self.llm_settings=LLMSettings(**data.get("llm_settings", {})) - self.content_safety=ContentSafetyConfig(**data.get("content_safety", {})) - self.t2i=data.get("t2i", True) - self.admins_id=data.get("admins_id", []) - self.https_proxy=data.get("https_proxy", "") - self.http_proxy=data.get("http_proxy", "") - self.dashboard=DashboardConfig(**data.get("dashboard", {})) - self.wake_prefix=data.get("wake_prefix", ["/"]) - self.log_level=data.get("log_level", "INFO") - self.t2i_endpoint=data.get("t2i_endpoint", "") - self.pip_install_arg=data.get("pip_install_arg", "") - self.plugin_repo_mirror=data.get("plugin_repo_mirror", "") - self.project_atri=ProjectATRI(**data.get("project_atri", {})) - - def flush_config(self, config: dict = None): - '''将配置写入文件, 如果没有传入配置,则写入默认配置''' - with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: - json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False) - f.flush() - - def save_config(self): - '''将现存配置写入文件''' - self.flush_config(self.to_dict()) - - def init_configs(self): - '''初始化必需的配置项''' - config = None - + def __init__(self): + super().__init__() if not self.check_exist(): - self.flush_config() - config = DEFAULT_CONFIG_VERSION_2 - else: - config = self.get_all() - - # 加载配置到对象 - self.load_from_dict(config) - # 保存到文件 - # 这一步操作是为了保证配置文件中的字段的完整性。 - # 在版本变动新增配置项时,将对象中新增的配置项的默认值写入文件。 - self.save_config() - - def get(self, key: str, default=None): - '''从文件系统中直接获取配置''' - with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: - d = json.load(f) - if key in d: - return d[key] - else: - return default - - def get_all(self): - '''从文件系统中获取所有配置''' - with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: - conf_str = f.read() - if conf_str.startswith(u'/ufeff'): # remove BOM - conf_str = conf_str.encode('utf8')[3:].decode('utf8') - if not conf_str: - return {} - conf = json.loads(conf_str) - return conf - - def put(self, key, value): - with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: - d = json.load(f) - d[key] = value + '''不存在时载入默认配置''' with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: - json.dump(d, f, indent=2, ensure_ascii=False) - f.flush() - - def to_dict(self) -> Dict: - return asdict(self) + json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False) + + with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: + conf_str = f.read() + if conf_str.startswith(u'/ufeff'): # remove BOM + conf_str = conf_str.encode('utf8')[3:].decode('utf8') + conf = json.loads(conf_str) + + # 检查配置完整性,并插入 + has_new = self.check_config_integrity(DEFAULT_CONFIG, conf) + self.update(conf) + if has_new: + self.save_config() + + self.update(conf) + + def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): + '''检查配置完整性,如果有新的配置项则返回 True''' + has_new = False + for key, value in refer_conf.items(): + if key not in conf: + logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}") + conf[key] = value + has_new = True + else: + if conf[key] == None: + conf[key] = value + has_new = True + elif isinstance(value, dict): + has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key) + return has_new + + def save_config(self, replace_config: Dict = None): + '''将配置写入文件 + + 如果传入 replace_config,则将配置替换为 replace_config + ''' + if replace_config: + self.update(replace_config) + with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: + json.dump(self, f, indent=2, ensure_ascii=False) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + return None + + def __delattr__(self, key): + try: + del self[key] + self.save_config() + except KeyError: + raise AttributeError(f"没有找到 Key: '{key}'") + + def __setattr__(self, key, value): + self[key] = value def check_exist(self) -> bool: return os.path.exists(ASTRBOT_CONFIG_PATH) \ No newline at end of file diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index c47df8d9e..98feee9ad 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1,177 +1,160 @@ ''' -这里定义了一些默认配置文件,请不要修改这个文件。如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 +如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 ''' VERSION = '3.4.0' DB_PATH = 'data/data_v3.db' +# 默认配置 +DEFAULT_CONFIG = { + "config_version": 2, + "platform_settings": { + "unique_session": False, + "rate_limit": { + "time": 60, + "count": 30, + "strategy": "stall" # stall, discard + }, + "reply_prefix": "", + "forward_threshold": 200, + "id_whitelist": [] + }, + "provider": [], + "provider_settings": { + "wake_prefix": "", + "web_search": False, + "identifier": False, + "default_personality": "", + "prompt_prefix": "" + }, + "content_safety": { + "internal_keywords": { + "enable": True, + "extra_keywords": [] + }, + "baidu_aip": { + "enable": False, + "app_id": "", + "api_key": "", + "secret_key": "" + } + }, + "admins_id": [], + "t2i": False, + "http_proxy": "", + "dashboard": { + "enable": True, + "username": "astrbot", + "password": "77b90590a8945a7d36c963981a307dc9" + }, + "platform": [], + "wake_prefix": ["/"], + "log_level": "INFO", + "t2i_endpoint": "", + "pip_install_arg": "", + "plugin_repo_mirror": "", + "project_atri": { + "enable": False, + "long_term_memory": { + "enable": False, + "summary_threshold_cnt": 5, + "embedding_provider_id": "", + "summarize_provider_id": "" + }, + "active_message": { + "enable": False + }, + "vision": { + "enable": False, + "provider_id_or_ofa_model_path": "", + "reply_meme_prob": 0.4, + "reply_meme_similar_threshold": 0.7 + }, + "persona": "", + "split_response": True, + "chat_provider_id": "", + "chat_base_model_path": "", + "chat_adapter_model_path": "", + "quantization_bit": 4 + } +} + # LLM 提供商配置模板 PROVIDER_CONFIG_TEMPLATE = { "openai": { "id": "default", - "name": "openai", + "type": "openai_chat_completion", "enable": True, "key": [], "api_base": "", - "prompt_prefix": "", - "default_personality": "", "model_config": { - "model": "gpt-4o", - "max_tokens": 6000, - "temperature": 0.9, - "top_p": 1, - }, - "image_generation_model_config": { - "enable": False, - "model": "dall-e-3", - "size": "1024x1024", - "style": "vivid", - "quality": "standard", - }, - "embedding_model": { - "enable": False, - "model": "text-embedding-3-small" + "model": "gpt-4o-mini", } }, "ollama": { "id": "ollama_default", - "name": "ollama", + "type": "openai_chat_completion", "enable": True, "key": ["ollama"], # ollama 的 key 默认是 ollama "api_base": "http://localhost:11434", - "prompt_prefix": "", - "default_personality": "", "model_config": { "model": "llama3.1-8b", - "temperature": 0.9, - "top_p": 1, } }, "gemini": { "id": "gemini_default", - "name": "gemini", + "type": "openai_chat_completion", "enable": True, "key": [], "api_base": "https://generativelanguage.googleapis.com/v1beta/openai/", - "prompt_prefix": "", - "default_personality": "", "model_config": { "model": "gemini-1.5-flash", } }, "deepseek": { "id": "deepseek_default", - "name": "deepseek", + "type": "openai_chat_completion", "enable": True, "key": [], "api_base": "https://api.deepseek.com/v1", - "prompt_prefix": "", - "default_personality": "", "model_config": { "model": "deepseek-chat", } }, "zhipu": { "id": "zhipu_default", - "name": "zhipu(glm)", + "type": "openai_chat_completion", "enable": True, "key": [], "api_base": "https://open.bigmodel.cn/api/paas/v4/", - "prompt_prefix": "", - "default_personality": "", "model_config": { "model": "glm-4-flash", } }, } -# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D -DEFAULT_CONFIG_VERSION_2 = { - "config_version": 2, - "platform": [ - { - "id": "default", - "name": "qq_official", - "enable": False, - "appid": "", - "secret": "", - "enable_group_c2c": True, - "enable_guild_direct_message": True, - }, - { - "id": "default", - "name": "aiocqhttp", - "enable": False, - "ws_reverse_host": "", - "ws_reverse_port": 6199, - "qq_id_whitelist": [], - "qq_group_id_whitelist": [] - }, - { - "id": "default", - "name": "wechat", - "enable": False, - "wechat_id_whitelist": [] - } - ], - "platform_settings": { - "unique_session": False, - "rate_limit": { - "time": 60, - "count": 30, - }, - "reply_prefix": "", - "forward_threshold": 200, # 转发消息的阈值 - }, - "llm": [ - PROVIDER_CONFIG_TEMPLATE["openai"] - ], - "llm_settings": { - "wake_prefix": "", - "web_search": False, - "identifier": False, - }, - "content_safety": { - "internal_keywords": { - "enable": True, - "extra_keywords": [], - } - }, - "wake_prefix": ["/"], - "t2i": True, - "admins_id": [], - "https_proxy": "", - "http_proxy": "", - "dashboard": { - "enable": True, - "username": "astrbot", - "password": "77b90590a8945a7d36c963981a307dc9", - }, - "log_level": "INFO", - "t2i_endpoint": "", - "pip_install_arg": "", - "plugin_repo_mirror": "default", - "project_atri": { +# 平台适配器配置模板 +ADAPTER_CONFIG_TEMPLATE = { + "qq_official": { + "id": "default", + "name": "qq_official", "enable": False, - "long_term_memory": { - "enable": False, - "summary_threshold_cnt": 6, - }, - "active_message": { - "enable": False, - }, - "vision": { - "enable": False, - "provider_id_or_ofa_model_path": "", - }, - "persona": "", - "split_response": True, - "embedding_provider_id": "", - "summarize_provider_id": "", - "chat_provider_id": "", - "chat_base_model_path": "", - "chat_adapter_model_path": "", - "quantization_bit": 4 + "appid": "", + "secret": "", + "enable_group_c2c": True, + "enable_guild_direct_message": True, + }, + "aiocqhtp": { + "id": "default", + "name": "aiocqhttp", + "enable": False, + "ws_reverse_host": "", + "ws_reverse_port": 6199 + }, + "wechat": { + "id": "default", + "name": "vchat", + "enable": False } } @@ -183,7 +166,7 @@ CONFIG_METADATA_2 = { "type": "list", "items": { "id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"}, - "name": {"description": "适配器类型", "type": "string", "hint": "当前版本下,内置支持 `qq_official`(QQ 官方机器人), `aiocqhttp`(Onebot 适用) 适配器类型。", "options": ["qq_official", "aiocqhttp", "wechat"], "readonly": True}, + "name": {"description": "适配器类型", "type": "string", "invisible": True}, "enable": {"description": "启用", "type": "bool", "hint": "是否启用该适配器。未启用的适配器对应的消息平台将不会接收到消息。"}, "appid": {"description": "appid", "type": "string", "hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。"}, "secret": {"description": "secret", "type": "string", "hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。"}, @@ -208,18 +191,20 @@ CONFIG_METADATA_2 = { "items": { "time": {"description": "消息速率限制时间", "type": "int"}, "count": {"description": "消息速率限制计数", "type": "int"}, + "strategy": {"description": "速率限制策略", "type": "string", "options": ["stall", "discard"], "hint": "当消息速率超过限制时的处理策略。stall 为等待,discard 为丢弃。"} } }, "reply_prefix": {"description": "回复前缀", "type": "string", "hint": "机器人回复消息时带有的前缀。"}, "forward_threshold": {"description": "转发消息的字数阈值", "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。"}, + "id_whitelist": {"description": "ID 白名单", "type": "list", "items": {"type": "int"}, "hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的 ID。"}, } }, - "llm": { + "provider": { "description": "大语言模型配置", "type": "list", "items": { "id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"}, - "name": {"description": "模型提供商类型", "type": "string", "hint": "如需变更模型提供商,请点击上面的 + 新建一个。如果没有找到你想要接入的提供商,可以前往你的提供商的官网查看是否兼容 OpenAI API,如兼容,可以选择 `openai`。大多数提供商都是兼容的。", "options": list(PROVIDER_CONFIG_TEMPLATE.keys()), "obvious_hint": True, "readonly": True}, + "type": {"description": "模型提供商类型", "type": "string", "invisible": True}, "enable": {"description": "启用", "type": "bool", "hint": "是否启用该模型。未启用的模型将不会被使用。"}, "key": {"description": "API Key", "type": "list", "items": {"type": "string"}, "hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。"}, "api_base": {"description": "API Base URL", "type": "string", "hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。"}, @@ -256,7 +241,7 @@ CONFIG_METADATA_2 = { } } }, - "llm_settings": { + "provider_settings": { "description": "大语言模型设置", "type": "object", "items": { @@ -292,7 +277,6 @@ CONFIG_METADATA_2 = { "wake_prefix": {"description": "机器人唤醒前缀", "type": "list", "items": {"type": "string"}, "hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。"}, "t2i": {"description": "文本转图像", "type": "bool", "hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。"}, "admins_id": {"description": "管理员 ID", "type": "list", "items": {"type": "int"}, "hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。"}, - "https_proxy": {"description": "HTTPS 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"}, "http_proxy": {"description": "HTTP 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"}, "dashboard": { "description": "管理面板配置", @@ -318,6 +302,8 @@ CONFIG_METADATA_2 = { "items": { "enable": {"description": "启用", "type": "bool"}, "summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"}, + "embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True}, + "summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, } }, "active_message": { @@ -335,10 +321,8 @@ CONFIG_METADATA_2 = { "provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"}, } }, - "split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"}, + "split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。"}, "persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True}, - "embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True}, - "summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, "chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, "chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True}, "chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True}, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index b8deef970..58adaa2e0 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -1,10 +1,13 @@ -import asyncio, time, threading +import asyncio, time, threading, os from .event_bus import EventBus from asyncio import Queue from typing import List from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.message.message_event_handler import MessageEventHandler -from astrbot.core.plugin import PluginManager +from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext +from astrbot.core.star import PluginManager +from astrbot.core.platform.manager import PlatformManager +from astrbot.core.star.context import Context +from astrbot.core.provider.manager import ProviderManager from astrbot.core import LogBroker from astrbot.core.db import BaseDatabase from astrbot.core.updator import AstrBotUpdator @@ -15,21 +18,46 @@ class AstrBotCoreLifecycle: def __init__(self, log_broker: LogBroker, db: BaseDatabase): self.log_broker = log_broker self.astrbot_config = AstrBotConfig() + self.db = db + + if self.astrbot_config['http_proxy']: + os.environ['https_proxy'] = self.astrbot_config['http_proxy'] + os.environ['http_proxy'] = self.astrbot_config['http_proxy'] + + async def initialize(self): logger.info("AstrBot v"+ VERSION) - logger.setLevel(self.astrbot_config.log_level) + logger.setLevel(self.astrbot_config['log_level']) self.event_queue = Queue() self.event_queue.closed = False - self.plugin_manager = PluginManager(self.astrbot_config, self.event_queue, db) - self.message_event_handler = MessageEventHandler(self.astrbot_config, self.plugin_manager) - self.astrbot_updator = AstrBotUpdator(self.astrbot_config.plugin_repo_mirror) - self.event_bus = EventBus(self.event_queue, self.message_event_handler) - self.stop_flag = False - self.start_time = int(time.time()) + self.provider_manager = ProviderManager(self.astrbot_config, self.db) + + self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) + + self.star_context = Context(self.event_queue, self.astrbot_config, self.db) + self.star_context.platform_manager = self.platform_manager + self.star_context.provider_manager = self.provider_manager + self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) + + self.plugin_manager.reload() + '''扫描、注册插件、实例化插件类''' + + await self.provider_manager.initialize() + '''根据配置实例化各个 Provider''' + + await self.platform_manager.initialize() + '''根据配置实例化各个平台适配器''' + + self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager)) + await self.pipeline_scheduler.initialize() + '''初始化消息事件流水线调度器''' + + self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror']) + self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) + self.start_time = int(time.time()) self.curr_tasks: List[asyncio.Task] = [] def _load(self): - self.plugin_manager.reload() platform_tasks = self.load_platform() event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") @@ -41,16 +69,26 @@ class AstrBotCoreLifecycle: self._load() await asyncio.gather(*self.curr_tasks, return_exceptions=True) - def stop(self): - self.stop_flag = True - + async def stop(self): + self.event_queue.closed = True + for task in self.curr_tasks: + task.cancel() + + for task in self.curr_tasks: + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"任务 {task.get_name()} 发生错误: {e}") + def restart(self): self.event_queue.closed = True threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start() def load_platform(self) -> List[asyncio.Task]: tasks = [] - platform_insts = self.plugin_manager.get_platform_insts() + platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)) return tasks \ No newline at end of file diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index f53bb3478..d993cd6bc 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from typing import List -from astrbot.core.db.po import Stats, LLMHistory +from astrbot.core.db.po import Stats, LLMHistory, ATRIVision @dataclass class BaseDatabase(abc.ABC): @@ -39,12 +39,12 @@ class BaseDatabase(abc.ABC): raise NotImplementedError @abc.abstractmethod - def update_llm_history(self, session_id: str, content: str): + def update_llm_history(self, session_id: str, content: str, provider_type: str): '''更新 LLM 历史记录。当不存在 session_id 时插入''' raise NotImplementedError @abc.abstractmethod - def get_llm_history(self, session_id: str = None) -> List[LLMHistory]: + def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]: '''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有''' raise NotImplementedError @@ -62,3 +62,18 @@ class BaseDatabase(abc.ABC): def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: '''获取基础统计数据(合并)''' raise NotImplementedError + + @abc.abstractmethod + def insert_atri_vision_data(self, vision_data: ATRIVision): + '''插入 ATRI 视觉数据''' + raise NotImplementedError + + @abc.abstractmethod + def get_atri_vision_data(self) -> List[ATRIVision]: + '''获取 ATRI 视觉数据''' + raise NotImplementedError + + @abc.abstractmethod + def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision: + '''通过 url 或 path 获取 ATRI 视觉数据''' + raise NotImplementedError \ No newline at end of file diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 96dc3d7e9..0e6b2e92c 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -37,5 +37,18 @@ class Stats(): @dataclass class LLMHistory(): + provider_type: str session_id: str - content: str \ No newline at end of file + content: str + +@dataclass +class ATRIVision(): + id: str + url_or_path: str + caption: str + is_meme: bool + keywords: List[str] + platform_name: str + session_id: str + sender_nickname: str + timestamp: int = -1 \ No newline at end of file diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index ed2f0c81e..2ecee276e 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -6,7 +6,8 @@ from astrbot.core.db.po import ( Command, Provider, Stats, - LLMHistory + LLMHistory, + ATRIVision ) from . import BaseDatabase from typing import Tuple @@ -75,28 +76,39 @@ class SQLiteDatabase(BaseDatabase): ''', (k, v, int(time.time())) ) - def update_llm_history(self, session_id: str, content: str): - res = self.get_llm_history(session_id) + def update_llm_history(self, session_id: str, content: str, provider_type: str): + res = self.get_llm_history(session_id, provider_type) if res: self._exec_sql( ''' - UPDATE llm_history SET content = ? WHERE session_id = ? - ''', (content, session_id) + UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ? + ''', (content, session_id, provider_type) ) else: self._exec_sql( ''' - INSERT INTO llm_history(session_id, content) VALUES (?, ?) - ''', (session_id, content) + INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?) + ''', (provider_type, session_id, content) ) - def get_llm_history(self, session_id: str = None) -> Tuple: + def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: c = self._get_conn(self.db_path).cursor() - - where_clause = "" if session_id is None else f"WHERE session_id = '{session_id}'" + + where_clause = "" + if session_id or provider_type: + where_clause += " WHERE " + has = False + if session_id: + where_clause += f"session_id = '{session_id}'" + has = True + if provider_type: + if has: + where_clause += " AND " + where_clause += f"provider_type = '{provider_type}'" + c.execute( ''' SELECT * FROM llm_history @@ -186,26 +198,53 @@ class SQLiteDatabase(BaseDatabase): for row in c.fetchall(): platform.append(Platform(*row)) - # c.execute( - # ''' - # SELECT name, SUM(count), timestamp FROM command - # ''' + where_clause + " GROUP BY name" - # ) - - # command = [] - # for row in c.fetchall(): - # command.append(Command(*row)) - - # c.execute( - # ''' - # SELECT name, SUM(count), timestamp FROM llm - # ''' + where_clause + " GROUP BY name" - # ) - - # llm = [] - # for row in c.fetchall(): - # llm.append(Provider(*row)) - c.close() - return Stats(platform, [], []) \ No newline at end of file + return Stats(platform, [], []) + + + def insert_atri_vision_data(self, vision: ATRIVision): + ts = int(time.time()) + keywords = ",".join(vision.keywords) + self._exec_sql( + ''' + INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts) + ) + + def get_atri_vision_data(self) -> Tuple: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + ''' + SELECT * FROM atri_vision + ''' + ) + + res = c.fetchall() + visions = [] + for row in res: + visions.append(ATRIVision(*row)) + c.close() + return visions + + def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision: + try: + c = self.conn.cursor() + except sqlite3.ProgrammingError: + c = self._get_conn(self.db_path).cursor() + + c.execute( + ''' + SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ? + ''', (url_or_path, id) + ) + + res = c.fetchone() + c.close() + if res: + return ATRIVision(*res) + return None \ No newline at end of file diff --git a/astrbot/core/db/sqlite_init.sql b/astrbot/core/db/sqlite_init.sql index 924a9d5ac..2cd7e77b3 100644 --- a/astrbot/core/db/sqlite_init.sql +++ b/astrbot/core/db/sqlite_init.sql @@ -19,6 +19,20 @@ CREATE TABLE IF NOT EXISTS command( timestamp INTEGER ); CREATE TABLE IF NOT EXISTS llm_history( + provider_type VARCHAR(32), session_id VARCHAR(32), content TEXT +); + +-- ATRI +CREATE TABLE IF NOT EXISTS atri_vision( + id TEXT, + url_or_path TEXT, + caption TEXT, + is_meme BOOLEAN, + keywords TEXT, + platform_name VARCHAR(32), + session_id VARCHAR(32), + sender_nickname VARCHAR(32), + timestamp INTEGER ); \ No newline at end of file diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 1ab5edb1e..75f937974 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -2,22 +2,22 @@ import asyncio from asyncio import Queue from collections import defaultdict from typing import List -from astrbot.core.message.message_event_handler import MessageEventHandler +from astrbot.core.pipeline.scheduler import PipelineScheduler from astrbot.core import logger from .platform import AstrMessageEvent from astrbot.core.message.components import Image, Plain class EventBus: - def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler): + def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): self.event_queue = event_queue - self.message_event_handler = message_event_handler + self.pipeline_scheduler = pipeline_scheduler async def dispatch(self): logger.info("事件总线已打开。") while True: event: AstrMessageEvent = await self.event_queue.get() self._print_event(event) - asyncio.create_task(self.message_event_handler.handle(event)) + asyncio.create_task(self.pipeline_scheduler.execute(event)) def _print_event(self, event: AstrMessageEvent): if event.get_sender_name(): diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index b34e14c1a..717a0f81b 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -271,6 +271,7 @@ class Image(BaseMessageComponent): c: T.Optional[int] = 2 # 额外 path: T.Optional[str] = "" + file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 def __init__(self, file: T.Optional[str], **_): # for k in _.keys(): diff --git a/astrbot/core/message/message_event_handler.py b/astrbot/core/message/message_event_handler.py deleted file mode 100644 index 45e2ecb1f..000000000 --- a/astrbot/core/message/message_event_handler.py +++ /dev/null @@ -1,177 +0,0 @@ -import asyncio, re, time -import inspect -import traceback -from typing import List, Union -from astrbot.core.platform import AstrMessageEvent -from astrbot.core.config.astrbot_config import AstrBotConfig -from .message_event_result import MessageEventResult, CommandResult, MessageChain -from astrbot.core.plugin import PluginManager, Context, CommandMetadata -from .components import * -from astrbot.core import logger -from astrbot.core import html_renderer - -class CommandTokens(): - def __init__(self) -> None: - self.tokens = [] - self.len = 0 - - def get(self, idx: int): - if idx >= self.len: - return None - return self.tokens[idx].strip() - -class CommandParser(): - def __init__(self): - pass - - def parse(self, message: str): - cmd_tokens = CommandTokens() - cmd_tokens.tokens = message.split(" ") - cmd_tokens.len = len(cmd_tokens.tokens) - return cmd_tokens - - def regex_match(self, message: str, command: str) -> bool: - return re.search(command, message, re.MULTILINE) is not None - - -class MessageEventHandler(): - ''' - 处理消息事件。 - ''' - def __init__(self, config: AstrBotConfig, plugin_manager: PluginManager): - self.config = config - self.plugin_manager = plugin_manager - self.command_parser = CommandParser() - - async def handle(self, event: AstrMessageEvent): - ''' - 处理消息事件。 - ''' - event.message_str = event.message_str.strip() - for admin_id in self.config.admins_id: - if event.get_sender_id() == admin_id: - event.role = "admin" - break - - # 检查 wake - wake_prefixes = self.config.wake_prefix - messages = event.get_messages() - is_wake = False - for wake_prefix in wake_prefixes: - if event.message_str.startswith(wake_prefix): - is_wake = True - break - if not is_wake: - # 检查是否有 at 消息 - for message in messages: - if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"): - is_wake = True - wake_prefix = "" - break - # 检查是否是私聊 - if event.is_private_chat(): - is_wake = True - wake_prefix = "" - event.is_wake = is_wake - - # 处理事件监听器(在指令扫描之前) - listeners = self.plugin_manager.context.registered_listeners - listeners_handler = self.plugin_manager.context.listeners_handler - for name in listeners: - if listeners_handler[name].after_commands: - continue - ret = await listeners_handler[name].handler(event) - if ret: - event.set_result(ret) - if event.get_result(): - return await self.post_handle(event) - - # 处理指令,指令带有指定过的前缀 - commands = self.plugin_manager.context.registered_commands - commands_handler = self.plugin_manager.context.commands_handler - - # 扫描指令 - for command in commands: - command = command[1] - trig = False - pre_ = "" - if not commands_handler[command].ignore_prefix: - pre_ = wake_prefix - - if commands_handler[command].use_regex: - trig = self.command_parser.regex_match(event.message_str, pre_ + command) - else: - trig = event.message_str.startswith(pre_ + command) - if trig: - ret = await self.execute_handler(command, commands_handler[command], event) - if ret: - event.set_result(ret) - if event.get_result(): - return await self.post_handle(event) - - # 处理事件监听器(在指令扫描之后) - for name in listeners: - if not listeners_handler[name].after_commands: - continue - ret = await listeners_handler[name].handler(event) - if ret: - event.set_result(ret) - if event.get_result(): - return await self.post_handle(event) - - async def post_handle(self, event: AstrMessageEvent): - result = event.get_result() - if result.callback: - await result.callback(event) - - # prefix - if self.config.platform_settings.reply_prefix: - result.chain.insert(0, Plain(self.config.platform_settings.reply_prefix)) - - # t2i - if (result.use_t2i_ is None and self.config.t2i) or result.use_t2i_: - plain_str = "" - for comp in result.chain: - if isinstance(comp, Plain): - plain_str += "\n\n" + comp.text - else: - break - if plain_str and len(plain_str) > 150: - render_start = time.time() - url = await html_renderer.render_t2i(plain_str, return_url=True) - if time.time() - render_start > 3: - logger.warning(f"图片转文本耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。") - if url: - result.chain = [Image.fromURL(url)] - - logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") - - await event.send(result) - - async def execute_handler(self, - command: str, - command_metadata: CommandMetadata, - message_event: AstrMessageEvent): - logger.info(f"触发 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令。") - handler = command_metadata.handler - try: - if inspect.iscoroutinefunction(handler): - command_result = await handler(message_event) - else: - command_result = handler(message_event) - - if command_result is not None: - message_event.set_result(command_result) - except TypeError as e: - # 兼容旧版本插件 - if inspect.iscoroutinefunction(handler): - command_result = await handler(message_event, self.plugin_manager.context) - else: - command_result = handler(message_event, self.plugin_manager.context) - - if command_result is not None: - message_event.set_result(command_result) - except BaseException as e: - logger.error(traceback.format_exc()) - text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}" - message_event.set_result(MessageEventResult().message(text)) \ No newline at end of file diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index a7629cbe0..1583e1850 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,43 +1,68 @@ -from typing import List, Union, Optional +import enum, logging + +from typing import List, Optional from dataclasses import dataclass, field from astrbot.core.message.components import * +from typing_extensions import deprecated @dataclass class MessageChain(): + '''MessageChain 描述了一整条消息中带有的所有组件。 + 现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。 + + Attributes: + `chain` (list): 用于顺序存储各个组件。 + `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 + `is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 + ''' + chain: List[BaseMessageComponent] = field(default_factory=list) use_t2i_: Optional[bool] = None # None 为跟随用户设置 is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。 def message(self, message: str): - ''' - 快捷回复消息。 + '''添加一条文本消息到消息链 `chain` 中。 - CommandResult().message("Hello, world!") + Example: + + CommandResult().message("Hello ").message("world!") + # 输出 Hello world! + ''' self.chain.append(Plain(message)) return self + @deprecated("请使用 message 方法代替。") def error(self, message: str): - ''' - 快捷回复消息。 + '''添加一条错误消息到消息链 `chain` 中 - CommandResult().error("Hello, world!") + Example: + + CommandResult().error("解析失败") + ''' self.chain.append(Plain(message)) return self def url_image(self, url: str): - ''' - 快捷回复图片(网络url的格式)。 + '''添加一条图片消息(https 链接)到消息链 `chain` 中。 - CommandResult().image("https://example.com/image.jpg") + Note: + 如果需要发送本地图片,请使用 `file_image` 方法。 + + Example: + + CommandResult().image("https://example.com/image.jpg") + ''' self.chain.append(Image.fromURL(url)) return self def file_image(self, path: str): - ''' - 快捷回复图片(本地文件路径的格式)。 + '''添加一条图片消息(本地文件路径)到消息链 `chain` 中。 + + Note: + 如果需要发送网络图片,请使用 `url_image` 方法。 CommandResult().image("image.jpg") ''' @@ -45,24 +70,65 @@ class MessageChain(): return self def use_t2i(self, use_t2i: bool): - ''' - 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 + '''设置是否使用文本转图片服务。 + + Args: + use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 ''' self.use_t2i_ = use_t2i return self def is_split(self, is_split: bool): - ''' - 设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 + '''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 - 具体的效果以各适配器实现为准。 + Note: + 具体的效果以各适配器实现为准。 + ''' self.is_split_ = is_split return self +class EventResultType(enum.Enum): + '''用于描述事件处理的结果类型。 + + Attributes: + CONTINUE: 事件将会继续传播 + STOP: 事件将会终止传播 + ''' + CONTINUE = enum.auto() + STOP = enum.auto() + @dataclass class MessageEventResult(MessageChain): - is_command_call: Optional[bool] = False - callback: Optional[callable] = None + '''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。 + 现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。 + + Attributes: + `chain` (list): 用于顺序存储各个组件。 + `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 + `is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。 + `result_type` (EventResultType): 事件处理的结果类型。 + ''' + + result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE) + + def stop_event(self) -> 'MessageEventResult': + '''终止事件传播。 + ''' + self.result_type = EventResultType.STOP + return self + + def continue_event(self) -> 'MessageEventResult': + '''继续事件传播。 + ''' + self.result_type = EventResultType.CONTINUE + return self + + def is_stopped(self) -> bool: + ''' + 是否终止事件传播。 + ''' + return self.result_type == EventResultType.STOP + CommandResult = MessageEventResult \ No newline at end of file diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py new file mode 100644 index 000000000..39251e8b2 --- /dev/null +++ b/astrbot/core/pipeline/__init__.py @@ -0,0 +1,18 @@ +from astrbot.core.message.message_event_result import MessageEventResult, EventResultType + +STAGES_ORDER = [ + "WakingCheckStage", # 检查是否需要唤醒 + "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 + "RateLimitCheckStage", # 检查会话是否超过频率限制 + "ContentSafetyCheckStage", # 检查内容安全 + "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 + "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 + "RespondStage" # 发送消息 +] + +from .waking_check.stage import WakingCheckStage +from .whitelist_check.stage import WhitelistCheckStage +from .content_safety_check.stage import ContentSafetyCheckStage +from .process_stage.stage import ProcessStage +from .result_decorate.stage import ResultDecorateStage +from .respond.stage import RespondStage \ No newline at end of file diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py new file mode 100644 index 000000000..16cc77337 --- /dev/null +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -0,0 +1,31 @@ +import asyncio +from datetime import datetime, timedelta +from collections import defaultdict, deque +from typing import DefaultDict, Deque, List, Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core import logger +from .strategies.strategy import StrategySelector + +@register_stage +class ContentSafetyCheckStage(Stage): + '''检查内容安全 + + 当前只会检查文本的。 + ''' + + async def initialize(self, ctx: PipelineContext): + config = ctx.astrbot_config['content_safety'] + self.strategy_selector = StrategySelector(config) + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + '''检查内容安全''' + ok, info = self.strategy_selector.check(event.get_message_str()) + if not ok: + event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。")) + event.stop_event() + logger.info(f"内容安全检查不通过,原因:{info}") + return + event.continue_event() diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py new file mode 100644 index 000000000..5962f27d8 --- /dev/null +++ b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py @@ -0,0 +1,8 @@ +import abc +from typing import Tuple + +class ContentSafetyStrategy(abc.ABC): + + @abc.abstractmethod + def check(self, content: str) -> Tuple[bool, str]: + raise NotImplementedError \ No newline at end of file diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py new file mode 100644 index 000000000..adf6a039e --- /dev/null +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -0,0 +1,31 @@ +''' +使用此功能应该先 pip install baidu-aip +''' +from . import ContentSafetyStrategy +from aip import AipContentCensor +from astrbot.core import logger + +class BaiduAipStrategy(ContentSafetyStrategy): + def __init__(self, appid: str, ak: str, sk: str) -> None: + self.app_id = appid + self.api_key = ak + self.secret_key = sk + self.client = AipContentCensor(self.app_id, + self.api_key, + self.secret_key) + + def check(self, content: str): + res = self.client.textCensorUserDefined(content) + if 'conclusionType' not in res: + return False, "" + if res['conclusionType'] == 1: + return True, "" + else: + if 'data' not in res: + return False, "" + count = len(res['data']) + info = f"百度审核服务发现 {count} 处违规:\n" + for i in res['data']: + info += f"{i['msg']};\n" + info += "\n判断结果:"+res['conclusion'] + return False, info \ No newline at end of file diff --git a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py new file mode 100644 index 000000000..433d93de3 --- /dev/null +++ b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py @@ -0,0 +1,21 @@ +import re, os, json, base64 +from . import ContentSafetyStrategy +from astrbot.core import logger + +class KeywordsStrategy(ContentSafetyStrategy): + def __init__(self, extra_keywords: list) -> None: + self.keywords = [] + if extra_keywords is None: + extra_keywords = [] + self.keywords.extend(extra_keywords) + keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words') + # internal keywords + if os.path.exists(keywords_path): + with open(keywords_path, "r", encoding="utf-8") as f: + self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords']) + + def check(self, content: str) -> bool: + for keyword in self.keywords: + if re.search(keyword, content): + return False, f"内容安全检查不通过,匹配到敏感词。" + return True, "" \ No newline at end of file diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py new file mode 100644 index 000000000..21bb58535 --- /dev/null +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -0,0 +1,27 @@ +from . import ContentSafetyStrategy +from typing import List, Tuple + + +class StrategySelector(): + def __init__(self, config: dict) -> None: + self.enabled_strategies: List[ContentSafetyStrategy] = [] + if config['internal_keywords']['enable']: + from .keywords import KeywordsStrategy + self.enabled_strategies.append(KeywordsStrategy( + config['internal_keywords']['extra_keywords'])) + if config['baidu_aip']['enable']: + try: + from .baidu_aip import BaiduAipStrategy + except ImportError: + raise ImportError("使用百度内容审核应该先 pip install baidu-aip") + self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'], + config['baidu_aip']['api_key'], + config['baidu_aip']['secret_key'] + )) + + def check(self, content: str) -> Tuple[bool, str]: + for strategy in self.enabled_strategies: + ok, info = strategy.check(content) + if not ok: + return False, info + return True, "" diff --git a/packages/astrbot/unfit_words b/astrbot/core/pipeline/content_safety_check/strategies/unfit_words similarity index 100% rename from packages/astrbot/unfit_words rename to astrbot/core/pipeline/content_safety_check/strategies/unfit_words diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py new file mode 100644 index 000000000..a6b41f8bf --- /dev/null +++ b/astrbot/core/pipeline/context.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.star import PluginManager + +@dataclass +class PipelineContext: + astrbot_config: AstrBotConfig + plugin_manager: PluginManager \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py new file mode 100644 index 000000000..e3f553c6d --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -0,0 +1,71 @@ +import asyncio, traceback, json +from typing import DefaultDict, Deque, List, Union, AsyncGenerator +from ...context import PipelineContext +from ..stage import Stage +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from astrbot.core.message.components import Image +from astrbot.core import logger +from astrbot.core.utils.metrics import Metric +from astrbot.core.provider.llm_response import LLMResponse + + +class LLMRequestSubStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.curr_provider = ctx.plugin_manager.context.get_using_provider() + self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix'] + self.identifier = ctx.astrbot_config['provider_settings']['identifier'] + self.ctx = ctx + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + if self.prompt_prefix: + event.message_str = self.prompt_prefix + event.message_str + if self.identifier: + user_id = event.message_obj.sender.user_id + user_nickname = event.message_obj.sender.nickname + user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n" + event.message_str = user_info + event.message_str + + image_urls = [] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + image_urls.append(image_url) + + tools = self.ctx.plugin_manager.context.get_llm_tools() + + try: + llm_response = await self.curr_provider.text_chat( + prompt=event.message_str, + session_id=event.session_id, + image_urls=image_urls, + tools=tools + ) + await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type) + + if llm_response.role == 'assistant': + # text completion + event.set_result(MessageEventResult().message(llm_response.completion_text)) + elif llm_response.role == 'tool': + # function calling + for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args): + func_tool = tools.get_func(func_tool_name) + logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") + try: + ret = await func_tool(event=event, *func_tool_args) + + if ret: + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。" + event.stop_event() + event.set_result(ret) + # 执行后续步骤来发送消息 + yield + + except BaseException as e: + logger.error(traceback.format_exc()) + + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) + return \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py new file mode 100644 index 000000000..ca06e8f0e --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -0,0 +1,48 @@ +from ...context import PipelineContext +from ..stage import Stage +from typing import Dict, Any, List, AsyncGenerator, Union +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType +from astrbot.core import logger +from astrbot.core.star.star_handler import StarHandlerMetadata +from astrbot.core.star.star import star_map +class StarRequestSubStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.curr_provider = ctx.plugin_manager.context.get_using_provider() + self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix'] + self.identifier = ctx.astrbot_config['provider_settings']['identifier'] + self.ctx = ctx + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") + handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params") + if not handlers_parsed_params: + handlers_parsed_params = {} + for handler in activated_handlers: + params = handlers_parsed_params.get(handler.handler_full_name, {}) + try: + if handler.handler_module_str not in star_map: + # 孤立无援的 star handler + continue + star_cls_obj = star_map.get(handler.handler_module_str).star_cls + + # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) + if hasattr(handler.handler, '__self__'): + # 猜测没有通过装饰器去注册 + try: + ret = await handler.handler(event, **params) + except TypeError: + # 向下兼容 + ret = await handler.handler(event, self.ctx.plugin_manager.context, **params) + else: + ret = await handler.handler(star_cls_obj, event, **params) + if ret: + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。" + event.stop_event() + event.set_result(ret) + # 执行后续步骤来发送消息 + yield + event.clear_result() # 清除上一个 handler 的结果 + except Exception as e: + logger.error(f"Star {handler.handler_full_name} handle error: {e}") \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py new file mode 100644 index 000000000..6f1ff7b70 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -0,0 +1,36 @@ +from typing import List, Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from .method.llm_request import LLMRequestSubStage +from .method.star_request import StarRequestSubStage +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType +from astrbot.core import logger +from astrbot.core.star.star_handler import StarHandlerMetadata +from astrbot.core.message.components import * +from astrbot.core import html_renderer + +@register_stage +class ProcessStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.config = ctx.astrbot_config + self.plugin_manager = ctx.plugin_manager + self.llm_request_sub_stage = LLMRequestSubStage() + await self.llm_request_sub_stage.initialize(ctx) + + self.star_request_sub_stage = StarRequestSubStage() + await self.star_request_sub_stage.initialize(ctx) + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + '''处理事件 + ''' + activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") + + if not activated_handlers: + async for _ in self.llm_request_sub_stage.process(event): + yield + else: + async for _ in self.star_request_sub_stage.process(event): + yield + \ No newline at end of file diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py new file mode 100644 index 000000000..033baf357 --- /dev/null +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -0,0 +1,87 @@ +import asyncio +from datetime import datetime, timedelta +from collections import defaultdict, deque +from typing import DefaultDict, Deque, List, Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core import logger +from astrbot.core.config.astrbot_config import RateLimitStrategy + + +@register_stage +class RateLimitStage(Stage): + """ + 检查是否需要限制消息发送的限流器。 + + 使用 Fixed Window 算法。 + 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 + """ + + def __init__(self): + # 存储每个会话的请求时间队列 + self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque) + # 为每个会话设置一个锁,避免并发冲突 + self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + # 限流参数 + self.rate_limit_count: int = 0 + self.rate_limit_time: timedelta = timedelta(0) + + async def initialize(self, ctx: PipelineContext) -> None: + """ + 初始化限流器,根据配置设置限流参数。 + """ + self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count'] + self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time']) + self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + """ + 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 + + Args: + event (AstrMessageEvent): 当前消息事件。 + ctx (PipelineContext): 流水线上下文。 + + Returns: + MessageEventResult: 继续或停止事件处理的结果。 + """ + session_id = event.session_id + now = datetime.now() + + async with self.locks[session_id]: # 确保同一会话不会并发修改队列 + timestamps = self.event_timestamps[session_id] + + self._remove_expired_timestamps(timestamps, now) + + if len(timestamps) >= self.rate_limit_count: + # 达到限流阈值,计算下一个窗口的时间 + next_window_time = timestamps[0] + self.rate_limit_time + stall_duration = (next_window_time - now).total_seconds() + + match self.rl_strategy: + case RateLimitStrategy.STALL: + logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。") + await asyncio.sleep(stall_duration) + case RateLimitStrategy.DISCARD: + event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。")) + return event.stop_event() + + self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration)) + + timestamps.append(now) + + return event.continue_event() + + def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None: + """ + 移除时间窗口外的时间戳。 + + Args: + timestamps (Deque[datetime]): 当前会话的时间戳队列。 + now (datetime): 当前时间,用于计算过期时间。 + """ + expiry_threshold: datetime = now - self.rate_limit_time + while timestamps and timestamps[0] < expiry_threshold: + timestamps.popleft() diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py new file mode 100644 index 000000000..0fbc07df0 --- /dev/null +++ b/astrbot/core/pipeline/respond/stage.py @@ -0,0 +1,25 @@ +import asyncio +from datetime import datetime, timedelta +from collections import defaultdict, deque +from typing import DefaultDict, Deque, List, Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core import logger +from astrbot.core.config.astrbot_config import RateLimitStrategy + +@register_stage +class RespondStage: + async def initialize(self, ctx: PipelineContext): + self.ctx = ctx + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + result = event.get_result() + if result is None: + return + + if len(result.chain) > 0: + await event.send(result) + logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") + \ No newline at end of file diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py new file mode 100644 index 000000000..a34bf7ac8 --- /dev/null +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -0,0 +1,45 @@ +import asyncio, time +from datetime import datetime, timedelta +from collections import defaultdict, deque +from typing import DefaultDict, Deque, List, Union, AsyncGenerator +from ..stage import Stage, register_stage +from ..context import PipelineContext +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core import logger +from astrbot.core.config.astrbot_config import RateLimitStrategy +from astrbot.core.message.components import Plain, Image +from astrbot.core import html_renderer + +@register_stage +class ResultDecorateStage: + async def initialize(self, ctx: PipelineContext): + self.ctx = ctx + self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix'] + self.t2i = ctx.astrbot_config['t2i'] + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + result = event.get_result() + if result is None: + return + + if len(result.chain) > 0: + # 回复前缀 + if self.reply_prefix: + result.chain.insert(0, Plain(self.reply_prefix)) + + # 文本转图片 + if (result.use_t2i_ is None and self.t2i) or result.use_t2i_: + plain_str = "" + for comp in result.chain: + if isinstance(comp, Plain): + plain_str += "\n\n" + comp.text + else: + break + if plain_str and len(plain_str) > 150: + render_start = time.time() + url = await html_renderer.render_t2i(plain_str, return_url=True) + if time.time() - render_start > 3: + logger.warning(f"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。") + if url: + result.chain = [Image.fromURL(url)] \ No newline at end of file diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py new file mode 100644 index 000000000..8590db25e --- /dev/null +++ b/astrbot/core/pipeline/scheduler.py @@ -0,0 +1,44 @@ +from . import STAGES_ORDER +from .stage import registered_stages, Stage +from .context import PipelineContext +from typing import AsyncGenerator +from astrbot.core.platform import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, EventResultType +from astrbot.core import logger + +class PipelineScheduler(): + def __init__(self, context: PipelineContext): + registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__)) + self.ctx = context + + async def initialize(self): + for stage in registered_stages: + logger.debug(f"初始化阶段 {stage.__class__ .__name__}") + + await stage.initialize(self.ctx) + + async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + for i in range(from_stage, len(registered_stages)): + stage = registered_stages[i] + logger.debug(f"执行阶段 {stage.__class__ .__name__}") + coro = stage.process(event) + if isinstance(coro, AsyncGenerator): + async for _ in coro: + if event.is_stopped(): + logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + break + await self._process_stages(event, i + 1) + else: + await coro + + if event.is_stopped(): + logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + break + + if event.is_stopped(): + logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。") + break + + async def execute(self, event: AstrMessageEvent): + '''执行 pipeline''' + await self._process_stages(event) \ No newline at end of file diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py new file mode 100644 index 000000000..4459ddc44 --- /dev/null +++ b/astrbot/core/pipeline/stage.py @@ -0,0 +1,32 @@ +from __future__ import annotations +import abc +from typing import List, Dict, AsyncGenerator, Union +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from .context import PipelineContext + +registered_stages: List[Stage] = [] +'''维护了所有已注册的 Stage 实现类''' + +def register_stage(cls): + '''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类 + ''' + registered_stages.append(cls()) + return cls + +class Stage(abc.ABC): + '''描述一个 Pipeline 的某个阶段 + ''' + + @abc.abstractmethod + async def initialize(self, ctx: PipelineContext) -> None: + '''初始化阶段 + ''' + raise NotImplementedError + + @abc.abstractmethod + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + '''处理事件 + ''' + raise NotImplementedError + + \ No newline at end of file diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py new file mode 100644 index 000000000..53c212761 --- /dev/null +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -0,0 +1,96 @@ +from ..stage import Stage, register_stage +from ..context import PipelineContext +from typing import Union, AsyncGenerator +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, EventResultType +from astrbot.core.message.components import At, Plain +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.filter.command_group import CommandGroupFilter + +@register_stage +class WakingCheckStage(Stage): + '''检查是否需要唤醒。唤醒机器人有如下几点条件: + + 1. 机器人被 @ 了 + 2. 机器人的消息被提到了 + 3. 以 wake_prefix 前缀开头 + 4. 插件(Star)的 handler filter 通过 + ''' + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + # 设置 sender 身份 + event.message_str = event.message_str.strip() + for admin_id in self.ctx.astrbot_config['admins_id']: + if event.get_sender_id() == admin_id: + event.role = "admin" + break + + # 检查 wake + wake_prefixes = self.ctx.astrbot_config['wake_prefix'] + messages = event.get_messages() + is_wake = False + for wake_prefix in wake_prefixes: + if event.message_str.startswith(wake_prefix): + is_wake = True + event.is_wake = True + event.message_str = event.message_str[len(wake_prefix):].strip() + break + if not is_wake: + # 检查是否有 at 消息 + for message in messages: + if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"): + is_wake = True + event.is_wake = True + wake_prefix = "" + break + # 检查是否是私聊 + if event.is_private_chat(): + is_wake = True + event.is_wake = True + wake_prefix = "" + + # 检查插件的 handler filter + activated_handlers = [] + handlers_parsed_params = {} # 注册了指令的 handler + for handler in star_handlers_registry: + # filter 需要满足 AND 的逻辑关系 + passed = False + child_command_handler_md = None + for filter in handler.event_filters: + try: + if isinstance(filter, CommandGroupFilter): + '''如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata''' + ok, child_command_handler_md = filter.filter(event, self.ctx.astrbot_config) + if ok: + passed = True + handler = child_command_handler_md # handler 覆盖 + break + else: + if filter.filter(event, self.ctx.astrbot_config): + passed = True + break + except Exception as e: + # event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}")) + # yield + await event.send(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}")) + event.stop_event() + passed = False + break + + if passed: + is_wake = True + event.is_wake = True + + activated_handlers.append(handler) + if 'parsed_params' in event.get_extra(): + handlers_parsed_params[handler.handler_full_name] = event.get_extra('parsed_params') + event.clear_extra() + + event.set_extra('activated_handlers', activated_handlers) + event.set_extra('handlers_parsed_params', handlers_parsed_params) + + if not is_wake: + event.stop_event() \ No newline at end of file diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py new file mode 100644 index 000000000..a7d00f194 --- /dev/null +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -0,0 +1,19 @@ +from ..stage import Stage, register_stage +from ..context import PipelineContext +from typing import List, Dict, AsyncGenerator, Union +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.message_type import MessageType +from astrbot.core import logger + +@register_stage +class WhitelistCheckStage(Stage): + '''检查是否在群聊/私聊白名单 + ''' + async def initialize(self, ctx: PipelineContext) -> None: + self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist'] + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + # 检查是否在白名单 + if event.unified_msg_origin not in self.whitelist: + logger.info(f"会话 {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。") + event.stop_event() \ No newline at end of file diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 8b319462a..9bc65531f 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,11 +1,11 @@ -import abc +import abc, logging from dataclasses import dataclass from .astrbot_message import AstrBotMessage from .platform_metadata import PlatformMetadata -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain +from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, EventResultType from astrbot.core.platform.message_type import MessageType from typing import List -from astrbot.core.message.components import BaseMessageComponent, Plain, Image +from astrbot.core.message.components import * from astrbot.core.utils.metrics import Metric @dataclass @@ -34,8 +34,6 @@ class AstrMessageEvent(abc.ABC): self.session_id = session_id self.role = "member" self.is_wake = False - - self._result: MessageEventResult = None self._extras = {} self.session = MessageSesion( platform_name=platform_meta.name, @@ -43,6 +41,9 @@ class AstrMessageEvent(abc.ABC): session_id=session_id ) self.unified_msg_origin = str(self.session) + + self._result: MessageEventResult = None + '''消息事件的结果''' def get_platform_name(self): return self.platform_meta.name @@ -58,8 +59,19 @@ class AstrMessageEvent(abc.ABC): for i in chain: if isinstance(i, Plain): outline += i.text - if isinstance(i, Image): + elif isinstance(i, Image): outline += "[图片]" + elif isinstance(i, Face): + outline += f"[表情:{i.id}]" + elif isinstance(i, At): + outline += f"[At:{i.qq}]" + elif isinstance(i, AtAll): + outline += "[At:全体成员]" + elif isinstance(i, Forward): + # 转发消息 + outline += f"[转发消息]" + else: + outline += f"[{i.type}]" return outline def get_message_outline(self) -> str: @@ -76,12 +88,24 @@ class AstrMessageEvent(abc.ABC): ''' return self.message_obj.message + def get_message_type(self) -> MessageType: + ''' + 获取消息类型。 + ''' + return self.message_obj.type + def get_session_id(self) -> str: ''' 获取会话id。 ''' return self.session_id + def get_group_id(self) -> str: + ''' + 获取群组id。如果不是群组消息,返回空字符串。 + ''' + return self.message_obj.group_id + def get_self_id(self) -> str: ''' 获取机器人自身的id。 @@ -101,16 +125,62 @@ class AstrMessageEvent(abc.ABC): return self.message_obj.sender.nickname def set_result(self, result: MessageEventResult): - ''' - 设置消息事件的结果。当设置了结果后,消息事件将不再继续传递。 + '''设置消息事件的结果。 + + Note: + 事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。 + + 如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。 + + Example: + + async def ban_handler(self, event: AstrMessageEvent): + if event.get_sender_id() in self.blacklist: + event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP) + return + + async def check_count(self, event: AstrMessageEvent): + self.count += 1 + event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE)) + return ''' self._result = result + def stop_event(self): + '''终止事件传播。 + ''' + if self._result is None: + self.set_result(MessageEventResult().stop_event()) + else: + self._result.stop_event() + + def continue_event(self): + '''继续事件传播。 + ''' + if self._result is None: + self.set_result(MessageEventResult().continue_event()) + else: + self._result.continue_event() + + def is_stopped(self) -> bool: + ''' + 是否终止事件传播。 + ''' + if self._result is None: + return False # 默认是继续传播 + return self._result.is_stopped() + def get_result(self) -> MessageEventResult: ''' 获取消息事件的结果。 ''' return self._result + + def clear_result(self): + ''' + 清除消息事件的结果。 + ''' + self._result = None def set_extra(self, key, value): ''' @@ -118,6 +188,20 @@ class AstrMessageEvent(abc.ABC): ''' self._extras[key] = value + def get_extra(self, key = None): + ''' + 获取额外的信息。 + ''' + if key is None: + return self._extras + return self._extras.get(key, None) + + def clear_extra(self): + ''' + 清除额外的信息。 + ''' + self._extras.clear() + def is_private_chat(self) -> bool: ''' 是否是私聊。 diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 7c8c9d5d0..1ca4f3109 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -15,8 +15,9 @@ class AstrBotMessage: ''' type: MessageType # 消息类型 self_id: str # 机器人的识别id - session_id: str # 会话id + session_id: str # 会话id。取决于 unique_session 的设置。 message_id: str # 消息id + group_id: str = "" # 群组id,如果为私聊,则为空 sender: MessageMember # 发送者 message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py new file mode 100644 index 000000000..3e155d108 --- /dev/null +++ b/astrbot/core/platform/manager.py @@ -0,0 +1,42 @@ +from astrbot.core.config.astrbot_config import AstrBotConfig +from .platform import PlatformMetadata, Platform +from typing import List +from asyncio import Queue +from .register import platform_registry, platform_cls_map +from astrbot.core import logger + + +class PlatformManager(): + def __init__(self, config: AstrBotConfig, event_queue: Queue): + self.platform_insts: List[Platform] = [] + '''加载的 Platform 的实例''' + + self.platforms_config = config['platform'] + self.settings = config['platform_settings'] + self.event_queue = event_queue + + for platform in self.platforms_config: + if not platform['enable']: + continue + match platform['name']: + case "aiocqhttp": + from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter + case "qqofficial": + from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialAdapter + case "vchat": + from .sources.vchat.vchat_platform_adapter import VChatAdapter + + async def initialize(self): + for platform in self.platforms_config: + if not platform['enable']: + continue + if platform['name'] not in platform_cls_map: + logger.error(f"未找到适用于 {platform['name']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。") + continue + cls_type = platform_cls_map[platform['name']] + logger.info(f"尝试实例化 {platform['name']}({platform['id']}) 平台适配器 ...") + inst = cls_type(platform, self.settings, self.event_queue) + self.platform_insts.append(inst) + + def get_insts(self): + return self.platform_insts \ No newline at end of file diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 9edff89c9..07f66f794 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass - +from typing import Type @dataclass class PlatformMetadata(): name: str # 平台的名称 diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py new file mode 100644 index 000000000..2db6026cd --- /dev/null +++ b/astrbot/core/platform/register.py @@ -0,0 +1,25 @@ +from typing import List, Dict, Type +from .platform_metadata import PlatformMetadata +from astrbot.core import logger + +platform_registry: List[PlatformMetadata] = [] +'''维护了通过装饰器注册的平台适配器''' +platform_cls_map: Dict[str, Type] = {} +'''维护了平台适配器名称和适配器类的映射''' + +def register_platform_adapter(adapter_name: str, desc: str): + '''用于注册平台适配器的带参装饰器''' + def decorator(cls): + if adapter_name in platform_cls_map: + raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。") + + pm = PlatformMetadata( + name=adapter_name, + description=desc, + ) + platform_registry.append(pm) + platform_cls_map[adapter_name] = cls + logger.debug(f"平台适配器 {adapter_name} 已注册") + return cls + + return decorator diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py similarity index 96% rename from packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py rename to astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index dadcc51a7..682d80316 100644 --- a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,6 +1,6 @@ import os, traceback, random, asyncio -from astrbot.api import AstrMessageEvent, MessageChain, logger +from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image from aiocqhttp import CQHttp from astrbot.core.utils.io import file_to_base64, download_image_by_url diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py similarity index 71% rename from packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py rename to astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 95b7eec71..95054dd95 100644 --- a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -4,24 +4,26 @@ import traceback import logging from typing import Awaitable, Any from aiocqhttp import CQHttp, Event -from astrbot.api import Platform -from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.event import MessageChain, MessageEventResult from .aiocqhttp_message_event import * from astrbot.api.message_components import * from astrbot.api import logger from .aiocqhttp_message_event import AiocqhttpMessageEvent -from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings +from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.platform.astr_message_event import MessageSesion +from ...register import register_platform_adapter +@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。") class AiocqhttpAdapter(Platform): - def __init__(self, platform_config: AiocqhttpPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: super().__init__(event_queue) self.config = platform_config self.settings = platform_settings - self.unique_session = platform_settings.unique_session - self.host = platform_config.ws_reverse_host - self.port = platform_config.ws_reverse_port + self.unique_session = platform_settings['unique_session'] + self.host = platform_config['ws_reverse_host'] + self.port = platform_config['ws_reverse_port'] self.metadata = PlatformMetadata( "aiocqhttp", @@ -51,6 +53,7 @@ class AiocqhttpAdapter(Platform): if event['message_type'] == 'group': abm.type = MessageType.GROUP_MESSAGE + abm.group_id = str(event.group_id) elif event['message_type'] == 'private': abm.type = MessageType.FRIEND_MESSAGE @@ -71,35 +74,18 @@ class AiocqhttpAdapter(Platform): except BaseException as e: logger.error(f"回复消息失败: {e}") return + logger.debug(f"aiocqhttp: 收到消息: {event.message}") for m in event.message: t = m['type'] a = None - if t == 'at': - a = At(**m['data']) - abm.message.append(a) if t == 'text': - a = Plain(text=m['data']['text']) message_str += m['data']['text'].strip() - abm.message.append(a) - if t == 'image': - file = m['data']['file'] if 'file' in m['data'] else None - url = m['data']['url'] if 'url' in m['data'] else None - a = Image(file=file, url=url) - abm.message.append(a) + a = ComponentTypes[t](**m['data']) + abm.message.append(a) abm.timestamp = int(time.time()) abm.message_str = message_str abm.raw_message = event return abm - - def handle_whitelist(self, event: Event) -> bool: - match event['message_type']: - case "group": - if self.config.qq_group_id_whitelist and str(event.group_id) in self.config.qq_group_id_whitelist: - return True - case "private": - if self.config.qq_id_whitelist and str(event.sender['user_id']) in self.config.qq_id_whitelist: - return True - return False def run(self) -> Awaitable[Any]: if not self.host or not self.port: @@ -107,18 +93,12 @@ class AiocqhttpAdapter(Platform): self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) @self.bot.on_message('group') async def group(event: Event): - if not self.handle_whitelist(event): - logger.debug(f"一个群消息({event.group_id})事件由于不在白名单而被过滤。") - return abm = self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message('private') async def private(event: Event): - if not self.handle_whitelist(event): - logger.debug(f"一个私聊消息({event.sender['nickname']}/{event.sender['user_id']})事件由于不在白名单而被过滤。") - return abm = self.convert_message(event) if abm: await self.handle_msg(abm) diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py similarity index 96% rename from packages/astrbot_adapter_qqofficial/qqofficial_message_event.py rename to astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2b02e3455..e3acc0878 100644 --- a/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -3,7 +3,8 @@ import botpy.message import botpy.types import botpy.types.message from astrbot.core.utils.io import file_to_base64, download_image_by_url -from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image from botpy import Client from botpy.http import Route diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py similarity index 87% rename from packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py rename to astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 11ba4f53b..8f973ba0c 100644 --- a/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -6,17 +6,17 @@ import botpy.types import botpy.types.message from botpy import Client -from astrbot.api import Platform -from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.event import MessageChain from typing import Union, List, Dict from astrbot.api.message_components import * from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion from .qqofficial_message_event import QQOfficialMessageEvent -from astrbot.core.config.astrbot_config import PlatformConfig, QQOfficialPlatformConfig, PlatformSettings -from astrbot.core.utils.io import save_temp_img, download_image_by_url +from ...register import register_platform_adapter # QQ 机器人官方框架 +@register_platform_adapter("qqofficial", "QQ 机器人官方 API 适配器") class botClient(Client): def set_platform(self, platform: 'QQOfficialPlatformAdapter'): self.platform = platform @@ -56,18 +56,18 @@ class botClient(Client): class QQOfficialPlatformAdapter(Platform): - def __init__(self, platform_config: QQOfficialPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: super().__init__(event_queue) self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.config = platform_config - self.appid = platform_config.appid - self.secret = platform_config.secret - self.unique_session = platform_settings.unique_session - qq_group = platform_config.enable_group_c2c - guild_dm = platform_config.enable_guild_direct_message + self.appid = platform_config['appid'] + self.secret = platform_config['secret'] + self.unique_session = platform_settings['unique_session'] + qq_group = platform_config['enable_group_c2c'] + guild_dm = platform_config['enable_guild_direct_message'] if qq_group: self.intents = botpy.Intents( @@ -115,6 +115,7 @@ class QQOfficialPlatformAdapter(Platform): message.author.member_openid, "" ) + abm.group_id = message.group_openid else: abm.sender = MessageMember( message.author.user_openid, @@ -157,6 +158,9 @@ class QQOfficialPlatformAdapter(Platform): str(message.author.id), str(message.author.username) ) + + if isinstance(message, botpy.message.Message): + abm.group_id = message.channel_id else: raise ValueError(f"Unknown message type: {message_type}") return abm diff --git a/packages/astrbot_adapter_wechat/wechat_message_event.py b/astrbot/core/platform/sources/vchat/vchat_message_event.py similarity index 84% rename from packages/astrbot_adapter_wechat/wechat_message_event.py rename to astrbot/core/platform/sources/vchat/vchat_message_event.py index 9451e043c..6aab78a62 100644 --- a/packages/astrbot_adapter_wechat/wechat_message_event.py +++ b/astrbot/core/platform/sources/vchat/vchat_message_event.py @@ -1,10 +1,12 @@ import random, asyncio from astrbot.core.utils.io import download_image_by_url -from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata +from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.platform import AstrBotMessage, PlatformMetadata from astrbot.api.message_components import Plain, Image from vchat import Core -class WechatPlatformEvent(AstrMessageEvent): +class VChatPlatformEvent(AstrMessageEvent): def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -36,6 +38,6 @@ class WechatPlatformEvent(AstrMessageEvent): async def send(self, message: MessageChain): - await WechatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username) + await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username) await super().send(message) \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/wechat_platform_adapter.py b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py similarity index 77% rename from packages/astrbot_adapter_wechat/wechat_platform_adapter.py rename to astrbot/core/platform/sources/vchat/vchat_platform_adapter.py index d743ddf45..0fd48bb9c 100644 --- a/packages/astrbot_adapter_wechat/wechat_platform_adapter.py +++ b/astrbot/core/platform/sources/vchat/vchat_platform_adapter.py @@ -1,15 +1,13 @@ -import sys, time, datetime, uuid +import sys, time, uuid import asyncio -from astrbot.api import Platform -from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata -from typing import Union, List, Dict +from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.api.event import MessageChain from astrbot.api.message_components import * from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion -from .wechat_message_event import WechatPlatformEvent -from astrbot.core.config.astrbot_config import PlatformConfig, WechatPlatformConfig, PlatformSettings -from astrbot.core.utils.io import save_temp_img, download_image_by_url +from .vchat_message_event import VChatPlatformEvent +from ...register import register_platform_adapter from vchat import Core from vchat import model @@ -18,10 +16,11 @@ if sys.version_info >= (3, 12): from typing import override else: from typing_extensions import override - -class WechatPlatformAdapter(Platform): - def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: +@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器") +class VChatPlatformAdapter(Platform): + + def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None: super().__init__(event_queue) self.config = platform_config self.settingss = platform_settings @@ -31,13 +30,13 @@ class WechatPlatformAdapter(Platform): @override async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): from_username = session.session_id.split('$$')[0] - await WechatPlatformEvent.send_with_client(self.client, message_chain, from_username) + await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username) await super().send_by_session(session, message_chain) @override def meta(self) -> PlatformMetadata: return PlatformMetadata( - "wechat", + "vchat", "基于 VChat 的 Wechat 适配器", ) @@ -53,10 +52,6 @@ class WechatPlatformAdapter(Platform): logger.debug(f"忽略旧消息: {msg}") return logger.debug(f"收到消息: {msg.todict()}") - if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist: - logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}") - return - logger.info(f"收到消息: {msg.todict()}") abmsg = self.convert_message(msg) # await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞 asyncio.create_task(self.handle_msg(abmsg)) @@ -92,12 +87,13 @@ class WechatPlatformAdapter(Platform): amsg.type = MessageType.FRIEND_MESSAGE elif isinstance(msg.from_, model.Chatroom): amsg.type = MessageType.GROUP_MESSAGE + amsg.group_id = msg.from_.username else: logger.error(f"不支持的 Wechat 消息类型: {msg.from_}") amsg.raw_message = msg - if self.settingss.unique_session: + if self.settingss['unique_session']: session_id = msg.from_.username + "$$" + msg.to.username if msg.chatroom_sender is not None: session_id += '$$' + msg.chatroom_sender.username @@ -108,7 +104,7 @@ class WechatPlatformAdapter(Platform): return amsg async def handle_msg(self, message: AstrBotMessage): - message_event = WechatPlatformEvent( + message_event = VChatPlatformEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), diff --git a/astrbot/core/plugin/__init__.py b/astrbot/core/plugin/__init__.py deleted file mode 100644 index 61e1860d7..000000000 --- a/astrbot/core/plugin/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .plugin import Plugin, RegisteredPlugin, PluginMetadata -from .plugin_manager import PluginManager -from .context import CommandMetadata, Context -from astrbot.core.provider import Provider \ No newline at end of file diff --git a/astrbot/core/plugin/context.py b/astrbot/core/plugin/context.py deleted file mode 100644 index a1d341fbe..000000000 --- a/astrbot/core/plugin/context.py +++ /dev/null @@ -1,217 +0,0 @@ -import heapq -from asyncio import Queue -from . import RegisteredPlugin, PluginMetadata -from typing import List, Dict, Awaitable, Union -from dataclasses import dataclass - -from astrbot.core.platform import Platform -from astrbot.core.db import BaseDatabase -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.utils.func_call import FuncCall -from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.message.message_event_result import MessageChain - -@dataclass -class CommandMetadata(): - ''' - 显式指令 - ''' - plugin_name: str - plugin_metadata: PluginMetadata - handler: Awaitable - use_regex: bool = False - ignore_prefix: bool = False - description: str = "" - -@dataclass -class EventListenerMetadata(): - ''' - 事件监听器 - ''' - plugin_name: str - plugin_metadata: PluginMetadata - handler: Awaitable - description: str = "" - after_commands: bool = False - - -class Context: - ''' - 暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。 - ''' - # 事件队列。消息平台通过事件队列传递消息事件。 - _event_queue: Queue = None - - # AstrBot 配置信息 - _config: AstrBotConfig = None - - # AstrBot 数据库 - _db: BaseDatabase = None - - # 维护了注册的插件的信息 - registered_plugins: List[RegisteredPlugin] = [] - - # 维护了插件注册的指令的信息的名字列表,用于优先级排序 - registered_commands: List[str] = [] - # 维护了插件注册的指令的信息 - commands_handler: Dict[str, CommandMetadata] = {} - - # 维护了插件注册的中间件的名字列表,用于优先级排序 - registered_listeners: List[str] = [] - # 维护了插件注册的中间件的信息 - listeners_handler: Dict[str, EventListenerMetadata] = {} - - # 维护了注册的平台的信息 - registered_platforms: List[Platform] = [] - - # 维护了 LLM Tools 信息 - llm_tools: FuncCall = FuncCall() - - # 维护插件存储的数据 - plugin_data: Dict[str, Dict[str, any]] = {} - - def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase): - self._event_queue = event_queue - self._config = config - self._db = db - - def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: - for plugin in self.registered_plugins: - if plugin.metadata.plugin_name == plugin_name: - return plugin - return None - - def register_listener(self, - plugin_name: str, - name: str, - handler: Awaitable, - description: str = None, - after_commands: bool = False): - ''' - 注册一个事件监听器。 - - after_commands: 是否在指令处理后执行。 - ''' - if name in self.registered_listeners: - raise ValueError(f"Middleware {name} already exists.") - self.registered_listeners.append(name) - self.listeners_handler[name] = EventListenerMetadata( - plugin_name=plugin_name, - plugin_metadata=None, - handler=handler, - description=description, - after_commands=after_commands - ) - - def register_commands(self, - plugin_name: str, - command_name: str, - description: str, - priority: int, - handler: Awaitable, - use_regex: bool = False, - ignore_prefix: bool = False): - ''' - 注册插件指令。 - - @param plugin_name: 插件名,注意需要和你的 metadata 中的一致。 - @param command_name: 指令名,如 "help"。不需要带前缀。 - @param description: 指令描述。 - @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 - @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context - @param use_regex: 是否使用正则表达式匹配指令名。 - @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 - - .. Example:: - - ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 - ''' - for command in self.registered_commands: - if command_name in command[1]: - raise ValueError(f"Command {command_name} already exists.") - if not handler: - raise ValueError(f"Handler of {command_name} is None.") - - heapq.heappush(self.registered_commands, (-priority, command_name)) - self.commands_handler[command_name] = CommandMetadata( - plugin_name=plugin_name, - plugin_metadata=None, - handler=handler, - use_regex=use_regex, - ignore_prefix=ignore_prefix, - description=description - ) - heapq.heapify(self.registered_commands) - - def register_platform(self, platform: Platform): - ''' - 注册一个消息平台。 - ''' - self.registered_platforms.append(platform) - - def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: - ''' - 为函数调用(function-calling / tools-use)添加工具。 - - @param name: 函数名 - @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] - @param desc: 函数描述 - @param func_obj: 异步处理函数。 - - 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 - ''' - self.llm_tools.add_func(name, func_args, desc, func_obj) - - def unregister_llm_tool(self, name: str) -> None: - ''' - 删除一个函数调用工具。 - ''' - self.llm_tools.remove_func(name) - - def get_config(self) -> AstrBotConfig: - ''' - 获取 AstrBot 配置信息。 - ''' - return self._config - - def get_db(self) -> BaseDatabase: - ''' - 获取 AstrBot 数据库。 - ''' - return self._db - - def get_event_queue(self) -> Queue: - ''' - 获取事件队列。 - ''' - return self._event_queue - - async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: - ''' - 根据 session(unified_msg_origin) 发送消息。 - - @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 - @param message_chain: 消息链。 - - @return: 是否找到匹配的平台。 - - 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 - ''' - - if isinstance(session, str): - try: - session = MessageSesion.from_str(session) - except BaseException as e: - raise ValueError("不合法的 session 字符串: " + str(e)) - - for platform in self.registered_platforms: - if platform.meta().name == session.platform_name: - await platform.send_by_session(session, message_chain) - return True - return False - - def set_data(self, plugin_name: str, key: str, value: any): - ''' - 设置插件数据。 - ''' - self.plugin_data[plugin_name][key] = value \ No newline at end of file diff --git a/astrbot/core/plugin/plugin.py b/astrbot/core/plugin/plugin.py deleted file mode 100644 index c27fb9925..000000000 --- a/astrbot/core/plugin/plugin.py +++ /dev/null @@ -1,43 +0,0 @@ -from enum import Enum -from types import ModuleType -from typing import List -from dataclasses import dataclass - -@dataclass -class PluginMetadata: - ''' - 插件的元数据。 - ''' - # required - plugin_name: str - author: str # 插件作者 - desc: str # 插件简介 - version: str # 插件版本 - - # optional - repo: str = None # 插件仓库地址 - - def __str__(self) -> str: - return f"PluginMetadata({self.plugin_name}, {self.desc}, {self.version}, {self.repo})" - - -@dataclass -class RegisteredPlugin: - ''' - 注册在 AstrBot 中的插件。 - ''' - metadata: PluginMetadata - plugin_instance: object - module_path: str - module: ModuleType - root_dir_name: str - reserved: bool # 是否是 AstrBot 的保留插件 - - def __str__(self) -> str: - return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" - - - -class Plugin: - def __init__(self): - pass \ No newline at end of file diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index d7f09bd65..665452cdb 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1 +1 @@ -from .provider import Provider, Personality \ No newline at end of file +from .provider import Provider, Personality, ProviderMetaData \ No newline at end of file diff --git a/astrbot/core/provider/llm_response.py b/astrbot/core/provider/llm_response.py new file mode 100644 index 000000000..89fbf4045 --- /dev/null +++ b/astrbot/core/provider/llm_response.py @@ -0,0 +1,13 @@ +from typing import Dict, List +from dataclasses import dataclass + +@dataclass +class LLMResponse: + role: str + '''角色''' + completion_text: str = None + '''LLM 返回的文本''' + tools_call_args: List[Dict[str, any]] = None + '''工具调用参数''' + tools_call_name: List[str] = None + '''工具调用名称''' diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py new file mode 100644 index 000000000..92cb470ea --- /dev/null +++ b/astrbot/core/provider/manager.py @@ -0,0 +1,49 @@ +from astrbot.core.config.astrbot_config import AstrBotConfig +from .provider import Provider +from typing import List +from astrbot.core.db import BaseDatabase +from collections import defaultdict +from astrbot.core.provider.tool import FuncCall +from .register import provider_cls_map, provider_registry +from astrbot.core import logger + +class ProviderManager(): + def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): + self.providers_config: List = config['provider'] + self.provider_settings: dict = config['provider_settings'] + self.provider_insts: List[Provider] = [] + '''加载的 Provider 的实例''' + self.llm_tools: FuncCall = FuncCall() + self.curr_provider_inst: Provider = None + self.loaded_ids = defaultdict(bool) + self.db_helper = db_helper + + for provider_cfg in self.providers_config: + if not provider_cfg['enable']: + continue + + if provider_cfg['id'] in self.loaded_ids: + raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。") + self.loaded_ids[provider_cfg['id']] = True + + match provider_cfg['type']: + case "openai_chat_completion": + from .sources.openai_source import ProviderOpenAIOfficial + + async def initialize(self): + for provider_config in self.providers_config: + if not provider_config['enable']: + continue + if provider_config['type'] not in provider_cls_map: + logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。") + continue + cls_type = provider_cls_map[provider_config['type']] + logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...") + inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) + self.provider_insts.append(inst) + + if len(self.provider_insts) > 0: + self.curr_provider_inst = self.provider_insts[0] + + def get_insts(self): + return self.provider_insts \ No newline at end of file diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 6e3c3c9bf..066a9a2e3 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,91 +1,134 @@ -import abc, json, threading, time +import abc, json from collections import defaultdict from typing import List from astrbot.core.db import BaseDatabase from astrbot.core import logger from typing import TypedDict - +from .provider_metadata import ProviderMetaData +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.tool import FuncCall +from astrbot.core.provider.llm_response import LLMResponse +from dataclasses import dataclass class Personality(TypedDict): - prompt: str - name: str + prompt: str = "" + name: str = "" + +@dataclass +class ProviderMeta(): + id: str + model: str + type: str + class Provider(abc.ABC): - def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None: - self.model_name = "unknown" - # 维护了 session_id 的上下文,不包含 system 指令 + def __init__( + self, + provider_config: dict, + provider_settings: dict, + persistant_history: bool = True, + db_helper: BaseDatabase = None + ) -> None: + self.model_name = "" + '''当前使用的模型名称''' + self.session_memory = defaultdict(list) - self.curr_personality = Personality(prompt=default_personality, name="") + '''维护了 session_id 的上下文,**不包含 system 指令**。''' + + self.provider_config = provider_config + + self.provider_settings = provider_settings + + self.curr_personality = Personality(prompt=provider_settings['default_personality']) + '''维护了当前的使用的 persona,即人格。''' + self.db_helper = db_helper + '''用于持久化的数据库操作对象。''' + if persistant_history: # 读取历史记录 try: - for history in db_helper.get_llm_history(): + for history in db_helper.get_llm_history(provider_type=provider_config['type']): self.session_memory[history.session_id] = json.loads(history.content) except BaseException as e: logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。") - def set_model(self, model_name: str): + '''设置当前使用的模型名称''' self.model_name = model_name - def get_model(self): + def get_model(self) -> str: + '''获得当前使用的模型名称''' return self.model_name - async def get_human_readable_context(self, session_id: str) -> List[str]: + @abc.abstractmethod + def get_current_key(self) -> str: + raise NotImplementedError() + + def get_keys(self) -> List[str]: + '''获得提供商 Key''' + return self.provider_config['key'] + + @abc.abstractmethod + def set_key(self, key: str): + raise NotImplementedError() + + @abc.abstractmethod + def get_models(self) -> List[str]: + '''获得支持的模型列表''' + raise NotImplementedError() + + @abc.abstractmethod + async def get_human_readable_context(self, session_id: str, page: int, page_size: int): + '''获取人类可读的上下文 + + Example: + + ["User: 你好", "Assistant: 你好!"] + + Return: + contexts: List[str]: 上下文列表 + total_pages: int: 总页数 ''' - 获取人类可读的上下文 - - example: - ["User: 你好", "Assistant: 你好"] - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - contexts = [] - for record in self.session_memory[session_id]: - if record['role'] == "user": - contexts.append(f"User: {record['content']}") - elif record['role'] == "assistant": - contexts.append(f"Assistant: {record['content']}") - - return contexts + raise NotImplementedError() @abc.abstractmethod async def text_chat(self, prompt: str, - session_id: str, - image_urls: List[str] = None, - tools = None, - contexts=None, - **kwargs) -> str: - ''' - prompt: 提示词 - session_id: 会话id + session_id: str=None, + image_urls: List[str]=None, + func_tool: FuncCall=None, + contexts: List=None, + **kwargs) -> LLMResponse: + '''获得 LLM 的文本对话结果。会使用当前的模型进行对话。 + + Args: + prompt: 提示词 + session_id: 会话 ID + image_urls: 图片 URL 列表 + tools: Function-calling 工具 + contexts: 上下文 + kwargs: 其他参数 + + Notes: + - 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话, + 并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。 + - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 + - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 + - 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。 + 传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。 - [optional] - image_url: 图片url(识图) - tools: 函数调用工具 - ''' - raise NotImplementedError() - - @abc.abstractmethod - async def image_generate(self, prompt: str, session_id: str, **kwargs) -> str: - ''' - prompt: 提示词 - session_id: 会话id - ''' - raise NotImplementedError() - - @abc.abstractmethod - async def get_embedding(self, text: str) -> List[float]: - ''' - 获取文本的嵌入 ''' raise NotImplementedError() @abc.abstractmethod async def forget(self, session_id: str) -> bool: - ''' - 重置会话 - ''' + '''重置某一个 session_id 的上下文''' raise NotImplementedError() + + def meta(self) -> ProviderMeta: + '''获取 Provider 的元数据''' + return ProviderMeta( + id=self.provider_config['id'], + model=self.get_model(), + type=self.provider_config['type'] + ) \ No newline at end of file diff --git a/astrbot/core/provider/provider_metadata.py b/astrbot/core/provider/provider_metadata.py new file mode 100644 index 000000000..34299a934 --- /dev/null +++ b/astrbot/core/provider/provider_metadata.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass + +@dataclass +class ProviderMetaData(): + type: str # 提供商适配器名称,如 openai, ollama + desc: str = "" # 提供商适配器描述. \ No newline at end of file diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py new file mode 100644 index 000000000..1ed812b68 --- /dev/null +++ b/astrbot/core/provider/register.py @@ -0,0 +1,25 @@ +from typing import List, Dict, Type +from .provider_metadata import ProviderMetaData +from astrbot.core import logger + +provider_registry: List[ProviderMetaData] = [] +'''维护了通过装饰器注册的 Provider''' +provider_cls_map: Dict[str, Type] = {} +'''维护了 Provider 类型名称和 Provider 类的映射''' + +def register_provider_adapter(provider_type_name: str, desc: str): + '''用于注册平台适配器的带参装饰器''' + def decorator(cls): + if provider_type_name in provider_cls_map: + raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。") + + pm = ProviderMetaData( + type=provider_type_name, + desc=desc, + ) + provider_registry.append(pm) + provider_cls_map[provider_type_name] = cls + logger.debug(f"Provider {provider_type_name} 已注册") + return cls + + return decorator diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py new file mode 100644 index 000000000..0852867c3 --- /dev/null +++ b/astrbot/core/provider/sources/openai_source.py @@ -0,0 +1,216 @@ +import asyncio +import traceback +import base64 +import json + +from openai import AsyncOpenAI, NOT_GIVEN +from openai.types.chat.chat_completion import ChatCompletion +from openai._exceptions import * +from astrbot.core.utils.io import download_image_by_url + +from astrbot.core.db import BaseDatabase +from astrbot.api.provider import Provider +from astrbot import logger +from astrbot.core.provider.tool import FuncCall +from typing import List +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from ..register import register_provider_adapter +from astrbot.core.provider.llm_response import LLMResponse + +@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器") +class ProviderOpenAIOfficial(Provider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history = True + ) -> None: + super().__init__(provider_config, provider_settings, persistant_history, db_helper) + self.chosen_api_key = None + self.api_keys: List = provider_config.get("key", []) + self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None + + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + self.set_model(provider_config['model_config']['model']) + + async def get_human_readable_context(self, session_id, page, page_size): + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + contexts = [] + for record in self.session_memory[session_id]: + if record['role'] == "user": + contexts.append(f"User: {record['content']}") + elif record['role'] == "assistant": + contexts.append(f"Assistant: {record['content']}") + + # 计算分页 + paged_contexts = contexts[(page-1)*page_size:page*page_size] + total_pages = len(contexts) // page_size + if len(contexts) % page_size != 0: + total_pages += 1 + + return paged_contexts, total_pages + + async def get_models(self): + try: + models_str = [] + models = await self.client.models.list() + for model in models: + models_str.append(model['id']) + return models_str + except NotFoundError as e: + raise Exception(f"获取模型列表失败:{e}") + + async def pop_record(self, session_id: str, pop_system_prompt: bool = False): + ''' + 弹出第一条记录 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + if len(self.session_memory[session_id]) == 0: + return None + + for i in range(len(self.session_memory[session_id])): + # 检查是否是 system prompt + if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": + # 如果只有一个 system prompt,才不删掉 + f = False + for j in range(i+1, len(self.session_memory[session_id])): + if self.session_memory[session_id][j]['user']['role'] == "system": + f = True + break + if not f: + continue + record = self.session_memory[session_id].pop(i) + break + + return record + + async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: + if tools: + payloads["tools"] = tools.get_func_desc_openai_style() + + completion = await self.client.chat.completions.create( + **payloads, + stream=False + ) + + assert isinstance(completion, ChatCompletion) + logger.debug(f"completion: {completion.usage}") + + if len(completion.choices) == 0: + raise Exception("API 返回的 completion 为空。") + choice = completion.choices[0] + + if choice.message.content: + # text completion + completion_text = str(choice.message.content).strip() + return LLMResponse("assistant", completion_text) + elif choice.message.tool_calls: + # tools call (function calling) + args_ls = [] + func_name_ls = [] + for tool_call in choice.message.tool_calls: + for tool in tools.func_list: + if tool['name'] == tool_call.function.name: + args = json.loads(tool_call.function.arguments) + args_ls.append(args) + func_name_ls.append(tool_call.function.name) + return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls) + else: + raise Exception("Internal Error") + + async def text_chat(self, + prompt: str, + session_id: str, + image_urls: List[str]=None, + func_tool: FuncCall=None, + contexts=None, + **kwargs + ) -> LLMResponse: + new_record = await self.assemble_context(prompt, image_urls) + + context_query = [] + + if not contexts: + context_query = [*self.session_memory[session_id], new_record] + if self.curr_personality["prompt"]: + context_query.insert(0, {"role": "system", "content": self.curr_personality["prompt"]}) + else: + context_query = contexts + + logger.debug(f"请求上下文:{context_query}") + + payloads = { + "messages": context_query, + **self.provider_config.get("model_config", {}) + } + + try: + llm_response = await self._query(payloads, func_tool) + except Exception as e: + if "maximum context length" in str(e): + logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + self.pop_record(session_id) + logger.warning(traceback.format_exc()) + + if llm_response.role == "assistant": + # 文本回复 + if not contexts: + # 添加用户 record + self.session_memory[session_id].append(new_record) + # 添加 assistant record + self.session_memory[session_id].append({ + "role": "assistant", + "content": llm_response.completion_text + }) + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type']) + + return llm_response + + async def forget(self, session_id: str) -> bool: + self.session_memory[session_id] = [] + return True + + def get_current_key(self) -> str: + return self.client.api_key + + def get_keys(self) -> List[str]: + return self.api_keys + + def set_key(self, key): + self.client.api_key = key + + async def assemble_context(self, text: str, image_urls: List[str] = None): + ''' + 组装上下文。 + ''' + if image_urls: + user_content = {"role": "user","content": [{"type": "text", "text": text}]} + for image_url in image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await self.encode_image_bs64(image_path) + else: + image_data = await self.encode_image_bs64(image_url) + user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) + return user_content + else: + return {"role": "user","content": text} + + async def encode_image_bs64(self, image_url: str) -> str: + ''' + 将图片转换为 base64 + ''' + if image_url.startswith("base64://"): + return image_url.replace("base64://", "data:image/jpeg;base64,") + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode('utf-8') + return "data:image/jpeg;base64," + image_bs64 + return '' \ No newline at end of file diff --git a/astrbot/core/utils/func_call.py b/astrbot/core/provider/tool.py similarity index 83% rename from astrbot/core/utils/func_call.py rename to astrbot/core/provider/tool.py index 581e2cc3e..429038ee3 100644 --- a/astrbot/core/utils/func_call.py +++ b/astrbot/core/provider/tool.py @@ -1,7 +1,7 @@ -from astrbot.core.provider import Provider -from typing import Awaitable import json import textwrap +from typing import Awaitable, Dict, List +from typing_extensions import TypedDict class FuncCallJsonFormatError(Exception): @@ -11,7 +11,6 @@ class FuncCallJsonFormatError(Exception): def __str__(self): return self.msg - class FuncNotFoundError(Exception): def __init__(self, msg): self.msg = msg @@ -19,10 +18,19 @@ class FuncNotFoundError(Exception): def __str__(self): return self.msg +class FuncTool(TypedDict): + ''' + 用于描述一个函数调用工具。 + ''' + name: str + parameters: Dict + description: str + func_obj: Awaitable + class FuncCall(): def __init__(self) -> None: - self.func_list = [] + self.func_list: List[FuncTool] = [] def empty(self) -> bool: return len(self.func_list) == 0 @@ -45,12 +53,7 @@ class FuncCall(): "type": param['type'], "description": param['description'] } - _func = { - "name": name, - "parameters": params, - "description": desc, - "func_obj": func_obj, - } + _func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj) self.func_list.append(_func) def remove_func(self, name: str) -> None: @@ -62,17 +65,16 @@ class FuncCall(): self.func_list.pop(i) break - def func_dump(self) -> str: - _l = [] + def get_func(self, name) -> FuncTool: for f in self.func_list: - _l.append({ - "name": f["name"], - "parameters": f["parameters"], - "description": f["description"], - }) - return json.dumps(_l, ensure_ascii=False) - - def get_func(self) -> list: + if f["name"] == name: + return f + return None + + def get_func_desc_openai_style(self) -> list: + ''' + 获得 OpenAI API 风格的工具描述 + ''' _l = [] for f in self.func_list: _l.append({ @@ -85,7 +87,17 @@ class FuncCall(): }) return _l - async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider) -> tuple: + async def func_call(self, question: str, session_id: str, provider) -> tuple: + + _l = [] + for f in self.func_list: + _l.append({ + "name": f["name"], + "parameters": f["parameters"], + "description": f["description"], + }) + func_definition = json.dumps(_l, ensure_ascii=False) + prompt = textwrap.dedent(f""" ROLE: 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 @@ -111,7 +123,6 @@ class FuncCall(): while _c < 3: try: res = await provider.text_chat(prompt, session_id) - print(res) if res.find('```') != -1: res = res[res.find('```json') + 7: res.rfind('```')] res = json.loads(res) diff --git a/packages/astrbot_plugin_openai/websearch/bing.py b/astrbot/core/provider/tools/websearch/engines/bing.py similarity index 100% rename from packages/astrbot_plugin_openai/websearch/bing.py rename to astrbot/core/provider/tools/websearch/engines/bing.py diff --git a/packages/astrbot_plugin_openai/websearch/config.py b/astrbot/core/provider/tools/websearch/engines/config.py similarity index 100% rename from packages/astrbot_plugin_openai/websearch/config.py rename to astrbot/core/provider/tools/websearch/engines/config.py diff --git a/packages/astrbot_plugin_openai/websearch/engine.py b/astrbot/core/provider/tools/websearch/engines/engine.py similarity index 100% rename from packages/astrbot_plugin_openai/websearch/engine.py rename to astrbot/core/provider/tools/websearch/engines/engine.py diff --git a/packages/astrbot_plugin_openai/websearch/google.py b/astrbot/core/provider/tools/websearch/engines/google.py similarity index 100% rename from packages/astrbot_plugin_openai/websearch/google.py rename to astrbot/core/provider/tools/websearch/engines/google.py diff --git a/packages/astrbot_plugin_openai/websearch/sogo.py b/astrbot/core/provider/tools/websearch/engines/sogo.py similarity index 100% rename from packages/astrbot_plugin_openai/websearch/sogo.py rename to astrbot/core/provider/tools/websearch/engines/sogo.py diff --git a/packages/astrbot_plugin_openai/web_searcher.py b/astrbot/core/provider/tools/websearch/web_searcher.py similarity index 91% rename from packages/astrbot_plugin_openai/web_searcher.py rename to astrbot/core/provider/tools/websearch/web_searcher.py index 92982d1c8..fa341d1a8 100644 --- a/packages/astrbot_plugin_openai/web_searcher.py +++ b/astrbot/core/provider/tools/websearch/web_searcher.py @@ -5,11 +5,13 @@ import os from readability import Document from bs4 import BeautifulSoup from openai._exceptions import * -from .websearch.config import HEADERS, USER_AGENTS -from .websearch.bing import Bing -from .websearch.sogo import Sogo -from .websearch.google import Google -from astrbot.api import logger, AstrMessageEvent, Provider, MessageChain, MessageEventResult +from engines.config import HEADERS, USER_AGENTS +from engines.bing import Bing +from engines.sogo import Sogo +from engines.google import Google +from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult +from astrbot.api.provider import Provider +from astrbot.api import logger bing_search = Bing() sogo_search = Sogo() @@ -61,7 +63,6 @@ async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provide return await summarize(ret, event, provider) - async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str: header = HEADERS header.update({'User-Agent': random.choice(USER_AGENTS)}) diff --git a/astrbot/core/star/README.md b/astrbot/core/star/README.md new file mode 100644 index 000000000..badb7fab0 --- /dev/null +++ b/astrbot/core/star/README.md @@ -0,0 +1,5 @@ +# AstrBot Star + +`AstrBot Star` 就是插件。 + +在 AstrBot v4.0 版本后,AstrBot 内部将插件命名为 `star`。插件的 handler 称作 `star_handler`。 \ No newline at end of file diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py new file mode 100644 index 000000000..2574ba6dc --- /dev/null +++ b/astrbot/core/star/__init__.py @@ -0,0 +1,4 @@ +from .star import Star, StarMetadata +from .star_manager import PluginManager +from .context import Context +from astrbot.core.provider import Provider \ No newline at end of file diff --git a/astrbot/core/plugin/config.py b/astrbot/core/star/config.py similarity index 100% rename from astrbot/core/plugin/config.py rename to astrbot/core/star/config.py diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py new file mode 100644 index 000000000..5a4a57ca2 --- /dev/null +++ b/astrbot/core/star/context.py @@ -0,0 +1,174 @@ +import heapq +from asyncio import Queue +from . import StarMetadata +from typing import List, Dict, TypedDict, Union + +from astrbot.core.platform import Platform +from astrbot.core.provider import Provider +from astrbot.core.db import BaseDatabase +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.provider.tool import FuncCall +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.manager import ProviderManager +from astrbot.core.platform.manager import PlatformManager +from .star import star_registry, star_map, StarMetadata +from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata +from .filter.command import CommandFilter +from .filter.regex import RegexFilter +from typing import Awaitable + +class StarCommand(TypedDict): + full_command_name: str + command_name: str + +class Context: + ''' + 暴露给插件的接口上下文。 + ''' + _event_queue: Queue = None + '''事件队列。消息平台通过事件队列传递消息事件。''' + + _config: AstrBotConfig = None + '''AstrBot 配置信息''' + + _db: BaseDatabase = None + '''AstrBot 数据库''' + + provider_manager: ProviderManager = None + + platform_manager: PlatformManager = None + + def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase): + self._event_queue = event_queue + self._config = config + self._db = db + + def get_registered_star(self, star_name: str) -> StarMetadata: + return star_map.get(star_name, None) + + def get_all_stars(self) -> List[StarMetadata]: + return star_registry + + def get_llm_tools(self) -> FuncCall: + ''' + 获取 LLM Tools。 + ''' + return self.provider_manager.llm_tools + + + # def get_star_commands(self, star_name: str) -> List[]: + # '''获得一个''' + + # def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: + # ''' + # 为函数调用(function-calling / tools-use)添加工具。 + + # @param name: 函数名 + # @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + # @param desc: 函数描述 + # @param func_obj: 异步处理函数。 + + # 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 + # ''' + # self.llm_tools.add_func(name, func_args, desc, func_obj) + + # def unregister_llm_tool(self, name: str) -> None: + # ''' + # 删除一个函数调用工具。 + # ''' + # self.llm_tools.remove_func(name) + + def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False): + ''' + 注册一个命令。 + + [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 + + @param star_name: 插件(Star)名称。 + @param command_name: 命令名称。 + @param desc: 命令描述。 + @param priority: 优先级。1-10。 + @param awaitable: 异步处理函数。 + + ''' + md = StarHandlerMetadata( + handler_full_name=awaitable.__module__ + "_" + awaitable.__name__, + handler_name=awaitable.__name__, + handler_module_str=awaitable.__module__, + handler=awaitable, + event_filters=[], + desc=desc + ) + if use_regex: + md.event_filters.append(RegexFilter( + regex=command_name + )) + else: + md.event_filters.append(CommandFilter( + command_name=command_name, + handler_md=md + )) + star_handlers_registry.append(md) + + def register_provider(self, provider: Provider): + ''' + 注册一个 LLM Provider。 + ''' + self.provider_manager.provider_insts.append(provider) + + def get_all_providers(self) -> List[Provider]: + ''' + 获取所有 LLM Provider。 + ''' + return self.provider_manager.provider_insts + + def get_using_provider(self) -> Provider: + ''' + 获取当前使用的 LLM Provider。 + + 通过 /provider 指令切换。 + ''' + return self.provider_manager.curr_provider_inst + + def get_config(self) -> AstrBotConfig: + ''' + 获取 AstrBot 配置信息。 + ''' + return self._config + + def get_db(self) -> BaseDatabase: + ''' + 获取 AstrBot 数据库。 + ''' + return self._db + + def get_event_queue(self) -> Queue: + ''' + 获取事件队列。 + ''' + return self._event_queue + + async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: + ''' + 根据 session(unified_msg_origin) 发送消息。 + + @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 + @param message_chain: 消息链。 + + @return: 是否找到匹配的平台。 + + 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 + ''' + + if isinstance(session, str): + try: + session = MessageSesion.from_str(session) + except BaseException as e: + raise ValueError("不合法的 session 字符串: " + str(e)) + + for platform in self.registered_platforms: + if platform.meta().name == session.platform_name: + await platform.send_by_session(session, message_chain) + return True + return False diff --git a/astrbot/core/star/filter/__init__.py b/astrbot/core/star/filter/__init__.py new file mode 100644 index 000000000..bada25d6f --- /dev/null +++ b/astrbot/core/star/filter/__init__.py @@ -0,0 +1,10 @@ +import abc +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig + +class HandlerFilter(abc.ABC): + @abc.abstractmethod + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + '''是否应当被过滤''' + raise NotImplementedError diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py new file mode 100644 index 000000000..ec71d9b47 --- /dev/null +++ b/astrbot/core/star/filter/command.py @@ -0,0 +1,67 @@ + +import re, inspect +from . import HandlerFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig +from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin +from typing import Awaitable +from ..star_handler import StarHandlerMetadata + +# 标准指令受到 wake_prefix 的制约。 +class CommandFilter(HandlerFilter, ParameterValidationMixin): + '''标准指令过滤器''' + def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None): + self.command_name = command_name + if handler_md: + self.init_handler_md(handler_md) + + def print_types(self): + result = "" + print(self.handler_params) + for k, v in self.handler_params.items(): + if isinstance(v, type): + result += f"{k}({v.__name__})," + else: + result += f"{k}({type(v).__name__})={v}," + return result + + def init_handler_md(self, handle_md: StarHandlerMetadata): + self.handler_md = handle_md + signature = inspect.signature(self.handler_md.handler) + self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 + idx = 0 + for k, v in signature.parameters.items(): + if idx < 2: + # 忽略前两个参数,即 self 和 event + idx += 1 + continue + if v.default == inspect.Parameter.empty: + self.handler_params[k] = v.annotation + else: + self.handler_params[k] = v.default + + def get_handler_md(self) -> StarHandlerMetadata: + return self.handler_md + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + if not event.is_wake_up(): + return False + + message_str = event.get_message_str().strip() + # 分割为列表(每个参数之间可能会有多个空格) + ls = re.split(r"\s+", message_str) + if self.command_name != ls[0]: + return False + # params_str = message_str[len(self.command_name):].strip() + ls = ls[1:] + # 去除空字符串 + ls = [param for param in ls if param] + params = {} + try: + params = self.validate_and_convert_params(ls, self.handler_params) + # 解析完成咱也不能丢掉呀,留着给后面的用 + except ValueError as e: + raise e + event.set_extra("parsed_params", params) + + return True \ No newline at end of file diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py new file mode 100644 index 000000000..be9a4ddae --- /dev/null +++ b/astrbot/core/star/filter/command_group.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import re +from typing import Awaitable, List, Union, Tuple +from . import HandlerFilter +from .command import CommandFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig +from ..star_handler import StarHandlerMetadata + +# 指令组受到 wake_prefix 的制约。 +class CommandGroupFilter(HandlerFilter): + def __init__(self, group_name: str): + self.group_name = group_name + self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] + + def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]): + self.sub_command_filters.append(sub_command_filter) + + # 以树的形式打印出来 + def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str: + result = "" + for sub_filter in sub_command_filters: + if isinstance(sub_filter, CommandFilter): + cmd_th = sub_filter.print_types() + result += f"{prefix}├── {sub_filter.command_name}" + if cmd_th: + result += f" ({cmd_th})" + else: + result += f" (无参数指令)" + + result += "\n" + elif isinstance(sub_filter, CommandGroupFilter): + result += f"{prefix}├── {sub_filter.group_name}" + result += "\n" + result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ") + return result + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]: + if not event.is_wake_up(): + return False, None + + message_str = event.get_message_str().strip() + ls = re.split(r"\s+", message_str) + + if ls[0] != self.group_name: + return False, None + # 改写 message_str + ls = ls[1:] + event.message_str = " ".join(ls) + event.message_str = event.message_str.strip() + + if event.message_str == "": + # 当前还是指令组 + tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters) + raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree) + + child_command_handler_md = None + for sub_filter in self.sub_command_filters: + if isinstance(sub_filter, CommandFilter): + if sub_filter.filter(event, cfg): + child_command_handler_md = sub_filter.get_handler_md() + return True, child_command_handler_md + elif isinstance(sub_filter, CommandGroupFilter): + ok, handler = sub_filter.filter(event, cfg) + if ok: + child_command_handler_md = handler + return True, child_command_handler_md + tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters) + raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree) diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py new file mode 100644 index 000000000..5e16e2e75 --- /dev/null +++ b/astrbot/core/star/filter/event_message_type.py @@ -0,0 +1,28 @@ +import enum +from . import HandlerFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig +from astrbot.core.platform.message_type import MessageType + +class EventMessageType(enum.Flag): + GROUP_MESSAGE = enum.auto() + PRIVATE_MESSAGE = enum.auto() + OTHER_MESSAGE = enum.auto() + ALL = GROUP_MESSAGE | PRIVATE_MESSAGE | OTHER_MESSAGE + +MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE = { + MessageType.GROUP_MESSAGE: EventMessageType.GROUP_MESSAGE, + MessageType.FRIEND_MESSAGE: EventMessageType.PRIVATE_MESSAGE, + MessageType.OTHER_MESSAGE: EventMessageType.OTHER_MESSAGE +} + +class EventMessageTypeFilter(HandlerFilter): + def __init__(self, event_message_type: EventMessageType): + self.event_message_type = event_message_type + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + message_type = event.get_message_type() + if message_type in MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE: + event_message_type = MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE[message_type] + return bool(event_message_type & self.event_message_type) + return False \ No newline at end of file diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py new file mode 100644 index 000000000..0da89bc6c --- /dev/null +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -0,0 +1,27 @@ +import enum +from . import HandlerFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig +from typing import Union + +class PlatformAdapterType(enum.Flag): + AIOCQHTTP = enum.auto() + QQOFFICIAL = enum.auto() + VCHAT = enum.auto() + ALL = AIOCQHTTP | QQOFFICIAL | VCHAT + +ADAPTER_NAME_2_TYPE = { + "aiocqhttp": PlatformAdapterType.AIOCQHTTP, + "qq_official": PlatformAdapterType.QQOFFICIAL, + "vchat": PlatformAdapterType.VCHAT +} + +class PlatformAdapterTypeFilter(HandlerFilter): + def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]): + self.type_or_str = platform_adapter_type_or_str + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + adapter_name = event.get_platform_name() + if adapter_name in ADAPTER_NAME_2_TYPE: + return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str + return False \ No newline at end of file diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py new file mode 100644 index 000000000..c5f919ee3 --- /dev/null +++ b/astrbot/core/star/filter/regex.py @@ -0,0 +1,14 @@ + +import re +from . import HandlerFilter +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.config import AstrBotConfig + +# 正则表达式过滤器不会受到 wake_prefix 的制约。 +class RegexFilter(HandlerFilter): + '''正则表达式过滤器''' + def __init__(self, regex: str): + self.regex = re.compile(regex) + + def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: + return bool(self.regex.match(event.get_message_str().strip())) \ No newline at end of file diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py new file mode 100644 index 000000000..0823582c4 --- /dev/null +++ b/astrbot/core/star/register/__init__.py @@ -0,0 +1,8 @@ +from .star import register_star +from .star_handler import ( + register_command, + register_command_group, + register_event_message_type, + register_platform_adapter_type, + register_regex +) \ No newline at end of file diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py new file mode 100644 index 000000000..53b7c950d --- /dev/null +++ b/astrbot/core/star/register/star.py @@ -0,0 +1,18 @@ +from ..star import star_registry, StarMetadata, star_map + +def register_star(name: str, author: str, desc: str, version: str, repo: str = None): + def decorator(cls): + star_metadata = StarMetadata( + name=name, + author=author, + desc=desc, + version=version, + repo=repo, + star_cls_type=cls, + module_path=cls.__module__ + ) + star_registry.append(star_metadata) + star_map[cls.__module__] = star_metadata + return cls + + return decorator diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py new file mode 100644 index 000000000..f066ef5d4 --- /dev/null +++ b/astrbot/core/star/register/star_handler.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata +from ..filter.command import CommandFilter +from ..filter.command_group import CommandGroupFilter +from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType +from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType +from ..filter.regex import RegexFilter +from typing import Awaitable, List, Dict + + +def get_handler_full_name(awatable: Awaitable) -> str: + '''获取 Handler 的全名''' + return f"{awatable.__module__}_{awatable.__name__}" + +def get_handler_or_create(handler: Awaitable) -> StarHandlerMetadata: + '''获取 Handler 或者创建一个新的 Handler''' + handler_full_name = get_handler_full_name(handler) + if handler_full_name in star_handlers_map: + return star_handlers_map[handler_full_name] + else: + md = StarHandlerMetadata( + handler_full_name=handler_full_name, + handler_name=handler.__name__, + handler_module_str=handler.__module__, + handler=handler, + event_filters=[] + ) + star_handlers_registry.append(md) + star_handlers_map[handler_full_name] = md + return md + +def register_command(command_name: str = None, *args): + '''注册一个 Command''' + + new_command = None + add_to_event_filters = False + if isinstance(command_name, RegisteringCommandable): + # 子指令 + new_command = CommandFilter(args[0], None) + command_name.parent_group.add_sub_command_filter(new_command) + else: + # 裸指令 + new_command = CommandFilter(command_name, None) + add_to_event_filters = True + + def decorator(awaitable): + handler_md = get_handler_or_create(awaitable) + new_command.init_handler_md(handler_md) + if add_to_event_filters: + # 裸指令 + handler_md.event_filters.append(new_command) + + return awaitable + + return decorator + +def register_command_group(command_group_name: str = None, *args): + '''注册一个 CommandGroup''' + + new_group = None + add_to_event_filters = False + if isinstance(command_group_name, RegisteringCommandable): + # 子指令组 + new_group = CommandGroupFilter(args[0]) + command_group_name.parent_group.add_sub_command_filter(new_group) + else: + # 根指令组 + new_group = CommandGroupFilter(command_group_name) + add_to_event_filters = True + + def decorator(obj): + if add_to_event_filters: + # 根指令组 + handler_md = get_handler_or_create(obj) + handler_md.event_filters.append(new_group) + + return RegisteringCommandable(new_group) + + return decorator + +class RegisteringCommandable(): + '''用于指令组级联注册''' + group = register_command_group + command = register_command + + def __init__(self, parent_group: CommandGroupFilter): + self.parent_group = parent_group + +def register_event_message_type(event_message_type: EventMessageType): + '''注册一个 EventMessageType''' + def decorator(awatable): + handler_md = get_handler_or_create(awatable) + handler_md.event_filters.append(EventMessageTypeFilter(event_message_type)) + return awatable + + return decorator + +def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType): + '''注册一个 PlatformAdapterType''' + def decorator(awatable): + handler_md = get_handler_or_create(awatable) + handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type)) + return awatable + + return decorator + +def register_regex(regex: str): + '''注册一个 Regex''' + def decorator(awatable): + handler_md = get_handler_or_create(awatable) + handler_md.event_filters.append(RegexFilter(regex)) + return awatable + + return decorator \ No newline at end of file diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py new file mode 100644 index 000000000..ddae6d7fb --- /dev/null +++ b/astrbot/core/star/star.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from types import ModuleType +from typing import List, Dict +from dataclasses import dataclass +from astrbot.core.utils.command_parser import CommandParserMixin + +star_registry: List[StarMetadata] = [] +star_map: Dict[str, StarMetadata] = {} +'''key 是模块路径,__module__''' + +class Star(CommandParserMixin): + '''所有插件(Star)的父类,所有插件都应该继承于这个类''' + def __init__(self): + pass + +@dataclass +class StarMetadata: + ''' + Star 的元数据。 + ''' + name: str + author: str # 插件作者 + desc: str # 插件简介 + version: str # 插件版本 + repo: str = None # 插件仓库地址 + + star_cls_type: type = None + '''Star 的类对象的类型''' + module_path: str = None + '''Star 的模块路径''' + + star_cls: object = None + '''Star 的类对象''' + module: ModuleType = None + '''Star 的模块对象''' + root_dir_name: str = None + '''Star 的根目录名''' + reserved: bool = False + '''是否是 AstrBot 的保留 Star''' + + def __str__(self) -> str: + return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" \ No newline at end of file diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py new file mode 100644 index 000000000..5ba6429de --- /dev/null +++ b/astrbot/core/star/star_handler.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Awaitable, List, Dict +from .filter import HandlerFilter + +star_handlers_registry: List[StarHandlerMetadata] = [] + +star_handlers_map: Dict[str, StarHandlerMetadata] = {} +'''用于快速查找。key 是 handler_full_name''' + +@dataclass +class StarHandlerMetadata(): + '''描述一个 Star 所注册的某一个 Handler。''' + + handler_full_name: str + '''格式为 f"{handler.__module__}_{handler.__name__}"''' + + handler_name: str + '''Handler 的名字,也就是方法名''' + + handler_module_str: str + '''Handler 所在的模块路径。''' + + handler: Awaitable + '''Handler 的函数对象,应当是一个异步函数''' + + event_filters: List[HandlerFilter] + '''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件''' + + desc: str = "" + '''Handler 的描述信息''' diff --git a/astrbot/core/plugin/plugin_manager.py b/astrbot/core/star/star_manager.py similarity index 57% rename from astrbot/core/plugin/plugin_manager.py rename to astrbot/core/star/star_manager.py index 34fdbc930..49ad4cc9a 100644 --- a/astrbot/core/plugin/plugin_manager.py +++ b/astrbot/core/star/star_manager.py @@ -1,31 +1,35 @@ import inspect import os -import sys import traceback -import uuid -import shutil import yaml import logging -from asyncio import Queue from types import ModuleType from typing import List, Awaitable from pip import main as pip_main from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core import logger from .context import Context -from . import RegisteredPlugin, PluginMetadata +from . import StarMetadata from .updator import PluginUpdator -from astrbot.core.db import BaseDatabase from astrbot.core.utils.io import remove_dir +from .star import star_registry, star_map + +from .star_handler import star_handlers_registry class PluginManager: - def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase): - self.updator = PluginUpdator(config.plugin_repo_mirror) - self.context = Context(event_queue, config, db) + def __init__( + self, + context: Context, + config: AstrBotConfig + ): + self.updator = PluginUpdator(config['plugin_repo_mirror']) + + self.context = context + self.config = config self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages")) - + def _get_classes(self, arg: ModuleType): classes = [] clsmembers = inspect.getmembers(arg, inspect.isclass) @@ -69,6 +73,9 @@ class PluginManager: return plugins def _check_plugin_dept_update(self, target_plugin: str = None): + '''检查插件的依赖 + 如果 target_plugin 为 None,则检查所有插件的依赖 + ''' plugin_dir = self.plugin_store_path if not os.path.exists(plugin_dir): return False @@ -76,7 +83,7 @@ class PluginManager: if target_plugin: to_update.append(target_plugin) else: - for p in self.context.registered_plugins: + for p in self.context.get_all_stars(): to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) @@ -89,6 +96,7 @@ class PluginManager: logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") def _update_plugin_dept(self, path): + '''更新插件的依赖''' args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/'] if self.config.pip_install_arg: args.extend(self.config.pip_install_arg) @@ -96,7 +104,11 @@ class PluginManager: if result_code != 0: raise Exception(str(result_code)) - def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: + def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata: + '''v3.4.0 以前的方式载入插件元数据 + + 先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 + ''' metadata = None if not os.path.exists(plugin_path): @@ -112,8 +124,8 @@ class PluginManager: if isinstance(metadata, dict): if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: raise Exception("插件元数据信息不完整。") - metadata = PluginMetadata( - plugin_name=metadata['name'], + metadata = StarMetadata( + name=metadata['name'], author=metadata['author'], desc=metadata['desc'], version=metadata['version'], @@ -123,72 +135,68 @@ class PluginManager: return metadata def reload(self): - ''' - 加载插件类 - ''' - registered_plugins = self.context.registered_plugins - plugins = self._get_plugin_modules() - if plugins is None: + '''扫描并加载所有的 Star''' + star_handlers_registry.clear() + + plugin_modules = self._get_plugin_modules() + if plugin_modules is None: return False, "未找到任何插件模块" fail_rec = "" - - registered_map = {} - for p in registered_plugins: - registered_map[p.module_path] = None - - for plugin in plugins: + + # 导入 Star 模块,并尝试实例化 Star 类 + for plugin_module in plugin_modules: try: - p = plugin['module'] - module_path = plugin['module_path'] - root_dir_name = plugin['pname'] - reserved = plugin.get('reserved', False) + module_str = plugin_module['module'] + module_path = plugin_module['module_path'] + root_dir_name = plugin_module['pname'] + reserved = plugin_module.get('reserved', False) - logger.info(f"正在加载插件 {root_dir_name} ...") + logger.info(f"正在载入插件 {root_dir_name} ...") - pre = "data.plugins." if not reserved else "packages." - - # 尝试导入插件模块 + # 尝试导入模块 + path = "data.plugins." if not reserved else "packages." + path += root_dir_name + "." + module_str try: - module = __import__(pre + root_dir_name + "." + p, fromlist=[p]) + module = __import__(path, fromlist=[module_str]) except (ModuleNotFoundError, ImportError) as e: - # 尝试安装插件依赖 + # 尝试安装依赖 self._check_plugin_dept_update(target_plugin=root_dir_name) - module = __import__(pre + root_dir_name + "." + p, fromlist=[p]) + module = __import__(path, fromlist=[module_str]) + except Exception as e: + logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}") + continue - cls = self._get_classes(module) - - # 实例化插件类 - try: - obj = getattr(module, cls[0])(context=self.context) - except BaseException as e: - logger.error(f"插件 {root_dir_name} 实例化失败。") - raise e - - # 解析插件元数据,加入注册列表 - metadata = None - plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name) - metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj) - if module_path not in registered_map: - registered_plugins.append(RegisteredPlugin( - metadata=metadata, - plugin_instance=obj, - module=module, - module_path=module_path, - root_dir_name=root_dir_name, - reserved=reserved - )) + if path in star_map: + # 通过装饰器的方式注册插件 + star_metadata = star_map[path] + star_metadata.star_cls = star_metadata.star_cls_type(context=self.context) + star_metadata.module = module + star_metadata.root_dir_name = root_dir_name + star_metadata.reserved = reserved + else: + # v3.4.0 以前的方式注册插件 + logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。") + classes = self._get_classes(module) + try: + obj = getattr(module, classes[0])(context=self.context) + except BaseException as e: + logger.error(f"插件 {root_dir_name} 实例化失败。") + raise e + + metadata = None + plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name) + metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj) + metadata.star_cls = obj + metadata.module = module + metadata.root_dir_name = root_dir_name + metadata.reserved = reserved + metadata.star_cls_type = obj.__class__ + metadata.module_path = path + star_map[path] = metadata - for command in self.context.commands_handler: - if self.context.commands_handler[command].plugin_name == metadata.plugin_name: - self.context.commands_handler[command].plugin_metadata = metadata - for listener in self.context.listeners_handler: - if self.context.listeners_handler[listener].plugin_name == metadata.plugin_name: - self.context.listeners_handler[listener].plugin_metadata = metadata - - except BaseException as e: traceback.print_exc() - fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" + fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n" # 清除 pip.main 导致的多余的 logging handlers for handler in logging.root.handlers[:]: @@ -200,26 +208,26 @@ class PluginManager: return False, fail_rec async def install_plugin(self, repo_url: str): - plugin_path = await self.updator.update(repo_url) - with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: - f.write(repo_url) + plugin_path = await self.updator.install(repo_url) self._check_plugin_dept_update() return plugin_path def uninstall_plugin(self, plugin_name: str): - plugin = self.context.get_registered_plugin(plugin_name) + plugin = self.context.get_registered_star(plugin_name) if not plugin: raise Exception("插件不存在。") if plugin.reserved: raise Exception("该插件是 AstrBot 保留插件,无法卸载。") root_dir_name = plugin.root_dir_name ppath = self.plugin_store_path - self.context.registered_plugins.remove(plugin) + + del star_map[plugin.module_path] + if not remove_dir(os.path.join(ppath, root_dir_name)): raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") async def update_plugin(self, plugin_name: str): - plugin = self.context.get_registered_plugin(plugin_name) + plugin = self.context.get_registered_star(plugin_name) if not plugin: raise Exception("插件不存在。") if plugin.reserved: @@ -228,41 +236,13 @@ class PluginManager: await self.updator.update(plugin) def install_plugin_from_file(self, zip_file_path: str): - # try to unzip - temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4())) - self.updator.unzip_file(zip_file_path, temp_dir) - # check if the plugin has metadata.yaml - if not os.path.exists(os.path.join(temp_dir, "metadata.yaml")): - remove_dir(temp_dir) - raise Exception("插件缺少 metadata.yaml 文件。") - - metadata = self._load_plugin_metadata(temp_dir) - plugin_name = metadata.plugin_name - if not plugin_name: - remove_dir(temp_dir) - raise Exception("插件 metadata.yaml 文件中 name 字段为空。") - plugin_name = self.updator.format_name(plugin_name) + desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path)) + self.updator.unzip_file(zip_file_path, desti_dir) - ppath = self.plugin_store_path - plugin_path = os.path.join(ppath, plugin_name) - if os.path.exists(plugin_path): - remove_dir(plugin_path) - - # move to the target path - shutil.move(temp_dir, plugin_path) - - if metadata.repo: - with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: - f.write(metadata.repo) - - # remove the temp dir - remove_dir(temp_dir) + # remove the zip + try: + os.remove(zip_file_path) + except BaseException as e: + logger.warning(f"删除插件压缩包失败: {str(e)}") self._check_plugin_dept_update() - - def get_platform_insts(self): - return self.context.registered_platforms - - def get_loaded_plugins(self): - return self.context.registered_plugins - \ No newline at end of file diff --git a/astrbot/core/plugin/updator.py b/astrbot/core/star/updator.py similarity index 69% rename from astrbot/core/plugin/updator.py rename to astrbot/core/star/updator.py index 3d358f834..4564ef5c9 100644 --- a/astrbot/core/plugin/updator.py +++ b/astrbot/core/star/updator.py @@ -2,7 +2,7 @@ import os, zipfile, shutil from ..updator import RepoZipUpdator from astrbot.core.utils.io import remove_dir, on_error -from ..plugin import RegisteredPlugin +from ..star.star import StarMetadata from typing import Union from astrbot.core import logger @@ -13,20 +13,22 @@ class PluginUpdator(RepoZipUpdator): def get_plugin_store_path(self) -> str: return self.plugin_store_path - - async def update(self, plugin: Union[RegisteredPlugin, str]) -> str: - repo_url = None + + async def install(self, repo_url: str) -> str: + repo_name = self.format_repo_name(repo_url) + plugin_path = os.path.join(self.plugin_store_path, repo_name) + await self.download_from_repo_url(plugin_path, repo_url) + self.unzip_file(plugin_path + ".zip", plugin_path) - if not isinstance(plugin, str): - plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) - if not os.path.exists(os.path.join(plugin_path, "REPO")): - raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。") - - with open(os.path.join(plugin_path, "REPO"), "r", encoding='utf-8') as f: - repo_url = f.read() - else: - repo_url = plugin - plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url)) + return plugin_path + + async def update(self, plugin: StarMetadata) -> str: + repo_url = plugin.repo + + if not repo_url: + raise Exception(f"插件 {plugin.name} 没有指定仓库地址。") + + plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}") await self.download_from_repo_url(plugin_path, repo_url) @@ -34,7 +36,7 @@ class PluginUpdator(RepoZipUpdator): try: remove_dir(plugin_path) except BaseException as e: - logger.error(f"删除旧版本插件 {plugin.metadata.plugin_name} 文件夹失败: {str(e)},使用覆盖安装。") + logger.error(f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。") self.unzip_file(plugin_path + ".zip", plugin_path) @@ -48,13 +50,10 @@ class PluginUpdator(RepoZipUpdator): update_dir = z.namelist()[0] z.extractall(target_dir) - avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir] - files = os.listdir(os.path.join(target_dir, update_dir)) for f in files: logger.info(f"移动更新文件/目录: {f}") if os.path.isdir(os.path.join(target_dir, update_dir, f)): - if f in avoid_dirs: continue if os.path.exists(os.path.join(target_dir, f)): shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) else: @@ -68,3 +67,4 @@ class PluginUpdator(RepoZipUpdator): os.remove(zip_path) except: logger.warning(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") + diff --git a/astrbot/core/utils/command_parser.py b/astrbot/core/utils/command_parser.py index ff1e4b767..f454a00f9 100644 --- a/astrbot/core/utils/command_parser.py +++ b/astrbot/core/utils/command_parser.py @@ -10,13 +10,10 @@ class CommandTokens(): return None return self.tokens[idx].strip() -class CommandParser(): - def __init__(self): - pass - - def parse(self, message: str): +class CommandParserMixin(): + def parse_commands(self, message: str): cmd_tokens = CommandTokens() - cmd_tokens.tokens = message.split(" ") + cmd_tokens.tokens = re.split(r"\s+", message) cmd_tokens.len = len(cmd_tokens.tokens) return cmd_tokens diff --git a/astrbot/core/utils/image_uploader.py b/astrbot/core/utils/image_uploader.py deleted file mode 100644 index a9a76d08a..000000000 --- a/astrbot/core/utils/image_uploader.py +++ /dev/null @@ -1,21 +0,0 @@ -import aiohttp -import uuid - -class ImageUploader(): - def __init__(self) -> None: - self.S3_URL = "https://s3.neko.soulter.top/astrbot-s3" - - async def upload_image(self, image_path: str) -> str: - ''' - 上传图像文件到S3 - ''' - with open(image_path, "rb") as f: - image = f.read() - - image_url = f"{self.S3_URL}/{uuid.uuid4().hex}.jpg" - - async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session: - async with session.put(image_url, data=image) as resp: - if resp.status != 200: - raise Exception(f"Failed to upload image: {resp.status}") - return image_url diff --git a/astrbot/core/utils/param_validation_mixin.py b/astrbot/core/utils/param_validation_mixin.py new file mode 100644 index 000000000..f3d999978 --- /dev/null +++ b/astrbot/core/utils/param_validation_mixin.py @@ -0,0 +1,31 @@ +import inspect +from typing import Awaitable, List, Union, Dict, Any, Type + +class ParameterValidationMixin: + def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]: + '''将参数列表 params 根据 param_type 转换为参数字典。 + ''' + result = {} + print(params, param_type) + for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()): + if i >= len(params): + if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty: + # 是类型 + raise ValueError(f"参数 {param_name} 缺失") + else: + # 是默认值 + result[param_name] = param_type_or_default_val + else: + # 尝试强制转换 + try: + if param_type_or_default_val == None: + if params[i].isdigit(): + result[param_name] = int(params[i]) + else: + result[param_name] = params[i] + else: + result[param_name] = type(param_type_or_default_val)(params[i]) + except ValueError: + raise ValueError(f"参数 {param_name} 类型错误") + print(result) + return result \ No newline at end of file diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py index dc0be1a35..13bbd3963 100644 --- a/astrbot/dashboard/dashboard_lifecycle.py +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -21,4 +21,4 @@ class AstrBotDashBoardLifecycle: await task except asyncio.CancelledError as e: logger.info("🌈 正在关闭 AstrBot...") - core_lifecycle.stop() \ No newline at end of file + await core_lifecycle.stop() \ No newline at end of file diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index b54b3aebc..dd295a68d 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -41,7 +41,7 @@ class AuthRoute(Route): if new_username: self.config.dashboard.username = new_username - self.config.flush_config() + self.config.save_config() return Response().ok(None, "修改成功").__dict__ diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 2a0508a27..0d08902d0 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,11 +1,10 @@ -import os, json +import os, json, traceback from .route import Route, Response, RouteContext from quart import Quart, request -from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE +from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE, ADAPTER_CONFIG_TEMPLATE from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.plugin.config import update_config +from astrbot.core.star.config import update_config from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from dataclasses import asdict def try_cast(value: str, type_: str): if type_ == "int" and value.isdigit(): @@ -55,10 +54,6 @@ def validate_config(data, config: AstrBotConfig): validate(value, meta["items"], path=f"{path}{key}.") validate(data) - # hardcode warning - data['config_version'] = config.config_version - data['dashboard'] = asdict(config.dashboard) - return errors def save_astrbot_config(post_config: dict, config: AstrBotConfig): @@ -66,7 +61,7 @@ def save_astrbot_config(post_config: dict, config: AstrBotConfig): errors = validate_config(post_config, config) if errors: raise ValueError(f"格式校验未通过: {errors}") - config.flush_config(post_config) + config.save_config(post_config) def save_extension_config(post_config: dict): if 'namespace' not in post_config: @@ -112,6 +107,7 @@ class ConfigRoute(Route): await self._save_astrbot_configs(post_configs) return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ except Exception as e: + traceback.print_exc() return Response().error(str(e)).__dict__ async def post_extension_configs(self): @@ -123,14 +119,15 @@ class ConfigRoute(Route): return Response().error(str(e)).__dict__ async def _get_astrbot_config(self): - config = self.config.to_dict() + config = self.config for key in self.config_key_dont_show: if key in config: del config[key] return { "metadata": CONFIG_METADATA_2, "config": config, - "provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE + "provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE, + "adapter_config_tmpl": ADAPTER_CONFIG_TEMPLATE } async def _get_extension_config(self, namespace: str): diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 3977d906b..e1eb3ea83 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -2,7 +2,7 @@ import threading, traceback, uuid from .route import Route, Response, RouteContext from astrbot.core import logger from quart import Quart, request -from astrbot.core.plugin.plugin_manager import PluginManager +from astrbot.core.star.star_manager import PluginManager from astrbot.core.core_lifecycle import AstrBotCoreLifecycle class PluginRoute(Route): @@ -21,14 +21,14 @@ class PluginRoute(Route): async def get_plugins(self): _plugin_resp = [] - for plugin in self.plugin_manager.context.registered_plugins: - _p = plugin.metadata + for plugin in self.plugin_manager.context.get_all_stars(): _t = { - "name": _p.plugin_name, - "repo": '' if _p.repo is None else _p.repo, - "author": _p.author, - "desc": _p.desc, - "version": _p.version + "name": plugin.name, + "repo": '' if plugin.repo is None else plugin.repo, + "author": plugin.author, + "desc": plugin.desc, + "version": plugin.version, + "reserved": plugin.reserved } _plugin_resp.append(_t) return Response().ok(_plugin_resp).__dict__ diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index d0b4d62d5..1812fd7c2 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -72,8 +72,8 @@ class StatRoute(Route): stat_dict.update({ "platform": self.db_helper.get_grouped_base_stats(offset_sec).platform, "message_count": self.db_helper.get_total_message_count() or 0, - "platform_count": len(self.core_lifecycle.plugin_manager.get_platform_insts()), - "plugin_count": len(self.core_lifecycle.plugin_manager.get_loaded_plugins()), + "platform_count": len(self.core_lifecycle.platform_manager.get_insts()), + "plugin_count": len(self.core_lifecycle.star_context.get_all_stars()), "message_time_series": message_time_based_stats, "running": self.format_sec(int(time.time()) - self.core_lifecycle.start_time), "memory": { diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 9f5c772ec..576a35136 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -39,10 +39,10 @@ class AstrBotDashboard(): return # claim jwt token = request.headers.get("Authorization") - if token.startswith("Bearer "): - token = token[7:] if not token: return Response().error("未授权").__dict__ + if token.startswith("Bearer "): + token = token[7:] try: jwt.decode(token, WEBUI_SK, algorithms=["HS256"]) except jwt.ExpiredSignatureError: diff --git a/main.py b/main.py index 8b1b2cf30..78500193d 100644 --- a/main.py +++ b/main.py @@ -96,7 +96,8 @@ if __name__ == "__main__": # print logo logger.info(logo_tmpl) - dashboard_lifecycle = AstrBotDashBoardLifecycle(db) core_lifecycle = AstrBotCoreLifecycle(log_broker, db) + asyncio.run(core_lifecycle.initialize()) + dashboard_lifecycle = AstrBotDashBoardLifecycle(db) asyncio.run(dashboard_lifecycle.start(core_lifecycle)) \ No newline at end of file diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index d74650932..6d69e3359 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,59 +1,21 @@ -import aiohttp, base64, os, json, re, time +import aiohttp +import astrbot.api.star as star +import astrbot.api.event.filter as filter from typing import Dict -from astrbot.api import Context, AstrMessageEvent, MessageEventResult -from astrbot.api import logger, command_parser +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.api.platform import MessageType +from astrbot.api import logger +from astrbot.api import personalities +from astrbot.api.provider import Personality -class Main: - def __init__(self, context: Context) -> None: +from typing import Union + +@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0") +class Main(star.Star): + def __init__(self, context: star.Context) -> None: self.context = context - context.register_commands("astrbot", "help", "查看 AstrBot 帮助", 10, self.help) - context.register_commands("astrbot", "plugin", "AstrBot 插件管理", 10, self.plugin) - context.register_commands("astrbot", "t2i", "关闭/启动文本转图片", 10, self.t2i) - context.register_commands("astrbot", "myid", "查看自己在该平台上的 ID", 10, self.myid) - context.register_listener("astrbot", "keywords_ban_rate_limit", self.keywords_ban, "关键词屏蔽和发言频率监听器") - # keywords - with open(os.path.join(os.path.dirname(__file__), "unfit_words"), "r", encoding="utf-8") as f: - self.keywords: list = json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'] - internal_keywords_cfg = context.get_config().content_safety.internal_keywords - if internal_keywords_cfg.enable: - self.keywords.extend(internal_keywords_cfg.extra_keywords) - - # rate limit - self.user_rate_limit: Dict[int, int] = {} - rl_cfg = context.get_config().platform_settings.rate_limit - self.rate_limit_time: int = rl_cfg.time - self.rate_limit_count: int = rl_cfg.count - self.user_frequency = {} - - async def keywords_ban(self, event: AstrMessageEvent): - if not event.is_wake_up(): - return - - # keywords 检测 - for i in self.keywords: - matches = re.match(i, event.get_message_str().strip(), re.I | re.M) - if matches: - event.set_result(MessageEventResult().message("你的消息中包含不适当的关键词,已被屏蔽。")) - return - - # rate limit 检测 - ts = int(time.time()) - if event.session_id in self.user_frequency: - if ts-self.user_frequency[event.session_id]['time'] > self.rate_limit_time: - # reset - self.user_frequency[event.session_id]['time'] = ts - self.user_frequency[event.session_id]['count'] = 1 - return - if self.user_frequency[event.session_id]['count'] >= self.rate_limit_count: - event.set_result(MessageEventResult().message("你发送消息的频率过快,请稍后再试。")) - return - self.user_frequency[event.session_id]['count'] += 1 - else: - t = {'time': ts, 'count': 1} - self.user_frequency[event.session_id] = t - - + @filter.command("help") async def help(self, event: AstrMessageEvent): notice = "" try: @@ -63,26 +25,41 @@ class Main: except BaseException as e: pass - msg = "# AstrBot 帮助\n## 已注册的指令\n" - for key, value in self.context.commands_handler.items(): - if value.plugin_metadata: - msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n" - else: msg += f"- `{key}`: {value.description}\n" + msg = "已注册的 AstrBot 内置指令:" + msg += f"""[System] +/plugin: 插件管理 +/t2i: 开启/关闭文本转图片模式 +/sid: 获取当前会话的 ID +/op : 授权管理员 +/deop : 取消管理员 +/wl : 添加会话白名单 +/dwl : 删除会话白名单 - msg += "\n> 提示:使用 /plugin 查看已加载的插件\n" - msg += notice +[大模型] +/provider: 查看、切换大模型提供商 +/model: 查看、切换提供商模型列表 +/key: 查看、切换 API Key +/reset: 重置 LLM 会话 +/history: 获取会话历史记录 +/persona: 情境人格设置 - event.set_result(MessageEventResult().message(msg)) - +提示:如果要查看插件指令,请输入 /plugin 查看具体信息。 +{notice}""" + + event.set_result(MessageEventResult().message(msg).use_t2i(False)) + + @filter.command("plugin") async def plugin(self, event: AstrMessageEvent): plugin_list_info = "已加载的插件:\n" for plugin in self.context.registered_plugins: - plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n" + plugin_list_info += f"- `{plugin.metadata.plugin_name}` By { + plugin.metadata.author}: {plugin.metadata.desc}\n" if plugin_list_info.strip() == "": plugin_list_info = "没有加载任何插件。" - + event.set_result(MessageEventResult().message(f"{plugin_list_info}")) - + + @filter.command("t2i") async def t2i(self, event: AstrMessageEvent): config = self.context.get_config() if config.t2i: @@ -93,7 +70,199 @@ class Main: config.t2i = True config.save_config() event.set_result(MessageEventResult().message("已开启文本转图片模式。")) - - async def myid(self, event: AstrMessageEvent): + + @filter.command("sid") + async def sid(self, event: AstrMessageEvent): + sid = event.unified_msg_origin user_id = str(event.get_sender_id()) - event.set_result(MessageEventResult().message(f"你的 ID 是 {user_id}。此 ID 可用于设置 AstrBot 管理员。")) \ No newline at end of file + ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl 添加白名单, /dwl 删除白名单。 +UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deop 取消管理员。""" + event.set_result(MessageEventResult().message(ret)) + + @filter.command("op") + async def op(self, event: AstrMessageEvent, admin_id: str): + self.context.get_config()['admins_id'].append(admin_id) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("授权成功。")) + + @filter.command("deop") + async def deop(self, event: AstrMessageEvent, admin_id: str): + try: + self.context.get_config()['admins_id'].remove(admin_id) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("取消授权成功。")) + except ValueError: + event.set_result(MessageEventResult().message("此用户 ID 不在管理员名单内。")) + + @filter.command("wl") + async def wl(self, event: AstrMessageEvent, sid: str): + self.context.get_config()['platform_settings']['id_whitelist'].append(sid) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("添加白名单成功。")) + + @filter.command("dwl") + async def dwl(self, event: AstrMessageEvent, sid: str): + try: + self.context.get_config()['platform_settings']['id_whitelist'].remove(sid) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("删除白名单成功。")) + except ValueError: + event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) + + @filter.command("provider") + async def provider(self, event: AstrMessageEvent, idx: int = None): + '''查看或者切换 LLM Provider''' + + if idx is None: + ret = "## 当前载入的 LLM 提供商\n" + for idx, llm in enumerate(self.context.get_all_providers()): + ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})" + if self.provider == llm: + ret += " (当前使用)" + ret += "\n" + + ret += "\n使用 /provider <序号> 切换提供商。" + event.set_result(MessageEventResult().message(ret)) + else: + if idx > len(self.context.get_all_providers()) or idx < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + + self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1] + + event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}。")) + + @filter.command("reset") + async def reset(self, message: AstrMessageEvent): + await self.context.get_using_provider().forget(message.session_id) + message.set_result(MessageEventResult().message("重置成功")) + + @filter.command("model") + async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None): + if idx_or_name is None: + models = [] + try: + models = await self.context.get_using_provider().get_models() + except BaseException as e: + message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e))) + return + i = 1 + ret = "下面列出了此服务提供商可用模型:" + for model in models: + ret += f"\n{i}. {model}" + i += 1 + ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + else: + if isinstance(idx_or_name, int): + models = [] + try: + models = await self.context.get_using_provider().get_models() + except BaseException as e: + message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e))) + return + if idx_or_name > len(models) or idx_or_name < 1: + message.set_result(MessageEventResult().message("模型序号错误。")) + else: + try: + new_model = models[idx_or_name-1] + self.context.get_using_provider().set_model(new_model) + except BaseException as e: + message.set_result( + MessageEventResult().message("切换模型未知错误: "+str(e))) + message.set_result(MessageEventResult().message("切换模型成功。")) + else: + self.context.get_using_provider().set_model(idx_or_name) + message.set_result( + MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}")) + + + @filter.command("history") + async def his(self, message: AstrMessageEvent, page: int = 1): + size_per_page = 3 + contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page) + + history = "" + for context in contexts: + history += f"{context}\n" + + ret = f"""历史记录: +{history} +第 {page} 页 | 共 {total_pages} 页 + +*输入 /history 2 跳转到第 2 页 +""" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + + @filter.command("key") + async def key(self, message: AstrMessageEvent, index: int=None): + + if index == None: + keys_data = self.context.get_using_provider().get_keys() + curr_key = self.context.get_using_provider().get_current_key() + ret = "Key:" + for i, k in enumerate(keys_data): + ret += f"\n{i+1}. {k[:8]}" + + ret += f"\n当前 Key: {curr_key[:8]}" + ret += "\n当前模型: " + self.context.get_using_provider().get_model() + ret += "\n使用 /key 切换 Key。" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + else: + keys_data = self.context.get_using_provider().get_keys() + if index > len(keys_data) or index < 1: + message.set_result(MessageEventResult().message("Key 序号错误。")) + else: + try: + new_key = keys_data[index-1] + self.context.get_using_provider().set_key(new_key) + except BaseException as e: + message.set_result( + MessageEventResult().message("切换 Key 未知错误: "+str(e))) + message.set_result(MessageEventResult().message("切换 Key 成功。")) + + @filter.command("persona") + async def persona(self, message: AstrMessageEvent): + l = message.message_str.split(" ") + if len(l) == 1: + message.set_result( + MessageEventResult().message(f"""[Persona] + +- 设置人格: `/persona 人格名`, 如 /persona 编剧 +- 人格列表: `/persona list` +- 人格详细信息: `/persona view 人格名` +- 自定义人格: /persona 人格文本 +- 重置 LLM 会话(清除人格): /reset +- 重置 LLM 会话(保留人格): /reset p + +【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])} +""")) + elif l[1] == "list": + msg = "人格列表:\n" + for key in personalities.keys(): + msg += f"- {key}\n" + msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息' + message.set_result(MessageEventResult().message(msg)) + elif l[1] == "view": + if len(l) == 2: + message.set_result(MessageEventResult().message("请输入人格名")) + ps = l[2].strip() + if ps in personalities: + msg = f"人格{ps}的详细信息:\n" + msg += f"{personalities[ps]}\n" + else: + msg = f"人格{ps}不存在" + message.set_result(MessageEventResult().message(msg)) + else: + ps = "".join(l[1:]).strip() + if ps in personalities: + self.context.get_using_provider().curr_personality = Personality( + name=ps, prompt=personalities[ps]) + message.set_result( + MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) + else: + self.context.get_using_provider().curr_personality = Personality( + name="自定义人格", prompt=ps) + message.set_result( + MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) diff --git a/packages/astrbot/metadata.yaml b/packages/astrbot/metadata.yaml deleted file mode 100644 index af88b41ad..000000000 --- a/packages/astrbot/metadata.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: astrbot # 插件名称 -desc: AstrBot 内置指令集 -help: -version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 -author: AstrBot # 作者 -repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/main.py b/packages/astrbot_adapter_aiocqhttp/main.py deleted file mode 100644 index 01ecbd9e6..000000000 --- a/packages/astrbot_adapter_aiocqhttp/main.py +++ /dev/null @@ -1,13 +0,0 @@ -from astrbot.api import Context -from .aiocqhttp_platform_adapter import AiocqhttpAdapter -from astrbot.api import logger - -class Main: - def __init__(self, context: Context) -> None: - self.context = context - platforms_config = context.get_config().platform - settings = context.get_config().platform_settings - for platform in platforms_config: - if platform.name == "aiocqhttp" and platform.enable: - self.context.register_platform(AiocqhttpAdapter(platform, settings, context.get_event_queue())) - logger.info(f"已注册 aiocqhttp({platform.id}) 消息适配器。") \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/metadata.yaml b/packages/astrbot_adapter_aiocqhttp/metadata.yaml deleted file mode 100644 index 4269c4e78..000000000 --- a/packages/astrbot_adapter_aiocqhttp/metadata.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: astrbot_adapter_aiocqhttp # 插件名称 -desc: 支持 OneBot 协议的消息平台适配器(反向 Websockets) -help: -version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 -author: Soulter # 作者 -repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/main.py b/packages/astrbot_adapter_qqofficial/main.py deleted file mode 100644 index b7e9fd88a..000000000 --- a/packages/astrbot_adapter_qqofficial/main.py +++ /dev/null @@ -1,18 +0,0 @@ -import botpy, logging -# delete qqbotpy's logger -for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - -from astrbot.api import Context -from .qqofficial_platform_adapter import QQOfficialPlatformAdapter -from astrbot.api import logger - -class Main: - def __init__(self, context: Context) -> None: - self.context = context - platforms_config = context.get_config().platform - settings = context.get_config().platform_settings - for platform in platforms_config: - if platform.name == "qq_official" and platform.enable: - self.context.register_platform(QQOfficialPlatformAdapter(platform, settings, context.get_event_queue())) - logger.info(f"已注册 qq_official({platform.id}) 消息适配器。") \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/metadata.yaml b/packages/astrbot_adapter_qqofficial/metadata.yaml deleted file mode 100644 index 263cbafc4..000000000 --- a/packages/astrbot_adapter_qqofficial/metadata.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: astrbot_adapter_qqofficial # 插件名称 -desc: 支持 QQ 官方机器人平台的消息平台适配器 -help: -version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 -author: Soulter # 作者 -repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/main.py b/packages/astrbot_adapter_wechat/main.py deleted file mode 100644 index b6eaf25e2..000000000 --- a/packages/astrbot_adapter_wechat/main.py +++ /dev/null @@ -1,18 +0,0 @@ -from astrbot.api import Context, AstrMessageEvent, MessageEventResult -from .wechat_platform_adapter import WechatPlatformAdapter -from astrbot.api import logger - -class Main: - def __init__(self, context: Context) -> None: - self.context = context - platforms_config = context.get_config().platform - settings = context.get_config().platform_settings - for platform in platforms_config: - if platform.name == "wechat" and platform.enable: - self.context.register_platform(WechatPlatformAdapter(platform, settings, context.get_event_queue())) - logger.info(f"已注册 wechat({platform.id}) 消息适配器。") - - self.context.register_commands("astrbot_adapter_wechat", "wechatid", "查看微信ID", 1, self.get_wechat_id) - - async def get_wechat_id(self, event: AstrMessageEvent): - event.set_result(MessageEventResult().message("这个会话的微信ID是" + event.message_obj.raw_message.from_.username)) \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/metadata.yaml b/packages/astrbot_adapter_wechat/metadata.yaml deleted file mode 100644 index 16c8db775..000000000 --- a/packages/astrbot_adapter_wechat/metadata.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: astrbot_adapter_wechat # 插件名称 -desc: 支持 Wechat(UOS) 的消息平台适配器 -help: -version: v1.0.0 # 插件版本号。格式:v1.1.1 或者 v1.1 -author: Soulter # 作者 -repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/__init__.py b/packages/astrbot_plugin_openai/__init__.py deleted file mode 100644 index 490c10d4b..000000000 --- a/packages/astrbot_plugin_openai/__init__.py +++ /dev/null @@ -1 +0,0 @@ -PLUGIN_NAME = "astrbot_plugin_openai" \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/commands.py b/packages/astrbot_plugin_openai/commands.py deleted file mode 100644 index 885b570fe..000000000 --- a/packages/astrbot_plugin_openai/commands.py +++ /dev/null @@ -1,169 +0,0 @@ -from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain -from . import PLUGIN_NAME -from astrbot.api import logger -from astrbot.api.message_components import Image, Plain -from astrbot.api import personalities -from astrbot.api import command_parser -from astrbot.api import Provider, Personality - - -class OpenAIAdapterCommand: - def __init__(self, context: Context) -> None: - self.provider: Provider = None - self.context = context - context.register_commands(PLUGIN_NAME, "reset", "重置会话", 10, self.reset) - context.register_commands(PLUGIN_NAME, "his", "查看历史记录", 10, self.his) - context.register_commands(PLUGIN_NAME, "status", "查看当前状态", 10, self.status) - context.register_commands(PLUGIN_NAME, "switch", "切换账号", 10, self.switch) - context.register_commands(PLUGIN_NAME, "persona", "设置个性化人格", 10, self.persona) - context.register_commands(PLUGIN_NAME, "draw", "调用 DallE 模型画图", 10, self.draw) - context.register_commands(PLUGIN_NAME, "model", "切换 LLM 模型", 10, self.model) - context.register_commands(PLUGIN_NAME, "画", "调用 DallE 模型画图", 10, self.draw) - - def set_provider(self, provider: Provider): - self.provider = provider - - async def reset(self, message: AstrMessageEvent): - tokens = command_parser.parse(message.message_str) - if tokens.len == 1: - await self.provider.forget(message.session_id) - message.set_result(MessageEventResult().message("重置成功")) - elif tokens.get(1) == 'p': - await self.provider.forget(message.session_id) - - async def model(self, message: AstrMessageEvent): - tokens = command_parser.parse(message.message_str) - if tokens.len == 1: - ret = await self._print_models() - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - return - model = tokens.get(1) - if model.isdigit(): - try: - models = await self.provider.get_models() - except BaseException as e: - logger.error(f"获取模型列表失败: {str(e)}。如果出现 404,可能与服务提供商未提供模型列表有关。") - message.set_result(MessageEventResult().message("获取模型列表失败,无法使用编号切换模型。可以尝试直接输入模型名来切换,如 gpt-4o。")) - models = list(models) - if int(model) <= len(models) and int(model) >= 1: - model = models[int(model)-1] - self.provider.set_model(model.id) - message.set_result(MessageEventResult().message(f"模型已设置为 {model.id}")) - else: - self.provider.set_model(model) - message.set_result(MessageEventResult().message(f"模型已设置为 {model} (自定义)")) - - async def _print_models(self): - models = [] - try: - models = await self.provider.get_models() - except BaseException as e: - return "获取模型列表失败: " + str(e) - i = 1 - ret = "下面列出了此服务提供商可用模型:" - for model in models: - ret += f"\n{i}. {model.id}" - i += 1 - ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" - logger.debug(ret) - return ret - - def his(self, message: AstrMessageEvent): - tokens = command_parser.parse(message.message_str) - size_per_page = 3 - page = 1 - if tokens.len == 2: - try: - page = int(tokens.get(1)) - except BaseException as e: - message.set_result(MessageEventResult().message("页码格式错误")) - contexts, total_num = self.provider.dump_contexts_page(message.session_id, size_per_page, page=page) - t_pages = total_num // size_per_page + 1 - message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n\n*输入 /his 2 跳转到第 2 页")) - - def status(self, message: AstrMessageEvent): - keys_data = self.provider.get_all_keys() - ret = "{} Key" - for k in keys_data: - ret += "\n|- " + k[:8] - - ret += "\n当前模型: " + self.provider.get_model() - - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - - async def switch(self, message: AstrMessageEvent): - ''' - 切换账号 - ''' - tokens = command_parser.parse(message.message_str) - if tokens.len == 1: - ret = "" - curr_ = self.provider.get_curr_key() - if curr_ is None: - ret += "当前您未选择账号。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。" - else: - ret += f"当前您选择的账号为:{curr_[:8]}。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。" - message.set_result(MessageEventResult().message(ret)) - elif tokens.len == 2: - try: - key_stat = self.provider.get_keys_data() - index = int(tokens.get(1)) - if index > len(key_stat) or index < 1: - message.set_result(MessageEventResult().message("账号序号错误。")) - else: - try: - new_key = list(key_stat.keys())[index-1] - self.provider.set_key(new_key) - except BaseException as e: - message.set_result(MessageEventResult().message("切换账号未知错误: "+str(e))) - message.set_result(MessageEventResult().message("切换账号成功。") ) - except BaseException as e: - message.set_result(MessageEventResult().message("切换账号错误。")) - else: - message.set_result(MessageEventResult().message("参数过多。")) - - - def persona(self, message: AstrMessageEvent): - l = message.message_str.split(" ") - if len(l) == 1: - message.set_result( - MessageEventResult().message(f"""[Persona] - -- 设置人格: `/persona 人格名`, 如 /persona 编剧 -- 人格列表: `/persona list` -- 人格详细信息: `/persona view 人格名` -- 自定义人格: /persona 人格文本 -- 重置 LLM 会话(清除人格): /reset -- 重置 LLM 会话(保留人格): /reset p - -【当前人格】: {str(self.provider.curr_personality['prompt'])} -""")) - elif l[1] == "list": - msg = "人格列表:\n" - for key in personalities.keys(): - msg += f"- {key}\n" - msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息' - message.set_result(MessageEventResult().message(msg)) - elif l[1] == "view": - if len(l) == 2: - message.set_result(MessageEventResult().message("请输入人格名")) - ps = l[2].strip() - if ps in personalities: - msg = f"人格{ps}的详细信息:\n" - msg += f"{personalities[ps]}\n" - else: - msg = f"人格{ps}不存在" - message.set_result(MessageEventResult().message(msg)) - else: - ps = "".join(l[1:]).strip() - if ps in personalities: - self.provider.curr_personality = Personality(name=ps, prompt=personalities[ps]) - message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) - else: - self.provider.curr_personality = Personality(name="自定义人格", prompt=ps) - message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) - - async def draw(self, message: AstrMessageEvent): - prompt = message.message_str.removeprefix("画") - img_url = await self.provider.image_generate(prompt) - message.set_result(MessageEventResult().url_image(img_url)) \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/main.py b/packages/astrbot_plugin_openai/main.py deleted file mode 100644 index 990b534da..000000000 --- a/packages/astrbot_plugin_openai/main.py +++ /dev/null @@ -1,253 +0,0 @@ -import json, traceback -from typing import List, Dict -from astrbot.api import Context, AstrMessageEvent, MessageEventResult -from .openai_adapter import ProviderOpenAIOfficial -from .commands import OpenAIAdapterCommand -from astrbot.api import logger -from . import PLUGIN_NAME -from astrbot.api import MessageChain -from astrbot.api.message_components import Image, Plain -from openai._exceptions import * -from openai.types.chat.chat_completion_message_tool_call import Function -from astrbot.api import command_parser -from .web_searcher import search_from_bing, fetch_website_content -from astrbot.core.utils.metrics import Metric -from astrbot.core.config.astrbot_config import LLMConfig - -class Main: - def __init__(self, context: Context) -> None: - supported_provider_names = ["openai", "ollama", "gemini", "deepseek", "zhipu"] - self.context = context - - # 各 Provider 实例 - self.provider_insts: List[ProviderOpenAIOfficial] = [] - # Provider 的配置 - self.provider_llm_configs: List[LLMConfig] = [] - # 当前使用的 Provider - self.provider = None - # 当前使用的 Provider 的配置 - self.provider_config = None - - atri_config = self.context.get_config().project_atri - - loaded = False - for llm in self.context.get_config().llm: - if llm.enable: - if llm.name in supported_provider_names: - if not llm.key or not llm.enable: - logger.warning("没有开启 LLM Provider 或 API Key 未填写。") - continue - self.provider_insts.append(ProviderOpenAIOfficial(llm, self.context.get_db())) - self.provider_llm_configs.append(llm) - loaded = True - logger.info(f"已启用 LLM Provider(OpenAI API 适配器): {llm.id}({llm.name})。") - if loaded: - self.command_handler = OpenAIAdapterCommand(self.context) - self.command_handler.set_provider(self.provider_insts[0]) - self.context.register_listener(PLUGIN_NAME, "llm_chat_listener", self.chat, "llm_chat_listener", after_commands=True) - self.provider = self.command_handler.provider - self.provider_config = self.provider_llm_configs[0] - self.context.register_commands(PLUGIN_NAME, "provider", "查看当前 LLM Provider", 10, self.provider_info) - self.context.register_commands(PLUGIN_NAME, "websearch", "启用/关闭网页搜索", 10, self.web_search) - - if self.context.get_config().llm_settings.web_search: - self.add_web_search_tools() - - # load atri - self.atri = None - if atri_config.enable: - try: - from .atri import ATRI - self.atri = ATRI(self.provider_llm_configs, atri_config, self.context) - self.command_handler.provider = self.atri.atri_chat_provider - except ImportError as e: - logger.error(traceback.format_exc()) - logger.error("载入 ATRI 失败。请确保使用 pip 安装了 requirements_atri.txt 下的库。") - self.atri = None - except BaseException as e: - logger.error(traceback.format_exc()) - logger.error("载入 ATRI 失败。") - self.atri = None - - def add_web_search_tools(self): - self.context.register_llm_tool("web_search", [{ - "type": "string", - "name": "keyword", - "description": "搜索关键词" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - search_from_bing - ) - self.context.register_llm_tool("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "要获取内容的网页链接" - }], - "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) - - def remove_web_search_tools(self): - self.context.unregister_llm_tool("web_search") - self.context.unregister_llm_tool("fetch_website_content") - - async def provider_info(self, event: AstrMessageEvent): - if len(self.provider_insts) == 0: - event.set_result(MessageEventResult().message("未启用任何 LLM Provider。")) - - tokens = command_parser.parse(event.get_message_str()) - - if tokens.len == 1: - ret = "## 当前载入的 LLM 接入源\n" - for idx, llm in enumerate(self.provider_insts): - ret += f"{idx}. {llm.llm_config.id} ({llm.llm_config.model_config.model})" - if self.provider == llm: - ret += " (当前使用)" - ret += "\n" - - ret += "\n使用 /provider <序号> 切换 LLM 接入源。" - event.set_result(MessageEventResult().message(ret)) - return - else: - try: - idx = int(tokens.get(1)) - if idx >= len(self.provider_insts): - event.set_result(MessageEventResult().message("无效的序号。")) - self.provider = self.provider_insts[idx] - self.provider_config = self.provider_llm_configs[idx] - self.command_handler.set_provider(self.provider) - event.set_result(MessageEventResult().message(f"已经成功切换到 LLM 接入源 {self.provider.llm_config.id}。")) - return - except BaseException as e: - event.set_result(MessageEventResult().message("provider: 参数错误。")) - return - - async def web_search(self, event: AstrMessageEvent): - websearch = self.context.get_config().llm_settings.web_search - if websearch: - # turn off - self.context.get_config().llm_settings.web_search = False - self.context.get_config().save_config() - self.remove_web_search_tools() - event.set_result(MessageEventResult().message("已关闭网页搜索。")) - return - # turn on - self.context.get_config().llm_settings.web_search = True - self.context.get_config().save_config() - self.add_web_search_tools() - event.set_result(MessageEventResult().message("已开启网页搜索。")) - - async def chat(self, event: AstrMessageEvent): - if not event.is_wake_up(): - return - if self.atri: - await self.atri.chat(event) - return - - # prompt 前缀 - if self.provider_config.prompt_prefix: - event.message_str = self.provider_config.prompt_prefix + event.message_str - - image_urls = [] - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - image_urls.append(image_url) - - tool_use_flag = False - llm_result = None - try: - if not self.context.llm_tools.empty(): - # tools-use - tool_use_flag = True - llm_result = await self.provider.text_chat( - prompt=event.message_str, - session_id=event.session_id, - tools=self.context.llm_tools.get_func() - ) - await Metric.upload(llm_tick=1, llm_name=self.provider.get_model(), llm_api_base=self.provider.base_url) - - if isinstance(llm_result, Function): - logger.debug(f"function-calling: {llm_result}") - func_obj = None - for i in self.context.llm_tools.func_list: - if i["name"] == llm_result.name: - func_obj = i["func_obj"] - break - if not func_obj: - event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:未找到请求的函数调用。")) - return - try: - args = json.loads(llm_result.arguments) - args['event'] = event - args['provider'] = self.provider - try: - func_result = await func_obj(**args) - except TypeError as e: - args.pop('event') - args.pop('provider') - func_result = await func_obj(**args) - if func_result: - logger.warning(f"function-calling: 工具函数 {llm_result.name} 返回了非空值,该值将被忽略。请使用 event.set_result() 设置返回值。") - return - if event.get_result(): - return - except BaseException as e: - traceback.print_exc() - event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:" + str(e))) - return - else: - event.set_result(MessageEventResult().message(llm_result)) - return - else: - # normal chat - # add user info to the prompt - if self.context.get_config().llm_settings.identifier: - user_id = event.message_obj.sender.user_id - user_nickname = event.message_obj.sender.nickname - user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n" - event.message_str = user_info + event.message_str - llm_result = await self.provider.text_chat( - prompt=event.message_str, - session_id=event.session_id, - image_urls=image_urls - ) - await Metric.upload(llm_tick=1, llm_name=self.provider.get_model(), llm_api_base=self.provider.base_url) - except BadRequestError as e: - if tool_use_flag: - # seems like the model don't support function-calling - logger.error(f"error: {e}. Using local function-calling implementation") - - try: - # use local function-calling implementation - args = { - 'question': llm_result, - 'func_definition': self.context.llm_tools.func_dump(), - } - _, has_func = await self.context.llm_tools.func_call(**args) - - if not has_func: - # normal chat - llm_result = await self.provider.text_chat( - prompt=event.message_str, - session_id=event.session_id, - image_urls=image_urls - ) - except BaseException as e: - logger.error(traceback.format_exc()) - event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:" + str(e))) - return - else: - logger.error(traceback.format_exc()) - logger.error(f"LLM 调用失败。") - event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) - return - except BaseException as e: - logger.error(traceback.format_exc()) - logger.error(f"LLM 调用失败。") - event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) - return - - if llm_result: - event.set_result(MessageEventResult().message(llm_result)) - return diff --git a/packages/astrbot_plugin_openai/metadata.yaml b/packages/astrbot_plugin_openai/metadata.yaml deleted file mode 100644 index 8af419fb6..000000000 --- a/packages/astrbot_plugin_openai/metadata.yaml +++ /dev/null @@ -1,6 +0,0 @@ -name: astrbot_plugin_openai # 插件名称 -desc: 支持 OpenAI API -help: -version: v1.5.0 # 插件版本号。格式:v1.1.1 或者 v1.1 -author: Soulter # 作者 -repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/openai_adapter.py b/packages/astrbot_plugin_openai/openai_adapter.py deleted file mode 100644 index 7fe8cfa8d..000000000 --- a/packages/astrbot_plugin_openai/openai_adapter.py +++ /dev/null @@ -1,254 +0,0 @@ -import os -import asyncio -import traceback -import base64 -import json - -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion -from openai._exceptions import * -from astrbot.core.utils.io import download_image_by_url - -from astrbot.core.db import BaseDatabase -from astrbot.api import Provider -from astrbot.core.config.astrbot_config import LLMConfig -from astrbot import logger -from typing import List, Dict -from dataclasses import asdict - -class ProviderOpenAIOfficial(Provider): - def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase, persistant_history = True) -> None: - super().__init__(db_helper, llm_config.default_personality, persistant_history) - self.api_keys = [] - self.chosen_api_key = None - self.base_url = None - self.llm_config = llm_config - self.api_keys = llm_config.key - if llm_config.api_base: - self.base_url = llm_config.api_base - self.chosen_api_key = self.api_keys[0] - - self.client = AsyncOpenAI( - api_key=self.chosen_api_key, - base_url=self.base_url - ) - self.set_model(llm_config.model_config.model) - - # 各类模型的配置 - self.image_generator_model_configs = None - self.embedding_model_configs = None - if llm_config.image_generation_model_config and llm_config.image_generation_model_config.enable: - self.image_generator_model_configs: Dict = asdict( - llm_config.image_generation_model_config) - self.image_generator_model_configs.pop("enable") - if llm_config.embedding_model and llm_config.embedding_model.enable: - self.embedding_model_configs: Dict = asdict( - llm_config.embedding_model) - self.embedding_model_configs.pop("enable") - - async def encode_image_bs64(self, image_url: str) -> str: - ''' - 将图片转换为 base64 - ''' - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode('utf-8') - return "data:image/jpeg;base64," + image_bs64 - return '' - - async def get_models(self): - models = [] - try: - models = await self.client.models.list() - except NotFoundError as e: - bu = str(self.client.base_url) - self.client.base_url = bu + "/v1" - models = await self.client.models.list() - return models - - async def pop_record(self, session_id: str, pop_system_prompt: bool = False): - ''' - 弹出第一条记录 - ''' - if session_id not in self.session_memory: - raise Exception("会话 ID 不存在") - - if len(self.session_memory[session_id]) == 0: - return None - - for i in range(len(self.session_memory[session_id])): - # 检查是否是 system prompt - if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": - # 如果只有一个 system prompt,才不删掉 - f = False - for j in range(i+1, len(self.session_memory[session_id])): - if self.session_memory[session_id][j]['user']['role'] == "system": - f = True - break - if not f: - continue - record = self.session_memory[session_id].pop(i) - break - - return record - - async def assemble_context(self, text: str, image_urls: List[str] = None): - ''' - 组装上下文。 - ''' - if image_urls: - user_content = {"role": "user","content": [{"type": "text", "text": text}]} - for image_url in image_urls: - if image_url.startswith("http"): - image_data = image_url - else: - image_data = await self.encode_image_bs64(image_url) - user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) - return user_content - else: - return {"role": "user","content": text} - - - async def text_chat(self, - prompt: str, - session_id: str, - image_urls=None, - tools=None, - contexts=None, - **kwargs - ) -> str: - ''' - 调用 LLM 进行文本对话。 - - @param tools: LLM Function-calling 的工具函数 - @param contexts: 如果不为 None,则会原封不动地使用这个上下文进行对话。 - ''' - if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on": - return "这是一个测试消息。" - - new_record = await self.assemble_context(prompt, image_urls) - if not contexts: - contexts = [*self.session_memory[session_id], new_record] - if self.curr_personality["prompt"]: - contexts.insert(0, {"role": "system", "content": self.curr_personality["prompt"]}) - - logger.debug(f"请求上下文:{contexts}") - conf = asdict(self.llm_config.model_config) - if tools: - conf['tools'] = tools - - # start request - retry = 0 - while retry < 3: - completion_coro = self.client.chat.completions.create( - messages=contexts, - stream=False, - **conf - ) - try: - completion = await completion_coro - break - except Exception as e: - retry += 1 - if retry >= 3: - logger.error(traceback.format_exc()) - raise Exception(f"请求失败:{e}。重试次数已达到上限。") - if "maximum context length" in str(e): - logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") - self.pop_record(session_id) - - logger.warning(traceback.format_exc()) - logger.warning(f"请求失败:{e}。重试第 {retry} 次。") - await asyncio.sleep(1) - - assert isinstance(completion, ChatCompletion) - logger.debug(f"completion: {completion.usage}") - - if len(completion.choices) == 0: - raise Exception("API 返回的 completion 为空。") - choice = completion.choices[0] - - if choice.message.content: - # 返回文本 - completion_text = str(choice.message.content).strip() - # 添加用户 record - self.session_memory[session_id].append(new_record) - # 添加 assistant record - self.session_memory[session_id].append({ - "role": "assistant", - "content": completion_text - }) - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id])) - return completion_text - elif choice.message.tool_calls and choice.message.tool_calls: - # tools call (function calling) - return choice.message.tool_calls[0].function - else: - raise Exception("Internal Error") - - async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str: - ''' - 生成图片 - ''' - retry = 0 - if not self.image_generator_model_configs: - return - while retry < 3: - try: - images_response = await self.client.images.generate( - prompt=prompt, - **self.image_generator_model_configs - ) - image_url = images_response.data[0].url - return image_url - except Exception as e: - retry += 1 - if retry >= 3: - logger.error(traceback.format_exc()) - raise Exception(f"图片生成请求失败:{e}。重试次数已达到上限。") - logger.warning(f"图片生成请求失败:{e}。重试第 {retry} 次。") - await asyncio.sleep(1) - - async def get_embedding(self, text) -> List[float]: - ''' - 获取文本的嵌入 - ''' - if not self.embedding_model_configs: - return - try: - embedding = await self.client.embeddings.create( - input=text, - **self.embedding_model_configs - ) - return embedding.data[0].embedding - except Exception as e: - logger.error(f"获取文本嵌入失败:{e}") - - async def forget(self, session_id: str) -> bool: - self.session_memory[session_id] = [] - return True - - def dump_contexts_page(self, session_id: str, size=5, page=1,): - ''' - 获取缓存的会话 - ''' - contexts_str = "" - if session_id in self.session_memory: - for record in self.session_memory[session_id]: - if record['role'] == "user": - text = record['content'][:100] + "..." if len( - record['content']) > 100 else record['content'] - contexts_str += f"User: {text}\n\n" - elif record['role'] == "assistant": - text = record['content'][:100] + "..." if len( - record['content']) > 100 else record['content'] - contexts_str += f"Assistant: {text}\n\n" - else: - contexts_str = "会话 ID 不存在。" - - return contexts_str, len(self.session_memory[session_id]) - - def get_curr_key(self): - return self.chosen_api_key - - def get_all_keys(self): - return self.api_keys \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1d5bfe9ae..89f0f0edc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ psutil lxml_html_clean colorlog aiocqhttp -pyjwt \ No newline at end of file +pyjwt +apscheduler \ No newline at end of file diff --git a/requirements_atri_base.txt b/requirements_atri_base.txt new file mode 100644 index 000000000..ea90da291 --- /dev/null +++ b/requirements_atri_base.txt @@ -0,0 +1,2 @@ +chromadb +openai \ No newline at end of file diff --git a/requirements_atri_ft.txt b/requirements_atri_ft.txt new file mode 100644 index 000000000..d2e4234f2 --- /dev/null +++ b/requirements_atri_ft.txt @@ -0,0 +1,2 @@ +llmtuner +bitsandbytes \ No newline at end of file