diff --git a/astrbot/__init__.py b/astrbot/__init__.py new file mode 100644 index 000000000..5dff017b0 --- /dev/null +++ b/astrbot/__init__.py @@ -0,0 +1,2 @@ +from .core.log import LogManager +logger = LogManager.GetLogger(log_name='astrbot') \ No newline at end of file diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py new file mode 100644 index 000000000..25aa68641 --- /dev/null +++ b/astrbot/api/__init__.py @@ -0,0 +1,15 @@ + +from astrbot.core.plugin import Context +from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from astrbot.core.message_event_result import MessageEventResult, MessageChain, CommandResult +from astrbot.core.provider import Provider +from astrbot.core.config.astrbot_config import AstrBotConfig +from nakuru.entities.components import * +from astrbot import logger +from astrbot.core.utils.personality import personalities + +from astrbot.core.utils.command_parser import CommandParser, CommandTokens +from astrbot.core.utils.func_call import FuncCall +from astrbot.core import html_renderer + +command_parser = CommandParser() \ No newline at end of file diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py deleted file mode 100644 index ceeff0877..000000000 --- a/astrbot/bootstrap.py +++ /dev/null @@ -1,148 +0,0 @@ -import asyncio -import traceback -import os -import threading -from astrbot.message.handler import MessageHandler -from astrbot.db.sqlite import SQLiteDatabase -from dashboard.server import AstrBotDashboard -from model.command.manager import CommandManager -from model.command.internal_handler import InternalCommandHandler -from model.plugin.manager import PluginManager -from model.platform.manager import PlatformManager -from typing import Union -from type.types import Context -from type.config import VERSION, DB_PATH -from logging import Logger -from util.cmd_config import AstrBotConfig, try_migrate -from util.metrics import MetricUploader -from util.updator.astrbot_updator import AstrBotUpdator -from util.log import LogManager -from util.cmd_config import LLMConfig - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -class AstrBotBootstrap(): - def __init__(self) -> None: - self.context = Context() - # load configs and ensure the backward compatibility - try_migrate() - self.config_helper = AstrBotConfig() - self.context.config_helper = self.config_helper - # set log queue handler - LogManager.set_queue_handler(logger, self.context.log_broker) - logger.info("AstrBot v" + VERSION) - # set log level - logger.setLevel(self.config_helper.log_level) - # apply proxy settings - http_proxy = self.context.config_helper.http_proxy - https_proxy = self.context.config_helper.https_proxy - if http_proxy: - os.environ['HTTP_PROXY'] = http_proxy - if https_proxy: - os.environ['HTTPS_PROXY'] = https_proxy - os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' - - if http_proxy and https_proxy: - logger.info(f"使用代理: {http_proxy}, {https_proxy}") - else: - logger.info("未使用代理。") - - self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' - - # set t2i endpoint - if self.context.config_helper.t2i_endpoint: - self.context.image_renderer.set_network_endpoint( - self.context.config_helper.t2i_endpoint - ) - - async def run(self): - self.command_manager = CommandManager() - self.plugin_manager = PluginManager(self.context) - self.updator = AstrBotUpdator() - self.cmd_handler = InternalCommandHandler(self.command_manager, self.plugin_manager) - self.db_helper = SQLiteDatabase(DB_PATH) - - # load llm provider - self.load_llm() - - self.message_handler = MessageHandler(self.context, self.command_manager, self.db_helper) - self.platfrom_manager = PlatformManager(self.context, self.message_handler) - self.dashboard = AstrBotDashboard(self.context, - plugin_manager=self.plugin_manager, - astrbot_updator=self.updator, - db_helper=self.db_helper) - self.metrics_uploader = MetricUploader(self.db_helper) - - self.context.metrics_uploader = self.metrics_uploader - self.context.updator = self.updator - self.context.plugin_updator = self.plugin_manager.updator - self.context.message_handler = self.message_handler - self.context.command_manager = self.command_manager - - # load dashboard - dashboard_server_task = asyncio.create_task(self.dashboard.run(), name="dashboard") - - if self.test_mode: - def run_dashboard(): - asyncio.run(self.dashboard.run()) - dashboard_thread = threading.Thread(target=run_dashboard, name="dashboard_thread", daemon=False) - dashboard_thread.start() - return - - # load plugins, plugins' commands. - self.load_plugins() - self.command_manager.register_from_pcb(self.context.plugin_command_bridge) - - # load platforms - platform_tasks = self.load_platform() - # load metrics uploader - metrics_upload_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics-uploader") - - tasks = [metrics_upload_task, dashboard_server_task, *platform_tasks, *self.context.ext_tasks] - tasks = [self.handle_task(task) for task in tasks] - await asyncio.gather(*tasks) - - async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): - while True: - try: - result = await task - return result - except asyncio.CancelledError: - logger.info(f"{task.get_name()} 任务已取消。") - return - except Exception as e: - logger.error(traceback.format_exc()) - logger.error(f"{task.get_name()} 任务发生错误。") - return - - def load_llm(self): - f = False - llms = self.context.config_helper.llm - logger.info(f"加载 {len(llms)} 个 LLM Provider...") - for llm in llms: - if llm.enable: - if llm.name == "openai": - if not llm.key or not llm.enable: - logger.warning("没有开启 LLM Provider 或 API Key 未填写。") - continue - self.load_openai(llm) - f = True - logger.info(f"已启用 LLM Provider(OpenAI API): {llm.id}({llm.name})。") - if f: - from model.command.openai_official_handler import OpenAIOfficialCommandHandler - self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) - self.openai_command_handler.set_provider(self.context.llms[0].llm_instance) - - def load_openai(self, llm_config: LLMConfig): - from model.provider.openai_official import ProviderOpenAIOfficial - inst = ProviderOpenAIOfficial(llm_config, self.db_helper) - self.context.register_provider(llm_config.id, inst) - - def load_plugins(self): - self.plugin_manager.plugin_reload() - - def load_platform(self): - platforms = self.platfrom_manager.load_platforms() - if not platforms: - logger.warning("未启用任何消息平台。") - return platforms \ No newline at end of file diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py new file mode 100644 index 000000000..d289c6e99 --- /dev/null +++ b/astrbot/core/__init__.py @@ -0,0 +1,5 @@ +from .log import LogManager, LogBroker +from core.utils.t2i.renderer import HtmlRenderer + +html_renderer = HtmlRenderer() +logger = LogManager.GetLogger(log_name='astrbot') \ No newline at end of file diff --git a/astrbot/core/config/__init__.py b/astrbot/core/config/__init__.py new file mode 100644 index 000000000..a781e688a --- /dev/null +++ b/astrbot/core/config/__init__.py @@ -0,0 +1,2 @@ +from .default import DEFAULT_CONFIG_VERSION_2, VERSION, DB_PATH +from .astrbot_config import AstrBotConfig \ No newline at end of file diff --git a/util/cmd_config.py b/astrbot/core/config/astrbot_config.py similarity index 81% rename from util/cmd_config.py rename to astrbot/core/config/astrbot_config.py index 423287129..6a61193de 100644 --- a/util/cmd_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -2,8 +2,7 @@ import os import json import shutil import logging -from util.io import on_error -from type.config import DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2 +from . import DEFAULT_CONFIG_VERSION_2 from dataclasses import dataclass, field, asdict from typing import List, Dict, Optional @@ -52,7 +51,7 @@ class NakuruPlatformConfig(PlatformConfig): class AiocqhttpPlatformConfig(PlatformConfig): ws_reverse_host: str = "" ws_reverse_port: int = 6199 - + @dataclass class ModelConfig: model: str = "gpt-4o" @@ -184,32 +183,6 @@ class AstrBotConfig(): self.pip_install_arg=data.get("pip_install_arg", "") self.plugin_repo_mirror=data.get("plugin_repo_mirror", "") - - def migrate_config_1_2(self, old: dict) -> dict: - '''将配置文件从版本 1 迁移至版本 2''' - logger.info("正在更新配置文件到 version 2...") - new_config = DEFAULT_CONFIG_VERSION_2 - mappings = MAPPINGS_1_2 - - def set_nested_value(d, keys, value): - cursor = d - for key in keys[:-1]: - cursor = cursor[key] - cursor[keys[-1]] = value - - for old_path, new_path in mappings: - value = old - try: - for key in old_path: - value = value[key] # soooooo convenient!! - set_nested_value(new_config, new_path, value) - except KeyError: - # 如果旧配置中没有这个键,跳过,即使用新配置的默认值 - continue - - logger.info("配置文件更新完成。") - return new_config - def flush_config(self, config: dict = None): '''将配置写入文件, 如果没有传入配置,则写入默认配置''' with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: @@ -273,27 +246,4 @@ class AstrBotConfig(): return asdict(self) def check_exist(self) -> bool: - return os.path.exists(ASTRBOT_CONFIG_PATH) - -def try_migrate(): - ''' - - 将 cmd_config.json 迁移至 data/cmd_config.json (如果存在) - - 将 addons/plugins 迁移至 data/plugins (如果存在) - ''' - if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"): - try: - shutil.move("cmd_config.json", "data/cmd_config.json") - except: - logger.error("迁移 cmd_config.json 失败。") - - if os.path.exists("addons/plugins"): - if os.path.exists("data/plugins"): - try: - shutil.rmtree("data/plugins", onerror=on_error) - except: - logger.error("删除 data/plugins 失败。") - try: - shutil.move("addons/plugins", "data/") - shutil.rmtree("addons", onerror=on_error) - except: - logger.error("迁移 addons/plugins 失败。") \ No newline at end of file + return os.path.exists(ASTRBOT_CONFIG_PATH) \ No newline at end of file diff --git a/type/config.py b/astrbot/core/config/default.py similarity index 77% rename from type/config.py rename to astrbot/core/config/default.py index 5b5bff593..9854d4b06 100644 --- a/type/config.py +++ b/astrbot/core/config/default.py @@ -1,4 +1,4 @@ -VERSION = '3.3.19' +VERSION = '3.4.0' DB_PATH = 'data/data_v2.db' # 新版本配置文件,摈弃旧版本令人困惑的配置项 :D @@ -14,18 +14,6 @@ DEFAULT_CONFIG_VERSION_2 = { "enable_group_c2c": True, "enable_guild_direct_message": True, }, - { - "id": "default", - "name": "nakuru", - "enable": False, - "host": "172.0.0.1", - "port": 5700, - "websocket_port": 6700, - "enable_group": True, - "enable_guild": True, - "enable_direct_message": True, - "enable_group_increase": True, - }, { "id": "default", "name": "aiocqhttp", @@ -75,12 +63,6 @@ DEFAULT_CONFIG_VERSION_2 = { "identifier": False, }, "content_safety": { - "baidu_aip": { - "enable": False, - "app_id": "", - "api_key": "", - "secret_key": "", - }, "internal_keywords": { "enable": True, "extra_keywords": [], @@ -103,63 +85,6 @@ DEFAULT_CONFIG_VERSION_2 = { "plugin_repo_mirror": "default", } -# 这个是用于迁移旧版本配置文件的映射表 -MAPPINGS_1_2 = [ - [["qqbot", "enable"], ["platform", 0, "enable"]], - [["qqbot", "appid"], ["platform", 0, "appid"]], - [["qqbot", "token"], ["platform", 0, "secret"]], - [["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]], - [["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]], - [["gocqbot", "enable"], ["platform", 1, "enable"]], - [["gocq_host"], ["platform", 1, "host"]], - [["gocq_http_port"], ["platform", 1, "port"]], - [["gocq_websocket_port"], ["platform", 1, "websocket_port"]], - [["gocq_react_group"], ["platform", 1, "enable_group"]], - [["gocq_react_guild"], ["platform", 1, "enable_guild"]], - [["gocq_react_friend"], ["platform", 1, "enable_direct_message"]], - [["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]], - [["aiocqhttp", "enable"], ["platform", 2, "enable"]], - [["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]], - [["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]], - [["uniqueSessionMode"], ["platform_settings", "unique_session"]], - [["limit", "time"], ["platform_settings", "rate_limit", "time"]], - [["limit", "count"], ["platform_settings", "rate_limit", "count"]], - [["reply_prefix"], ["platform_settings", "reply_prefix"]], - [["qq_forward_threshold"], ["platform_settings", "forward_threshold"]], - - [["openai", "key"], ["llm", 0, "key"]], - [["openai", "api_base"], ["llm", 0, "api_base"]], - [["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]], - [["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]], - [["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]], - [["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]], - [["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]], - [["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]], - - [["default_personality_str"], ["llm", 0, "default_personality"]], - [["llm_env_prompt"], ["llm", 0, "prompt_prefix"]], - [["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]], - [["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]], - [["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]], - [["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]], - - [["llm_wake_prefix"], ["llm_settings", "wake_prefix"]], - - [["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]], - [["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]], - [["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]], - [["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]], - - [["qq_pic_mode"], ["t2i"]], - [["dump_history_interval"], ["dump_history_interval"]], - [["other_admins"], ["admins_id"]], - [["http_proxy"], ["http_proxy"]], - [["https_proxy"], ["https_proxy"]], - [["dashboard_username"], ["dashboard", "username"]], - [["dashboard_password"], ["dashboard", "password"]], - [["nick_qq"], ["wake_prefix"]], -] - # 配置项的中文描述、值类型 CONFIG_METADATA_2 = { "config_version": {"description": "配置版本", "type": "int"}, @@ -291,3 +216,13 @@ CONFIG_METADATA_2 = { "pip_install_arg": {"description": "pip 安装参数", "type": "string", "hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。"}, "plugin_repo_mirror": {"description": "插件仓库镜像", "type": "string", "hint": "插件仓库的镜像地址,用于加速插件的下载。", "options": ["default", "https://github-mirror.us.kg/"]}, } + +DEFAULT_VALUE_MAP = { + "int": 0, + "float": 0.0, + "bool": False, + "string": "", + "text": "", + "list": [], + "object": {}, +} \ No newline at end of file diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py new file mode 100644 index 000000000..350ab0b61 --- /dev/null +++ b/astrbot/core/core_lifecycle.py @@ -0,0 +1,59 @@ +import asyncio, time, threading +from .event_bus import EventBus +from asyncio import Queue +from typing import List +from core.config.astrbot_config import AstrBotConfig +from core.message_event_handler import MessageEventHandler +from core.plugin import PluginManager +from core import LogBroker +from core.utils.metrics import MetricUploader +from core.db import BaseDatabase +from core.updator import AstrBotUpdator +from core import logger +from core.config.default import VERSION + +class AstrBotCoreLifecycle: + def __init__(self, log_broker: LogBroker, db: BaseDatabase): + self.log_broker = log_broker + self.astrbot_config = AstrBotConfig() + logger.info("AstrBot v"+ VERSION) + logger.setLevel(self.astrbot_config.log_level) + self.event_queue = Queue() + self.event_queue.closed = False + self.plugin_manager = PluginManager(self.astrbot_config, self.event_queue, db) + self.message_event_handler = MessageEventHandler(self.astrbot_config, self.plugin_manager) + self.metrics_uploader = MetricUploader(db) + self.astrbot_updator = AstrBotUpdator(self.astrbot_config.plugin_repo_mirror) + self.event_bus = EventBus(self.event_queue, self.message_event_handler) + self.stop_flag = False + self.start_time = int(time.time()) + + self.curr_tasks: List[asyncio.Task] = [] + + def _load(self): + self.plugin_manager.reload() + + platform_tasks = self.load_platform() + event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus") + metrics_uploader_task = asyncio.create_task(self.metrics_uploader.upload_metrics(), name="metrics") + + self.curr_tasks = [event_bus_task, metrics_uploader_task, *platform_tasks] + self.start_time = int(time.time()) + + async def start(self): + self._load() + await asyncio.gather(*self.curr_tasks, return_exceptions=True) + + def stop(self): + self.stop_flag = True + + def restart(self): + self.event_queue.closed = True + threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start() + + def load_platform(self) -> List[asyncio.Task]: + tasks = [] + platform_insts = self.plugin_manager.get_platform_insts() + for platform_inst in platform_insts: + tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)) + return tasks \ No newline at end of file diff --git a/astrbot/db/__init__.py b/astrbot/core/db/__init__.py similarity index 97% rename from astrbot/db/__init__.py rename to astrbot/core/db/__init__.py index 6e3b6d8c2..e536598d9 100644 --- a/astrbot/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -1,7 +1,7 @@ import abc from dataclasses import dataclass from typing import List -from astrbot.db.po import Stats, LLMHistory +from core.db.po import Stats, LLMHistory @dataclass class BaseDatabase(abc.ABC): diff --git a/astrbot/db/po.py b/astrbot/core/db/po.py similarity index 97% rename from astrbot/db/po.py rename to astrbot/core/db/po.py index b09e12994..96dc3d7e9 100644 --- a/astrbot/db/po.py +++ b/astrbot/core/db/po.py @@ -1,7 +1,6 @@ '''指标数据''' from dataclasses import dataclass, field -# default_factory from typing import List @dataclass diff --git a/astrbot/db/sqlite.py b/astrbot/core/db/sqlite.py similarity index 99% rename from astrbot/db/sqlite.py rename to astrbot/core/db/sqlite.py index 1d065afe3..c9a4699fc 100644 --- a/astrbot/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,7 +1,7 @@ import sqlite3 import os import time -from astrbot.db.po import ( +from core.db.po import ( Platform, Command, Provider, diff --git a/astrbot/db/sqlite_init.sql b/astrbot/core/db/sqlite_init.sql similarity index 100% rename from astrbot/db/sqlite_init.sql rename to astrbot/core/db/sqlite_init.sql diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py new file mode 100644 index 000000000..d6da1d467 --- /dev/null +++ b/astrbot/core/event_bus.py @@ -0,0 +1,26 @@ +import asyncio +from asyncio import Queue +from collections import defaultdict +from typing import List +from .message_event_handler import MessageEventHandler +from core import logger +from .platform import AstrMessageEvent +from nakuru.entities.components import Plain, Image + +class EventBus: + def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler): + self.event_queue = event_queue + self.message_event_handler = message_event_handler + + async def dispatch(self): + logger.info("事件总线已打开。") + while True: + event: AstrMessageEvent = await self.event_queue.get() + self._print_event(event) + asyncio.create_task(self.message_event_handler.handle(event)) + + def _print_event(self, event: AstrMessageEvent): + if event.get_sender_name(): + logger.info(f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}") + else: + logger.info(f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}") \ No newline at end of file diff --git a/util/log.py b/astrbot/core/log.py similarity index 98% rename from util/log.py rename to astrbot/core/log.py index 32e1532da..73d3c1b86 100644 --- a/util/log.py +++ b/astrbot/core/log.py @@ -1,6 +1,6 @@ import logging, colorlog, asyncio from collections import deque -from asyncio import Queue, Lock +from asyncio import Queue from typing import List CACHED_SIZE = 200 diff --git a/astrbot/core/message_event_handler.py b/astrbot/core/message_event_handler.py new file mode 100644 index 000000000..7eda2039d --- /dev/null +++ b/astrbot/core/message_event_handler.py @@ -0,0 +1,175 @@ +import asyncio, re +import inspect +import traceback +from typing import List, Union +from .platform import AstrMessageEvent +from .config.astrbot_config import AstrBotConfig +from .message_event_result import MessageEventResult, CommandResult, MessageChain +from .plugin import PluginManager, Context, CommandMetadata +from .provider import Provider +from nakuru.entities.components import * +from core import logger +from core import html_renderer + +class CommandTokens(): + def __init__(self) -> None: + self.tokens = [] + self.len = 0 + + def get(self, idx: int): + if idx >= self.len: + return None + return self.tokens[idx].strip() + +class CommandParser(): + def __init__(self): + pass + + def parse(self, message: str): + cmd_tokens = CommandTokens() + cmd_tokens.tokens = message.split(" ") + cmd_tokens.len = len(cmd_tokens.tokens) + return cmd_tokens + + def regex_match(self, message: str, command: str) -> bool: + return re.search(command, message, re.MULTILINE) is not None + + +class MessageEventHandler(): + ''' + 处理消息事件。 + ''' + def __init__(self, config: AstrBotConfig, plugin_manager: PluginManager): + self.config = config + self.plugin_manager = plugin_manager + self.command_parser = CommandParser() + + async def handle(self, event: AstrMessageEvent): + ''' + 处理消息事件。 + ''' + event.message_str = event.message_str.strip() + for admin_id in self.config.admins_id: + if event.get_sender_id() == admin_id: + event.role = "admin" + break + + # 检查 wake + wake_prefixes = self.config.wake_prefix + messages = event.get_messages() + is_wake = False + for wake_prefix in wake_prefixes: + if event.message_str.startswith(wake_prefix): + is_wake = True + break + if not is_wake: + # 检查是否有 at 消息 + for message in messages: + if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"): + is_wake = True + wake_prefix = "" + break + # 检查是否是私聊 + if event.is_private_chat(): + is_wake = True + wake_prefix = "" + event.is_wake = is_wake + + # 处理事件监听器(在指令扫描之前) + listeners = self.plugin_manager.context.registered_listeners + listeners_handler = self.plugin_manager.context.listeners_handler + for name in listeners: + if listeners_handler[name].after_commands: + continue + ret = await listeners_handler[name].handler(event) + if ret: + event.set_result(ret) + if event.get_result(): + return await self.post_handle(event) + + # 处理指令,指令带有指定过的前缀 + commands = self.plugin_manager.context.registered_commands + commands_handler = self.plugin_manager.context.commands_handler + + # 扫描指令 + for command in commands: + command = command[1] + trig = False + pre_ = "" + if not commands_handler[command].ignore_prefix: + pre_ = wake_prefix + + if commands_handler[command].use_regex: + trig = self.command_parser.regex_match(event.message_str, pre_ + command) + else: + trig = event.message_str.startswith(pre_ + command) + if trig: + ret = await self.execute_handler(command, commands_handler[command], event) + if ret: + event.set_result(ret) + if event.get_result(): + return await self.post_handle(event) + + # 处理事件监听器(在指令扫描之后) + for name in listeners: + if not listeners_handler[name].after_commands: + continue + ret = await listeners_handler[name].handler(event) + if ret: + event.set_result(ret) + if event.get_result(): + return await self.post_handle(event) + + async def post_handle(self, event: AstrMessageEvent): + result = event.get_result() + if result.callback: + await result.callback(event) + + # prefix + if self.config.platform_settings.reply_prefix: + result.chain.insert(0, Plain(self.config.platform_settings.reply_prefix)) + + # t2i + if (result.use_t2i_ is None and self.config.t2i) or result.use_t2i_: + plain_str = "" + for comp in result.chain: + if isinstance(comp, Plain): + plain_str += "\n\n" + comp.text + else: + break + if plain_str and len(plain_str) > 150: + url = await html_renderer.render_t2i(plain_str, return_url=True) + if url: + result.chain = [Image.fromURL(url)] + + logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") + + await event.send(result) + + async def execute_handler(self, + command: str, + command_metadata: CommandMetadata, + message_event: AstrMessageEvent): + logger.info(f"触发 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令。") + handler = command_metadata.handler + try: + if inspect.iscoroutinefunction(handler): + command_result = await handler(message_event) + else: + command_result = handler(message_event) + + if command_result is not None: + message_event.set_result(command_result) + except TypeError as e: + # 兼容旧版本插件 + if inspect.iscoroutinefunction(handler): + command_result = await handler(message_event, self.plugin_manager.context) + else: + command_result = handler(message_event, self.plugin_manager.context) + + if command_result is not None: + message_event.set_result(command_result) + except BaseException as e: + logger.error(traceback.format_exc()) + text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}" + message_event.set_result(MessageEventResult().message(text)) \ No newline at end of file diff --git a/astrbot/core/message_event_result.py b/astrbot/core/message_event_result.py new file mode 100644 index 000000000..89136871b --- /dev/null +++ b/astrbot/core/message_event_result.py @@ -0,0 +1,58 @@ +from typing import List, Union, Optional +from dataclasses import dataclass, field +from nakuru.entities.components import * + +@dataclass +class MessageChain(): + chain: List[BaseMessageComponent] = field(default_factory=list) + use_t2i_: Optional[bool] = None # None 为跟随用户设置 + + def message(self, message: str): + ''' + 快捷回复消息。 + + CommandResult().message("Hello, world!") + ''' + self.chain.append(Plain(message)) + return self + + def error(self, message: str): + ''' + 快捷回复消息。 + + CommandResult().error("Hello, world!") + ''' + self.chain.append(Plain(message)) + return self + + def url_image(self, url: str): + ''' + 快捷回复图片(网络url的格式)。 + + CommandResult().image("https://example.com/image.jpg") + ''' + self.chain.append(Image.fromURL(url)) + return self + + def file_image(self, path: str): + ''' + 快捷回复图片(本地文件路径的格式)。 + + CommandResult().image("image.jpg") + ''' + self.chain.append(Image.fromFileSystem(path)) + return self + + def use_t2i(self, use_t2i: bool): + ''' + 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 + ''' + self.use_t2i_ = use_t2i + return self + +@dataclass +class MessageEventResult(MessageChain): + is_command_call: Optional[bool] = False + callback: Optional[callable] = None + +CommandResult = MessageEventResult \ No newline at end of file diff --git a/astrbot/core/platform/__init__.py b/astrbot/core/platform/__init__.py new file mode 100644 index 000000000..d7db721ef --- /dev/null +++ b/astrbot/core/platform/__init__.py @@ -0,0 +1,4 @@ +from .platform import Platform +from .astr_message_event import AstrMessageEvent +from .platform_metadata import PlatformMetadata +from .astrbot_message import AstrBotMessage, MessageMember, MessageType \ No newline at end of file diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py new file mode 100644 index 000000000..28ca834c1 --- /dev/null +++ b/astrbot/core/platform/astr_message_event.py @@ -0,0 +1,142 @@ +import abc +from dataclasses import dataclass +from .astrbot_message import AstrBotMessage +from .platform_metadata import PlatformMetadata +from core.message_event_result import MessageEventResult, MessageChain +from core.platform.message_type import MessageType +from typing import List +from nakuru.entities.components import BaseMessageComponent, Plain, Image + +@dataclass +class MessageSesion: + platform_name: str + message_type: MessageType + session_id: str + + def __str__(self): + return f"{self.platform_name}:{self.message_type.value}:{self.session_id}" + + @staticmethod + def from_str(session_str: str): + platform_name, message_type, session_id = session_str.split(":") + return MessageSesion(platform_name, MessageType(message_type), session_id) + +class AstrMessageEvent(abc.ABC): + def __init__(self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str,): + self.message_str = message_str + self.message_obj = message_obj + self.platform_meta = platform_meta + self.session_id = session_id + self.role = "member" + self.is_wake = False + + self._result: MessageEventResult = None + self._extras = {} + self.session = MessageSesion( + platform_name=platform_meta.name, + message_type=message_obj.type, + session_id=session_id + ) + self.unified_msg_origin = str(self.session) + + def get_platform_name(self): + return self.platform_meta.name + + def get_message_str(self) -> str: + ''' + 获取消息字符串。 + ''' + return self.message_str + + def _outline_chain(self, chain: List[BaseMessageComponent]) -> str: + outline = "" + for i in chain: + if isinstance(i, Plain): + outline += i.text + if isinstance(i, Image): + outline += "[图片]" + return outline + + def get_message_outline(self) -> str: + ''' + 获取消息概要。 + + 除了文本消息外,其他消息类型会被转换为对应的占位符。如图片消息会被转换为 [图片]。 + ''' + return self._outline_chain(self.message_obj.message) + + def get_messages(self) -> List[BaseMessageComponent]: + ''' + 获取消息链。 + ''' + return self.message_obj.message + + def get_session_id(self) -> str: + ''' + 获取会话id。 + ''' + return self.session_id + + def get_self_id(self) -> str: + ''' + 获取机器人自身的id。 + ''' + return self.message_obj.self_id + + def get_sender_id(self) -> str: + ''' + 获取消息发送者的id。 + ''' + return self.message_obj.sender.user_id + + def get_sender_name(self) -> str: + ''' + 获取消息发送者的名称。(可能会返回空字符串) + ''' + return self.message_obj.sender.nickname + + def set_result(self, result: MessageEventResult): + ''' + 设置消息事件的结果。当设置了结果后,消息事件将不再继续传递。 + ''' + self._result = result + + def get_result(self) -> MessageEventResult: + ''' + 获取消息事件的结果。 + ''' + return self._result + + def set_extra(self, key, value): + ''' + 设置额外的信息。 + ''' + self._extras[key] = value + + def is_private_chat(self) -> bool: + ''' + 是否是私聊。 + ''' + return self.message_obj.type.value == (MessageType.FRIEND_MESSAGE).value + + def is_wake_up(self) -> bool: + ''' + 是否是唤醒机器人的事件。 + + 机器人被唤醒的条件: + 1. 消息以用户设置的唤醒前缀开头,默认是 `/`. + 2. 消息中有 at 机器人的消息。 + 3. 是私聊。 + ''' + return self.is_wake + + @abc.abstractmethod + async def send(self, message: MessageChain): + ''' + 发送消息。 + ''' + raise NotImplementedError() \ No newline at end of file diff --git a/type/astrbot_message.py b/astrbot/core/platform/astrbot_message.py similarity index 73% rename from type/astrbot_message.py rename to astrbot/core/platform/astrbot_message.py index 61719fbbc..2b3fb6519 100644 --- a/type/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -1,25 +1,18 @@ import time -from enum import Enum from typing import List from dataclasses import dataclass from nakuru.entities.components import BaseMessageComponent - -class MessageType(Enum): - GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 - FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - GUILD_MESSAGE = 'GuildMessage' # 频道消息 +from .message_type import MessageType @dataclass class MessageMember(): user_id: str # 发送者id nickname: str = None - -class AstrBotMessage(): +class AstrBotMessage: ''' AstrBot 的消息对象 ''' - tag: str # 消息来源标签 type: MessageType # 消息类型 self_id: str # 机器人的识别id session_id: str # 会话id diff --git a/astrbot/core/platform/message_type.py b/astrbot/core/platform/message_type.py new file mode 100644 index 000000000..b1c7721e5 --- /dev/null +++ b/astrbot/core/platform/message_type.py @@ -0,0 +1,6 @@ +from enum import Enum + +class MessageType(Enum): + GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 + FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 + \ No newline at end of file diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py new file mode 100644 index 000000000..01d767e31 --- /dev/null +++ b/astrbot/core/platform/platform.py @@ -0,0 +1,42 @@ +import abc +from typing import Awaitable, Any +from asyncio import Queue +from .platform_metadata import PlatformMetadata +from .astr_message_event import AstrMessageEvent +from core.message_event_result import MessageChain +from .astr_message_event import MessageSesion + +class Platform(abc.ABC): + def __init__(self, event_queue: Queue): + super().__init__() + # 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。 + self._event_queue = event_queue + + @abc.abstractmethod + def run(self) -> Awaitable[Any]: + ''' + 得到一个平台的运行实例,需要返回一个协程对象。 + ''' + raise NotImplementedError + + @abc.abstractmethod + def meta(self) -> PlatformMetadata: + ''' + 得到一个平台的元数据。 + ''' + raise NotImplementedError + + @abc.abstractmethod + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain) -> Awaitable[Any]: + ''' + 通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 + + 异步方法。 + ''' + raise NotImplementedError + + def commit_event(self, event: AstrMessageEvent): + ''' + 提交一个事件到事件队列。 + ''' + self._event_queue.put_nowait(event) \ No newline at end of file diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py new file mode 100644 index 000000000..9edff89c9 --- /dev/null +++ b/astrbot/core/platform/platform_metadata.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass + +@dataclass +class PlatformMetadata(): + name: str # 平台的名称 + description: str # 平台的描述 \ No newline at end of file diff --git a/astrbot/core/plugin/__init__.py b/astrbot/core/plugin/__init__.py new file mode 100644 index 000000000..f054edcd3 --- /dev/null +++ b/astrbot/core/plugin/__init__.py @@ -0,0 +1,4 @@ +from .plugin import Plugin, RegisteredPlugin, PluginMetadata +from .plugin_manager import PluginManager +from .context import CommandMetadata, Context +from core.provider import Provider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/config.py b/astrbot/core/plugin/config.py similarity index 100% rename from util/plugin_dev/api/v1/config.py rename to astrbot/core/plugin/config.py diff --git a/astrbot/core/plugin/context.py b/astrbot/core/plugin/context.py new file mode 100644 index 000000000..cf3d4c409 --- /dev/null +++ b/astrbot/core/plugin/context.py @@ -0,0 +1,208 @@ +import heapq +from asyncio import Queue +from . import RegisteredPlugin, PluginMetadata +from typing import List, Dict, Awaitable, Union +from dataclasses import dataclass + +from core.platform import Platform +from core.db import BaseDatabase +from core.config.astrbot_config import AstrBotConfig +from core.utils.func_call import FuncCall +from core.platform.astr_message_event import MessageSesion +from core.message_event_result import MessageChain + +@dataclass +class CommandMetadata(): + ''' + 显式指令 + ''' + plugin_name: str + plugin_metadata: PluginMetadata + handler: Awaitable + use_regex: bool = False + ignore_prefix: bool = False + description: str = "" + +@dataclass +class EventListenerMetadata(): + ''' + 事件监听器 + ''' + plugin_name: str + plugin_metadata: PluginMetadata + handler: Awaitable + description: str = "" + after_commands: bool = False + + +class Context: + ''' + 暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。 + ''' + # 事件队列。消息平台通过事件队列传递消息事件。 + _event_queue: Queue = None + + # AstrBot 配置信息 + _config: AstrBotConfig = None + + # AstrBot 数据库 + _db: BaseDatabase = None + + # 维护了注册的插件的信息 + registered_plugins: List[RegisteredPlugin] = [] + + # 维护了插件注册的指令的信息的名字列表,用于优先级排序 + registered_commands: List[str] = [] + # 维护了插件注册的指令的信息 + commands_handler: Dict[str, CommandMetadata] = {} + + # 维护了插件注册的中间件的名字列表,用于优先级排序 + registered_listeners: List[str] = [] + # 维护了插件注册的中间件的信息 + listeners_handler: Dict[str, EventListenerMetadata] = {} + + # 维护了注册的平台的信息 + registered_platforms: List[Platform] = [] + + # 维护了 LLM Tools 信息 + llm_tools: FuncCall = FuncCall() + + def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase): + self._event_queue = event_queue + self._config = config + self._db = db + + def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: + for plugin in self.registered_plugins: + if plugin.metadata.plugin_name == plugin_name: + return plugin + return None + + def register_listener(self, + plugin_name: str, + name: str, + handler: Awaitable, + description: str = None, + after_commands: bool = False): + ''' + 注册一个事件监听器。 + + after_commands: 是否在指令处理后执行。 + ''' + if name in self.registered_listeners: + raise ValueError(f"Middleware {name} already exists.") + self.registered_listeners.append(name) + self.listeners_handler[name] = EventListenerMetadata( + plugin_name=plugin_name, + plugin_metadata=None, + handler=handler, + description=description, + after_commands=after_commands + ) + + def register_commands(self, + plugin_name: str, + command_name: str, + description: str, + priority: int, + handler: Awaitable, + use_regex: bool = False, + ignore_prefix: bool = False): + ''' + 注册插件指令。 + + @param plugin_name: 插件名,注意需要和你的 metadata 中的一致。 + @param command_name: 指令名,如 "help"。不需要带前缀。 + @param description: 指令描述。 + @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 + @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context + @param use_regex: 是否使用正则表达式匹配指令名。 + @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 + + .. Example:: + + ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 + ''' + for command in self.registered_commands: + if command_name in command[1]: + raise ValueError(f"Command {command_name} already exists.") + if not handler: + raise ValueError(f"Handler of {command_name} is None.") + + heapq.heappush(self.registered_commands, (-priority, command_name)) + self.commands_handler[command_name] = CommandMetadata( + plugin_name=plugin_name, + plugin_metadata=None, + handler=handler, + use_regex=use_regex, + ignore_prefix=ignore_prefix, + description=description + ) + heapq.heapify(self.registered_commands) + + def register_platform(self, platform: Platform): + ''' + 注册一个消息平台。 + ''' + self.registered_platforms.append(platform) + + def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: + ''' + 为函数调用(function-calling / tools-use)添加工具。 + + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 异步处理函数。 + + 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 + ''' + self.llm_tools.add_func(name, func_args, desc, func_obj) + + def unregister_llm_tool(self, name: str) -> None: + ''' + 删除一个函数调用工具。 + ''' + self.llm_tools.remove_func(name) + + def get_config(self) -> AstrBotConfig: + ''' + 获取 AstrBot 配置信息。 + ''' + return self._config + + def get_db(self) -> BaseDatabase: + ''' + 获取 AstrBot 数据库。 + ''' + return self._db + + def get_event_queue(self) -> Queue: + ''' + 获取事件队列。 + ''' + return self._event_queue + + async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool: + ''' + 根据 session(unified_msg_origin) 发送消息。 + + @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 + @param message_chain: 消息链。 + + @return: 是否找到匹配的平台。 + + 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 + ''' + + if isinstance(session, str): + try: + session = MessageSesion.from_str(session) + except BaseException as e: + raise ValueError("不合法的 session 字符串: " + str(e)) + + for platform in self.registered_platforms: + if platform.meta().name == session.platform_name: + await platform.send_by_session(session, message_chain) + return True + return False \ No newline at end of file diff --git a/type/plugin.py b/astrbot/core/plugin/plugin.py similarity index 61% rename from type/plugin.py rename to astrbot/core/plugin/plugin.py index 23dc976a2..c27fb9925 100644 --- a/type/plugin.py +++ b/astrbot/core/plugin/plugin.py @@ -3,12 +3,6 @@ from types import ModuleType from typing import List from dataclasses import dataclass -class PluginType(Enum): - PLATFORM = 'platform' # 平台类插件。 - LLM = 'llm' # 大语言模型类插件 - COMMON = 'common' # 其他插件 - - @dataclass class PluginMetadata: ''' @@ -16,7 +10,6 @@ class PluginMetadata: ''' # required plugin_name: str - plugin_type: PluginType author: str # 插件作者 desc: str # 插件简介 version: str # 插件版本 @@ -25,7 +18,7 @@ class PluginMetadata: repo: str = None # 插件仓库地址 def __str__(self) -> str: - return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" + return f"PluginMetadata({self.plugin_name}, {self.desc}, {self.version}, {self.repo})" @dataclass @@ -38,16 +31,13 @@ class RegisteredPlugin: module_path: str module: ModuleType root_dir_name: str - trig_cnt: int = 0 - - def reset_trig_cnt(self): - self.trig_cnt = 0 - - def trig(self): - self.trig_cnt += 1 + reserved: bool # 是否是 AstrBot 的保留插件 def __str__(self) -> str: return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" -RegisteredPlugins = List[RegisteredPlugin] + +class Plugin: + def __init__(self): + pass \ No newline at end of file diff --git a/model/plugin/manager.py b/astrbot/core/plugin/plugin_manager.py similarity index 60% rename from model/plugin/manager.py rename to astrbot/core/plugin/plugin_manager.py index 89c78495f..93bbb82f9 100644 --- a/model/plugin/manager.py +++ b/astrbot/core/plugin/plugin_manager.py @@ -6,26 +6,27 @@ import uuid import shutil import yaml import logging - -from util.updator.plugin_updator import PluginUpdator -from util.io import remove_dir, download_file +from asyncio import Queue from types import ModuleType -from type.types import Context -from type.plugin import * -from type.register import * -from util.log import LogManager -from logging import Logger +from typing import List, Awaitable from pip import main as pip_main +from core.config.astrbot_config import AstrBotConfig +from core import logger +from .context import Context +from . import RegisteredPlugin, PluginMetadata +from .updator import PluginUpdator +from core.db import BaseDatabase +from core.utils.io import remove_dir -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -class PluginManager(): - def __init__(self, context: Context): - self.updator = PluginUpdator(context.config_helper.plugin_repo_mirror) - self.plugin_store_path = self.updator.get_plugin_store_path() - self.context = context - - def get_classes(self, arg: ModuleType): +class PluginManager: + def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase): + self.updator = PluginUpdator(config.plugin_repo_mirror) + self.context = Context(event_queue, config, db) + self.config = config + self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) + self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages")) + + def _get_classes(self, arg: ModuleType): classes = [] clsmembers = inspect.getmembers(arg, inspect.isclass) for (name, _) in clsmembers: @@ -34,7 +35,7 @@ class PluginManager(): break return classes - def get_modules(self, path): + def _get_modules(self, path): modules = [] dirs = os.listdir(path) @@ -56,17 +57,18 @@ class PluginManager(): }) return modules - def get_plugin_modules(self): + def _get_plugin_modules(self) -> List[dict]: plugins = [] - try: - plugin_dir = self.plugin_store_path - if os.path.exists(plugin_dir): - plugins = self.get_modules(plugin_dir) - return plugins - except BaseException as e: - raise e + if os.path.exists(self.plugin_store_path): + plugins.extend(self._get_modules(self.plugin_store_path)) + if os.path.exists(self.reserved_plugin_path): + _p = self._get_modules(self.reserved_plugin_path) + for p in _p: + p['reserved'] = True + plugins.extend(_p) + return plugins - def check_plugin_dept_update(self, target_plugin: str = None): + def _check_plugin_dept_update(self, target_plugin: str = None): plugin_dir = self.plugin_store_path if not os.path.exists(plugin_dir): return False @@ -74,7 +76,7 @@ class PluginManager(): if target_plugin: to_update.append(target_plugin) else: - for p in self.context.cached_plugins: + for p in self.context.registered_plugins: to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) @@ -82,56 +84,56 @@ class PluginManager(): pth = os.path.join(plugin_path, "requirements.txt") logger.info(f"正在检查插件 {p} 的依赖: {pth}") try: - self.update_plugin_dept(os.path.join(plugin_path, "requirements.txt")) + self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt")) except Exception as e: logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") - def update_plugin_dept(self, path): - args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/', '--break-system-package'] - if self.context.config_helper.pip_install_arg: - args.extend(self.context.config_helper.pip_install_arg) + def _update_plugin_dept(self, path): + args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/'] + if self.config.pip_install_arg: + args.extend(self.config.pip_install_arg) result_code = pip_main(args) if result_code != 0: - raise Exception(str(result_code)) - - async def install_plugin(self, repo_url: str): - plugin_path = await self.updator.update(repo_url) - with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: - f.write(repo_url) - # self.check_plugin_dept_update() - return plugin_path + raise Exception(str(result_code)) - def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin: - for p in self.context.cached_plugins: - if p.metadata.plugin_name == plugin_name: - return p + def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: + metadata = None + + if not os.path.exists(plugin_path): + raise Exception("插件不存在。") + + if os.path.exists(os.path.join(plugin_path, "metadata.yaml")): + with open(os.path.join(plugin_path, "metadata.yaml"), "r", encoding='utf-8') as f: + metadata = yaml.safe_load(f) + elif plugin_obj: + # 使用 info() 函数 + metadata = plugin_obj.info() + + if isinstance(metadata, dict): + if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: + raise Exception("插件元数据信息不完整。") + metadata = PluginMetadata( + plugin_name=metadata['name'], + author=metadata['author'], + desc=metadata['desc'], + version=metadata['version'], + repo=metadata['repo'] if 'repo' in metadata else None + ) + + return metadata - def uninstall_plugin(self, plugin_name: str): - plugin = self.get_registered_plugin(plugin_name) - if not plugin: - raise Exception("插件不存在。") - root_dir_name = plugin.root_dir_name - ppath = self.plugin_store_path - self.context.cached_plugins.remove(plugin) - if not remove_dir(os.path.join(ppath, root_dir_name)): - raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") - - async def update_plugin(self, plugin_name: str): - plugin = self.get_registered_plugin(plugin_name) - if not plugin: - raise Exception("插件不存在。") - - await self.updator.update(plugin) - - def plugin_reload(self): - cached_plugins = self.context.cached_plugins - plugins = self.get_plugin_modules() + def reload(self): + ''' + 加载插件类 + ''' + registered_plugins = self.context.registered_plugins + plugins = self._get_plugin_modules() if plugins is None: return False, "未找到任何插件模块" fail_rec = "" registered_map = {} - for p in cached_plugins: + for p in registered_plugins: registered_map[p.module_path] = None for plugin in plugins: @@ -139,43 +141,51 @@ class PluginManager(): p = plugin['module'] module_path = plugin['module_path'] root_dir_name = plugin['pname'] + reserved = plugin.get('reserved', False) logger.info(f"正在加载插件 {root_dir_name} ...") - - # self.check_plugin_dept_update(target_plugin=root_dir_name) + pre = "data.plugins." if not reserved else "packages." + + # 尝试导入插件模块 try: - module = __import__("data.plugins." + - root_dir_name + "." + p, fromlist=[p]) + module = __import__(pre + root_dir_name + "." + p, fromlist=[p]) except (ModuleNotFoundError, ImportError) as e: # 尝试安装插件依赖 - self.check_plugin_dept_update(target_plugin=root_dir_name) - module = __import__("data.plugins." + - root_dir_name + "." + p, fromlist=[p]) + self._check_plugin_dept_update(target_plugin=root_dir_name) + module = __import__(pre + root_dir_name + "." + p, fromlist=[p]) - cls = self.get_classes(module) + cls = self._get_classes(module) + # 实例化插件类 try: - # 尝试传入 ctx obj = getattr(module, cls[0])(context=self.context) - except TypeError: - obj = getattr(module, cls[0])() except BaseException as e: + logger.error(f"插件 {root_dir_name} 实例化失败。") raise e - + + # 解析插件元数据,加入注册列表 metadata = None - - plugin_path = os.path.join(self.plugin_store_path, root_dir_name) - metadata = self.load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj) - + plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name) + metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj) if module_path not in registered_map: - cached_plugins.append(RegisteredPlugin( + registered_plugins.append(RegisteredPlugin( metadata=metadata, plugin_instance=obj, module=module, module_path=module_path, - root_dir_name=root_dir_name + root_dir_name=root_dir_name, + reserved=reserved )) + + for command in self.context.commands_handler: + if self.context.commands_handler[command].plugin_name == metadata.plugin_name: + self.context.commands_handler[command].plugin_metadata = metadata + for listener in self.context.listeners_handler: + if self.context.listeners_handler[listener].plugin_name == metadata.plugin_name: + self.context.listeners_handler[listener].plugin_metadata = metadata + + except BaseException as e: traceback.print_exc() fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" @@ -184,12 +194,39 @@ class PluginManager(): for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) - if not fail_rec: return True, None else: return False, fail_rec + async def install_plugin(self, repo_url: str): + plugin_path = await self.updator.update(repo_url) + with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f: + f.write(repo_url) + self._check_plugin_dept_update() + return plugin_path + + def uninstall_plugin(self, plugin_name: str): + plugin = self.context.get_registered_plugin(plugin_name) + if not plugin: + raise Exception("插件不存在。") + if plugin.reserved: + raise Exception("该插件是 AstrBot 保留插件,无法卸载。") + root_dir_name = plugin.root_dir_name + ppath = self.plugin_store_path + self.context.registered_plugins.remove(plugin) + if not remove_dir(os.path.join(ppath, root_dir_name)): + raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。") + + async def update_plugin(self, plugin_name: str): + plugin = self.context.get_registered_plugin(plugin_name) + if not plugin: + raise Exception("插件不存在。") + if plugin.reserved: + raise Exception("该插件是 AstrBot 保留插件,无法更新。") + + await self.updator.update(plugin) + def install_plugin_from_file(self, zip_file_path: str): # try to unzip temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4())) @@ -199,7 +236,7 @@ class PluginManager(): remove_dir(temp_dir) raise Exception("插件缺少 metadata.yaml 文件。") - metadata = self.load_plugin_metadata(temp_dir) + metadata = self._load_plugin_metadata(temp_dir) plugin_name = metadata.plugin_name if not plugin_name: remove_dir(temp_dir) @@ -221,35 +258,11 @@ class PluginManager(): # remove the temp dir remove_dir(temp_dir) - # self.check_plugin_dept_update() - - # ok, err = self.plugin_reload() - # if not ok: - # raise Exception(err) - - def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata: - metadata = None + self._check_plugin_dept_update() - if not os.path.exists(plugin_path): - raise Exception("插件不存在。") - - if os.path.exists(os.path.join(plugin_path, "metadata.yaml")): - with open(os.path.join(plugin_path, "metadata.yaml"), "r", encoding='utf-8') as f: - metadata = yaml.safe_load(f) - elif plugin_obj: - # 使用 info() 函数 - metadata = plugin_obj.info() - - if isinstance(metadata, dict): - if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata: - raise Exception("插件元数据信息不完整。") - metadata = PluginMetadata( - plugin_name=metadata['name'], - plugin_type=PluginType.COMMON if 'plugin_type' not in metadata else PluginType(metadata['plugin_type']), - author=metadata['author'], - desc=metadata['desc'], - version=metadata['version'], - repo=metadata['repo'] if 'repo' in metadata else None - ) - - return metadata \ No newline at end of file + def get_platform_insts(self): + return self.context.registered_platforms + + def get_loaded_plugins(self): + return self.context.registered_plugins + \ No newline at end of file diff --git a/util/updator/plugin_updator.py b/astrbot/core/plugin/updator.py similarity index 71% rename from util/updator/plugin_updator.py rename to astrbot/core/plugin/updator.py index 03a41faea..33032173a 100644 --- a/util/updator/plugin_updator.py +++ b/astrbot/core/plugin/updator.py @@ -1,20 +1,15 @@ import os, zipfile, shutil -from util.updator.zip_updator import RepoZipUpdator -from util.io import remove_dir -from type.register import RegisteredPlugin +from ..updator import RepoZipUpdator +from core.utils.io import remove_dir, on_error +from ..plugin import RegisteredPlugin from typing import Union -from util.log import LogManager -from logging import Logger -from util.io import on_error - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - +from core import logger class PluginUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/plugins")) + self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins")) def get_plugin_store_path(self) -> str: return self.plugin_store_path @@ -54,10 +49,6 @@ class PluginUpdator(RepoZipUpdator): z.extractall(target_dir) avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir] - # copy addons/plugins to the target_dir temporarily - # if os.path.exists(os.path.join(target_dir, "addons/plugins")): - # logger.info("备份插件目录:从 addons/plugins 到 temp_plugins") - # shutil.copytree(os.path.join(target_dir, "addons/plugins"), "temp_plugins") files = os.listdir(os.path.join(target_dir, update_dir)) for f in files: @@ -71,15 +62,9 @@ class PluginUpdator(RepoZipUpdator): os.remove(os.path.join(target_dir, f)) shutil.move(os.path.join(target_dir, update_dir, f), target_dir) - # move back - # if os.path.exists("temp_plugins"): - # logger.info("恢复插件目录:从 temp_plugins 到 addons/plugins") - # shutil.rmtree(os.path.join(target_dir, "addons/plugins"), onerror=on_error) - # shutil.move("temp_plugins", os.path.join(target_dir, "addons/plugins")) - try: logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except: - logger.warn(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") + logger.warning(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}") diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py new file mode 100644 index 000000000..bf9be469a --- /dev/null +++ b/astrbot/core/provider/__init__.py @@ -0,0 +1 @@ +from .provider import Provider \ No newline at end of file diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py new file mode 100644 index 000000000..1efb5d58a --- /dev/null +++ b/astrbot/core/provider/provider.py @@ -0,0 +1,46 @@ +import abc +from collections import defaultdict +from typing import List +# from core.utils.func_call import FuncCall + +class Provider(abc.ABC): + def __init__(self) -> None: + self.model_name = "unknown" + + def set_model(self, model_name: str): + self.model_name = model_name + + def get_model(self): + return self.model_name + + @abc.abstractmethod + async def text_chat(self, + prompt: str, + session_id: str, + image_urls: List[str] = None, + tool = None, + **kwargs) -> str: + ''' + prompt: 提示词 + session_id: 会话id + + [optional] + image_url: 图片url(识图) + tools: 函数调用工具 + ''' + raise NotImplementedError() + + @abc.abstractmethod + async def image_generate(self, prompt: str, session_id: str, **kwargs) -> str: + ''' + prompt: 提示词 + session_id: 会话id + ''' + raise NotImplementedError() + + @abc.abstractmethod + async def forget(self, session_id: str) -> bool: + ''' + 重置会话 + ''' + raise NotImplementedError() diff --git a/util/updator/astrbot_updator.py b/astrbot/core/updator.py similarity index 88% rename from util/updator/astrbot_updator.py rename to astrbot/core/updator.py index b67661421..99a31324d 100644 --- a/util/updator/astrbot_updator.py +++ b/astrbot/core/updator.py @@ -1,11 +1,8 @@ import os, psutil, sys, time -from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator -from util.log import LogManager -from logging import Logger -from type.config import VERSION -from util.io import on_error, download_file - -logger: Logger = LogManager.GetLogger(log_name='astrbot') +from .zip_updator import ReleaseInfo, RepoZipUpdator +from core import logger +from core.config.default import VERSION +from core.utils.io import download_file class AstrBotUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: @@ -31,14 +28,12 @@ class AstrBotUpdator(RepoZipUpdator): except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = None, context = None): + def _reboot(self, delay: int = 3): if os.environ.get('TEST_MODE', 'off') == 'on': logger.info("测试模式下不会重启。") return - # if delay: time.sleep(delay) py = sys.executable - context.running = False - time.sleep(3) + time.sleep(delay) self.terminate_child_processes() py = py.replace(" ", "\\ ") try: diff --git a/model/command/parser.py b/astrbot/core/utils/command_parser.py similarity index 100% rename from model/command/parser.py rename to astrbot/core/utils/command_parser.py diff --git a/util/agent/func_call.py b/astrbot/core/utils/func_call.py similarity index 94% rename from util/agent/func_call.py rename to astrbot/core/utils/func_call.py index 830496cab..20e349d62 100644 --- a/util/agent/func_call.py +++ b/astrbot/core/utils/func_call.py @@ -1,7 +1,9 @@ -from model.provider.provider import Provider +from core.provider import Provider +from typing import Awaitable import json import textwrap + class FuncCallJsonFormatError(Exception): def __init__(self, msg): self.msg = msg @@ -19,14 +21,13 @@ class FuncNotFoundError(Exception): class FuncCall(): - def __init__(self, provider: Provider) -> None: + def __init__(self) -> None: self.func_list = [] - self.provider = provider def empty(self) -> bool: return len(self.func_list) == 0 - def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None: + def add_func(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: ''' 为函数调用(function-calling / tools-use)添加工具。 @@ -84,11 +85,7 @@ class FuncCall(): }) return _l - async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple: - - if not provider: - provider = self.provider - + async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider) -> tuple: prompt = textwrap.dedent(f""" ROLE: 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 diff --git a/util/image_uploader.py b/astrbot/core/utils/image_uploader.py similarity index 100% rename from util/image_uploader.py rename to astrbot/core/utils/image_uploader.py diff --git a/util/io.py b/astrbot/core/utils/io.py similarity index 91% rename from util/io.py rename to astrbot/core/utils/io.py index f66dd14b4..bf8907176 100644 --- a/util/io.py +++ b/astrbot/core/utils/io.py @@ -4,13 +4,9 @@ import shutil import socket import time import aiohttp +import base64 from PIL import Image -from util.log import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - def on_error(func, path, exc_info): ''' @@ -30,7 +26,6 @@ def remove_dir(file_path) -> bool: shutil.rmtree(file_path, onerror=on_error) return True except BaseException as e: - logger.error(f"删除文件/文件夹 {file_path} 失败: {str(e)}") return False def port_checker(port: int, host: str = "localhost"): @@ -67,7 +62,6 @@ def save_temp_img(img: Image) -> str: else: with open(p, "wb") as f: f.write(img) - logger.info(f"保存临时图片: {p}") return p async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str: @@ -75,7 +69,6 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict = 下载图片, 返回 path ''' try: - logger.info(f"下载图片: {url}") async with aiohttp.ClientSession() as session: if post: async with session.post(url, json=post_data) as resp: @@ -102,7 +95,6 @@ async def download_file(url: str, path: str): 从指定 url 下载文件到指定路径 path ''' try: - logger.info(f"下载文件: {url}") async with aiohttp.ClientSession() as session: async with session.get(url) as resp: with open(path, 'wb') as f: @@ -114,6 +106,11 @@ async def download_file(url: str, path: str): except Exception as e: raise e +def file_to_base64(file_path: str) -> str: + with open(file_path, "rb") as f: + data_bytes = f.read() + base64_str = base64.b64encode(data_bytes).decode() + return "base64://" + base64_str def get_local_ip_addresses(): ip = '' diff --git a/util/metrics.py b/astrbot/core/utils/metrics.py similarity index 96% rename from util/metrics.py rename to astrbot/core/utils/metrics.py index f870fd0cc..38cb4a4ef 100644 --- a/util/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -4,9 +4,9 @@ import json import sys import logging -from astrbot.db import BaseDatabase +from core.db import BaseDatabase from collections import defaultdict -from type.config import VERSION +from core.config import VERSION logger = logging.getLogger("astrbot") diff --git a/util/personality.py b/astrbot/core/utils/personality.py similarity index 100% rename from util/personality.py rename to astrbot/core/utils/personality.py diff --git a/util/t2i/strategies/base_strategy.py b/astrbot/core/utils/t2i/__init__.py similarity index 100% rename from util/t2i/strategies/base_strategy.py rename to astrbot/core/utils/t2i/__init__.py diff --git a/util/t2i/strategies/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py similarity index 99% rename from util/t2i/strategies/local_strategy.py rename to astrbot/core/utils/t2i/local_strategy.py index 764ab9873..33e0088ba 100644 --- a/util/t2i/strategies/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -2,9 +2,9 @@ import re import aiohttp from io import BytesIO -from .base_strategy import RenderStrategy +from . import RenderStrategy from PIL import ImageFont, Image, ImageDraw -from util.io import save_temp_img +from core.utils.io import save_temp_img class LocalRenderStrategy(RenderStrategy): diff --git a/util/t2i/strategies/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py similarity index 92% rename from util/t2i/strategies/network_strategy.py rename to astrbot/core/utils/t2i/network_strategy.py index 7b73f1f70..b305a6fcc 100644 --- a/util/t2i/strategies/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -1,9 +1,9 @@ import aiohttp import os -from .base_strategy import RenderStrategy -from type.config import VERSION -from util.io import download_image_by_url +from . import RenderStrategy +from core.config import VERSION +from core.utils.io import download_image_by_url ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img" @@ -47,5 +47,5 @@ class NetworkRenderStrategy(RenderStrategy): with open(os.path.join(self.TEMPLATE_PATH, "base.html"), "r", encoding='utf-8') as f: tmpl_str = f.read() assert(tmpl_str) - text = text.replace("`", "\`") + text = text.replace("`", "\\`") return await self.render_custom_template(tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url) \ No newline at end of file diff --git a/util/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py similarity index 61% rename from util/t2i/renderer.py rename to astrbot/core/utils/t2i/renderer.py index 5db6be6ee..94acd82ad 100644 --- a/util/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -1,17 +1,14 @@ -from util.t2i.strategies.local_strategy import LocalRenderStrategy -from util.t2i.strategies.network_strategy import NetworkRenderStrategy -from util.t2i.context import RenderContext -from util.log import LogManager -from logging import Logger +from .network_strategy import NetworkRenderStrategy +from .local_strategy import LocalRenderStrategy +from core.log import LogManager -logger: Logger = LogManager.GetLogger(log_name='astrbot') +logger = LogManager.GetLogger(log_name='astrbot') -class TextToImageRenderer: +class HtmlRenderer: def __init__(self, endpoint_url: str = None): self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - self.context = RenderContext(self.network_strategy) - + def set_network_endpoint(self, endpoint_url: str): '''设置 t2i 的网络端点。 ''' @@ -31,16 +28,14 @@ class TextToImageRenderer: local.pop('self') return await self.network_strategy.render_custom_template(**local) - async def render(self, text: str, use_network: bool = True, return_url: bool = False): + async def render_t2i(self, text: str, use_network: bool = True, return_url: bool = False): '''使用默认文转图模板。 ''' if use_network: try: - return await self.context.render(text, return_url=return_url) + return await self.network_strategy.render(text, return_url=return_url) except BaseException as e: logger.error(f"Failed to render image via AstrBot API: {e}. Falling back to local rendering.") - self.context.set_strategy(self.local_strategy) - return await self.context.render(text) + return await self.local_strategy.render(text) else: - self.context.set_strategy(self.local_strategy) - return await self.context.render(text) + return await self.local_strategy.render(text) diff --git a/util/t2i/strategies/template/base.html b/astrbot/core/utils/t2i/template/base.html similarity index 100% rename from util/t2i/strategies/template/base.html rename to astrbot/core/utils/t2i/template/base.html diff --git a/util/updator/zip_updator.py b/astrbot/core/zip_updator.py similarity index 97% rename from util/updator/zip_updator.py rename to astrbot/core/zip_updator.py index a3cac943a..eaf2ca2c5 100644 --- a/util/updator/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -1,9 +1,6 @@ import aiohttp, os, zipfile, shutil -from util.log import LogManager -from logging import Logger -from util.io import on_error, download_file - -logger: Logger = LogManager.GetLogger(log_name='astrbot') +from core.utils.io import on_error, download_file +from core import logger class ReleaseInfo(): version: str diff --git a/astrbot/dashboard/__init__.py b/astrbot/dashboard/__init__.py new file mode 100644 index 000000000..3c420fa41 --- /dev/null +++ b/astrbot/dashboard/__init__.py @@ -0,0 +1 @@ +from .dashboard_lifecycle import AstrBotDashBoardLifecycle \ No newline at end of file diff --git a/astrbot/dashboard/dashboard_lifecycle.py b/astrbot/dashboard/dashboard_lifecycle.py new file mode 100644 index 000000000..b084d88ef --- /dev/null +++ b/astrbot/dashboard/dashboard_lifecycle.py @@ -0,0 +1,24 @@ +import asyncio +from multiprocessing import Process +from core import logger +from core.core_lifecycle import AstrBotCoreLifecycle +from .server import AstrBotDashboard +from core.db import BaseDatabase + +class AstrBotDashBoardLifecycle: + def __init__(self, db: BaseDatabase): + self.db = db + self.logger = logger + self.dashboard_server = None + + async def start(self, core_lifecycle: AstrBotCoreLifecycle): + core_task = core_lifecycle.start() + self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db) + + task = asyncio.gather(core_task, self.dashboard_server.run()) + + try: + await task + except asyncio.CancelledError as e: + logger.info("🌈 正在关闭 AstrBot...") + core_lifecycle.stop() \ No newline at end of file diff --git a/dashboard/dist/_redirects b/astrbot/dashboard/dist/_redirects similarity index 100% rename from dashboard/dist/_redirects rename to astrbot/dashboard/dist/_redirects diff --git a/dashboard/dist/assets/BlankLayout-b9300230.js b/astrbot/dashboard/dist/assets/BlankLayout-b9300230.js similarity index 100% rename from dashboard/dist/assets/BlankLayout-b9300230.js rename to astrbot/dashboard/dist/assets/BlankLayout-b9300230.js diff --git a/dashboard/dist/assets/ConfigPage-3a024ddb.js b/astrbot/dashboard/dist/assets/ConfigPage-3a024ddb.js similarity index 100% rename from dashboard/dist/assets/ConfigPage-3a024ddb.js rename to astrbot/dashboard/dist/assets/ConfigPage-3a024ddb.js diff --git a/dashboard/dist/assets/ConfigPage-f564cc69.css b/astrbot/dashboard/dist/assets/ConfigPage-f564cc69.css similarity index 100% rename from dashboard/dist/assets/ConfigPage-f564cc69.css rename to astrbot/dashboard/dist/assets/ConfigPage-f564cc69.css diff --git a/dashboard/dist/assets/ConsoleDisplayer-721c13f2.js b/astrbot/dashboard/dist/assets/ConsoleDisplayer-721c13f2.js similarity index 100% rename from dashboard/dist/assets/ConsoleDisplayer-721c13f2.js rename to astrbot/dashboard/dist/assets/ConsoleDisplayer-721c13f2.js diff --git a/dashboard/dist/assets/ConsolePage-0f6141a7.js b/astrbot/dashboard/dist/assets/ConsolePage-0f6141a7.js similarity index 100% rename from dashboard/dist/assets/ConsolePage-0f6141a7.js rename to astrbot/dashboard/dist/assets/ConsolePage-0f6141a7.js diff --git a/dashboard/dist/assets/ConsolePage-6748dc2b.css b/astrbot/dashboard/dist/assets/ConsolePage-6748dc2b.css similarity index 100% rename from dashboard/dist/assets/ConsolePage-6748dc2b.css rename to astrbot/dashboard/dist/assets/ConsolePage-6748dc2b.css diff --git a/dashboard/dist/assets/DefaultDashboard-bdafda32.js b/astrbot/dashboard/dist/assets/DefaultDashboard-bdafda32.js similarity index 100% rename from dashboard/dist/assets/DefaultDashboard-bdafda32.js rename to astrbot/dashboard/dist/assets/DefaultDashboard-bdafda32.js diff --git a/dashboard/dist/assets/ExtensionPage-0885ea13.js b/astrbot/dashboard/dist/assets/ExtensionPage-0885ea13.js similarity index 100% rename from dashboard/dist/assets/ExtensionPage-0885ea13.js rename to astrbot/dashboard/dist/assets/ExtensionPage-0885ea13.js diff --git a/dashboard/dist/assets/FineTunePage-df498a32.js b/astrbot/dashboard/dist/assets/FineTunePage-df498a32.js similarity index 100% rename from dashboard/dist/assets/FineTunePage-df498a32.js rename to astrbot/dashboard/dist/assets/FineTunePage-df498a32.js diff --git a/dashboard/dist/assets/FullLayout-d561a146.js b/astrbot/dashboard/dist/assets/FullLayout-d561a146.js similarity index 100% rename from dashboard/dist/assets/FullLayout-d561a146.js rename to astrbot/dashboard/dist/assets/FullLayout-d561a146.js diff --git a/dashboard/dist/assets/LoginPage-0e7a1264.js b/astrbot/dashboard/dist/assets/LoginPage-0e7a1264.js similarity index 100% rename from dashboard/dist/assets/LoginPage-0e7a1264.js rename to astrbot/dashboard/dist/assets/LoginPage-0e7a1264.js diff --git a/dashboard/dist/assets/LoginPage-74e85ca7.css b/astrbot/dashboard/dist/assets/LoginPage-74e85ca7.css similarity index 100% rename from dashboard/dist/assets/LoginPage-74e85ca7.css rename to astrbot/dashboard/dist/assets/LoginPage-74e85ca7.css diff --git a/dashboard/dist/assets/WaitingForRestart-cde6f809.js b/astrbot/dashboard/dist/assets/WaitingForRestart-cde6f809.js similarity index 100% rename from dashboard/dist/assets/WaitingForRestart-cde6f809.js rename to astrbot/dashboard/dist/assets/WaitingForRestart-cde6f809.js diff --git a/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js b/astrbot/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js similarity index 100% rename from dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js rename to astrbot/dashboard/dist/assets/_plugin-vue_export-helper-c27b6911.js diff --git a/dashboard/dist/assets/common-40607810.js b/astrbot/dashboard/dist/assets/common-40607810.js similarity index 100% rename from dashboard/dist/assets/common-40607810.js rename to astrbot/dashboard/dist/assets/common-40607810.js diff --git a/dashboard/dist/assets/index-d089162b.js b/astrbot/dashboard/dist/assets/index-d089162b.js similarity index 100% rename from dashboard/dist/assets/index-d089162b.js rename to astrbot/dashboard/dist/assets/index-d089162b.js diff --git a/dashboard/dist/assets/index-d7da5bd1.css b/astrbot/dashboard/dist/assets/index-d7da5bd1.css similarity index 100% rename from dashboard/dist/assets/index-d7da5bd1.css rename to astrbot/dashboard/dist/assets/index-d7da5bd1.css diff --git a/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot b/astrbot/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot similarity index 100% rename from dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot rename to astrbot/dashboard/dist/assets/materialdesignicons-webfont-67d24abe.eot diff --git a/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff b/astrbot/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff similarity index 100% rename from dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff rename to astrbot/dashboard/dist/assets/materialdesignicons-webfont-80bb28b3.woff diff --git a/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf b/astrbot/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf similarity index 100% rename from dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf rename to astrbot/dashboard/dist/assets/materialdesignicons-webfont-a58ecb54.ttf diff --git a/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 b/astrbot/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 similarity index 100% rename from dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 rename to astrbot/dashboard/dist/assets/materialdesignicons-webfont-c1c004a9.woff2 diff --git a/dashboard/dist/assets/md5-cf2f62a3.js b/astrbot/dashboard/dist/assets/md5-cf2f62a3.js similarity index 100% rename from dashboard/dist/assets/md5-cf2f62a3.js rename to astrbot/dashboard/dist/assets/md5-cf2f62a3.js diff --git a/dashboard/dist/favicon.svg b/astrbot/dashboard/dist/favicon.svg similarity index 100% rename from dashboard/dist/favicon.svg rename to astrbot/dashboard/dist/favicon.svg diff --git a/dashboard/dist/index.html b/astrbot/dashboard/dist/index.html similarity index 100% rename from dashboard/dist/index.html rename to astrbot/dashboard/dist/index.html diff --git a/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py similarity index 98% rename from dashboard/routes/__init__.py rename to astrbot/dashboard/routes/__init__.py index bf7edbae1..5149b4bae 100644 --- a/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -6,6 +6,7 @@ from .stat import StatRoute from .log import LogRoute from .static_file import StaticFileRoute + __all__ = [ "AuthRoute", "PluginRoute", @@ -14,4 +15,5 @@ __all__ = [ "StatRoute", "LogRoute", "StaticFileRoute" -] \ No newline at end of file +] + diff --git a/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py similarity index 66% rename from dashboard/routes/auth.py rename to astrbot/dashboard/routes/auth.py index febfb5f42..2da21ea81 100644 --- a/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -1,10 +1,10 @@ -from .. import Route, Response +from .route import Route, Response from quart import Quart, request -from type.types import Context +from core.config.astrbot_config import AstrBotConfig class AuthRoute(Route): - def __init__(self, context: Context, app: Quart) -> None: - super().__init__(context, app) + def __init__(self, config: AstrBotConfig, app: Quart) -> None: + super().__init__(config, app) self.routes = { '/auth/login': ('POST', self.login), '/auth/password/reset': ('POST', self.reset_password), @@ -12,8 +12,8 @@ class AuthRoute(Route): self.register_routes() async def login(self): - username = self.context.config_helper.dashboard.username - password = self.context.config_helper.dashboard.password + username = self.config.dashboard.username + password = self.config.dashboard.password post_data = await request.json if post_data["username"] == username and post_data["password"] == password: return Response().ok({ @@ -24,10 +24,10 @@ class AuthRoute(Route): return Response().error("用户名或密码错误").__dict__ async def reset_password(self): - password = self.context.config_helper.dashboard.password + password = self.config.dashboard.password post_data = await request.json if post_data["password"] == password: - self.context.config_helper.dashboard.password = post_data['new_password'] + self.config.dashboard.password = post_data['new_password'] return Response().ok(None).__dict__ else: return Response().error("原密码错误").__dict__ \ No newline at end of file diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py new file mode 100644 index 000000000..b8c245b68 --- /dev/null +++ b/astrbot/dashboard/routes/config.py @@ -0,0 +1,159 @@ +import os, json +from .route import Route, Response +from quart import Quart, request +from core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP +from core.config.astrbot_config import AstrBotConfig +from core.plugin.config import update_config +from core.core_lifecycle import AstrBotCoreLifecycle +from dataclasses import asdict + +def try_cast(value: str, type_: str): + if type_ == "int" and value.isdigit(): + return int(value) + elif type_ == "float" and isinstance(value, str) \ + and value.replace(".", "", 1).isdigit(): + return float(value) + elif type_ == "float" and isinstance(value, int): + return float(value) + +def validate_config(data, config: AstrBotConfig): + errors = [] + def validate(data, metadata=CONFIG_METADATA_2, path=""): + for key, meta in metadata.items(): + if key not in data: + continue + value = data[key] + # null 转换 + if value is None: + data[key] = DEFAULT_VALUE_MAP(meta["type"]) + continue + # 递归验证 + if meta["type"] == "list" and isinstance(value, list): + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") + elif meta["type"] == "object" and isinstance(value, dict): + validate(value, meta["items"], path=f"{path}{key}.") + + if meta["type"] == "int" and not isinstance(value, int): + casted = try_cast(value, "int") + if casted is None: + errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}") + data[key] = casted + elif meta["type"] == "float" and not isinstance(value, float): + casted = try_cast(value, "float") + if casted is None: + errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}") + data[key] = casted + elif meta["type"] == "bool" and not isinstance(value, bool): + errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}") + elif meta["type"] in ["string", "text"] and not isinstance(value, str): + errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}") + elif meta["type"] == "list" and not isinstance(value, list): + errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}") + elif meta["type"] == "object" and not isinstance(value, dict): + errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}") + validate(value, meta["items"], path=f"{path}{key}.") + validate(data) + + # hardcode warning + data['config_version'] = config.config_version + data['dashboard'] = asdict(config.dashboard) + + return errors + +def save_astrbot_config(post_config: dict, config: AstrBotConfig): + '''验证并保存配置''' + errors = validate_config(post_config, config) + if errors: + raise ValueError(f"格式校验未通过: {errors}") + config.flush_config(post_config) + +def save_extension_config(post_config: dict): + if 'namespace' not in post_config: + raise ValueError("Missing key: namespace") + if 'config' not in post_config: + raise ValueError("Missing key: config") + + namespace = post_config['namespace'] + config: list = post_config['config'][0]['body'] + for item in config: + key = item['path'] + value = item['value'] + typ = item['val_type'] + if typ == 'int': + if not value.isdigit(): + raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}") + value = int(value) + update_config(namespace, key, value) + +class ConfigRoute(Route): + def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle) -> None: + super().__init__(config, app) + self.config_key_dont_show = ['dashboard', 'config_version'] + self.core_lifecycle = core_lifecycle + self.routes = { + '/config/get': ('GET', self.get_configs), + '/config/astrbot/update': ('POST', self.post_astrbot_configs), + '/config/plugin/update': ('POST', self.post_extension_configs), + } + self.register_routes() + + async def get_configs(self): + # namespace 为空时返回 AstrBot 配置 + # 否则返回指定 namespace 的插件配置 + namespace = "" if "namespace" not in request.args else request.args["namespace"] + if not namespace: + return Response().ok(await self._get_astrbot_config()).__dict__ + return Response().ok(await self._get_extension_config(namespace)).__dict__ + + async def post_astrbot_configs(self): + post_configs = await request.json + try: + await self._save_astrbot_configs(post_configs) + return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ + except Exception as e: + return Response().error(str(e)).__dict__ + + async def post_extension_configs(self): + post_configs = await request.json + try: + await self._save_extension_configs(post_configs) + return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ + except Exception as e: + return Response().error(str(e)).__dict__ + + async def _get_astrbot_config(self): + config = self.config.to_dict() + for key in self.config_key_dont_show: + if key in config: + del config[key] + return { + "metadata": CONFIG_METADATA_2, + "config": config, + } + + async def _get_extension_config(self, namespace: str): + path = f"data/config/{namespace}.json" + if not os.path.exists(path): + return [] + with open(path, "r", encoding="utf-8-sig") as f: + return [{ + "config_type": "group", + "name": namespace + " 插件配置", + "description": "", + "body": list(json.load(f).values()) + },] + + async def _save_astrbot_configs(self, post_configs: dict): + try: + save_astrbot_config(post_configs, self.config) + self.core_lifecycle.restart() + except Exception as e: + raise e + + async def _save_extension_configs(self, post_configs: dict): + try: + save_extension_config(post_configs) + self.core_lifecycle.restart() + except Exception as e: + raise e \ No newline at end of file diff --git a/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py similarity index 57% rename from dashboard/routes/log.py rename to astrbot/dashboard/routes/log.py index f175e08aa..bda4af55f 100644 --- a/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,19 +1,20 @@ import asyncio from quart import websocket from quart import Quart -from type.types import Context -from .. import logger +from core.config.astrbot_config import AstrBotConfig +from core import logger, LogBroker +from .route import Route, Response -class LogRoute: - def __init__(self, context: Context, app: Quart) -> None: - self.app = app - self.context = context +class LogRoute(Route): + def __init__(self, config: AstrBotConfig, app: Quart, log_broker: LogBroker) -> None: + super().__init__(config, app) + self.log_broker = log_broker self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True) async def log(self): queue = None try: - queue = self.context.log_broker.register() + queue = self.log_broker.register() while True: message = await queue.get() await websocket.send(message) @@ -23,4 +24,4 @@ class LogRoute: logger.error(f"WebSocket 连接错误: {e}") finally: if queue: - self.context.log_broker.unregister(queue) \ No newline at end of file + self.log_broker.unregister(queue) \ No newline at end of file diff --git a/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py similarity index 75% rename from dashboard/routes/plugin.py rename to astrbot/dashboard/routes/plugin.py index e7fda7f69..e8f82222e 100644 --- a/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,13 +1,14 @@ import threading, traceback, uuid -from .. import Route, Response, logger +from .route import Route, Response +from core import logger from quart import Quart, request -from type.types import Context -from model.plugin.manager import PluginManager -from util.updator.astrbot_updator import AstrBotUpdator +from core.config.astrbot_config import AstrBotConfig +from core.plugin.plugin_manager import PluginManager +from core.core_lifecycle import AstrBotCoreLifecycle class PluginRoute(Route): - def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator, plugin_manager: PluginManager) -> None: - super().__init__(context, app) + def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None: + super().__init__(config, app) self.routes = { '/plugin/get': ('GET', self.get_plugins), '/plugin/install': ('POST', self.install_plugin), @@ -15,13 +16,13 @@ class PluginRoute(Route): '/plugin/update': ('POST', self.update_plugin), '/plugin/uninstall': ('POST', self.uninstall_plugin), } - self.astrbot_updator = astrbot_updator + self.core_lifecycle = core_lifecycle self.plugin_manager = plugin_manager self.register_routes() async def get_plugins(self): _plugin_resp = [] - for plugin in self.context.cached_plugins: + for plugin in self.plugin_manager.context.registered_plugins: _p = plugin.metadata _t = { "name": _p.plugin_name, @@ -39,9 +40,9 @@ class PluginRoute(Route): try: logger.info(f"正在安装插件 {repo_url}") await self.plugin_manager.install_plugin(repo_url) - threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() - logger.info(f"安装插件 {repo_url} 成功, 2秒后重启") - return Response().ok(None, "安装成功,程序将在 2 秒内重启。").__dict__ + self.core_lifecycle.restart() + logger.info(f"安装插件 {repo_url} 成功。") + return Response().ok(None, "安装成功。").__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ @@ -55,8 +56,8 @@ class PluginRoute(Route): await file.save(file_path) self.plugin_manager.install_plugin_from_file(file_path) logger.info(f"安装插件 {file.filename} 成功") - threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() - return Response().ok(None, "安装成功,程序将在 2 秒内重启。").__dict__ + self.core_lifecycle.restart() + return Response().ok(None, "安装成功。").__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ @@ -79,7 +80,7 @@ class PluginRoute(Route): try: logger.info(f"正在更新插件 {plugin_name}") await self.plugin_manager.update_plugin(plugin_name) - threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() + self.core_lifecycle.restart() logger.info(f"更新插件 {plugin_name} 成功,2秒后重启") return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__ except Exception as e: diff --git a/dashboard/__init__.py b/astrbot/dashboard/routes/route.py similarity index 79% rename from dashboard/__init__.py rename to astrbot/dashboard/routes/route.py index 167381421..5cc8160c6 100644 --- a/dashboard/__init__.py +++ b/astrbot/dashboard/routes/route.py @@ -1,13 +1,12 @@ -import logging +from core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass from quart import Quart -from type.types import Context -logger = logging.getLogger("astrbot") + class Route(): - def __init__(self, context: Context, app: Quart): - self.context = context + def __init__(self, config: AstrBotConfig, app: Quart): self.app = app + self.config = config def register_routes(self): for route, (method, func) in self.routes.items(): diff --git a/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py similarity index 71% rename from dashboard/routes/stat.py rename to astrbot/dashboard/routes/stat.py index 294053f50..cc5bd95e0 100644 --- a/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,21 +1,29 @@ import traceback, psutil, time, aiohttp -from .. import Route, Response, logger +from .route import Route, Response +from core import logger from quart import Quart, request -from type.types import Context -from astrbot.db import BaseDatabase -from type.config import VERSION +from core.config.astrbot_config import AstrBotConfig +from core.core_lifecycle import AstrBotCoreLifecycle +from core.db import BaseDatabase +from core.config import VERSION class StatRoute(Route): - def __init__(self, context: Context, app: Quart, db_helper: BaseDatabase) -> None: - super().__init__(context, app) + def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: + super().__init__(config, app) self.routes = { '/stat/get': ('GET', self.get_stat), '/stat/version': ('GET', self.get_version), '/stat/dashboard-version': ('GET', self.get_dashboard_version), - '/stat/start-time': ('GET', self.get_start_time) + '/stat/start-time': ('GET', self.get_start_time), + '/stat/restart-core': ('GET', self.restart_core) } self.db_helper = db_helper self.register_routes() + self.core_lifecycle = core_lifecycle + + async def restart_core(self): + self.core_lifecycle.restart() + return Response().ok().__dict__ def format_sec(self, sec: int): m, s = divmod(sec, 60) @@ -39,7 +47,7 @@ class StatRoute(Route): async def get_start_time(self): return Response().ok({ - "start_time": self.context._start_running, + "start_time": self.core_lifecycle.start_time }).__dict__ async def get_stat(self): @@ -64,10 +72,10 @@ class StatRoute(Route): stat_dict.update({ "platform": self.db_helper.get_grouped_base_stats(offset_sec).platform, "message_count": self.db_helper.get_total_message_count() or 0, - "platform_count": len(self.context.platforms), - "plugin_count": len(self.context.cached_plugins), + "platform_count": len(self.core_lifecycle.plugin_manager.get_platform_insts()), + "plugin_count": len(self.core_lifecycle.plugin_manager.get_loaded_plugins()), "message_time_series": message_time_based_stats, - "running": self.format_sec(int(time.time() - self.context._start_running)), + "running": self.format_sec(int(time.time()) - self.core_lifecycle.start_time), "memory": { "process": psutil.Process().memory_info().rss >> 20, "system": psutil.virtual_memory().total >> 20 diff --git a/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py similarity index 64% rename from dashboard/routes/static_file.py rename to astrbot/dashboard/routes/static_file.py index 49c6de221..53ccf907d 100644 --- a/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -1,10 +1,10 @@ -from .. import Route +from .route import Route from quart import Quart -from type.types import Context +from core.config.astrbot_config import AstrBotConfig class StaticFileRoute(Route): - def __init__(self, context: Context, app: Quart) -> None: - super().__init__(context, app) + def __init__(self, config: AstrBotConfig, app: Quart) -> None: + super().__init__(config, app) index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default'] for i in index_: diff --git a/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py similarity index 77% rename from dashboard/routes/update.py rename to astrbot/dashboard/routes/update.py index d51ed70db..c810e1163 100644 --- a/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,12 +1,13 @@ import threading, traceback -from .. import Route, Response, logger +from .route import Route, Response from quart import Quart, request -from type.types import Context -from util.updator.astrbot_updator import AstrBotUpdator +from core.config.astrbot_config import AstrBotConfig +from core.updator import AstrBotUpdator +from core import logger class UpdateRoute(Route): - def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator) -> None: - super().__init__(context, app) + def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None: + super().__init__(config, app) self.routes = { '/update/check': ('GET', self.check_update), '/update/do': ('POST', self.update_project), @@ -39,7 +40,7 @@ class UpdateRoute(Route): try: await self.astrbot_updator.update(latest=latest, version=version) threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start() - return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__ + return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__ except Exception as e: logger.error(f"/api/update_project: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ \ No newline at end of file diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py new file mode 100644 index 000000000..38b4faf52 --- /dev/null +++ b/astrbot/dashboard/server.py @@ -0,0 +1,40 @@ +import logging +import asyncio +from quart import Quart +from quart.logging import default_handler +from core.core_lifecycle import AstrBotCoreLifecycle +from .routes import * +from core import logger +from core.db import BaseDatabase +from core.plugin.plugin_manager import PluginManager +from core.updator import AstrBotUpdator +from core.utils.io import get_local_ip_addresses +from core.config import AstrBotConfig +from core.db import BaseDatabase + +class AstrBotDashboard(): + def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None: + self.core_lifecycle = core_lifecycle + self.config = core_lifecycle.astrbot_config + self.app = Quart("dashboard", static_folder="dist", static_url_path="/") + self.app.json.sort_keys = False + + logging.getLogger(self.app.name).removeHandler(default_handler) + + self.ar = AuthRoute(self.config, self.app) + self.ur = UpdateRoute(self.config, self.app, core_lifecycle.astrbot_updator) + self.sr = StatRoute(self.config, self.app, db, core_lifecycle) + self.pr = PluginRoute(self.config, self.app, core_lifecycle, core_lifecycle.plugin_manager) + self.cr = ConfigRoute(self.config, self.app, core_lifecycle) + self.lr = LogRoute(self.config, self.app, core_lifecycle.log_broker) + self.sfr = StaticFileRoute(self.config, self.app) + + async def shutdown_trigger_placeholder(self): + while not self.core_lifecycle.event_queue.closed: + await asyncio.sleep(1) + logger.info("管理面板已关闭。") + + def run(self): + ip_addr = get_local_ip_addresses() + logger.info(f"\n-----\n🌈 管理面板已启动,可访问 \n1. http://{ip_addr}:6185\n2. http://localhost:6185 登录。\n------") + return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder) \ No newline at end of file diff --git a/astrbot/main.py b/astrbot/main.py new file mode 100644 index 000000000..f4b7f6107 --- /dev/null +++ b/astrbot/main.py @@ -0,0 +1,57 @@ + +import os +import asyncio +import sys +import mimetypes + +from core.core_lifecycle import AstrBotCoreLifecycle +from core.db.sqlite import SQLiteDatabase +from core.config import DB_PATH +from dashboard import AstrBotDashBoardLifecycle + +from core import logger, LogManager, LogBroker + +# add parent path to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +logo_tmpl = r""" + ___ _______.___________..______ .______ ______ .___________. + / \ / | || _ \ | _ \ / __ \ | | + / ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----` + / /_\ \ \ \ | | | / | _ < | | | | | | + / _____ \ .----) | | | | |\ \----.| |_) | | `--' | | | +/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__| + +""" + +def check_env(): + if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): + logger.error("请使用 Python3.10+ 运行本项目。") + exit() + + os.makedirs("data/config", exist_ok=True) + os.makedirs("data/plugins", exist_ok=True) + os.makedirs("data/temp", exist_ok=True) + + # workaround for issue #181 + mimetypes.add_type("text/javascript", ".js") + mimetypes.add_type("text/javascript", ".mjs") + mimetypes.add_type("application/json", ".json") + +if __name__ == "__main__": + check_env() + + # start log broker + log_broker = LogBroker() + LogManager.set_queue_handler(logger, log_broker) + + # start db + db = SQLiteDatabase(DB_PATH) + + # print logo + logger.info(logo_tmpl) + + dashboard_lifecycle = AstrBotDashBoardLifecycle(db) + core_lifecycle = AstrBotCoreLifecycle(log_broker, db) + + asyncio.run(dashboard_lifecycle.start(core_lifecycle)) \ No newline at end of file diff --git a/astrbot/message/baidu_aip_judge.py b/astrbot/message/baidu_aip_judge.py deleted file mode 100644 index f8d09a860..000000000 --- a/astrbot/message/baidu_aip_judge.py +++ /dev/null @@ -1,28 +0,0 @@ -from aip import AipContentCensor -from util.cmd_config import BaiduAIPConfig - - -class BaiduJudge: - def __init__(self, baidu_configs: BaiduAIPConfig) -> None: - self.app_id = baidu_configs.app_id - self.api_key = baidu_configs.api_key - self.secret_key = baidu_configs.secret_key - self.client = AipContentCensor(self.app_id, - self.api_key, - self.secret_key) - - def judge(self, text): - res = self.client.textCensorUserDefined(text) - if 'conclusionType' not in res: - return False, "百度审核服务未知错误" - if res['conclusionType'] == 1: - return True, "合规" - else: - if 'data' not in res: - return False, "百度审核服务未知错误" - count = len(res['data']) - info = f"百度审核服务发现 {count} 处违规:\n" - for i in res['data']: - info += f"{i['msg']};\n" - info += "\n判断结果:"+res['conclusion'] - return False, info diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py deleted file mode 100644 index f08783ee5..000000000 --- a/astrbot/message/handler.py +++ /dev/null @@ -1,296 +0,0 @@ -import time, json -import re, os -import asyncio -import traceback -import astrbot.message.unfit_words as uw - -from typing import Dict -from astrbot.db import BaseDatabase -from model.provider.provider import Provider -from model.command.manager import CommandManager -from type.message_event import AstrMessageEvent, MessageResult -from type.types import Context -from type.command import CommandResult -from util.log import LogManager -from logging import Logger -from nakuru.entities.components import Image -from util.agent.func_call import FuncCall -from openai._exceptions import * -from openai.types.chat.chat_completion_message_tool_call import Function - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -class RateLimitHelper(): - def __init__(self, context: Context) -> None: - self.user_rate_limit: Dict[int, int] = {} - rl = context.config_helper.platform_settings.rate_limit - self.rate_limit_time: int = rl.time - self.rate_limit_count: int = rl.count - self.user_frequency = {} - - def check_frequency(self, session_id: str) -> bool: - ''' - 检查发言频率 - ''' - ts = int(time.time()) - if session_id in self.user_frequency: - if ts-self.user_frequency[session_id]['time'] > self.rate_limit_time: - self.user_frequency[session_id]['time'] = ts - self.user_frequency[session_id]['count'] = 1 - return True - else: - if self.user_frequency[session_id]['count'] >= self.rate_limit_count: - return False - else: - self.user_frequency[session_id]['count'] += 1 - return True - else: - t = {'time': ts, 'count': 1} - self.user_frequency[session_id] = t - return True - -class ContentSafetyHelper(): - def __init__(self, context: Context) -> None: - self.baidu_judge = None - aip = context.config_helper.content_safety.baidu_aip - if aip.enable: - try: - from astrbot.message.baidu_aip_judge import BaiduJudge - self.baidu_judge = BaiduJudge(aip) - logger.info("已启用百度 AI 内容审核。") - except ImportError as e: - logger.error("检测到库依赖不完整,将不会启用百度 AI 内容审核。请先使用 pip 安装 `baidu_aip` 包。") - logger.error(e) - except BaseException as e: - logger.error("百度 AI 内容审核初始化失败。") - logger.error(e) - - async def check_content(self, content: str) -> bool: - ''' - 检查文本内容是否合法 - ''' - for i in uw.unfit_words_q: - matches = re.match(i, content.strip(), re.I | re.M) - if matches: - return False - if self.baidu_judge != None: - check, msg = await asyncio.to_thread(self.baidu_judge.judge, content) - if not check: - logger.info(f"百度 AI 内容审核发现以下违规:{msg}") - return False - return True - - def filter_content(self, content: str) -> str: - ''' - 过滤文本内容 - ''' - for i in uw.unfit_words_q: - content = re.sub(i, "*", content, flags=re.I) - return content - - def baidu_check(self, content: str) -> bool: - ''' - 使用百度 AI 内容审核检查文本内容是否合法 - ''' - if self.baidu_judge != None: - check, msg = self.baidu_judge.judge(content) - if not check: - logger.info(f"百度 AI 内容审核发现以下违规:{msg}") - return False - return True - -class MessageHandler(): - def __init__(self, context: Context, - command_manager: CommandManager, - db_helper: BaseDatabase) -> None: - self.context = context - self.command_manager = command_manager - self.db_helper = db_helper - self.rate_limit_helper = RateLimitHelper(context) - self.content_safety_helper = ContentSafetyHelper(context) - self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix - self.llm_identifier = self.context.config_helper.llm_settings.identifier - if self.llm_wake_prefix: - self.llm_wake_prefix = self.llm_wake_prefix.strip() - self.provider = self.context.llms[0].llm_instance if len(self.context.llms) > 0 else None - self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix) - self.llm_tools = FuncCall(self.provider) - - def set_provider(self, provider: Provider): - self.provider = provider - - async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult: - ''' - Handle the message event, including commands, plugins, etc. - - `llm_provider`: the provider to use for LLM. If None, use the default provider - ''' - msg_plain = message.message_str.strip() - provider = llm_provider if llm_provider else self.provider - - # TODO: this should be configurable - # if not message.message_str: - # return MessageResult("Hi~") - - # check the rate limit - if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): - logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。") - return - - # remove the nick prefix - for nick in self.context.config_helper.wake_prefix: - if msg_plain.startswith(nick): - msg_plain = msg_plain.removeprefix(nick) - break - message.message_str = msg_plain - - # scan candidate commands - cmd_res = await self.command_manager.scan_command(message, self.context) - if cmd_res: - assert(isinstance(cmd_res, CommandResult)) - return MessageResult( - cmd_res.message_chain, - is_command_call=True, - use_t2i=cmd_res.is_use_t2i - ) - - # middlewares - for middleware in self.context.middlewares: - try: - logger.info(f"执行中间件 {middleware.origin}/{middleware.name}...") - await middleware.func(message, self.context) - except BaseException as e: - logger.error(f"中间件 {middleware.origin}/{middleware.name} 处理消息时发生异常:{e},跳过。") - logger.error(traceback.format_exc()) - - if message.only_command: - return - - # next is the LLM part - # check if the message is a llm-wake-up command - if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix): - logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。") - return - - if not provider: - logger.debug("没有任何 LLM 可用,忽略。") - return - - # check the content safety - if not await self.content_safety_helper.check_content(msg_plain): - return MessageResult("信息包含违规内容,由于机器人管理者开启内容安全审核,你的此条消息已被停止继续处理。") - - image_url = None - for comp in message.message_obj.message: - if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - break - llm_result = None - try: - if not self.llm_tools.empty(): - # tools-use - tool_use_flag = True - llm_result = await provider.text_chat( - prompt=msg_plain, - session_id=message.session_id, - tools=self.llm_tools.get_func() - ) - self.context.metrics_uploader.llm_stats[provider.get_curr_model()] += 1 - - if isinstance(llm_result, Function): - logger.debug(f"function-calling: {llm_result}") - func_obj = None - for i in self.llm_tools.func_list: - if i["name"] == llm_result.name: - func_obj = i["func_obj"] - break - if not func_obj: - return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。") - try: - args = json.loads(llm_result.arguments) - args['ame'] = message - args['context'] = self.context - try: - cmd_res = await func_obj(**args) - except TypeError as e: - args.pop('ame') - args.pop('context') - cmd_res = await func_obj(**args) - if isinstance(cmd_res, CommandResult): - return MessageResult( - cmd_res.message_chain, - is_command_call=True, - use_t2i=cmd_res.is_use_t2i - ) - elif isinstance(cmd_res, str): - return MessageResult(cmd_res) - elif not cmd_res: - return - else: - return MessageResult(f"AstrBot Function-calling 异常:调用:{llm_result} 时,返回了未知的返回值类型。") - except BaseException as e: - traceback.print_exc() - return MessageResult("AstrBot Function-calling 异常:" + str(e)) - else: - return MessageResult(llm_result) - - else: - # normal chat - tool_use_flag = False - # add user info to the prompt - if self.llm_identifier: - user_id = message.message_obj.sender.user_id - user_nickname = message.message_obj.sender.nickname - user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n" - msg_plain = user_info + msg_plain - - llm_result = await provider.text_chat( - prompt=msg_plain, - session_id=message.session_id, - image_url=image_url - ) - self.context.metrics_uploader.llm_stats[provider.get_curr_model()] += 1 - except BadRequestError as e: - if tool_use_flag: - # seems like the model don't support function-calling - logger.error(f"error: {e}. Using local function-calling implementation") - - try: - # use local function-calling implementation - args = { - 'question': llm_result, - 'func_definition': self.llm_tools.func_dump(), - } - _, has_func = await self.llm_tools.func_call(**args) - - if not has_func: - # normal chat - llm_result = await provider.text_chat( - prompt=msg_plain, - session_id=message.session_id, - image_url=image_url - ) - except BaseException as e: - logger.error(traceback.format_exc()) - return CommandResult("AstrBot Function-calling 异常:" + str(e)) - else: - logger.error(traceback.format_exc()) - logger.error(f"LLM 调用失败。") - return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e)) - except BaseException as e: - logger.error(traceback.format_exc()) - logger.error(f"LLM 调用失败。") - return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e)) - - # concatenate reply prefix - if self.reply_prefix: - llm_result = self.reply_prefix + llm_result - - # mask unsafe content - llm_result = self.content_safety_helper.filter_content(llm_result) - check = self.content_safety_helper.baidu_check(llm_result) - if not check: - return MessageResult("LLM 输出的信息包含违规内容,由于机器人管理者开启了内容安全审核,该条消息已拦截。") - - return MessageResult(llm_result) \ No newline at end of file diff --git a/astrbot/message/unfit_words.py b/astrbot/message/unfit_words.py deleted file mode 100644 index fbe245d30..000000000 --- a/astrbot/message/unfit_words.py +++ /dev/null @@ -1,114 +0,0 @@ -# 网上搜集的。 -# 来源词拦截 -# 支持正则 -unfit_words_q = ["习近平", - "胡锦涛", - "江泽民", - "温家宝", - "李克强", - "李长春", - "毛泽东", - "邓小平", - "周恩来", - "社会主义", - "共产党", - "共产主义", - "大陆官方", - "北京政权", - "中华帝国", - "中国政府", - "共狗", - "六四事件", - "天安门", - "六四", - "政治局常委", - "学潮", - "八九", - "二十大", - "民进党", - "台独", - "台湾独立", - "台湾国", - "国民党", - "台湾民国", - "中华民国", - "pornhub", - "Pornhub", - "作爱", - "做爱", - "性交", - "自慰", - "阴茎", - "淫妇", - "肛交", - "交配", - "性关系", - "性活动", - "色情", - "色图", - "裸体", - "小穴", - "淫荡", - "性爱", - "港独", - "法轮功", - "六四"] - -# 回复词过滤 -unfit_words = ["习近平", - "胡锦涛", - "江泽民", - "温家宝", - "李克强", - "李长春", - "毛泽东", - "邓小平", - "周恩来", - "社会主义", - "共产党", - "共产主义", - "大陆官方", - "北京政权", - "中华帝国", - "中国政府", - "共狗", - "六四事件", - "天安门", - "六四", - "政治局常委", - "学潮", - "八九", - "二十大", - "民进党", - "台独", - "台湾独立", - "台湾国", - "国民党", - "台湾民国", - "中华民国", - "pornhub", - "Pornhub", - "作爱", - "做爱", - "性交", - "自慰", - "阴茎", - "淫妇", - "肛交", - "交配", - "性关系", - "性活动", - "色情", - "色图", - "涩图", - "裸体", - "小穴", - "淫荡", - "性爱", - "中华人民共和国", - "党中央", - "中央军委主席", - "台湾", - "港独", - "法轮功", - "PRC"] diff --git a/dashboard/routes/config.py b/dashboard/routes/config.py deleted file mode 100644 index dcae5b92b..000000000 --- a/dashboard/routes/config.py +++ /dev/null @@ -1,80 +0,0 @@ -import os, json, threading -from .. import Route, Response -from ..utils.config import * -from quart import Quart, request -from type.types import Context -from type.config import CONFIG_METADATA_2 -from util.updator.astrbot_updator import AstrBotUpdator - - -class ConfigRoute(Route): - def __init__(self, context: Context, app: Quart, astrbot_updator: AstrBotUpdator) -> None: - super().__init__(context, app) - self.config_key_dont_show = ['dashboard', 'config_version'] - self.astrbot_updator = astrbot_updator - self.routes = { - '/config/get': ('GET', self.get_configs), - '/config/astrbot/update': ('POST', self.post_astrbot_configs), - '/config/plugin/update': ('POST', self.post_extension_configs), - } - self.register_routes() - - async def get_configs(self): - # namespace 为空时返回 AstrBot 配置 - # 否则返回指定 namespace 的插件配置 - namespace = "" if "namespace" not in request.args else request.args["namespace"] - if not namespace: - return Response().ok(await self._get_astrbot_config()).__dict__ - return Response().ok(await self._get_extension_config(namespace)).__dict__ - - async def post_astrbot_configs(self): - post_configs = await request.json - try: - await self._save_astrbot_configs(post_configs) - return Response().ok(None, "保存成功~ 机器人将在 3 秒内重启以应用新的配置。").__dict__ - except Exception as e: - return Response().error(str(e)).__dict__ - - async def post_extension_configs(self): - post_configs = await request.json - try: - await self._save_extension_configs(post_configs) - return Response().ok(None, "保存成功~ 机器人将在 3 秒内重启以应用新的配置。").__dict__ - except Exception as e: - return Response().error(str(e)).__dict__ - - async def _get_astrbot_config(self): - config = self.context.config_helper.to_dict() - for key in self.config_key_dont_show: - if key in config: - del config[key] - return { - "metadata": CONFIG_METADATA_2, - "config": config, - } - - async def _get_extension_config(self, namespace: str): - path = f"data/config/{namespace}.json" - if not os.path.exists(path): - return [] - with open(path, "r", encoding="utf-8-sig") as f: - return [{ - "config_type": "group", - "name": namespace + " 插件配置", - "description": "", - "body": list(json.load(f).values()) - },] - - async def _save_astrbot_configs(self, post_configs: dict): - try: - save_astrbot_config(post_configs, self.context) - threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start() - except Exception as e: - raise e - - async def _save_extension_configs(self, post_configs: dict): - try: - save_extension_config(post_configs) - threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start() - except Exception as e: - raise e \ No newline at end of file diff --git a/dashboard/server.py b/dashboard/server.py deleted file mode 100644 index dea7f4873..000000000 --- a/dashboard/server.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -import asyncio -from quart import Quart -from quart.logging import default_handler -from type.types import Context -from .routes import * -from . import logger -from astrbot.db import BaseDatabase -from model.plugin.manager import PluginManager -from util.updator.astrbot_updator import AstrBotUpdator -from util.io import get_local_ip_addresses - -class AstrBotDashboard(): - def __init__(self, context: Context, - plugin_manager: PluginManager, - astrbot_updator: AstrBotUpdator, - db_helper: BaseDatabase) -> None: - self.context = context - self.app = Quart("dashboard", static_folder="dist", static_url_path="/") - self.app.json.sort_keys = False - - logging.getLogger(self.app.name).removeHandler(default_handler) - - self.ar = AuthRoute(context, self.app) - self.ur = UpdateRoute(context, self.app, astrbot_updator) - self.sr = StatRoute(context, self.app, db_helper) - self.pr = PluginRoute(context, self.app, astrbot_updator, plugin_manager) - self.cr = ConfigRoute(context, self.app, astrbot_updator) - self.lr = LogRoute(context, self.app) - self.sfr = StaticFileRoute(context, self.app) - - async def shutdown_trigger_placeholder(self): - while self.context.running: - await asyncio.sleep(1) - - def run(self): - ip_addr = get_local_ip_addresses() - logger.info(f"管理面板已启动,可访问 http://{ip_addr}:6185 登录。") - return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder) \ No newline at end of file diff --git a/dashboard/utils/config.py b/dashboard/utils/config.py deleted file mode 100644 index e12ee61ac..000000000 --- a/dashboard/utils/config.py +++ /dev/null @@ -1,100 +0,0 @@ -from dataclasses import asdict -from util.plugin_dev.api.v1.config import update_config -from type.config import CONFIG_METADATA_2 -from type.types import Context - -def try_cast(value: str, type_: str): - if type_ == "int" and value.isdigit(): - return int(value) - elif type_ == "float" and isinstance(value, str) \ - and value.replace(".", "", 1).isdigit(): - return float(value) - elif type_ == "float" and isinstance(value, int): - return float(value) - -def get_default_val_by_type(type_: str): - if type_ == "int": - return 0 - elif type_ == "float": - return 0.0 - elif type_ == "bool": - return False - elif type_ == "string": - return "" - elif type_ == "text": - return "" - elif type_ == "list": - return [] - elif type_ == "object": - return {} - - -def validate_config(data, context: Context): - errors = [] - def validate(data, metadata=CONFIG_METADATA_2, path=""): - for key, meta in metadata.items(): - if key not in data: - continue - value = data[key] - # null 转换 - if value is None: - data[key] = get_default_val_by_type(meta["type"]) - continue - # 递归验证 - if meta["type"] == "list" and isinstance(value, list): - for item in value: - validate(item, meta["items"], path=f"{path}{key}.") - elif meta["type"] == "object" and isinstance(value, dict): - validate(value, meta["items"], path=f"{path}{key}.") - - if meta["type"] == "int" and not isinstance(value, int): - casted = try_cast(value, "int") - if casted is None: - errors.append(f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}") - data[key] = casted - elif meta["type"] == "float" and not isinstance(value, float): - casted = try_cast(value, "float") - if casted is None: - errors.append(f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}") - data[key] = casted - elif meta["type"] == "bool" and not isinstance(value, bool): - errors.append(f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}") - elif meta["type"] in ["string", "text"] and not isinstance(value, str): - errors.append(f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}") - elif meta["type"] == "list" and not isinstance(value, list): - errors.append(f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}") - elif meta["type"] == "object" and not isinstance(value, dict): - errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}") - validate(value, meta["items"], path=f"{path}{key}.") - validate(data) - - # hardcode warning - data['config_version'] = context.config_helper.config_version - data['dashboard'] = asdict(context.config_helper.dashboard) - - return errors - -def save_astrbot_config(post_config: dict, context: Context): - '''验证并保存配置''' - errors = validate_config(post_config, context) - if errors: - raise ValueError(f"格式校验未通过: {errors}") - context.config_helper.flush_config(post_config) - -def save_extension_config(post_config: dict): - if 'namespace' not in post_config: - raise ValueError("Missing key: namespace") - if 'config' not in post_config: - raise ValueError("Missing key: config") - - namespace = post_config['namespace'] - config: list = post_config['config'][0]['body'] - for item in config: - key = item['path'] - value = item['value'] - typ = item['val_type'] - if typ == 'int': - if not value.isdigit(): - raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}") - value = int(value) - update_config(namespace, key, value) diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py deleted file mode 100644 index b44c37c46..000000000 --- a/model/command/internal_handler.py +++ /dev/null @@ -1,270 +0,0 @@ -import aiohttp, os - -from model.command.manager import CommandManager -from model.plugin.manager import PluginManager -from type.message_event import AstrMessageEvent -from type.command import CommandResult -from type.types import Context -from type.config import VERSION -from util.log import LogManager -from logging import Logger -from util.agent.web_searcher import search_from_bing, fetch_website_content - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -class InternalCommandHandler: - def __init__(self, manager: CommandManager, plugin_manager: PluginManager) -> None: - self.manager = manager - self.plugin_manager = plugin_manager - - self.manager.register("help", "查看帮助", 10, self.help) - self.manager.register("wake", "唤醒前缀", 10, self.set_nick) - self.manager.register("update", "更新管理", 10, self.update) - self.manager.register("plugin", "插件管理", 10, self.plugin) - self.manager.register("reboot", "重启 AstrBot", 10, self.reboot) - self.manager.register("websearch", "网页搜索", 10, self.web_search) - self.manager.register("t2i", "文转图", 10, self.t2i_toggle) - self.manager.register("myid", "用户ID", 10, self.myid) - self.manager.register("provider", "LLM 接入源", 10, self.provider) - - def _check_auth(self, message: AstrMessageEvent, context: Context): - if os.environ.get("TEST_MODE", "off") == "on": - return - if message.role != "admin": - user_id = message.message_obj.sender.user_id - raise Exception(f"用户(ID: {user_id}) 没有足够的权限使用该指令。") - - def provider(self, message: AstrMessageEvent, context: Context): - if len(context.llms) == 0: - return CommandResult().message("当前没有加载任何 LLM 接入源。") - - tokens = self.manager.command_parser.parse(message.message_str) - - if tokens.len == 1: - ret = "## 当前载入的 LLM 接入源\n" - for idx, llm in enumerate(context.llms): - ret += f"{idx}. {llm.llm_name}" - if llm.origin: - ret += f" (来源: {llm.origin})" - if context.message_handler.provider == llm.llm_instance: - ret += " (当前使用)" - ret += "\n" - - ret += "\n使用 provider <序号> 切换 LLM 接入源。" - return CommandResult().message(ret) - else: - try: - idx = int(tokens.get(1)) - if idx >= len(context.llms): - return CommandResult().message("provider: 无效的序号。") - context.message_handler.set_provider(context.llms[idx].llm_instance) - return CommandResult().message(f"已经成功切换到 LLM 接入源 {context.llms[idx].llm_name}。") - except BaseException as e: - return CommandResult().message("provider: 参数错误。") - - def set_nick(self, message: AstrMessageEvent, context: Context): - self._check_auth(message, context) - message_str = message.message_str - l = message_str.split(" ") - if len(l) == 1: - return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}") - nick = l[1].strip() - if not nick: - return CommandResult().message("wake: 请指定唤醒词。") - context.config_helper.wake_prefix = [nick] - context.config_helper.save_config() - return CommandResult( - hit=True, - success=True, - message_chain=f"已经成功将唤醒前缀设定为 {nick}。", - ) - - async def update(self, message: AstrMessageEvent, context: Context): - self._check_auth(message, context) - tokens = self.manager.command_parser.parse(message.message_str) - update_info = await context.updator.check_update(None, None) - if tokens.len == 1: - ret = "" - if not update_info: - ret = f"当前已经是最新版本 v{VERSION}。" - else: - ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。" - return CommandResult().message(ret) - else: - if tokens.get(1) == "latest": - try: - await context.updator.update() - return CommandResult().message(f"已经成功更新到最新版本 v{update_info.version}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") - except BaseException as e: - return CommandResult().message(f"更新失败。原因:{str(e)}") - elif tokens.get(1).startswith("v"): - try: - await context.updator.update(version=tokens.get(1)) - return CommandResult().message(f"已经成功更新到版本 v{tokens.get(1)}。要应用更新,请重启 AstrBot。输入 /reboot 即可重启") - except BaseException as e: - return CommandResult().message(f"更新失败。原因:{str(e)}") - else: - return CommandResult().message("update: 参数错误。") - - def reboot(self, message: AstrMessageEvent, context: Context): - self._check_auth(message, context) - context.updator._reboot(3, context) - return CommandResult( - hit=True, - success=True, - message_chain="AstrBot 将在 3s 后重启。", - ) - - async def plugin(self, message: AstrMessageEvent, context: Context): - tokens = self.manager.command_parser.parse(message.message_str) - if tokens.len == 1: - ret = "# 插件指令面板 \n- 安装插件: `plugin i 插件Github地址`\n- 卸载插件: `plugin d 插件名`\n- 查看插件列表:`plugin l`\n - 更新插件: `plugin u 插件名`\n" - return CommandResult().message(ret) - - if tokens.get(1) == "l": - plugin_list_info = "" - for plugin in context.cached_plugins: - plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n" - if plugin_list_info.strip() == "": - return CommandResult().message("plugin v: 没有找到插件。") - return CommandResult().message(plugin_list_info) - - self._check_auth(message, context) - - if tokens.get(1) == "d": - if tokens.len == 2: - return CommandResult().message("plugin d: 请指定要卸载的插件名。") - plugin_name = tokens.get(2) - try: - self.plugin_manager.uninstall_plugin(plugin_name) - except BaseException as e: - return CommandResult().message(f"plugin d: 卸载插件失败。原因:{str(e)}") - return CommandResult().message(f"plugin d: 已经成功卸载插件 {plugin_name}。") - - elif tokens.get(1) == "i": - if tokens.len == 2: - return CommandResult().message("plugin i: 请指定要安装的插件的 Github 地址,或者前往可视化面板安装。") - plugin_url = tokens.get(2) - try: - await self.plugin_manager.install_plugin(plugin_url) - except BaseException as e: - return CommandResult().message(f"plugin i: 安装插件失败。原因:{str(e)}") - return CommandResult().message("plugin i: 已经成功安装插件。") - - elif tokens.get(1) == "u": - if tokens.len == 2: - return CommandResult().message("plugin u: 请指定要更新的插件名。") - plugin_name = tokens.get(2) - try: - await context.plugin_updator.update(plugin_name) - except BaseException as e: - return CommandResult().message(f"plugin u: 更新插件失败。原因:{str(e)}") - return CommandResult().message(f"plugin u: 已经成功更新插件 {plugin_name}。") - - return CommandResult().message("plugin: 参数错误。") - - async def help(self, message: AstrMessageEvent, context: Context): - notice = "" - try: - async with aiohttp.ClientSession() as session: - async with session.get("https://soulter.top/channelbot/notice.json") as resp: - notice = (await resp.json())["notice"] - except BaseException as e: - logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.") - - msg = "# 帮助中心\n## 指令\n" - for key, value in self.manager.commands_handler.items(): - if value.plugin_metadata: - msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n" - else: msg += f"- `{key}`: {value.description}\n" - # plugins - msg += "\n> 使用 plugin l 查看已加载的插件\n" - msg += notice - - return CommandResult().message(msg) - - def web_search(self, message: AstrMessageEvent, context: Context): - l = message.message_str.split(' ') - if len(l) == 1: - return CommandResult( - hit=True, - success=True, - message_chain=f"网页搜索功能当前状态: {context.config_helper.llm_settings.web_search}", - ) - elif l[1] == 'on': - context.config_helper.llm_settings.web_search = True - context.config_helper.save_config() - context.register_llm_tool("web_search", [{ - "type": "string", - "name": "keyword", - "description": "搜索关键词" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - search_from_bing - ) - context.register_llm_tool("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "要获取内容的网页链接" - }], - "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) - - return CommandResult( - hit=True, - success=True, - message_chain="已开启网页搜索", - ) - elif l[1] == 'off': - context.config_helper.llm_settings.web_search = False - context.config_helper.save_config() - context.unregister_llm_tool("web_search") - context.unregister_llm_tool("fetch_website_content") - - return CommandResult( - hit=True, - success=True, - message_chain="已关闭网页搜索", - ) - else: - return CommandResult( - hit=True, - success=False, - message_chain="参数错误", - ) - - def t2i_toggle(self, message: AstrMessageEvent, context: Context): - p = context.config_helper.t2i - if p: - context.config_helper.t2i = False - context.config_helper.save_config() - return CommandResult( - hit=True, - success=True, - message_chain="已关闭文本转图片模式。", - ) - context.config_helper.t2i = True - context.config_helper.save_config() - - return CommandResult( - hit=True, - success=True, - message_chain="已开启文本转图片模式。", - ) - - def myid(self, message: AstrMessageEvent, context: Context): - try: - user_id = str(message.message_obj.sender.user_id) - return CommandResult( - hit=True, - success=True, - message_chain=f"你在此平台上的ID:{user_id}", - ) - except BaseException as e: - return CommandResult( - hit=True, - success=False, - message_chain=f"获取失败,原因: {str(e)}", - ) diff --git a/model/command/manager.py b/model/command/manager.py deleted file mode 100644 index 08a5156df..000000000 --- a/model/command/manager.py +++ /dev/null @@ -1,145 +0,0 @@ -import heapq -import inspect -import traceback -from typing import Dict -from type.types import Context -from type.plugin import PluginMetadata -from type.message_event import AstrMessageEvent -from type.command import CommandResult -from type.register import RegisteredPlugins -from model.command.parser import CommandParser -from model.plugin.command import PluginCommandBridge -from util.log import LogManager -from logging import Logger -from dataclasses import dataclass - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -@dataclass -class CommandMetadata(): - inner_command: bool - plugin_metadata: PluginMetadata - handler: callable - use_regex: bool = False - ignore_prefix: bool = False - description: str = "" - -class CommandManager(): - def __init__(self): - self.commands = [] - self.commands_handler: Dict[str, CommandMetadata] = {} - self.command_parser = CommandParser() - - def register(self, - command: str, - description: str, - priority: int, - handler: callable, - use_regex: bool = False, - ignore_prefix: bool = False, - plugin_metadata: PluginMetadata = None, - ): - ''' - 优先级越高,越先被处理。 - - use_regex: 是否使用正则表达式匹配指令。 - ''' - if command in self.commands_handler: - raise ValueError(f"Command {command} already exists.") - if not handler: - raise ValueError(f"Handler of {command} is None.") - - heapq.heappush(self.commands, (-priority, command)) - self.commands_handler[command] = CommandMetadata( - inner_command=plugin_metadata == None, - plugin_metadata=plugin_metadata, - handler=handler, - use_regex=use_regex, - ignore_prefix=ignore_prefix, - description=description - ) - if plugin_metadata: - logger.debug(f"已注册 {plugin_metadata.author}/{plugin_metadata.plugin_name} 的指令 {command}。") - else: - logger.debug(f"已注册指令 {command}。") - - def register_from_pcb(self, pcb: PluginCommandBridge): - for request in pcb.plugin_commands_waitlist: - plugin = None - for registered_plugin in pcb.cached_plugins: - if registered_plugin.metadata.plugin_name == request.plugin_name: - plugin = registered_plugin - break - if not plugin: - logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}。") - else: - self.register(command=request.command_name, - description=request.description, - priority=request.priority, - handler=request.handler, - use_regex=request.use_regex, - ignore_prefix=request.ignore_prefix, - plugin_metadata=plugin.metadata) - self.plugin_commands_waitlist = [] - - async def check_command_ignore_prefix(self, message_str: str) -> bool: - for _, command in self.commands: - command_metadata = self.commands_handler[command] - if command_metadata.ignore_prefix: - trig = False - if self.commands_handler[command].use_regex: - trig = self.command_parser.regex_match(message_str, command) - else: - trig = message_str.startswith(command) - if trig: - return True - return False - - async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult: - message_str = message_event.message_str - for _, command in self.commands: - trig = False - if self.commands_handler[command].use_regex: - trig = self.command_parser.regex_match(message_str, command) - else: - trig = message_str.startswith(command) - if trig: - logger.info(f"触发 {command} 指令。") - command_result = await self.execute_handler(command, message_event, context) - if not command_result: - continue - if command_result.hit: - return command_result - - async def execute_handler(self, - command: str, - message_event: AstrMessageEvent, - context: Context) -> CommandResult: - command_metadata = self.commands_handler[command] - handler = command_metadata.handler - # call handler - try: - if inspect.iscoroutinefunction(handler): - command_result = await handler(message_event, context) - else: - command_result = handler(message_event, context) - - # if not isinstance(command_result, CommandResult): - # raise ValueError(f"Command {command} handler should return CommandResult.") - - if not command_result: - return - - context.metrics_uploader.command_stats[command] += 1 - - return command_result - except BaseException as e: - logger.error(traceback.format_exc()) - - if not command_metadata.inner_command: - text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}" - logger.error(text) - else: - text = f"执行 {command} 指令时发生了异常。{e}" - logger.error(text) - return CommandResult().message(text) \ No newline at end of file diff --git a/model/command/openai_official_handler.py b/model/command/openai_official_handler.py deleted file mode 100644 index 11ca1dd98..000000000 --- a/model/command/openai_official_handler.py +++ /dev/null @@ -1,183 +0,0 @@ -from model.command.manager import CommandManager -from type.message_event import AstrMessageEvent -from type.command import CommandResult -from type.types import Context -from util.log import LogManager -from logging import Logger -from nakuru.entities.components import Image -from util.personality import personalities -from util.io import download_image_by_url - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -class OpenAIOfficialCommandHandler(): - def __init__(self, manager: CommandManager) -> None: - self.manager = manager - - self.provider = None - - self.manager.register("reset", "重置会话", 10, self.reset) - self.manager.register("his", "查看历史记录", 10, self.his) - self.manager.register("status", "查看当前状态", 10, self.status) - self.manager.register("switch", "切换账号", 10, self.switch) - self.manager.register("unset", "清除个性化人格设置", 10, self.unset) - self.manager.register("set", "设置个性化人格", 10, self.set) - self.manager.register("draw", "调用 DallE 模型画图", 10, self.draw) - self.manager.register("model", "切换模型", 10, self.model) - self.manager.register("画", "调用 DallE 模型画图", 10, self.draw) - - def set_provider(self, provider): - self.provider = provider - - async def reset(self, message: AstrMessageEvent, context: Context): - tokens = self.manager.command_parser.parse(message.message_str) - if tokens.len == 1: - await self.provider.forget(message.session_id, keep_system_prompt=True) - return CommandResult().message("重置成功") - elif tokens.get(1) == 'p': - await self.provider.forget(message.session_id) - - async def model(self, message: AstrMessageEvent, context: Context): - tokens = self.manager.command_parser.parse(message.message_str) - if tokens.len == 1: - ret = await self._print_models() - return CommandResult().message(ret) - model = tokens.get(1) - if model.isdigit(): - try: - models = await self.provider.get_models() - except BaseException as e: - logger.error(f"获取模型列表失败: {str(e)}") - return CommandResult().message("获取模型列表失败,无法使用编号切换模型。可以尝试直接输入模型名来切换,如 gpt-4o。") - models = list(models) - if int(model) <= len(models) and int(model) >= 1: - model = models[int(model)-1] - self.provider.set_model(model.id) - return CommandResult().message(f"模型已设置为 {model.id}") - else: - self.provider.set_model(model) - return CommandResult().message(f"模型已设置为 {model} (自定义)") - - async def _print_models(self): - try: - models = await self.provider.get_models() - except BaseException as e: - return "获取模型列表失败: " + str(e) - i = 1 - ret = "OpenAI GPT 类可用模型" - for model in models: - ret += f"\n{i}. {model.id}" - i += 1 - ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" - logger.debug(ret) - return ret - - def his(self, message: AstrMessageEvent, context: Context): - tokens = self.manager.command_parser.parse(message.message_str) - size_per_page = 3 - page = 1 - if tokens.len == 2: - try: - page = int(tokens.get(1)) - except BaseException as e: - return CommandResult().message("页码格式错误") - contexts, total_num = self.provider.dump_contexts_page(message.session_id, size_per_page, page=page) - t_pages = total_num // size_per_page + 1 - return CommandResult().message(f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页") - - def status(self, message: AstrMessageEvent, context: Context): - keys_data = self.provider.get_keys_data() - ret = "OpenAI Key" - for k in keys_data: - status = "🟢" if keys_data[k] else "🔴" - ret += "\n|- " + k[:8] + " " + status - - conf = self.provider.get_configs() - ret += "\n当前模型: " + conf['model'] - - if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]): - ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens" - - return CommandResult().message(ret) - - async def switch(self, message: AstrMessageEvent, context: Context): - ''' - 切换账号 - ''' - tokens = self.manager.command_parser.parse(message.message_str) - if tokens.len == 1: - _, ret, _ = self.status() - curr_ = self.provider.get_curr_key() - if curr_ is None: - ret += "当前您未选择账号。输入/switch <账号序号>切换账号。" - else: - ret += f"当前您选择的账号为:{curr_[-8:]}。输入/switch <账号序号>切换账号。" - return CommandResult().message(ret) - elif tokens.len == 2: - try: - key_stat = self.provider.get_keys_data() - index = int(tokens.get(1)) - if index > len(key_stat) or index < 1: - return CommandResult().message("账号序号错误。") - else: - try: - new_key = list(key_stat.keys())[index-1] - self.provider.set_key(new_key) - except BaseException as e: - return CommandResult().message("切换账号未知错误: "+str(e)) - return CommandResult().message("切换账号成功。") - except BaseException as e: - return CommandResult().message("切换账号错误。") - else: - return CommandResult().message("参数过多。") - - def unset(self, message: AstrMessageEvent, context: Context): - self.provider.curr_personality = {} - self.provider.forget(message.session_id) - return CommandResult().message("已清除个性化设置。") - - - def set(self, message: AstrMessageEvent, context: Context): - l = message.message_str.split(" ") - if len(l) == 1: - return CommandResult().message("- 设置人格: \nset 人格名。例如 set 编剧\n- 人格列表: set list\n- 人格详细信息: set view 人格名\n- 自定义人格: set 人格文本\n- 重置会话(清除人格): reset\n- 重置会话(保留人格): reset p\n\n【当前人格】: " + str(self.provider.curr_personality['prompt'])) - elif l[1] == "list": - msg = "人格列表:\n" - for key in personalities.keys(): - msg += f"- {key}\n" - msg += '\n\n*输入 set view 人格名 查看人格详细信息' - return CommandResult().message(msg) - elif l[1] == "view": - if len(l) == 2: - return CommandResult().message("请输入人格名") - ps = l[2].strip() - if ps in personalities: - msg = f"人格{ps}的详细信息:\n" - msg += f"{personalities[ps]}\n" - else: - msg = f"人格{ps}不存在" - return CommandResult().message(msg) - else: - ps = "".join(l[1:]).strip() - if ps in personalities: - self.provider.curr_personality = { - 'name': ps, - 'prompt': personalities[ps] - } - self.provider.personality_set(self.provider.curr_personality, message.session_id) - return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") - else: - self.provider.curr_personality = { - 'name': '自定义人格', - 'prompt': ps - } - self.provider.personality_set(self.provider.curr_personality, message.session_id) - return CommandResult().message(f"人格已设置。 \n人格信息: {ps}") - - async def draw(self, message: AstrMessageEvent, context: Context): - message = message.message_str.removeprefix("画") - img_url = await self.provider.image_generate(message) - return CommandResult( - message_chain=[Image.fromURL(img_url)], - ) \ No newline at end of file diff --git a/model/platform/__init__.py b/model/platform/__init__.py deleted file mode 100644 index 0d46c2350..000000000 --- a/model/platform/__init__.py +++ /dev/null @@ -1,94 +0,0 @@ -import abc -from typing import Union, Any, List -from nakuru.entities.components import Plain, At, Image, BaseMessageComponent -from type.astrbot_message import AstrBotMessage -from type.command import CommandResult -from type.astrbot_message import MessageType - -class T2IException(Exception): - def __init__(self, message: str = "文本转图片时发生错误") -> None: - super().__init__(message) - -class Platform(): - def __init__(self, platform_name: str, context) -> None: - self.PLATFORM_NAME = platform_name - self.context = context - - @abc.abstractmethod - async def handle_msg(self, message: AstrBotMessage): - ''' - 处理到来的消息 - ''' - pass - - @abc.abstractmethod - async def reply_msg(self, message: AstrBotMessage, - result_message: List[BaseMessageComponent]): - ''' - 回复用户唤醒机器人的消息。(被动回复) - ''' - pass - - @abc.abstractmethod - async def send_msg(self, target: Any, result_message: CommandResult): - ''' - 发送消息(主动) - ''' - pass - - @abc.abstractmethod - async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): - ''' - 发送消息(主动) - ''' - pass - - def parse_message_outline(self, message: Union[AstrBotMessage, list]) -> str: - ''' - 将消息解析成大纲消息形式,如: xxxxx[图片]xxxxx。用于输出日志等。 - ''' - ret = '' - if isinstance(message, list): - parsed = message - elif isinstance(message, AstrBotMessage): - parsed = message.message - elif isinstance(message, str): - return message - - try: - for node in parsed: - if isinstance(node, Plain): - ret += node.text.replace('\n', ' ') - elif isinstance(node, At): - ret += f'[At: {node.name}/{node.qq}]' - elif isinstance(node, Image): - ret += '[图片]' - except Exception as e: - pass - return ret[:100] if len(ret) > 100 else ret - - def check_nick(self, message_str: str) -> bool: - w = self.context.config_helper.wake_prefix - if not w: return False - for nick in w: - if nick and message_str.strip().startswith(nick): - return True - return False - - async def convert_to_t2i_chain(self, message_result: list) -> Union[List[Image], None]: - plain_str = "" - rendered_images = [] - for i in message_result: - if isinstance(i, Plain): - plain_str += i.text - if plain_str and len(plain_str) > 50: - p = await self.context.image_renderer.render(plain_str, return_url=True) - if p.startswith('http'): - rendered_images.append(Image.fromURL(p)) - else: - rendered_images.append(Image.fromFileSystem(p)) - message_result = rendered_images - return message_result - - async def record_metrics(self): - self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME) \ No newline at end of file diff --git a/model/platform/manager.py b/model/platform/manager.py deleted file mode 100644 index 8d09b409e..000000000 --- a/model/platform/manager.py +++ /dev/null @@ -1,96 +0,0 @@ -import asyncio - -from util.io import port_checker -from type.register import RegisteredPlatform -from type.types import Context -from util.log import LogManager -from logging import Logger -from astrbot.message.handler import MessageHandler -from util.cmd_config import ( - PlatformConfig, - AiocqhttpPlatformConfig, - NakuruPlatformConfig, - QQOfficialPlatformConfig -) - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -class PlatformManager(): - def __init__(self, context: Context, message_handler: MessageHandler) -> None: - self.context = context - self.msg_handler = message_handler - - def load_platforms(self): - tasks = [] - - platforms = self.context.config_helper.platform - for platform in platforms: - if not platform.enable: - continue - if platform.name == "qq_official": - assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。" - logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})") - tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter")) - elif platform.name == "nakuru": - assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。" - logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})") - tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter")) - elif platform.name == "aiocqhttp": - assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" - logger.info("加载 QQ(aiocqhttp) 机器人消息平台") - tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter")) - - return tasks - - async def nakuru_bot(self, config: NakuruPlatformConfig): - ''' - 运行 QQ(nakuru 适配器) - ''' - from model.platform.qq_nakuru import QQNakuru - noticed = False - host = config.host - port = config.websocket_port - http_port = config.port - logger.info( - f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}") - while True: - if not port_checker(port=port, host=host) or not port_checker(port=http_port, host=host): - if not noticed: - noticed = True - logger.warning( - f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。") - await asyncio.sleep(5) - else: - logger.info("nakuru 适配器已连接。") - break - try: - qq_gocq = QQNakuru(self.context, self.msg_handler, config) - self.context.platforms.append(RegisteredPlatform( - platform_name="nakuru", platform_instance=qq_gocq, origin="internal")) - await qq_gocq.run() - except BaseException as e: - logger.error("启动 nakuru 适配器时出现错误: " + str(e)) - - def aiocq_bot(self, config): - ''' - 运行 QQ(aiocqhttp 适配器) - ''' - from model.platform.qq_aiocqhttp import AIOCQHTTP - qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config) - self.context.platforms.append(RegisteredPlatform( - platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal")) - return qq_aiocqhttp.run_aiocqhttp() - - def qqofficial_bot(self, config): - ''' - 运行 QQ 官方机器人适配器 - ''' - try: - from model.platform.qq_official import QQOfficial - qqchannel_bot = QQOfficial(self.context, self.msg_handler, config) - self.context.platforms.append(RegisteredPlatform( - platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal")) - return qqchannel_bot.run() - except BaseException as e: - logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e)) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py deleted file mode 100644 index 64120d5dd..000000000 --- a/model/platform/qq_aiocqhttp.py +++ /dev/null @@ -1,267 +0,0 @@ -import time -import asyncio -import traceback -import logging -from aiocqhttp import CQHttp, Event -from aiocqhttp.exceptions import ActionFailed -from . import Platform, T2IException -from type.astrbot_message import * -from type.message_event import * -from type.command import * -from typing import Union, List, Dict -from nakuru.entities.components import * -from util.log import LogManager -from logging import Logger -from astrbot.message.handler import MessageHandler -from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -class AIOCQHTTP(Platform): - def __init__(self, context: Context, - message_handler: MessageHandler, - platform_config: PlatformConfig) -> None: - super().__init__("aiocqhttp", context) - assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" - - self.message_handler = message_handler - self.context = context - self.config = platform_config - self.unique_session = context.config_helper.platform_settings.unique_session - self.host = platform_config.ws_reverse_host - self.port = platform_config.ws_reverse_port - - def convert_message(self, event: Event) -> AstrBotMessage: - - abm = AstrBotMessage() - abm.self_id = str(event.self_id) - abm.tag = "aiocqhttp" - - abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname']) - - if event['message_type'] == 'group': - abm.type = MessageType.GROUP_MESSAGE - elif event['message_type'] == 'private': - abm.type = MessageType.FRIEND_MESSAGE - - if self.unique_session: - abm.session_id = abm.sender.user_id - else: - abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id - - abm.message_id = str(event.message_id) - abm.message = [] - - message_str = "" - if not isinstance(event.message, list): - err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。" - logger.critical(err) - try: - self.bot.send(event, err) - except BaseException as e: - logger.error(f"回复消息失败: {e}") - return - for m in event.message: - t = m['type'] - a = None - if t == 'at': - a = At(**m['data']) - abm.message.append(a) - if t == 'text': - a = Plain(text=m['data']['text']) - message_str += m['data']['text'].strip() - abm.message.append(a) - if t == 'image': - file = m['data']['file'] if 'file' in m['data'] else None - url = m['data']['url'] if 'url' in m['data'] else None - a = Image(file=file, url=url) - abm.message.append(a) - abm.timestamp = int(time.time()) - abm.message_str = message_str - abm.raw_message = event - return abm - - def run_aiocqhttp(self): - if not self.host or not self.port: - return - self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) - @self.bot.on_message('group') - async def group(event: Event): - abm = self.convert_message(event) - if abm: - await self.handle_msg(abm) - - @self.bot.on_message('private') - async def private(event: Event): - abm = self.convert_message(event) - if abm: - await self.handle_msg(abm) - - bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - logging.getLogger('aiocqhttp').setLevel(logging.ERROR) - - return bot - - async def shutdown_trigger_placeholder(self): - while self.context.running: - await asyncio.sleep(1) - - async def pre_check(self, message: AstrBotMessage) -> bool: - # if message chain contains Plain components or - # At components which points to self_id, return True - if message.type == MessageType.FRIEND_MESSAGE: - return True, "friend" - for comp in message.message: - if isinstance(comp, At) and str(comp.qq) == message.self_id: - return True, "at" - # check commands which ignore prefix - if await self.context.command_manager.check_command_ignore_prefix(message.message_str): - return True, "command" - # check nicks - if self.check_nick(message.message_str): - return True, "nick" - return False, "none" - - async def handle_msg(self, message: AstrBotMessage): - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") - - ok, reason = await self.pre_check(message) - if not ok: - return - - # parse unified message origin - unified_msg_origin = None - assert isinstance(message.raw_message, Event) - if message.type == MessageType.GROUP_MESSAGE: - unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}" - elif message.type == MessageType.FRIEND_MESSAGE: - unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}" - - logger.debug(f"unified_msg_origin: {unified_msg_origin}") - - # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, - self.context, - "aiocqhttp", - message.session_id, - unified_msg_origin, - reason == "command") # only_command - - # transfer control to message handler - message_result = await self.message_handler.handle(ame) - if not message_result: return - - await self.reply_msg(message, message_result.result_message, message_result.use_t2i) - if message_result.callback: - message_result.callback() - - return message_result - - async def reply_msg(self, - message: AstrBotMessage, - result_message: list, - use_t2i: bool = None): - """ - 回复用户唤醒机器人的消息。(被动回复) - """ - try: - await self._reply(message, result_message, use_t2i) - except T2IException as e: - logger.error(traceback.format_exc()) - logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") - await self._reply(message, result_message, False) - return result_message - - async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None): - await self.record_metrics() - if isinstance(message_chain, str): - message_chain = [Plain(text=message_chain), ] - - # 文转图处理 - if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list): - try: - message_chain = await self.convert_to_t2i_chain(message_chain) - except BaseException as e: - raise T2IException() - - # log - if isinstance(message, AstrBotMessage): - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}") - else: - logger.info(f"回复消息: {message_chain}") - - # 解析成 OneBot json 格式并发送 - ret = [] - image_idx = [] - for idx, segment in enumerate(message_chain): - d = segment.toDict() - if isinstance(segment, Plain): - d['type'] = 'text' - if isinstance(segment, Image): - image_idx.append(idx) - ret.append(d) - if os.environ.get('TEST_MODE', 'off') == 'on': - logger.info(f"回复消息: {ret}") - return - try: - await self._reply_wrapper(message, ret) - except ActionFailed as e: - if e.retcode == 1200: - # ENOENT - if not image_idx: - raise e - logger.warning("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。") - for idx in image_idx: - if ret[idx]['data']['file'].startswith('file://'): - logger.info(f"正在上传图片: {ret[idx]['data']['path']}") - # 除了上传到图床,想不到更好的办法。 - image_url = await self.context.image_uploader.upload_image(ret[idx]['data']['path']) - logger.info(f"上传成功。") - ret[idx]['data']['file'] = image_url - ret[idx]['data']['path'] = image_url - await self._reply_wrapper(message, ret) - else: - logger.error(traceback.format_exc()) - logger.error(f"回复消息失败: {e}") - raise e - - async def _reply_wrapper(self, message: Union[AstrBotMessage, Dict], ret: List): - if isinstance(message, AstrBotMessage): - await self.bot.send(message.raw_message, ret) - if isinstance(message, dict): - if 'group_id' in message: - await self.bot.send_group_msg(group_id=message['group_id'], message=ret) - elif 'user_id' in message: - await self.bot.send_private_msg(user_id=message['user_id'], message=ret) - else: - raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。") - - async def send_msg(self, target: Dict[str, int], result_message: CommandResult): - ''' - 以主动的方式给QQ用户、QQ群发送一条消息。 - - `target` 接收一个 dict 类型的值引用。 - - - 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号; - - 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号; - - ''' - try: - await self._reply(target, result_message.message_chain, result_message.is_use_t2i) - except T2IException as e: - logger.error(traceback.format_exc()) - logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") - await self._reply(target, result_message.message_chain, False) - - async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): - if message_type == MessageType.GROUP_MESSAGE: - await self.send_msg({'group_id': int(target)}, result_message) - elif message_type == MessageType.FRIEND_MESSAGE: - await self.send_msg({'user_id': int(target)}, result_message) - else: - raise Exception("aiocqhttp: 无法识别的消息类型。") \ No newline at end of file diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py deleted file mode 100644 index e498db793..000000000 --- a/model/platform/qq_nakuru.py +++ /dev/null @@ -1,286 +0,0 @@ -import time, asyncio, traceback - -from nakuru.entities.components import Plain, At, Image, Node, BaseMessageComponent -from nakuru import ( - CQHTTP, - GuildMessage, - GroupMessage, - FriendMessage, - GroupMemberIncrease, - MessageItemType -) -from typing import Union, List, Dict -from type.types import Context -from . import Platform, T2IException -from type.astrbot_message import * -from type.message_event import * -from type.command import * -from util.log import LogManager -from logging import Logger -from astrbot.message.handler import MessageHandler -from util.cmd_config import PlatformConfig, NakuruPlatformConfig - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -class FakeSource: - def __init__(self, type, group_id): - self.type = type - self.group_id = group_id - - -class QQNakuru(Platform): - def __init__(self, context: Context, - message_handler: MessageHandler, - platform_config: PlatformConfig) -> None: - super().__init__("nakuru", context) - assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。" - - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - self.message_handler = message_handler - self.context = context - self.unique_session = context.config_helper.platform_settings.unique_session - self.config = platform_config - - self.client = CQHTTP( - host=self.config.host, - port=self.config.websocket_port, - http_port=self.config.port - ) - gocq_app = self.client - - @gocq_app.receiver("GroupMessage") - async def _(app: CQHTTP, source: GroupMessage): - if self.config.enable_group: - abm = self.convert_message(source) - await self.handle_msg(abm) - - @gocq_app.receiver("FriendMessage") - async def _(app: CQHTTP, source: FriendMessage): - if self.config.enable_direct_message: - abm = self.convert_message(source) - await self.handle_msg(abm) - - @gocq_app.receiver("GuildMessage") - async def _(app: CQHTTP, source: GuildMessage): - if self.config.enable_guild: - abm = self.convert_message(source) - await self.handle_msg(abm) - - def pre_check(self, message: AstrBotMessage) -> bool: - # if message chain contains Plain components or At components which points to self_id, return True - if message.type == MessageType.FRIEND_MESSAGE: - return True, "friend" - for comp in message.message: - if isinstance(comp, At) and str(comp.qq) == message.self_id: - return True, "at" - # check commands which ignore prefix - if self.context.command_manager.check_command_ignore_prefix(message.message_str): - return True, "command" - # check nicks - if self.check_nick(message.message_str): - return True, "nick" - return False, "none" - - def run(self): - coro = self.client._run() - return coro - - async def handle_msg(self, message: AstrBotMessage): - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") - - assert isinstance(message.raw_message, - (GroupMessage, FriendMessage, GuildMessage)) - - # 判断是否响应消息 - ok, reason = self.pre_check(message) - if not ok: - return - - # 解析 session_id - if self.unique_session or message.type == MessageType.FRIEND_MESSAGE: - session_id = message.raw_message.user_id - elif message.type == MessageType.GROUP_MESSAGE: - session_id = message.raw_message.group_id - elif message.type == MessageType.GUILD_MESSAGE: - session_id = message.raw_message.channel_id - else: - session_id = message.raw_message.user_id - - message.session_id = session_id - - # parse unified message origin - unified_msg_origin = None - if message.type == MessageType.GROUP_MESSAGE: - assert isinstance(message.raw_message, GroupMessage) - unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}" - elif message.type == MessageType.FRIEND_MESSAGE: - assert isinstance(message.raw_message, FriendMessage) - unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}" - elif message.type == MessageType.GUILD_MESSAGE: - assert isinstance(message.raw_message, GuildMessage) - unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}" - - logger.debug(f"unified_msg_origin: {unified_msg_origin}") - - - # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, - self.context, - "nakuru", - session_id, - unified_msg_origin, - reason == 'command') # only_command - - # transfer control to message handler - message_result = await self.message_handler.handle(ame) - if not message_result: return - - await self.reply_msg(message, message_result.result_message, message_result.use_t2i) - if message_result.callback: - message_result.callback() - - async def reply_msg(self, - message: AstrBotMessage, - result_message: List[BaseMessageComponent], - use_t2i: bool = None): - """ - 回复用户唤醒机器人的消息。(被动回复) - """ - assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage)) - - try: - await self._reply(message, result_message, use_t2i) - except T2IException as e: - logger.error(traceback.format_exc()) - logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") - await self._reply(message, result_message, False) - return result_message - - async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent], use_t2i: bool = None): - await self.record_metrics() - if isinstance(message_chain, str): - message_chain = [Plain(text=message_chain), ] - - # 文转图处理 - if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(message_chain, list): - try: - message_chain = await self.convert_to_t2i_chain(message_chain) - except BaseException as e: - raise T2IException() - - # log - if isinstance(message, AstrBotMessage): - logger.info( - f"{message.sender.nickname}/{message.sender.user_id} <- {self.parse_message_outline(message_chain)}") - else: - logger.info(f"回复消息: {message_chain}") - - source = message.raw_message - is_dict = isinstance(source, dict) - - # 发消息 - typ = None - if is_dict: - if "group_id" in source: - typ = "GroupMessage" - elif "user_id" in source: - typ = "FriendMessage" - elif "guild_id" in source: - typ = "GuildMessage" - else: - typ = source.type - - if typ == "GuildMessage": - guild_id = source['guild_id'] if is_dict else source.guild_id - chan_id = source['channel_id'] if is_dict else source.channel_id - await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain) - elif typ == "FriendMessage": - user_id = source['user_id'] if is_dict else source.user_id - await self.client.sendFriendMessage(user_id, message_chain) - elif typ == "GroupMessage": - group_id = source['group_id'] if is_dict else source.group_id - # 过长时forward发送 - plain_text_len = 0 - image_num = 0 - for i in message_chain: - if isinstance(i, Plain): - plain_text_len += len(i.text) - elif isinstance(i, Image): - image_num += 1 - if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1: - # 删除At - for i in message_chain: - if isinstance(i, At): - message_chain.remove(i) - node = Node(message_chain) - node.uin = 123456 - node.name = f"bot" - node.time = int(time.time()) - nodes = [node] - await self.client.sendGroupForwardMessage(group_id, nodes) - return - await self.client.sendGroupMessage(group_id, message_chain) - - async def send_msg(self, target: Dict[str, int], result_message: CommandResult): - ''' - 以主动的方式给用户、群或者频道发送一条消息。 - - `target` 接收一个 dict 类型的值引用。 - - - 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号; - - 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号; - - 要发给某个频道,请添加 key `guild_id`, `channel_id`。均为 int 类型。 - - guild_id 不是频道号。 - ''' - try: - await self._reply(target, result_message.message_chain, result_message.is_use_t2i) - except T2IException as e: - logger.error(traceback.format_exc()) - logger.warning(f"文本转图片时发生错误,将使用纯文本发送。") - await self._reply(target, result_message.message_chain, False) - return result_message - - async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): - ''' - 以主动的方式给用户、群或者频道发送一条消息。 - - `message_type` 为 MessageType 枚举类型。 - - - 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE; - - 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE; - - 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。 - ''' - if message_type == MessageType.FRIEND_MESSAGE: - await self.send_msg({"user_id": int(target)}, result_message) - elif message_type == MessageType.GROUP_MESSAGE: - await self.send_msg({"group_id": int(target)}, result_message) - elif message_type == MessageType.GUILD_MESSAGE: - await self.send_msg({"channel_id": int(target)}, result_message) - - def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: - abm = AstrBotMessage() - abm.type = MessageType(message.type) - abm.raw_message = message - abm.message_id = message.message_id - - plain_content = "" - for i in message.message: - if isinstance(i, Plain): - plain_content += i.text - abm.message_str = plain_content.strip() - if message.type == MessageItemType.GuildMessage: - abm.self_id = str(message.self_tiny_id) - else: - abm.self_id = str(message.self_id) - abm.sender = MessageMember( - str(message.sender.user_id), - str(message.sender.nickname) - ) - abm.tag = "nakuru" - abm.message = message.message - return abm \ No newline at end of file diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py deleted file mode 100644 index 2c534c590..000000000 --- a/model/platform/qq_official.py +++ /dev/null @@ -1,380 +0,0 @@ -import botpy -import re -import time -import traceback -import asyncio -import botpy.message -import botpy.types -import botpy.types.message - -from botpy.types.message import Reference, Media -from botpy import Client -from util.io import save_temp_img, download_image_by_url -from . import Platform -from type.astrbot_message import * -from type.message_event import * -from type.command import * -from typing import Union, List, Dict -from nakuru.entities.components import * -from util.log import LogManager -from logging import Logger -from astrbot.message.handler import MessageHandler -from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - -# QQ 机器人官方框架 -class botClient(Client): - def set_platform(self, platform: 'QQOfficial'): - self.platform = platform - - # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): - abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) - await self.platform.handle_msg(abm) - - # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): - # 转换层 - abm = self.platform._parse_from_qqofficial(message, MessageType.GUILD_MESSAGE) - await self.platform.handle_msg(abm) - - # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): - # 转换层 - abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) - await self.platform.handle_msg(abm) - - # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): - abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) - await self.platform.handle_msg(abm) - - -class QQOfficial(Platform): - - def __init__(self, context: Context, - message_handler: MessageHandler, - platform_config: PlatformConfig, - test_mode = False) -> None: - super().__init__("qqofficial", context) - assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。" - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - self.message_handler = message_handler - self.context = context - self.config = platform_config - - self.appid = platform_config.appid - self.secret = platform_config.secret - self.unique_session = context.config_helper.platform_settings.unique_session - qq_group = platform_config.enable_group_c2c - guild_dm = platform_config.enable_guild_direct_message - - if qq_group: - self.intents = botpy.Intents( - public_messages=True, - public_guild_messages=True, - direct_message=guild_dm - ) - else: - self.intents = botpy.Intents( - public_guild_messages=True, - direct_message=guild_dm - ) - self.client = botClient( - intents=self.intents, - bot_log=False, - timeout=20, - ) - - self.client.set_platform(self) - - self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' - - async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): - plain_text = "" - image_path = None # only one img supported - for i in message: - if isinstance(i, Plain): - plain_text += i.text - elif isinstance(i, Image) and not image_path: - if i.path: - image_path = i.path - elif i.file and i.file.startswith("base64://"): - img_data = base64.b64decode(i.file[9:]) - image_path = save_temp_img(img_data) - elif i.file and i.file.startswith("http"): - # 如果是群消息,不需要下载 - image_path = await download_image_by_url(i.file) if not is_group else i.file - return plain_text, image_path - - def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage], - message_type: MessageType): - abm = AstrBotMessage() - abm.type = message_type - abm.timestamp = int(time.time()) - abm.raw_message = message - abm.message_id = message.id - abm.tag = "qqofficial" - msg: List[BaseMessageComponent] = [] - - if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage): - if isinstance(message, botpy.message.GroupMessage): - abm.sender = MessageMember( - message.author.member_openid, - "" - ) - else: - abm.sender = MessageMember( - message.author.user_openid, - "" - ) - abm.message_str = message.content.strip() - abm.self_id = "unknown_selfid" - - msg.append(Plain(abm.message_str)) - if message.attachments: - for i in message.attachments: - if i.content_type.startswith("image"): - url = i.url - if not url.startswith("http"): - url = "https://"+url - img = Image.fromURL(url) - msg.append(img) - abm.message = msg - - elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage): - try: - abm.self_id = str(message.mentions[0].id) - except: - abm.self_id = "" - - plain_content = message.content.replace( - "<@!"+str(abm.self_id)+">", "").strip() - msg.append(Plain(plain_content)) - if message.attachments: - for i in message.attachments: - if i.content_type.startswith("image"): - url = i.url - if not url.startswith("http"): - url = "https://"+url - img = Image.fromURL(url) - msg.append(img) - abm.message = msg - abm.message_str = plain_content - abm.sender = MessageMember( - str(message.author.id), - str(message.author.username) - ) - else: - raise ValueError(f"Unknown message type: {message_type}") - return abm - - def run(self): - return self.client.start( - appid=self.appid, - secret=self.secret - ) - - async def handle_msg(self, message: AstrBotMessage): - assert isinstance(message.raw_message, (botpy.message.Message, - botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) - is_group = message.type != MessageType.FRIEND_MESSAGE - - _t = "/私聊" if not is_group else "" - logger.info( - f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}") - - # 解析出 session_id - if self.unique_session or not is_group: - session_id = message.sender.user_id - else: - if message.type == MessageType.GUILD_MESSAGE: - session_id = message.raw_message.channel_id - elif message.type == MessageType.GROUP_MESSAGE: - session_id = str(message.raw_message.group_openid) - else: - session_id = str(message.raw_message.author.id) - message.session_id = session_id - - # construct astrbot message event - ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id) - - message_result = await self.message_handler.handle(ame) - if not message_result: - return - - ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i) - if message_result.callback: - message_result.callback() - - return ret - - async def reply_msg(self, - message: AstrBotMessage, - result_message: List[BaseMessageComponent], - use_t2i: bool = None): - ''' - 回复频道消息 - ''' - source = message.raw_message - assert isinstance(source, (botpy.message.Message, - botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) - logger.info( - f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}") - - plain_text = '' - image_path = '' - msg_ref = None - rendered_images = None - - if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): - try: - rendered_images = await self.convert_to_t2i_chain(result_message) - except BaseException as e: - logger.warning(traceback.format_exc()) - logger.warning(f"文本转图片时发生错误: {e},将尝试默认方式。") - rendered_images = None - - if isinstance(result_message, list): - plain_text, image_path = await self._parse_to_qqofficial(result_message, message.type == MessageType.GROUP_MESSAGE) - else: - plain_text = result_message - - if source and not image_path: # file_image与message_reference不能同时传入 - msg_ref = Reference(message_id=source.id, - ignore_get_message_error=False) - - # 到这里,我们得到了 plain_text,image_path,msg_ref - data = { - 'content': plain_text, - 'msg_id': message.message_id, - 'message_reference': msg_ref - } - - if isinstance(message.raw_message, botpy.message.GroupMessage): - data['group_openid'] = str(source.group_openid) - elif isinstance(message.raw_message, botpy.message.Message): - data['channel_id'] = source.channel_id - elif isinstance(message.raw_message, botpy.message.DirectMessage): - data['guild_id'] = source.guild_id - elif isinstance(message.raw_message, botpy.message.C2CMessage): - data['openid'] = source.author.user_openid - if image_path: - data['file_image'] = image_path - if rendered_images: - # 文转图 - _data = data.copy() - _data['content'] = '' - _data['file_image'] = rendered_images[0].file - _data['message_reference'] = None - - try: - return await self._reply(**_data) - except BaseException as e: - logger.warn(traceback.format_exc()) - logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") - - try: - return await self._reply(**data) - except BaseException as e: - logger.error(traceback.format_exc()) - # 分割过长的消息 - if "msg over length" in str(e): - split_res = [] - split_res.append(plain_text[:len(plain_text)//2]) - split_res.append(plain_text[len(plain_text)//2:]) - for i in split_res: - data['content'] = i - return await self._reply(**data) - else: - try: - # 防止被qq频道过滤消息 - plain_text = plain_text.replace(".", " . ") - return await self._reply(**data) - except BaseException as e: - try: - data['content'] = str.join(" ", plain_text) - return await self._reply(**data) - except BaseException as e: - plain_text = re.sub( - r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) - plain_text = plain_text.replace(".", "·") - data['content'] = plain_text - return await self._reply(**data) - - async def _reply(self, **kwargs): - await self.record_metrics() - if 'group_openid' in kwargs or 'openid' in kwargs: - # QQ群组消息 - if 'file_image' in kwargs and kwargs['file_image']: - file_image_path = kwargs['file_image'].replace("file:///", "") - if file_image_path: - - if file_image_path.startswith("http"): - image_url = file_image_path - else: - logger.debug(f"上传图片: {file_image_path}") - image_url = await self.context.image_uploader.upload_image(file_image_path) - logger.debug(f"上传成功: {image_url}") - if 'group_openid' in kwargs: - media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) - elif 'openid' in kwargs: - media = await self.client.api.post_c2c_file(kwargs['openid'], 1, image_url) - del kwargs['file_image'] - kwargs['media'] = media - logger.debug(f"发送群图片: {media}") - kwargs['msg_type'] = 7 # 富媒体 - if self.test_mode: - return kwargs - if 'group_openid' in kwargs: - await self.client.api.post_group_message(**kwargs) - elif 'openid' in kwargs: - await self.client.api.post_c2c_message(**kwargs) - elif 'channel_id' in kwargs: - # 频道消息 - if 'file_image' in kwargs and kwargs['file_image']: - kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") - # 频道消息发图只支持本地 - if kwargs['file_image'].startswith("http"): - kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) - if self.test_mode: - return kwargs - await self.client.api.post_message(**kwargs) - elif 'guild_id' in kwargs: - # 频道私聊消息 - if 'file_image' in kwargs and kwargs['file_image']: - kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") - if kwargs['file_image'].startswith("http"): - kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) - if self.test_mode: - return kwargs - await self.client.api.post_dms(**kwargs) - else: - raise ValueError("Unknown target type.") - - async def send_msg(self, target: Dict[str, str], result_message: CommandResult): - ''' - 以主动的方式给频道用户、群、频道或者消息列表用户(QQ用户)发送一条消息。 - - `target` 接收一个 dict 类型的值引用。 - - - 如果目标是 QQ 群,请添加 key `group_openid`。 - - 如果目标是 频道消息,请添加 key `channel_id`。 - - 如果目标是 频道私聊,请添加 key `guild_id`。 - - 如果目标是 QQ 用户,请添加 key `openid`。 - ''' - plain_text, image_path = await self._parse_to_qqofficial(result_message.message_chain) - - payload = { - 'content': plain_text, - **target - } - if image_path: - payload['file_image'] = image_path - await self._reply(**payload) - - async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult): - raise NotImplementedError("qqofficial 不支持此方法。") diff --git a/model/plugin/command.py b/model/plugin/command.py deleted file mode 100644 index 5104a4e6e..000000000 --- a/model/plugin/command.py +++ /dev/null @@ -1,26 +0,0 @@ -from dataclasses import dataclass -from type.register import RegisteredPlugins -from typing import List, Union, Callable -from util.log import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - - -@dataclass -class CommandRegisterRequest(): - command_name: str - description: str - priority: int - handler: Callable - use_regex: bool = False - plugin_name: str = None - ignore_prefix: bool = False - -class PluginCommandBridge(): - def __init__(self, cached_plugins: RegisteredPlugins): - self.plugin_commands_waitlist: List[CommandRegisterRequest] = [] - self.cached_plugins = cached_plugins - - def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False): - self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix)) diff --git a/model/provider/provider.py b/model/provider/provider.py deleted file mode 100644 index e5665d98d..000000000 --- a/model/provider/provider.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections import defaultdict - -class Provider: - def __init__(self) -> None: - self.curr_model_name = "unknown" - - def reset_model_stat(self): - self.model_stat.clear() - - def set_curr_model(self, model_name: str): - self.curr_model_name = model_name - - def get_curr_model(self): - return self.curr_model_name - - async def text_chat(self, - prompt: str, - session_id: str, - image_url: None = None, - tools: None = None, - extra_conf: dict = None, - default_personality: dict = None, - **kwargs) -> str: - ''' - [require] - prompt: 提示词 - session_id: 会话id - - [optional] - image_url: 图片url(识图) - tools: 函数调用工具 - extra_conf: 额外配置 - default_personality: 默认人格 - ''' - raise NotImplementedError() - - async def image_generate(self, prompt, session_id, **kwargs) -> str: - ''' - [require] - prompt: 提示词 - session_id: 会话id - ''' - raise NotImplementedError() - - async def forget(self, session_id=None) -> bool: - ''' - 重置会话 - ''' - raise NotImplementedError() diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py new file mode 100644 index 000000000..d74650932 --- /dev/null +++ b/packages/astrbot/main.py @@ -0,0 +1,99 @@ +import aiohttp, base64, os, json, re, time +from typing import Dict +from astrbot.api import Context, AstrMessageEvent, MessageEventResult +from astrbot.api import logger, command_parser + +class Main: + def __init__(self, context: Context) -> None: + self.context = context + context.register_commands("astrbot", "help", "查看 AstrBot 帮助", 10, self.help) + context.register_commands("astrbot", "plugin", "AstrBot 插件管理", 10, self.plugin) + context.register_commands("astrbot", "t2i", "关闭/启动文本转图片", 10, self.t2i) + context.register_commands("astrbot", "myid", "查看自己在该平台上的 ID", 10, self.myid) + + context.register_listener("astrbot", "keywords_ban_rate_limit", self.keywords_ban, "关键词屏蔽和发言频率监听器") + # keywords + with open(os.path.join(os.path.dirname(__file__), "unfit_words"), "r", encoding="utf-8") as f: + self.keywords: list = json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'] + internal_keywords_cfg = context.get_config().content_safety.internal_keywords + if internal_keywords_cfg.enable: + self.keywords.extend(internal_keywords_cfg.extra_keywords) + + # rate limit + self.user_rate_limit: Dict[int, int] = {} + rl_cfg = context.get_config().platform_settings.rate_limit + self.rate_limit_time: int = rl_cfg.time + self.rate_limit_count: int = rl_cfg.count + self.user_frequency = {} + + async def keywords_ban(self, event: AstrMessageEvent): + if not event.is_wake_up(): + return + + # keywords 检测 + for i in self.keywords: + matches = re.match(i, event.get_message_str().strip(), re.I | re.M) + if matches: + event.set_result(MessageEventResult().message("你的消息中包含不适当的关键词,已被屏蔽。")) + return + + # rate limit 检测 + ts = int(time.time()) + if event.session_id in self.user_frequency: + if ts-self.user_frequency[event.session_id]['time'] > self.rate_limit_time: + # reset + self.user_frequency[event.session_id]['time'] = ts + self.user_frequency[event.session_id]['count'] = 1 + return + if self.user_frequency[event.session_id]['count'] >= self.rate_limit_count: + event.set_result(MessageEventResult().message("你发送消息的频率过快,请稍后再试。")) + return + self.user_frequency[event.session_id]['count'] += 1 + else: + t = {'time': ts, 'count': 1} + self.user_frequency[event.session_id] = t + + + async def help(self, event: AstrMessageEvent): + notice = "" + try: + async with aiohttp.ClientSession() as session: + async with session.get("https://astrbot.soulter.top/notice.json") as resp: + notice = (await resp.json())["notice"] + except BaseException as e: + pass + + msg = "# AstrBot 帮助\n## 已注册的指令\n" + for key, value in self.context.commands_handler.items(): + if value.plugin_metadata: + msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n" + else: msg += f"- `{key}`: {value.description}\n" + + msg += "\n> 提示:使用 /plugin 查看已加载的插件\n" + msg += notice + + event.set_result(MessageEventResult().message(msg)) + + async def plugin(self, event: AstrMessageEvent): + plugin_list_info = "已加载的插件:\n" + for plugin in self.context.registered_plugins: + plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n" + if plugin_list_info.strip() == "": + plugin_list_info = "没有加载任何插件。" + + event.set_result(MessageEventResult().message(f"{plugin_list_info}")) + + async def t2i(self, event: AstrMessageEvent): + config = self.context.get_config() + if config.t2i: + config.t2i = False + config.save_config() + event.set_result(MessageEventResult().message("已关闭文本转图片模式。")) + return + config.t2i = True + config.save_config() + event.set_result(MessageEventResult().message("已开启文本转图片模式。")) + + async def myid(self, event: AstrMessageEvent): + user_id = str(event.get_sender_id()) + event.set_result(MessageEventResult().message(f"你的 ID 是 {user_id}。此 ID 可用于设置 AstrBot 管理员。")) \ No newline at end of file diff --git a/packages/astrbot/metadata.yaml b/packages/astrbot/metadata.yaml new file mode 100644 index 000000000..af88b41ad --- /dev/null +++ b/packages/astrbot/metadata.yaml @@ -0,0 +1,6 @@ +name: astrbot # 插件名称 +desc: AstrBot 内置指令集 +help: +version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 +author: AstrBot # 作者 +repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot/unfit_words b/packages/astrbot/unfit_words new file mode 100644 index 000000000..030426e54 --- /dev/null +++ b/packages/astrbot/unfit_words @@ -0,0 +1 @@ +ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ== \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py new file mode 100644 index 000000000..237619889 --- /dev/null +++ b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_message_event.py @@ -0,0 +1,37 @@ +import os, traceback + +from astrbot.api import AstrMessageEvent, MessageChain, logger +from astrbot.api import Plain, Image +from aiocqhttp import CQHttp +from astrbot.core.utils.io import file_to_base64, download_image_by_url + +class AiocqhttpMessageEvent(AstrMessageEvent): + def __init__(self, message_str, message_obj, platform_meta, session_id, bot: CQHttp): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.bot = bot + + @staticmethod + async def _parse_onebot_josn(message_chain: MessageChain): + '''解析成 OneBot json 格式''' + ret = [] + for segment in message_chain.chain: + d = segment.toDict() + if isinstance(segment, Plain): + d['type'] = 'text' + if isinstance(segment, Image): + # convert to base64 + if segment.file and segment.file.startswith("file:///"): + image_base64 = file_to_base64(segment.file[8:]) + image_file_path = segment.file[8:] + elif segment.file and segment.file.startswith("http"): + image_file_path = await download_image_by_url(segment.file) + image_base64 = file_to_base64(image_file_path) + d['data']['file'] = image_base64 + ret.append(d) + return ret + + async def send(self, message: MessageChain): + ret = await AiocqhttpMessageEvent._parse_onebot_josn(message) + if os.environ.get('TEST_MODE', 'off') == 'on': + return + await self.bot.send(self.message_obj.raw_message, ret) \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py new file mode 100644 index 000000000..81491c3d4 --- /dev/null +++ b/packages/astrbot_adapter_aiocqhttp/aiocqhttp_platform_adapter.py @@ -0,0 +1,135 @@ +import time +import asyncio +import traceback +import logging +from typing import Awaitable, Any +from aiocqhttp import CQHttp, Event +from astrbot.api import Platform +from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from .aiocqhttp_message_event import * +from nakuru.entities.components import * +from astrbot.api import logger +from .aiocqhttp_message_event import AiocqhttpMessageEvent +from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings +from astrbot.core.platform.astr_message_event import MessageSesion + +class AiocqhttpAdapter(Platform): + def __init__(self, platform_config: AiocqhttpPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + + self.config = platform_config + self.settings = platform_settings + self.unique_session = platform_settings.unique_session + self.host = platform_config.ws_reverse_host + self.port = platform_config.ws_reverse_port + + self.metadata = PlatformMetadata( + "aiocqhttp", + "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", + ) + + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + ret = await AiocqhttpMessageEvent._parse_onebot_josn(message_chain) + match session.message_type.value: + case MessageType.GROUP_MESSAGE.value: + if "_" in session.session_id: + # 独立会话 + _, group_id = session.session_id.split("_") + await self.bot.send_group_msg(group_id=group_id, message=ret) + return + await self.bot.send_group_msg(group_id=session.session_id, message=ret) + case MessageType.FRIEND_MESSAGE.value: + await self.bot.send_private_msg(user_id=session.session_id, message=ret) + + def convert_message(self, event: Event) -> AstrBotMessage: + abm = AstrBotMessage() + abm.self_id = str(event.self_id) + abm.tag = "aiocqhttp" + + abm.sender = MessageMember(str(event.sender['user_id']), event.sender['nickname']) + + if event['message_type'] == 'group': + abm.type = MessageType.GROUP_MESSAGE + elif event['message_type'] == 'private': + abm.type = MessageType.FRIEND_MESSAGE + + if self.unique_session: + abm.session_id = abm.sender.user_id + "_" + str(event.group_id) # 也保留群组 id + else: + abm.session_id = str(event.group_id) if abm.type == MessageType.GROUP_MESSAGE else abm.sender.user_id + + abm.message_id = str(event.message_id) + abm.message = [] + + message_str = "" + if not isinstance(event.message, list): + err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。" + logger.critical(err) + try: + self.bot.send(event, err) + except BaseException as e: + logger.error(f"回复消息失败: {e}") + return + for m in event.message: + t = m['type'] + a = None + if t == 'at': + a = At(**m['data']) + abm.message.append(a) + if t == 'text': + a = Plain(text=m['data']['text']) + message_str += m['data']['text'].strip() + abm.message.append(a) + if t == 'image': + file = m['data']['file'] if 'file' in m['data'] else None + url = m['data']['url'] if 'url' in m['data'] else None + a = Image(file=file, url=url) + abm.message.append(a) + abm.timestamp = int(time.time()) + abm.message_str = message_str + abm.raw_message = event + return abm + + def run(self) -> Awaitable[Any]: + if not self.host or not self.port: + return + self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180) + @self.bot.on_message('group') + async def group(event: Event): + abm = self.convert_message(event) + if abm: + await self.handle_msg(abm) + + @self.bot.on_message('private') + async def private(event: Event): + abm = self.convert_message(event) + if abm: + await self.handle_msg(abm) + + bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder) + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.getLogger('aiocqhttp').setLevel(logging.ERROR) + + return bot + + def meta(self) -> PlatformMetadata: + return self.metadata + + async def shutdown_trigger_placeholder(self): + while not self._event_queue.closed: + await asyncio.sleep(1) + logger.info("aiocqhttp 适配器已关闭。") + + async def handle_msg(self, message: AstrBotMessage): + + message_event = AiocqhttpMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + bot=self.bot + ) + + self.commit_event(message_event) \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/main.py b/packages/astrbot_adapter_aiocqhttp/main.py new file mode 100644 index 000000000..01ecbd9e6 --- /dev/null +++ b/packages/astrbot_adapter_aiocqhttp/main.py @@ -0,0 +1,13 @@ +from astrbot.api import Context +from .aiocqhttp_platform_adapter import AiocqhttpAdapter +from astrbot.api import logger + +class Main: + def __init__(self, context: Context) -> None: + self.context = context + platforms_config = context.get_config().platform + settings = context.get_config().platform_settings + for platform in platforms_config: + if platform.name == "aiocqhttp" and platform.enable: + self.context.register_platform(AiocqhttpAdapter(platform, settings, context.get_event_queue())) + logger.info(f"已注册 aiocqhttp({platform.id}) 消息适配器。") \ No newline at end of file diff --git a/packages/astrbot_adapter_aiocqhttp/metadata.yaml b/packages/astrbot_adapter_aiocqhttp/metadata.yaml new file mode 100644 index 000000000..4269c4e78 --- /dev/null +++ b/packages/astrbot_adapter_aiocqhttp/metadata.yaml @@ -0,0 +1,6 @@ +name: astrbot_adapter_aiocqhttp # 插件名称 +desc: 支持 OneBot 协议的消息平台适配器(反向 Websockets) +help: +version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 +author: Soulter # 作者 +repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/main.py b/packages/astrbot_adapter_qqofficial/main.py new file mode 100644 index 000000000..b7e9fd88a --- /dev/null +++ b/packages/astrbot_adapter_qqofficial/main.py @@ -0,0 +1,18 @@ +import botpy, logging +# delete qqbotpy's logger +for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + +from astrbot.api import Context +from .qqofficial_platform_adapter import QQOfficialPlatformAdapter +from astrbot.api import logger + +class Main: + def __init__(self, context: Context) -> None: + self.context = context + platforms_config = context.get_config().platform + settings = context.get_config().platform_settings + for platform in platforms_config: + if platform.name == "qq_official" and platform.enable: + self.context.register_platform(QQOfficialPlatformAdapter(platform, settings, context.get_event_queue())) + logger.info(f"已注册 qq_official({platform.id}) 消息适配器。") \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/metadata.yaml b/packages/astrbot_adapter_qqofficial/metadata.yaml new file mode 100644 index 000000000..263cbafc4 --- /dev/null +++ b/packages/astrbot_adapter_qqofficial/metadata.yaml @@ -0,0 +1,6 @@ +name: astrbot_adapter_qqofficial # 插件名称 +desc: 支持 QQ 官方机器人平台的消息平台适配器 +help: +version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1 +author: Soulter # 作者 +repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py b/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py new file mode 100644 index 000000000..b9346cf7b --- /dev/null +++ b/packages/astrbot_adapter_qqofficial/qqofficial_message_event.py @@ -0,0 +1,79 @@ +import os, traceback, base64, botpy +import botpy.message +import botpy.types +import botpy.types.message +from astrbot.core.utils.io import file_to_base64, download_image_by_url +from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType +from astrbot.api import Plain, Image +from botpy import Client +from botpy.http import Route + + +class QQOfficialMessageEvent(AstrMessageEvent): + def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.bot = bot + + async def send(self, message: MessageChain): + source = self.message_obj.raw_message + assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) + + plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message) + + payload = { + 'content': plain_text, + 'msg_id': self.message_obj.message_id, + } + + match type(source): + case botpy.message.GroupMessage: + if image_base64: + media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid) + payload['media'] = media + await self.bot.api.post_group_message(group_openid=source.group_openid, **payload) + case botpy.message.C2CMessage: + if image_base64: + media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid) + payload['media'] = media + await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload) + case botpy.message.Message: + if image_path: + payload['file_image'] = image_path + await self.bot.api.post_message(channel_id=source.channel_id, **payload) + case botpy.message.DirectMessage: + if image_path: + payload['file_image'] = image_path + await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + + + async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media: + payload = { + 'file_data': image_base64, + 'file_type': file_type, + "srv_send_msg": False + } + if 'openid' in kwargs: + payload['openid'] = kwargs['openid'] + route = Route("POST", "/v2/users/{openid}/files", openid=kwargs['openid']) + return await self.bot.api._http.request(route, json=payload) + elif 'group_openid' in kwargs: + payload['group_openid'] = kwargs['group_openid'] + route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs['group_openid']) + return await self.bot.api._http.request(route, json=payload) + + @staticmethod + async def _parse_to_qqofficial(message: MessageChain): + plain_text = "" + image_base64 = None # only one img supported + image_file_path = None + for i in message.chain: + if isinstance(i, Plain): + plain_text += i.text + elif isinstance(i, Image) and not image_base64: + if i.file and i.file.startswith("file:///"): + image_base64 = file_to_base64(i.file[8:]) + image_file_path = i.file[8:] + elif i.file and i.file.startswith("http"): + image_file_path = await download_image_by_url(i.file) + image_base64 = file_to_base64(image_file_path) + return plain_text, image_base64, image_file_path \ No newline at end of file diff --git a/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py b/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py new file mode 100644 index 000000000..6bcc07000 --- /dev/null +++ b/packages/astrbot_adapter_qqofficial/qqofficial_platform_adapter.py @@ -0,0 +1,168 @@ +import botpy +import time +import asyncio +import botpy.message +import botpy.types +import botpy.types.message + +from botpy import Client +from astrbot.api import Platform +from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from typing import Union, List, Dict +from nakuru.entities.components import * +from astrbot.api import logger +from astrbot.core.platform.astr_message_event import MessageSesion +from .qqofficial_message_event import QQOfficialMessageEvent +from astrbot.core.config.astrbot_config import PlatformConfig, QQOfficialPlatformConfig, PlatformSettings +from astrbot.core.utils.io import save_temp_img, download_image_by_url + +# QQ 机器人官方框架 +class botClient(Client): + def set_platform(self, platform: 'QQOfficialPlatformAdapter'): + self.platform = platform + + # 收到群消息 + async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) + abm.session_id = abm.sender.user_id if self.platform.unique_session else message.group_openid + self._commit(abm) + + # 收到频道消息 + async def on_at_message_create(self, message: botpy.message.Message): + abm = self.platform._parse_from_qqofficial(message, MessageType.GROUP_MESSAGE) + abm.session_id = abm.sender.user_id if self.platform.unique_session else message.channel_id + self._commit(abm) + + # 收到私聊消息 + async def on_direct_message_create(self, message: botpy.message.DirectMessage): + abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm.session_id = abm.sender.user_id + self._commit(abm) + + # 收到 C2C 消息 + async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) + abm.session_id = abm.sender.user_id + self._commit(abm) + + def _commit(self, abm: AstrBotMessage): + self.platform.commit_event(QQOfficialMessageEvent( + abm.message_str, + abm, + self.platform.meta(), + abm.session_id, + self.platform.client + )) + +class QQOfficialPlatformAdapter(Platform): + + def __init__(self, platform_config: QQOfficialPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.config = platform_config + + self.appid = platform_config.appid + self.secret = platform_config.secret + self.unique_session = platform_settings.unique_session + qq_group = platform_config.enable_group_c2c + guild_dm = platform_config.enable_guild_direct_message + + if qq_group: + self.intents = botpy.Intents( + public_messages=True, + public_guild_messages=True, + direct_message=guild_dm + ) + else: + self.intents = botpy.Intents( + public_guild_messages=True, + direct_message=guild_dm + ) + self.client = botClient( + intents=self.intents, + bot_log=False, + timeout=20, + ) + + self.client.set_platform(self) + + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' + + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") + + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "qqofficial", + "QQ 机器人官方 API 适配器", + ) + + def _parse_from_qqofficial(self, message: Union[botpy.message.Message, botpy.message.GroupMessage], + message_type: MessageType): + abm = AstrBotMessage() + abm.type = message_type + abm.timestamp = int(time.time()) + abm.raw_message = message + abm.message_id = message.id + abm.tag = "qqofficial" + msg: List[BaseMessageComponent] = [] + + if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage): + if isinstance(message, botpy.message.GroupMessage): + abm.sender = MessageMember( + message.author.member_openid, + "" + ) + else: + abm.sender = MessageMember( + message.author.user_openid, + "" + ) + abm.message_str = message.content.strip() + abm.self_id = "unknown_selfid" + + msg.append(Plain(abm.message_str)) + if message.attachments: + for i in message.attachments: + if i.content_type.startswith("image"): + url = i.url + if not url.startswith("http"): + url = "https://"+url + img = Image.fromURL(url) + msg.append(img) + abm.message = msg + + elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage): + try: + abm.self_id = str(message.mentions[0].id) + except: + abm.self_id = "" + + plain_content = message.content.replace( + "<@!"+str(abm.self_id)+">", "").strip() + msg.append(Plain(plain_content)) + if message.attachments: + for i in message.attachments: + if i.content_type.startswith("image"): + url = i.url + if not url.startswith("http"): + url = "https://"+url + img = Image.fromURL(url) + msg.append(img) + abm.message = msg + abm.message_str = plain_content + abm.sender = MessageMember( + str(message.author.id), + str(message.author.username) + ) + else: + raise ValueError(f"Unknown message type: {message_type}") + return abm + + def run(self): + return self.client.start( + appid=self.appid, + secret=self.secret + ) \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/__init__.py b/packages/astrbot_plugin_openai/__init__.py new file mode 100644 index 000000000..490c10d4b --- /dev/null +++ b/packages/astrbot_plugin_openai/__init__.py @@ -0,0 +1 @@ +PLUGIN_NAME = "astrbot_plugin_openai" \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/commands.py b/packages/astrbot_plugin_openai/commands.py new file mode 100644 index 000000000..31a613eb5 --- /dev/null +++ b/packages/astrbot_plugin_openai/commands.py @@ -0,0 +1,180 @@ +from astrbot.api import Context, AstrMessageEvent, MessageEventResult, MessageChain +from . import PLUGIN_NAME +from astrbot.api import logger, Image, Plain +from astrbot.api import personalities +from astrbot.api import command_parser +from astrbot.api import Provider + + +class OpenAIAdapterCommand: + def __init__(self, context: Context) -> None: + self.provider: Provider = None + self.context = context + context.register_commands(PLUGIN_NAME, "reset", "重置会话", 10, self.reset) + context.register_commands(PLUGIN_NAME, "his", "查看历史记录", 10, self.his) + context.register_commands(PLUGIN_NAME, "status", "查看当前状态", 10, self.status) + context.register_commands(PLUGIN_NAME, "switch", "切换账号", 10, self.switch) + context.register_commands(PLUGIN_NAME, "persona", "设置个性化人格", 10, self.persona) + context.register_commands(PLUGIN_NAME, "draw", "调用 DallE 模型画图", 10, self.draw) + context.register_commands(PLUGIN_NAME, "model", "切换 LLM 模型", 10, self.model) + context.register_commands(PLUGIN_NAME, "画", "调用 DallE 模型画图", 10, self.draw) + + def set_provider(self, provider: Provider): + self.provider = provider + + async def reset(self, message: AstrMessageEvent): + tokens = command_parser.parse(message.message_str) + if tokens.len == 1: + await self.provider.forget(message.session_id, keep_system_prompt=True) + message.set_result(MessageEventResult().message("重置成功")) + elif tokens.get(1) == 'p': + await self.provider.forget(message.session_id) + + async def model(self, message: AstrMessageEvent): + tokens = command_parser.parse(message.message_str) + if tokens.len == 1: + ret = await self._print_models() + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + return + model = tokens.get(1) + if model.isdigit(): + try: + models = await self.provider.get_models() + except BaseException as e: + logger.error(f"获取模型列表失败: {str(e)}。如果出现 404,可能与服务提供商未提供模型列表有关。") + message.set_result(MessageEventResult().message("获取模型列表失败,无法使用编号切换模型。可以尝试直接输入模型名来切换,如 gpt-4o。")) + models = list(models) + if int(model) <= len(models) and int(model) >= 1: + model = models[int(model)-1] + self.provider.set_model(model.id) + message.set_result(MessageEventResult().message(f"模型已设置为 {model.id}")) + else: + self.provider.set_model(model) + message.set_result(MessageEventResult().message(f"模型已设置为 {model} (自定义)")) + + async def _print_models(self): + models = [] + try: + models = await self.provider.get_models() + except BaseException as e: + return "获取模型列表失败: " + str(e) + i = 1 + ret = "下面列出了此服务提供商可用模型:" + for model in models: + ret += f"\n{i}. {model.id}" + i += 1 + ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + logger.debug(ret) + return ret + + def his(self, message: AstrMessageEvent): + tokens = command_parser.parse(message.message_str) + size_per_page = 3 + page = 1 + if tokens.len == 2: + try: + page = int(tokens.get(1)) + except BaseException as e: + message.set_result(MessageEventResult().message("页码格式错误")) + contexts, total_num = self.provider.dump_contexts_page(message.session_id, size_per_page, page=page) + t_pages = total_num // size_per_page + 1 + message.set_result(MessageEventResult().message(f"历史记录:\n\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n\n*输入 /his 2 跳转到第 2 页")) + + def status(self, message: AstrMessageEvent): + keys_data = self.provider.get_keys_data() + ret = "OpenAI Key" + for k in keys_data: + status = "🟢" if keys_data[k] else "🔴" + ret += "\n|- " + k[:8] + " " + status + + ret += "\n当前模型: " + self.provider.get_model() + + if message.session_id in self.provider.session_memory and len(self.provider.session_memory[message.session_id]): + ret += "\n你的会话上下文: " + str(self.provider.session_memory[message.session_id][-1]['usage_tokens']) + " tokens" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + + async def switch(self, message: AstrMessageEvent): + ''' + 切换账号 + ''' + tokens = command_parser.parse(message.message_str) + if tokens.len == 1: + ret = "" + curr_ = self.provider.get_curr_key() + if curr_ is None: + ret += "当前您未选择账号。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。" + else: + ret += f"当前您选择的账号为:{curr_[:8]}。输入/switch <账号序号>切换账号。使用 /status 查看账号列表。" + message.set_result(MessageEventResult().message(ret)) + elif tokens.len == 2: + try: + key_stat = self.provider.get_keys_data() + index = int(tokens.get(1)) + if index > len(key_stat) or index < 1: + message.set_result(MessageEventResult().message("账号序号错误。")) + else: + try: + new_key = list(key_stat.keys())[index-1] + self.provider.set_key(new_key) + except BaseException as e: + message.set_result(MessageEventResult().message("切换账号未知错误: "+str(e))) + message.set_result(MessageEventResult().message("切换账号成功。") ) + except BaseException as e: + message.set_result(MessageEventResult().message("切换账号错误。")) + else: + message.set_result(MessageEventResult().message("参数过多。")) + + + def persona(self, message: AstrMessageEvent): + l = message.message_str.split(" ") + if len(l) == 1: + message.set_result( + MessageEventResult().message(f"""[Persona] + +- 设置人格: `/persona 人格名`, 如 /persona 编剧 +- 人格列表: `/persona list` +- 人格详细信息: `/persona view 人格名` +- 自定义人格: /persona 人格文本 +- 重置 LLM 会话(清除人格): /reset +- 重置 LLM 会话(保留人格): /reset p + +【当前人格】: {str(self.provider.curr_personality['prompt'])} +""")) + elif l[1] == "list": + msg = "人格列表:\n" + for key in personalities.keys(): + msg += f"- {key}\n" + msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息' + message.set_result(MessageEventResult().message(msg)) + elif l[1] == "view": + if len(l) == 2: + message.set_result(MessageEventResult().message("请输入人格名")) + ps = l[2].strip() + if ps in personalities: + msg = f"人格{ps}的详细信息:\n" + msg += f"{personalities[ps]}\n" + else: + msg = f"人格{ps}不存在" + message.set_result(MessageEventResult().message(msg)) + else: + ps = "".join(l[1:]).strip() + if ps in personalities: + self.provider.curr_personality = { + 'name': ps, + 'prompt': personalities[ps] + } + self.provider.personality_set(self.provider.curr_personality, message.session_id) + message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) + else: + self.provider.curr_personality = { + 'name': '自定义人格', + 'prompt': ps + } + self.provider.personality_set(self.provider.curr_personality, message.session_id) + message.set_result(MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) + + async def draw(self, message: AstrMessageEvent): + prompt = message.message_str.removeprefix("画") + img_url = await self.provider.image_generate(prompt) + message.set_result(MessageEventResult().url_image(img_url)) \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/main.py b/packages/astrbot_plugin_openai/main.py new file mode 100644 index 000000000..026d82a9c --- /dev/null +++ b/packages/astrbot_plugin_openai/main.py @@ -0,0 +1,217 @@ +import json, traceback +from typing import List +from astrbot.api import Context, AstrMessageEvent, MessageEventResult +from .openai_adapter import ProviderOpenAIOfficial +from .commands import OpenAIAdapterCommand +from astrbot.api import logger +from . import PLUGIN_NAME +from astrbot.api import Image, Plain, MessageChain +from openai._exceptions import * +from openai.types.chat.chat_completion_message_tool_call import Function +from astrbot.api import command_parser +from .web_searcher import search_from_bing, fetch_website_content + +class Main: + def __init__(self, context: Context) -> None: + self.context = context + + self.provider_insts: List[ProviderOpenAIOfficial] = [] + self.provider = None + + llms_config = self.context.get_config().llm + loaded = False + for llm in llms_config: + if llm.enable: + if llm.name == "openai": + if not llm.key or not llm.enable: + logger.warning("没有开启 LLM Provider 或 API Key 未填写。") + continue + self.provider_insts.append(ProviderOpenAIOfficial(llm, self.context.get_db())) + loaded = True + logger.info(f"已启用 LLM Provider(OpenAI API): {llm.id}({llm.name})。") + + if loaded: + self.command_handler = OpenAIAdapterCommand(self.context) + self.command_handler.set_provider(self.provider_insts[0]) + self.context.register_listener(PLUGIN_NAME, "openai_adapter_chat", self.chat, "OpenAI Adapter LLM 调用监听器", after_commands=True) + self.provider = self.command_handler.provider + + self.context.register_commands(PLUGIN_NAME, "provider", "查看当前 LLM Provider", 10, self.provider_info) + self.context.register_commands(PLUGIN_NAME, "websearch", "启用/关闭网页搜索", 10, self.web_search) + + if self.context.get_config().llm_settings.web_search: + self.add_web_search_tools() + + def add_web_search_tools(self): + self.context.register_llm_tool("web_search", [{ + "type": "string", + "name": "keyword", + "description": "搜索关键词" + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + search_from_bing + ) + self.context.register_llm_tool("fetch_website_content", [{ + "type": "string", + "name": "url", + "description": "要获取内容的网页链接" + }], + "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content + ) + + async def remove_web_search_tools(self): + self.context.unregister_llm_tool("web_search") + self.context.unregister_llm_tool("fetch_website_content") + + async def provider_info(self, event: AstrMessageEvent): + if len(self.provider_insts) == 0: + event.set_result(MessageEventResult().message("未启用任何 LLM Provider。")) + + tokens = command_parser.parse(event.get_message_str()) + + if tokens.len == 1: + ret = "## 当前载入的 LLM 接入源\n" + for idx, llm in enumerate(self.provider_insts): + ret += f"{idx}. {llm.llm_config.id} ({llm.llm_config.model_config.model})" + if self.provider == llm: + ret += " (当前使用)" + ret += "\n" + + ret += "\n使用 /provider <序号> 切换 LLM 接入源。" + event.set_result(MessageEventResult().message(ret)) + return + else: + try: + idx = int(tokens.get(1)) + if idx >= len(self.provider_insts): + event.set_result(MessageEventResult().message("无效的序号。")) + self.provider = self.provider_insts[idx] + self.command_handler.set_provider(self.provider) + event.set_result(MessageEventResult().message(f"已经成功切换到 LLM 接入源 {self.provider.llm_config.id}。")) + return + except BaseException as e: + event.set_result(MessageEventResult().message("provider: 参数错误。")) + return + + async def web_search(self, event: AstrMessageEvent): + websearch = self.context.get_config().llm_settings.web_search + if websearch: + # turn off + self.context.get_config().llm_settings.web_search = False + self.context.get_config().save_config() + self.remove_web_search_tools() + event.set_result(MessageEventResult().message("已关闭网页搜索。")) + return + # turn on + self.context.get_config().llm_settings.web_search = True + self.context.get_config().save_config() + self.add_web_search_tools() + event.set_result(MessageEventResult().message("已开启网页搜索。")) + + async def chat(self, event: AstrMessageEvent): + if not event.is_wake_up(): + return + + image_url = None + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + break + llm_result = None + try: + if not self.context.llm_tools.empty(): + # tools-use + tool_use_flag = True + llm_result = await self.provider.text_chat( + prompt=event.message_str, + session_id=event.session_id, + tools=self.context.llm_tools.get_func() + ) + # self.context.metrics_uploader.llm_stats[provider.get_curr_model()] += 1 + + if isinstance(llm_result, Function): + logger.debug(f"function-calling: {llm_result}") + func_obj = None + for i in self.context.llm_tools.func_list: + if i["name"] == llm_result.name: + func_obj = i["func_obj"] + break + if not func_obj: + event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:未找到请求的函数调用。")) + return + try: + args = json.loads(llm_result.arguments) + args['event'] = event + args['provider'] = self.provider + try: + func_result = await func_obj(**args) + except TypeError as e: + args.pop('event') + args.pop('provider') + func_result = await func_obj(**args) + if func_result: + logger.warning(f"function-calling: 工具函数 {llm_result.name} 返回了非空值,该值将被忽略。请使用 event.set_result() 设置返回值。") + return + if event.get_result(): + return + except BaseException as e: + traceback.print_exc() + event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:" + str(e))) + return + else: + event.set_result(MessageEventResult().message(llm_result)) + return + else: + # normal chat + tool_use_flag = False + # add user info to the prompt + if self.context.get_config().llm_settings.identifier: + user_id = event.message_obj.sender.user_id + user_nickname = event.message_obj.sender.nickname + user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n" + event.message_str = user_info + event.message_str + llm_result = await self.provider.text_chat( + prompt=event.message_str, + session_id=event.session_id, + image_url=image_url + ) + # self.context.metrics_uploader.llm_stats[provider.get_curr_model()] += 1 + except BadRequestError as e: + if tool_use_flag: + # seems like the model don't support function-calling + logger.error(f"error: {e}. Using local function-calling implementation") + + try: + # use local function-calling implementation + args = { + 'question': llm_result, + 'func_definition': self.context.llm_tools.func_dump(), + } + _, has_func = await self.context.llm_tools.func_call(**args) + + if not has_func: + # normal chat + llm_result = await self.provider.text_chat( + prompt=event.message_str, + session_id=event.session_id, + image_url=image_url + ) + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result(MessageEventResult().message("AstrBot Function-calling 异常:" + str(e))) + return + else: + logger.error(traceback.format_exc()) + logger.error(f"LLM 调用失败。") + event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) + return + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error(f"LLM 调用失败。") + event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) + return + + if llm_result: + event.set_result(MessageEventResult().message(llm_result)) + return diff --git a/packages/astrbot_plugin_openai/metadata.yaml b/packages/astrbot_plugin_openai/metadata.yaml new file mode 100644 index 000000000..8af419fb6 --- /dev/null +++ b/packages/astrbot_plugin_openai/metadata.yaml @@ -0,0 +1,6 @@ +name: astrbot_plugin_openai # 插件名称 +desc: 支持 OpenAI API +help: +version: v1.5.0 # 插件版本号。格式:v1.1.1 或者 v1.1 +author: Soulter # 作者 +repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/model/provider/openai_official.py b/packages/astrbot_plugin_openai/openai_adapter.py similarity index 92% rename from model/provider/openai_official.py rename to packages/astrbot_plugin_openai/openai_adapter.py index 19aa256f1..4488acd2b 100644 --- a/model/provider/openai_official.py +++ b/packages/astrbot_plugin_openai/openai_adapter.py @@ -10,19 +10,16 @@ import base64 from openai import AsyncOpenAI from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import * -from util.io import download_image_by_url +from astrbot.core.utils.io import download_image_by_url -from astrbot.db import BaseDatabase -from model.provider.provider import Provider -from util.cmd_config import LLMConfig -from util.log import LogManager -from logging import Logger +from astrbot.core.db import BaseDatabase +from astrbot.api import Provider +from astrbot.core.config.astrbot_config import LLMConfig +from astrbot import logger from typing import List, Dict from dataclasses import asdict -logger: Logger = LogManager.GetLogger(log_name='astrbot') - class ProviderOpenAIOfficial(Provider): def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase) -> None: super().__init__() @@ -46,7 +43,7 @@ class ProviderOpenAIOfficial(Provider): api_key=self.chosen_api_key, base_url=self.base_url ) - super().set_curr_model(llm_config.model_config.model) + self.set_model(llm_config.model_config.model) if llm_config.image_generation_model_config: self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config) self.session_memory: Dict[str, List] = {} # 会话记忆 @@ -135,14 +132,14 @@ class ProviderOpenAIOfficial(Provider): return context async def get_models(self): + models = [] try: models = await self.client.models.list() except NotFoundError as e: bu = str(self.client.base_url) self.client.base_url = bu + "/v1" models = await self.client.models.list() - finally: - return filter(lambda x: x.id.startswith("gpt"), models.data) + return models async def assemble_context(self, session_id: str, prompt: str, image_url: str = None): ''' @@ -218,9 +215,8 @@ class ProviderOpenAIOfficial(Provider): async def text_chat(self, prompt: str, session_id: str, - image_url: None=None, - tools: None=None, - extra_conf: Dict = None, + image_url=None, + tools=None, **kwargs ) -> str: if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on": @@ -244,14 +240,11 @@ class ProviderOpenAIOfficial(Provider): contexts = await self.retrieve_context(session_id) conf = asdict(self.llm_config.model_config) - if extra_conf: conf.update(extra_conf) # start request retry = 0 rate_limit_retry = 0 while retry < 3 or rate_limit_retry < 5: - logger.debug(conf) - logger.debug(contexts) if tools: completion_coro = self.client.chat.completions.create( messages=contexts, @@ -353,8 +346,9 @@ class ProviderOpenAIOfficial(Provider): retry = 0 conf = self.image_generator_model_configs if not conf: - logger.error("OpenAI 图片生成模型配置不存在。") - raise Exception("OpenAI 图片生成模型配置不存在。") + logger.error("图片生成模型配置不存在。") + raise Exception("图片生成模型配置不存在。") + conf.pop("enable") while retry < 3: try: images_response = await self.client.images.generate( @@ -367,8 +361,8 @@ class ProviderOpenAIOfficial(Provider): retry += 1 if retry >= 3: logger.error(traceback.format_exc()) - raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") - logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。") + raise Exception(f"图片生成请求失败:{e}。重试次数已达到上限。") + logger.warning(f"图片生成请求失败:{e}。重试第 {retry} 次。") await asyncio.sleep(1) async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: @@ -389,18 +383,14 @@ class ProviderOpenAIOfficial(Provider): for record in self.session_memory[session_id]: if "user" in record and record['user']: text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content'] - contexts_str += f"User: {text}\n" + contexts_str += f"User: {text}\n\n" if "AI" in record and record['AI']: text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content'] - contexts_str += f"Assistant: {text}\n" + contexts_str += f"Assistant: {text}\n\n" else: contexts_str = "会话 ID 不存在。" return contexts_str, len(self.session_memory[session_id]) - - def set_model(self, model: str): - # TODO: 更新配置文件 - super().set_curr_model(model) def get_configs(self): return asdict(self.llm_config) diff --git a/util/agent/web_searcher.py b/packages/astrbot_plugin_openai/web_searcher.py similarity index 63% rename from util/agent/web_searcher.py rename to packages/astrbot_plugin_openai/web_searcher.py index 5809c99c6..92982d1c8 100644 --- a/util/agent/web_searcher.py +++ b/packages/astrbot_plugin_openai/web_searcher.py @@ -4,21 +4,12 @@ import os from readability import Document from bs4 import BeautifulSoup -from openai.types.chat.chat_completion_message_tool_call import Function from openai._exceptions import * -from util.agent.func_call import FuncCall -from util.websearch.config import HEADERS, USER_AGENTS -from util.websearch.bing import Bing -from util.websearch.sogo import Sogo -from util.websearch.google import Google -from model.provider.provider import Provider -from util.log import LogManager -from logging import Logger -from type.types import Context -from type.message_event import AstrMessageEvent - -logger: Logger = LogManager.GetLogger(log_name='astrbot') - +from .websearch.config import HEADERS, USER_AGENTS +from .websearch.bing import Bing +from .websearch.sogo import Sogo +from .websearch.google import Google +from astrbot.api import logger, AstrMessageEvent, Provider, MessageChain, MessageEventResult bing_search = Bing() sogo_search = Sogo() @@ -31,7 +22,7 @@ def tidy_text(text: str) -> str: ''' return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") -async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str: +async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provider: Provider = None) -> str: ''' tools, 从 bing 搜索引擎搜索 ''' @@ -68,10 +59,10 @@ async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n" idx += 1 - return await summarize(context, ame, ret) + return await summarize(ret, event, provider) -async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str): +async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str: header = HEADERS header.update({'User-Agent': random.choice(USER_AGENTS)}) async with aiohttp.ClientSession() as session: @@ -81,18 +72,18 @@ async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: st ret = doc.summary(html_partial=True) soup = BeautifulSoup(ret, 'html.parser') ret = tidy_text(soup.get_text()) - return await summarize(context, ame, ret) + return await summarize(ret, event, provider) -async def summarize(context: Context, ame: AstrMessageEvent, text: str): +async def summarize(text: str, event: AstrMessageEvent = None, provider: Provider = None) -> str: summary_prompt = f""" -你是一个专业且高效的助手,你的任务是 -1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结; -2. 简单地发表你对这个问题的看法。 +你是一个专业且高效的助手,你擅长总结给定文本。你的任务是 +1. 回答用户的问题 `{event.message_str}`,用户的问题相关的材料在下方; +2. 简略发表你的看法。 # 例子 -1. 从网上的信息来看,可以知道...我个人认为...你觉得呢? -2. 根据网上的最新信息,可以得知...我觉得...你怎么看? +1. 从网上的信息来看,可以知道...我个人认为... +2. 根据网上的最新信息,可以得知...我觉得... # 限制 1. 限制在 200-300 字; @@ -100,6 +91,5 @@ async def summarize(context: Context, ame: AstrMessageEvent, text: str): # 相关材料 {text}""" - - provider = context.get_current_llm_provider() - return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id) \ No newline at end of file + ret = await provider.text_chat(summary_prompt, session_id=event.session_id) + event.set_result(MessageEventResult().message(ret)) \ No newline at end of file diff --git a/util/websearch/bing.py b/packages/astrbot_plugin_openai/websearch/bing.py similarity index 82% rename from util/websearch/bing.py rename to packages/astrbot_plugin_openai/websearch/bing.py index c0801dce7..3108ab529 100644 --- a/util/websearch/bing.py +++ b/packages/astrbot_plugin_openai/websearch/bing.py @@ -1,11 +1,6 @@ from typing import List - -try: - from util.websearch.engine import SearchEngine, SearchResult - from util.websearch.config import HEADERS, USER_AGENT_BING -except ImportError: - from engine import SearchEngine, SearchResult - from config import HEADERS, USER_AGENT_BING +from .engine import SearchEngine, SearchResult +from .config import HEADERS, USER_AGENT_BING class Bing(SearchEngine): def __init__(self) -> None: diff --git a/util/websearch/config.py b/packages/astrbot_plugin_openai/websearch/config.py similarity index 100% rename from util/websearch/config.py rename to packages/astrbot_plugin_openai/websearch/config.py diff --git a/util/websearch/engine.py b/packages/astrbot_plugin_openai/websearch/engine.py similarity index 94% rename from util/websearch/engine.py rename to packages/astrbot_plugin_openai/websearch/engine.py index 57bf2f784..a72972679 100644 --- a/util/websearch/engine.py +++ b/packages/astrbot_plugin_openai/websearch/engine.py @@ -1,9 +1,5 @@ import random -try: - from util.websearch.config import HEADERS, USER_AGENTS -except ImportError: - from config import HEADERS, USER_AGENTS - +from .config import HEADERS, USER_AGENTS from bs4 import BeautifulSoup from aiohttp import ClientSession from dataclasses import dataclass diff --git a/util/websearch/google.py b/packages/astrbot_plugin_openai/websearch/google.py similarity index 68% rename from util/websearch/google.py rename to packages/astrbot_plugin_openai/websearch/google.py index f1d78470b..4e950975e 100644 --- a/util/websearch/google.py +++ b/packages/astrbot_plugin_openai/websearch/google.py @@ -1,12 +1,8 @@ import os from googlesearch import search -try: - from util.websearch.engine import SearchEngine, SearchResult - from util.websearch.config import HEADERS, USER_AGENTS -except ImportError: - from engine import SearchEngine, SearchResult - from config import HEADERS, USER_AGENTS +from .engine import SearchEngine, SearchResult +from .config import HEADERS, USER_AGENTS from typing import List @@ -18,7 +14,6 @@ class Google(SearchEngine): async def search(self, query: str, num_results: int) -> List[SearchResult]: results = [] try: - print("use proxy:", self.proxy) ls = search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy) for i in ls: results.append(SearchResult(title=i.title, url=i.url, snippet=i.description)) diff --git a/util/websearch/sogo.py b/packages/astrbot_plugin_openai/websearch/sogo.py similarity index 85% rename from util/websearch/sogo.py rename to packages/astrbot_plugin_openai/websearch/sogo.py index 07a7a5b74..c4841662d 100644 --- a/util/websearch/sogo.py +++ b/packages/astrbot_plugin_openai/websearch/sogo.py @@ -1,12 +1,7 @@ import random, re from bs4 import BeautifulSoup - -try: - from util.websearch.engine import SearchEngine, SearchResult - from util.websearch.config import HEADERS, USER_AGENTS -except ImportError: - from engine import SearchEngine, SearchResult - from config import HEADERS, USER_AGENTS +from .engine import SearchEngine, SearchResult +from .config import HEADERS, USER_AGENTS from typing import List diff --git a/tests/mocks/onebot.py b/tests/mocks/onebot.py deleted file mode 100644 index eed5b0193..000000000 --- a/tests/mocks/onebot.py +++ /dev/null @@ -1,18 +0,0 @@ -import copy -from aiocqhttp import Event - -class MockOneBotMessage(): - def __init__(self): - # 这些数据不是敏感的 - self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 'test', 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 'test', 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470}) - self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 'test', 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 'test', 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'}) - - def create_random_group_message(self): - return self.group_event_sample - - def create_random_direct_message(self): - return self.friend_event_sample - - def create_msg(self, text: str): - self.group_event_sample.message = [{'data': {'qq': 'test'}, 'type': 'at'}, {'data': {'text': text}, 'type': 'text'}] - return self.group_event_sample \ No newline at end of file diff --git a/tests/mocks/qq_official.py b/tests/mocks/qq_official.py deleted file mode 100644 index 0978502aa..000000000 --- a/tests/mocks/qq_official.py +++ /dev/null @@ -1,54 +0,0 @@ -import botpy.message - -class MockQQOfficialMessage(): - def __init__(self): - # 这些数据已经经过去敏处理 - self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T19:58:52+08:00'} - self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:06:32+08:00'} - self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:15:24+08:00'} - self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849" - - self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'} - self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'} - self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'} - self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" - - self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'} - self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'} - self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'} - self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" - - def create_random_group_message(self): - mocked = botpy.message.GroupMessage( - api=None, - event_id=self.group_event_id_sample, - data=self.group_plain_text_sample - ) - return mocked - - def create_random_guild_message(self): - mocked = botpy.message.Message( - api=None, - event_id=self.guild_event_id_sample, - data=self.guild_plain_text_sample - ) - return mocked - - def create_random_direct_message(self): - mocked = botpy.message.DirectMessage( - api=None, - event_id=self.direct_event_id_sample, - data=self.direct_plain_text_sample - ) - return mocked - - def create_msg(self, text: str): - sample = self.group_plain_text_sample.copy() - sample['content'] = text - mocked = botpy.message.Message( - api=None, - event_id=self.group_event_id_sample, - data=sample - ) - return mocked - diff --git a/tests/test_message.py b/tests/test_message.py deleted file mode 100644 index e40bd21ed..000000000 --- a/tests/test_message.py +++ /dev/null @@ -1,215 +0,0 @@ -import asyncio, aiohttp -import pytest -import os - -from tests.mocks.qq_official import MockQQOfficialMessage -from tests.mocks.onebot import MockOneBotMessage - -from astrbot.bootstrap import AstrBotBootstrap -from model.platform.qq_official import QQOfficial -from model.platform.qq_aiocqhttp import AIOCQHTTP -from model.provider.openai_official import ProviderOpenAIOfficial -from type.astrbot_message import * -from type.message_event import * -from util.log import LogManager - -from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig - -logger = LogManager.GetLogger(log_name='astrbot') -pytest_plugins = ('pytest_asyncio',) - -os.environ['TEST_MODE'] = 'on' -bootstrap = AstrBotBootstrap() - -llm_config = bootstrap.context.config_helper.llm[0] -llm_config.api_base = os.environ['OPENAI_API_BASE'] -llm_config.key = [os.environ['OPENAI_API_KEY']] -llm_config.model_config.model = os.environ['LLM_MODEL'] -llm_config.model_config.max_tokens = 1000 -asyncio.run(bootstrap.run()) -llm_provider = ProviderOpenAIOfficial(llm_config, bootstrap.db_helper) -bootstrap.message_handler.provider = llm_provider -bootstrap.config_helper.wake_prefix = ["/"] -bootstrap.config_helper.admins_id = ["905617992"] - -for p_config in bootstrap.context.config_helper.platform: - if isinstance(p_config, QQOfficialPlatformConfig): - qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler, p_config) - elif isinstance(p_config, AiocqhttpPlatformConfig): - aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler, p_config) - -class TestBasicMessageHandle(): - @pytest.mark.asyncio - async def test_qqofficial_group_message(self): - group_message = MockQQOfficialMessage().create_random_group_message() - abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE) - ret = await qq_official.handle_msg(abm) - print(ret) - - @pytest.mark.asyncio - async def test_qqofficial_guild_message(self): - guild_message = MockQQOfficialMessage().create_random_guild_message() - abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE) - ret = await qq_official.handle_msg(abm) - print(ret) - - @pytest.mark.asyncio - async def test_aiocqhttp_group_message(self): - event = MockOneBotMessage().create_random_group_message() - abm = aiocqhttp.convert_message(event) - ret = await aiocqhttp.handle_msg(abm) - print(ret) - - @pytest.mark.asyncio - async def test_aiocqhttp_direct_message(self): - event = MockOneBotMessage().create_random_direct_message() - abm = aiocqhttp.convert_message(event) - ret = await aiocqhttp.handle_msg(abm) - print(ret) - -class TestInteralCommandHandle(): - def create(self, text: str): - event = MockOneBotMessage().create_msg(text) - abm = aiocqhttp.convert_message(event) - return abm - - async def fast_test(self, text: str): - abm = self.create(text) - ret = await aiocqhttp.handle_msg(abm) - print(f"Command: {text}, Result: {ret.result_message}") - return ret - - @pytest.mark.asyncio - async def test_config_save(self): - abm = self.create("/websearch on") - ret = await aiocqhttp.handle_msg(abm) - assert bootstrap.context.config_helper.llm_settings.web_search \ - == bootstrap.config_helper.get("llm_settings")['web_search'] - - @pytest.mark.asyncio - async def test_websearch(self): - await self.fast_test("/websearch") - await self.fast_test("/websearch on") - await self.fast_test("/websearch off") - - @pytest.mark.asyncio - async def test_help(self): - await self.fast_test("/help") - - @pytest.mark.asyncio - async def test_myid(self): - await self.fast_test("/myid") - - @pytest.mark.asyncio - async def test_wake(self): - await self.fast_test("/wake") - await self.fast_test("/wake #") - assert "#" in bootstrap.context.config_helper.wake_prefix - assert "#" in bootstrap.context.config_helper.get("wake_prefix") - await self.fast_test("#wake /") - - @pytest.mark.asyncio - async def test_sleep(self): - await self.fast_test("/provider") - - @pytest.mark.asyncio - async def test_update(self): - await self.fast_test("/update") - - @pytest.mark.asyncio - async def test_t2i(self): - if not bootstrap.context.config_helper.t2i: - abm = self.create("/t2i") - await aiocqhttp.handle_msg(abm) - await self.fast_test("/help") - - @pytest.mark.asyncio - async def test_plugin(self): - pname = "astrbot_plugin_bilibili" - url = f"https://github.com/Soulter/{pname}" - await self.fast_test("/plugin") - await self.fast_test(f"/plugin l") - await self.fast_test(f"/plugin i {url}") - await self.fast_test(f"/plugin u {url}") - await self.fast_test(f"/plugin d {pname}") - - @pytest.mark.asyncio - async def test_llm(self): - await self.fast_test("/reset") - os.environ["TEST_LLM"] = "on" - ret = await llm_provider.text_chat("Just reply `ok`", "test") - print(ret) - event = MockOneBotMessage().create_msg("Just reply `ok`") - abm = aiocqhttp.convert_message(event) - ret = await aiocqhttp.handle_msg(abm) - print(ret) - os.environ["TEST_LLM"] = "off" - await self.fast_test("/reset") - await self.fast_test("/status") - await self.fast_test("/his") - await self.fast_test("/switch") - await self.fast_test("/set") - await self.fast_test("/set list") - await self.fast_test("/unset") - -BASE_URL = "http://0.0.0.0:6185/api" -class TestHTTPServer: - - async def get_url(self, url): - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - return await response.json(), response.status - - async def post_url(self, url, data): - async with aiohttp.ClientSession() as session: - async with session.post(url, json=data) as response: - return await response.json(), response.status - - @pytest.mark.asyncio - async def test_config(self): - configs, status = await self.get_url(f"{BASE_URL}/config/get") - assert status == 200 - assert 'data' in configs and 'metadata' in configs['data'] \ - and 'config' in configs['data'] - config = configs['data']['config'] - # test post config - await self.post_url(f"{BASE_URL}/config/astrbot/update", config) - # text post config with invalid data - assert 'rate_limit' in config['platform_settings'] - config['platform_settings']['rate_limit'] = "invalid" - ret, status = await self.post_url(f"{BASE_URL}/config/astrbot/update", config) - assert status == 200 - assert 'status' in ret and ret['status'] == 'error' - - @pytest.mark.asyncio - async def test_update(self): - _, status = await self.get_url(f"{BASE_URL}/update/check") - assert status == 200 - - @pytest.mark.asyncio - async def test_stats(self): - _, status = await self.get_url(f"{BASE_URL}/stat/get") - _, status = await self.get_url(f"{BASE_URL}/stat/version") - _, status = await self.get_url(f"{BASE_URL}/stat/start-time") - assert status == 200 - - @pytest.mark.asyncio - async def test_plugins(self): - pname = "astrbot_plugin_bilibili" - url = f"https://github.com/Soulter/{pname}" - - _, status = await self.get_url(f"{BASE_URL}/plugin/get") - - # test install plugin - _, status = await self.post_url(f"{BASE_URL}/plugin/install", { - "url": url - }) - - # test uninstall plugin - _, status = await self.post_url(f"{BASE_URL}/plugin/uninstall", { - "name": pname - }) - - assert status == 200 - - bootstrap.context.running = False diff --git a/type/command.py b/type/command.py deleted file mode 100644 index ac9ca0d50..000000000 --- a/type/command.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Union, List, Callable -from dataclasses import dataclass -from nakuru.entities.components import Plain, Image - -@dataclass -class CommandItem(): - ''' - 用来描述单个指令 - ''' - - command_name: Union[str, tuple] # 指令名 - callback: Callable # 回调函数 - description: str # 描述 - origin: str # 注册来源 - -class CommandResult(): - ''' - 用于在Command中返回多个值 - ''' - - def __init__(self, - hit: bool = True, - success: bool = True, - message_chain: list = [], - command_name: str = "unknown_command", - use_t2i: bool = None) -> None: - self.hit = hit - self.success = success - self.message_chain = message_chain - self.command_name = command_name - self.is_use_t2i = use_t2i - - def message(self, message: str): - ''' - 快捷回复消息。 - - CommandResult().message("Hello, world!") - ''' - self.message_chain = [Plain(message), ] - return self - - def error(self, message: str): - ''' - 快捷回复消息。 - - CommandResult().error("Hello, world!") - ''' - self.success = False - self.message_chain = [Plain(message), ] - return self - - def url_image(self, url: str): - ''' - 快捷回复图片(网络url的格式)。 - - CommandResult().image("https://example.com/image.jpg") - ''' - self.message_chain = [Image.fromURL(url), ] - return self - - def file_image(self, path: str): - ''' - 快捷回复图片(本地文件路径的格式)。 - - CommandResult().image("image.jpg") - ''' - self.message_chain = [Image.fromFileSystem(path), ] - return self - - def use_t2i(self, use_t2i: bool): - ''' - 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 - ''' - self.is_use_t2i = use_t2i - return self - - def _result_tuple(self): - return (self.success, self.message_chain, self.command_name) diff --git a/type/message_event.py b/type/message_event.py deleted file mode 100644 index 5c106f9e1..000000000 --- a/type/message_event.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import List, Union, Optional -from dataclasses import dataclass -from type.register import RegisteredPlatform -from type.types import Context -from type.astrbot_message import AstrBotMessage, MessageType - -@dataclass -class MessageResult(): - result_message: Union[str, list] - is_command_call: Optional[bool] = False - use_t2i: Optional[bool] = None # None 为跟随用户设置 - callback: Optional[callable] = None - -class AstrMessageEvent(): - - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform: RegisteredPlatform, - role: str, - context: Context, - session_id: str = None, - unified_msg_origin: str = None, - only_command: bool = False): - ''' - AstrBot 消息事件。 - - `message_str`: 纯消息字符串 - `message_obj`: AstrBotMessage 对象 - `platform`: 平台对象 - `role`: 角色,`admin` or `member` - `context`: 全局对象 - `session_id`: 会话id - `unified_msg_origin`: 统一消息来源 - `only_command`: 是否只处理指令,而不使用 LLM 回复 - ''' - self.context = context - self.message_str = message_str - self.message_obj = message_obj - self.platform = platform - self.role = role - self.session_id = session_id - self.unified_msg_origin = unified_msg_origin - self.only_command = only_command - - def from_astrbot_message(message: AstrBotMessage, - context: Context, - platform_name: str, - session_id: str, - - unified_msg_origin: str = None, - only_command: bool = False): - - # 解析 role - sender_id = str(message.sender.user_id) - if sender_id in context.config_helper.admins_id: - role = 'admin' - else: - role = 'member' - - ame = AstrMessageEvent(message.message_str, - message, - context.find_platform(platform_name), - role, - context, - session_id, - unified_msg_origin, - only_command=only_command) - return ame - diff --git a/type/middleware.py b/type/middleware.py deleted file mode 100644 index dc1b1cc8e..000000000 --- a/type/middleware.py +++ /dev/null @@ -1,8 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class Middleware(): - name: str = "" - description: str = "" - origin: str = "" # 注册来源 - func: callable = None \ No newline at end of file diff --git a/type/register.py b/type/register.py deleted file mode 100644 index 1cf679a79..000000000 --- a/type/register.py +++ /dev/null @@ -1,27 +0,0 @@ -from model.provider.provider import Provider as LLMProvider -from model.platform import Platform -from type.plugin import * -from typing import List -from dataclasses import dataclass - -@dataclass -class RegisteredPlatform: - ''' - 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 - ''' - platform_name: str - platform_instance: Platform - origin: str = None # 注册来源 - - def __str__(self) -> str: - return self.platform_name - - -@dataclass -class RegisteredLLM: - ''' - 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 - ''' - llm_name: str - llm_instance: LLMProvider - origin: str = None # 注册来源 diff --git a/type/types.py b/type/types.py deleted file mode 100644 index e14fc36fc..000000000 --- a/type/types.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio, os, time -from asyncio import Task -from type.register import * -from typing import List, Awaitable -from logging import Logger -from util.cmd_config import AstrBotConfig -from util.t2i.renderer import TextToImageRenderer -from util.updator.astrbot_updator import AstrBotUpdator -from util.image_uploader import ImageUploader -from util.updator.plugin_updator import PluginUpdator -from type.command import CommandResult -from type.middleware import Middleware -from type.astrbot_message import MessageType -from model.plugin.command import PluginCommandBridge -from model.provider.provider import Provider -from util.log import LogBroker -from util.metrics import MetricUploader - - -class Context: - ''' - 存放一些公用的数据,用于在不同模块(如core与command)之间传递 - ''' - - def __init__(self): - self.running = True - self.logger: Logger = None - self.config_helper: AstrBotConfig = None - self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件 - self.platforms: List[RegisteredPlatform] = [] - self.llms: List[RegisteredLLM] = [] - self.default_personality: dict = None - - self.metrics_uploader: MetricUploader = None - self.updator: AstrBotUpdator = None - self.plugin_updator: PluginUpdator = None - self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) - self.image_renderer = TextToImageRenderer() - self.image_uploader = ImageUploader() - self.message_handler = None # see astrbot/message/handler.py - self.ext_tasks: List[Task] = [] - self.middlewares: List[Middleware] = [] - - self.command_manager = None - self.running = True - self._loop = asyncio.get_event_loop() - self._start_running = int(time.time()) - - self.log_broker = LogBroker() - - def register_commands(self, - plugin_name: str, - command_name: str, - description: str, - priority: int, - handler: callable, - use_regex: bool = False, - ignore_prefix: bool = False): - ''' - 注册插件指令。 - - @param plugin_name: 插件名,注意需要和你的 metadata 中的一致。 - @param command_name: 指令名,如 "help"。不需要带前缀。 - @param description: 指令描述。 - @param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。 - @param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context - @param use_regex: 是否使用正则表达式匹配指令名。 - @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。 - - .. Example:: - - ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。 - ''' - self.plugin_command_bridge.register_command(plugin_name, - command_name, - description, - priority, - handler, - use_regex, - ignore_prefix) - - def register_task(self, coro: Awaitable, task_name: str): - ''' - 注册任务。适用于需要长时间运行的插件。 - - `coro`: 协程对象 - `task_name`: 任务名,用于标识任务。自定义即可。 - ''' - task = asyncio.create_task(coro, name=task_name) - self.ext_tasks.append(task) - - def register_provider(self, llm_name: str, provider: Provider, origin: str = ''): - ''' - 注册一个提供 LLM 资源的 Provider。 - - `llm_name`: 自定义的用于识别 Provider 的名称。在 AstrBot 配置中,是 `llm` 字段下的 `id` 字段。 - `provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。 - ''' - self.llms.append(RegisteredLLM(llm_name, provider, origin)) - - def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable): - ''' - 为函数调用(function-calling / tools-use)添加工具。 - - @param name: 函数名 - @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] - @param desc: 函数描述 - @param func_obj: 处理函数 - ''' - self.message_handler.llm_tools.add_func(tool_name, params, desc, func) - - def unregister_llm_tool(self, tool_name: str): - ''' - 删除一个函数调用工具。 - ''' - self.message_handler.llm_tools.remove_func(tool_name) - - def register_middleware(self, middleware: Middleware): - ''' - 注册一个中间件。所有的消息事件都会经过中间件处理,然后再进入 LLM 聊天模块。 - - 在 AstrBot 中,会对到来的消息事件首先检查指令,然后再检查中间件。触发指令后将不会进入 LLM 聊天模块,而中间件会。 - ''' - self.middlewares.append(middleware) - - def find_platform(self, platform_name: str) -> RegisteredPlatform: - for platform in self.platforms: - if platform_name == platform.platform_name: - return platform - - if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错 - raise ValueError("couldn't find the platform you specified") - - async def send_message(self, unified_msg_origin: str, message: CommandResult): - ''' - 发送消息。 - - `unified_msg_origin`: 统一消息来源 - `message`: 消息内容 - ''' - l = unified_msg_origin.split(":") - if len(l) != 3: - raise ValueError("Invalid unified_msg_origin") - platform_name, message_type, id = l - platform = self.find_platform(platform_name) - await platform.platform_instance.send_msg_new(MessageType(message_type), id, message) - - def get_current_llm_provider(self) -> Provider: - ''' - 获取当前的 LLM Provider。 - ''' - return self.message_handler.provider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/__init__.py b/util/plugin_dev/api/v1/__init__.py deleted file mode 100644 index 390130e75..000000000 --- a/util/plugin_dev/api/v1/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .bot import * -from .config import * -from .llm import * -from .message import * -from .platform import * -from .register import * -from .types import * \ No newline at end of file diff --git a/util/plugin_dev/api/v1/bot.py b/util/plugin_dev/api/v1/bot.py deleted file mode 100644 index c334872b6..000000000 --- a/util/plugin_dev/api/v1/bot.py +++ /dev/null @@ -1,5 +0,0 @@ -from type.plugin import PluginMetadata, PluginType -from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins -from type.types import Context -from type.message_event import AstrMessageEvent -from type.command import CommandResult \ No newline at end of file diff --git a/util/plugin_dev/api/v1/llm.py b/util/plugin_dev/api/v1/llm.py deleted file mode 100644 index 11d2be559..000000000 --- a/util/plugin_dev/api/v1/llm.py +++ /dev/null @@ -1 +0,0 @@ -from model.provider.provider import Provider as LLMProvider \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py deleted file mode 100644 index 218498c4b..000000000 --- a/util/plugin_dev/api/v1/message.py +++ /dev/null @@ -1,4 +0,0 @@ -from type.message_event import * -from type.astrbot_message import * -from type.command import CommandResult -from astrbot.message.handler import MessageHandler \ No newline at end of file diff --git a/util/plugin_dev/api/v1/platform.py b/util/plugin_dev/api/v1/platform.py deleted file mode 100644 index 55574e0cf..000000000 --- a/util/plugin_dev/api/v1/platform.py +++ /dev/null @@ -1,12 +0,0 @@ -''' -消息平台。 - -Platform类是消息平台的抽象类,定义了消息平台的基本接口。 -消息平台的具体实现类需要继承Platform类,并实现其中的抽象方法。 -''' - -from model.platform import Platform - -from model.platform.qq_nakuru import QQNakuru -from model.platform.qq_official import QQOfficial -from model.platform.qq_aiocqhttp import AIOCQHTTP \ No newline at end of file diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py deleted file mode 100644 index da5881922..000000000 --- a/util/plugin_dev/api/v1/register.py +++ /dev/null @@ -1,68 +0,0 @@ -''' -允许开发者注册某一个类的实例到 LLM 或者 PLATFORM 中,方便其他插件调用。 - -必须分别实现 Platform 和 LLMProvider 中涉及的接口 -''' -from model.provider.provider import Provider as LLMProvider -from model.platform import Platform -from type.types import Context -from type.register import RegisteredPlatform, RegisteredLLM - -def register_platform(platform_name: str, context: Context, platform_instance: Platform = None) -> None: - ''' - 注册一个消息平台。 - - Args: - platform_name: 平台名称。 - platform_instance: 平台实例,可为空。 - context: 上下文对象。 - - Note: - 当插件类被加载时,AstrBot 会传给插件 context 对象。插件可以通过 context 对象注册指令、长任务等。 - ''' - - # check 是否已经注册 - for platform in context.platforms: - if platform.platform_name == platform_name: - raise ValueError(f"Platform {platform_name} has been registered.") - - context.platforms.append(RegisteredPlatform(platform_name, platform_instance)) - -def register_llm(llm_name: str, llm_instance: LLMProvider, context: Context) -> None: - ''' - 注册一个大语言模型。 - - Args: - llm_name: 大语言模型名称。 - llm_instance: 大语言模型实例。 - ''' - # check 是否已经注册 - for llm in context.llms: - if llm.llm_name == llm_name: - raise ValueError(f"LLMProvider {llm_name} has been registered.") - - context.llms.append(RegisteredLLM(llm_name, llm_instance)) - -def unregister_platform(platform_name: str, context: Context) -> None: - ''' - 注销一个消息平台。 - - Args: - platform_name: 平台名称。 - ''' - for i, platform in enumerate(context.platforms): - if platform.platform_name == platform_name: - context.platforms.pop(i) - return - -def unregister_llm(llm_name: str, context: Context) -> None: - ''' - 注销一个大语言模型。 - - Args: - llm_name: 大语言模型名称。 - ''' - for i, llm in enumerate(context.llms): - if llm.llm_name == llm_name: - context.llms.pop(i) - return \ No newline at end of file diff --git a/util/plugin_dev/api/v1/types.py b/util/plugin_dev/api/v1/types.py deleted file mode 100644 index 4665bad72..000000000 --- a/util/plugin_dev/api/v1/types.py +++ /dev/null @@ -1,7 +0,0 @@ -''' -插件类型、消息组件类型 -''' - -from type.plugin import PluginType -from type.middleware import Middleware -from nakuru.entities.components import Image, Plain, At, Node, BaseMessageComponent \ No newline at end of file diff --git a/util/t2i/__init__.py b/util/t2i/__init__.py deleted file mode 100644 index 047b584c3..000000000 --- a/util/t2i/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .renderer import TextToImageRenderer \ No newline at end of file diff --git a/util/t2i/context.py b/util/t2i/context.py deleted file mode 100644 index ac9f837a0..000000000 --- a/util/t2i/context.py +++ /dev/null @@ -1,11 +0,0 @@ -from util.t2i.strategies.base_strategy import RenderStrategy - -class RenderContext: - def __init__(self, strategy: RenderStrategy): - self._strategy = strategy - - def set_strategy(self, strategy: RenderStrategy): - self._strategy = strategy - - async def render(self, text: str, return_url: bool = False): - return await self._strategy.render(text, return_url) diff --git a/util/t2i/strategies/__init__.py b/util/t2i/strategies/__init__.py deleted file mode 100644 index 1fd9d92cc..000000000 --- a/util/t2i/strategies/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base_strategy import RenderStrategy -from .local_strategy import LocalRenderStrategy -from .network_strategy import NetworkRenderStrategy