refactor: im so tired :)
This commit is contained in:
+6
-1
@@ -2,6 +2,7 @@ __pycache__
|
||||
botpy.log
|
||||
.vscode
|
||||
data_v2.db
|
||||
data_v3.db
|
||||
configs/session
|
||||
configs/config.yaml
|
||||
**/.DS_Store
|
||||
@@ -11,4 +12,8 @@ data
|
||||
cookies.json
|
||||
logs/
|
||||
addons/plugins
|
||||
.coverage
|
||||
.coverage
|
||||
|
||||
|
||||
tests/astrbot_plugin_openai
|
||||
chroma
|
||||
+1
-13
@@ -1,16 +1,4 @@
|
||||
|
||||
from astrbot.core.plugin import Context
|
||||
from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult
|
||||
from astrbot.core.provider import Provider, Personality
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
|
||||
from astrbot.core.utils.command_parser import CommandParser, CommandTokens
|
||||
from astrbot.core.utils.func_call import FuncCall
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
from astrbot.core.plugin.config import *
|
||||
|
||||
command_parser = CommandParser()
|
||||
from astrbot.core import html_renderer
|
||||
@@ -0,0 +1,40 @@
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
# event
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult, MessageChain, CommandResult, EventResultType
|
||||
)
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
# star register
|
||||
from astrbot.core.star.register import (
|
||||
register_command as command,
|
||||
register_command_group as command_group,
|
||||
register_event_message_type as event_message_type,
|
||||
register_regex as regex,
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
)
|
||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
||||
from astrbot.core.star.register import (
|
||||
register_star as register # 注册插件(Star)
|
||||
)
|
||||
from astrbot.core.star import Context, Star
|
||||
from astrbot.core.star.config import *
|
||||
|
||||
|
||||
# provider
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
|
||||
# platform
|
||||
from astrbot.core.platform import (
|
||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
)
|
||||
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
|
||||
from .message_components import *
|
||||
@@ -0,0 +1,5 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult, MessageChain, CommandResult, EventResultType
|
||||
)
|
||||
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
@@ -0,0 +1,10 @@
|
||||
from astrbot.core.star.register import (
|
||||
register_command as command,
|
||||
register_command_group as command_group,
|
||||
register_event_message_type as event_message_type,
|
||||
register_regex as regex,
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
)
|
||||
|
||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
||||
@@ -0,0 +1,5 @@
|
||||
from astrbot.core.platform import (
|
||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
)
|
||||
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
@@ -0,0 +1 @@
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
@@ -0,0 +1,6 @@
|
||||
from astrbot.core.star.register import (
|
||||
register_star as register # 注册插件(Star)
|
||||
)
|
||||
|
||||
from astrbot.core.star import Context, Star
|
||||
from astrbot.core.star.config import *
|
||||
@@ -1,2 +1,2 @@
|
||||
from .default import DEFAULT_CONFIG_VERSION_2, VERSION, DB_PATH
|
||||
from .astrbot_config import AstrBotConfig
|
||||
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
|
||||
from .astrbot_config import *
|
||||
@@ -1,295 +1,82 @@
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import logging
|
||||
from . import DEFAULT_CONFIG_VERSION_2
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Dict, Optional
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG
|
||||
from typing import List, Dict
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
time: int = 60
|
||||
count: int = 30
|
||||
class RateLimitStrategy(enum.Enum):
|
||||
STALL = "stall"
|
||||
DISCARD = "discard"
|
||||
|
||||
@dataclass
|
||||
class PlatformSettings:
|
||||
unique_session: bool = False
|
||||
rate_limit: RateLimit = field(default_factory=RateLimit)
|
||||
reply_prefix: str = ""
|
||||
forward_threshold: int = 200
|
||||
class AstrBotConfig(dict):
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
|
||||
|
||||
def __post_init__(self):
|
||||
self.rate_limit = RateLimit(**self.rate_limit)
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig():
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
enable: bool = False
|
||||
|
||||
@dataclass
|
||||
class QQOfficialPlatformConfig(PlatformConfig):
|
||||
appid: str = ""
|
||||
secret: str = ""
|
||||
enable_group_c2c: bool = True
|
||||
enable_guild_direct_message: bool = True
|
||||
|
||||
@dataclass
|
||||
class AiocqhttpPlatformConfig(PlatformConfig):
|
||||
ws_reverse_host: str = ""
|
||||
ws_reverse_port: int = 6199
|
||||
qq_id_whitelist: List[str] = field(default_factory=list)
|
||||
qq_group_id_whitelist: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class WechatPlatformConfig(PlatformConfig):
|
||||
wechat_id_whitelist: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
model: str = "gpt-4o"
|
||||
max_tokens: int = 6000
|
||||
temperature: float = 0.9
|
||||
top_p: float = 1
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationModelConfig:
|
||||
enable: bool = True
|
||||
model: str = "dall-e-3"
|
||||
size: str = "1024x1024"
|
||||
style: str = "vivid"
|
||||
quality: str = "standard"
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModel:
|
||||
enable: bool = False
|
||||
model: str = ""
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
id: str = ""
|
||||
name: str = "openai"
|
||||
enable: bool = True
|
||||
key: List[str] = field(default_factory=list)
|
||||
api_base: str = ""
|
||||
prompt_prefix: str = ""
|
||||
default_personality: str = ""
|
||||
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||
image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig)
|
||||
embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.model_config, dict):
|
||||
self.model_config = ModelConfig(**self.model_config)
|
||||
if isinstance(self.image_generation_model_config, dict):
|
||||
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None
|
||||
if isinstance(self.embedding_model, dict):
|
||||
self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None
|
||||
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
wake_prefix: str = ""
|
||||
web_search: bool = False
|
||||
identifier: bool = False
|
||||
|
||||
@dataclass
|
||||
class BaiduAIPConfig:
|
||||
enable: bool = False
|
||||
app_id: str = ""
|
||||
api_key: str = ""
|
||||
secret_key: str = ""
|
||||
|
||||
@dataclass
|
||||
class InternalKeywordsConfig:
|
||||
enable: bool = True
|
||||
extra_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ContentSafetyConfig:
|
||||
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
|
||||
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
|
||||
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
|
||||
|
||||
@dataclass
|
||||
class DashboardConfig:
|
||||
enable: bool = True
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ATRILongTermMemory:
|
||||
enable: bool = False
|
||||
summary_threshold_cnt: int = 5
|
||||
|
||||
@dataclass
|
||||
class ATRIActiveMessage:
|
||||
enable: bool = False
|
||||
|
||||
@dataclass
|
||||
class ProjectATRI:
|
||||
enable: bool = False
|
||||
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
|
||||
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
|
||||
persona: str = ""
|
||||
split_response: bool = True
|
||||
embedding_provider_id: str = ""
|
||||
summarize_provider_id: str = ""
|
||||
chat_provider_id: str = ""
|
||||
chat_base_model_path: str = ""
|
||||
chat_adapter_model_path: str = ""
|
||||
quantization_bit: int = 4
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.long_term_memory, dict):
|
||||
self.long_term_memory = ATRILongTermMemory(**self.long_term_memory)
|
||||
if isinstance(self.active_message, dict):
|
||||
self.active_message = ATRIActiveMessage(**self.active_message)
|
||||
|
||||
@dataclass
|
||||
class AstrBotConfig():
|
||||
config_version: int = 2
|
||||
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
|
||||
llm: List[LLMConfig] = field(default_factory=list)
|
||||
llm_settings: LLMSettings = field(default_factory=LLMSettings)
|
||||
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
|
||||
t2i: bool = True
|
||||
admins_id: List[str] = field(default_factory=list)
|
||||
https_proxy: str = ""
|
||||
http_proxy: str = ""
|
||||
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
log_level: str = "INFO"
|
||||
t2i_endpoint: str = ""
|
||||
pip_install_arg: str = ""
|
||||
plugin_repo_mirror: str = ""
|
||||
project_atri: ProjectATRI = field(default_factory=ProjectATRI)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.init_configs()
|
||||
|
||||
# compability
|
||||
if isinstance(self.wake_prefix, str):
|
||||
self.wake_prefix = [self.wake_prefix]
|
||||
|
||||
if len(self.wake_prefix) == 0:
|
||||
self.wake_prefix.append("/")
|
||||
|
||||
def load_from_dict(self, data: Dict):
|
||||
'''从字典中加载配置到对象。
|
||||
|
||||
@note: 适用于 version 2 配置文件。
|
||||
'''
|
||||
self.config_version=data.get("version", 2)
|
||||
self.platform=[]
|
||||
|
||||
left_platforms = ["qq_official", "aiocqhttp", "wechat"]
|
||||
for p in data.get("platform", []):
|
||||
if 'name' not in p:
|
||||
logger.warning("A platform config missing name, skipping.")
|
||||
continue
|
||||
if p["name"] == "qq_official":
|
||||
self.platform.append(QQOfficialPlatformConfig(**p))
|
||||
left_platforms.remove(p["name"])
|
||||
elif p["name"] == "aiocqhttp":
|
||||
self.platform.append(AiocqhttpPlatformConfig(**p))
|
||||
left_platforms.remove(p["name"])
|
||||
elif p["name"] == "wechat":
|
||||
self.platform.append(WechatPlatformConfig(**p))
|
||||
left_platforms.remove(p["name"])
|
||||
# 注入默认配置
|
||||
for p in left_platforms:
|
||||
if p == "qq_official":
|
||||
self.platform.append(QQOfficialPlatformConfig(id="default", name=p))
|
||||
elif p == "aiocqhttp":
|
||||
self.platform.append(AiocqhttpPlatformConfig(id="default", name=p))
|
||||
elif p == "wechat":
|
||||
self.platform.append(WechatPlatformConfig(id="default", name=p))
|
||||
|
||||
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
|
||||
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
|
||||
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
|
||||
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
|
||||
self.t2i=data.get("t2i", True)
|
||||
self.admins_id=data.get("admins_id", [])
|
||||
self.https_proxy=data.get("https_proxy", "")
|
||||
self.http_proxy=data.get("http_proxy", "")
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", ["/"])
|
||||
self.log_level=data.get("log_level", "INFO")
|
||||
self.t2i_endpoint=data.get("t2i_endpoint", "")
|
||||
self.pip_install_arg=data.get("pip_install_arg", "")
|
||||
self.plugin_repo_mirror=data.get("plugin_repo_mirror", "")
|
||||
self.project_atri=ProjectATRI(**data.get("project_atri", {}))
|
||||
|
||||
def flush_config(self, config: dict = None):
|
||||
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def save_config(self):
|
||||
'''将现存配置写入文件'''
|
||||
self.flush_config(self.to_dict())
|
||||
|
||||
def init_configs(self):
|
||||
'''初始化必需的配置项'''
|
||||
config = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not self.check_exist():
|
||||
self.flush_config()
|
||||
config = DEFAULT_CONFIG_VERSION_2
|
||||
else:
|
||||
config = self.get_all()
|
||||
|
||||
# 加载配置到对象
|
||||
self.load_from_dict(config)
|
||||
# 保存到文件
|
||||
# 这一步操作是为了保证配置文件中的字段的完整性。
|
||||
# 在版本变动新增配置项时,将对象中新增的配置项的默认值写入文件。
|
||||
self.save_config()
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
'''从文件系统中直接获取配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
if key in d:
|
||||
return d[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def get_all(self):
|
||||
'''从文件系统中获取所有配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
if not conf_str:
|
||||
return {}
|
||||
conf = json.loads(conf_str)
|
||||
return conf
|
||||
|
||||
def put(self, key, value):
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
d[key] = value
|
||||
'''不存在时载入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self)
|
||||
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
|
||||
self.update(conf)
|
||||
if has_new:
|
||||
self.save_config()
|
||||
|
||||
self.update(conf)
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
'''检查配置完整性,如果有新的配置项则返回 True'''
|
||||
has_new = False
|
||||
for key, value in refer_conf.items():
|
||||
if key not in conf:
|
||||
logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
|
||||
conf[key] = value
|
||||
has_new = True
|
||||
else:
|
||||
if conf[key] == None:
|
||||
conf[key] = value
|
||||
has_new = True
|
||||
elif isinstance(value, dict):
|
||||
has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key)
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: Dict = None):
|
||||
'''将配置写入文件
|
||||
|
||||
如果传入 replace_config,则将配置替换为 replace_config
|
||||
'''
|
||||
if replace_config:
|
||||
self.update(replace_config)
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(self, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def __delattr__(self, key):
|
||||
try:
|
||||
del self[key]
|
||||
self.save_config()
|
||||
except KeyError:
|
||||
raise AttributeError(f"没有找到 Key: '{key}'")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self[key] = value
|
||||
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
+111
-127
@@ -1,177 +1,160 @@
|
||||
'''
|
||||
这里定义了一些默认配置文件,请不要修改这个文件。如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
'''
|
||||
|
||||
VERSION = '3.4.0'
|
||||
DB_PATH = 'data/data_v3.db'
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"config_version": 2,
|
||||
"platform_settings": {
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
"strategy": "stall" # stall, discard
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200,
|
||||
"id_whitelist": []
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"identifier": False,
|
||||
"default_personality": "",
|
||||
"prompt_prefix": ""
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {
|
||||
"enable": True,
|
||||
"extra_keywords": []
|
||||
},
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
}
|
||||
},
|
||||
"admins_id": [],
|
||||
"t2i": False,
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
"password": "77b90590a8945a7d36c963981a307dc9"
|
||||
},
|
||||
"platform": [],
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "",
|
||||
"project_atri": {
|
||||
"enable": False,
|
||||
"long_term_memory": {
|
||||
"enable": False,
|
||||
"summary_threshold_cnt": 5,
|
||||
"embedding_provider_id": "",
|
||||
"summarize_provider_id": ""
|
||||
},
|
||||
"active_message": {
|
||||
"enable": False
|
||||
},
|
||||
"vision": {
|
||||
"enable": False,
|
||||
"provider_id_or_ofa_model_path": "",
|
||||
"reply_meme_prob": 0.4,
|
||||
"reply_meme_similar_threshold": 0.7
|
||||
},
|
||||
"persona": "",
|
||||
"split_response": True,
|
||||
"chat_provider_id": "",
|
||||
"chat_base_model_path": "",
|
||||
"chat_adapter_model_path": "",
|
||||
"quantization_bit": 4
|
||||
}
|
||||
}
|
||||
|
||||
# LLM 提供商配置模板
|
||||
PROVIDER_CONFIG_TEMPLATE = {
|
||||
"openai": {
|
||||
"id": "default",
|
||||
"name": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"enable": False,
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
},
|
||||
"embedding_model": {
|
||||
"enable": False,
|
||||
"model": "text-embedding-3-small"
|
||||
"model": "gpt-4o-mini",
|
||||
}
|
||||
},
|
||||
"ollama": {
|
||||
"id": "ollama_default",
|
||||
"name": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "llama3.1-8b",
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
}
|
||||
},
|
||||
"gemini": {
|
||||
"id": "gemini_default",
|
||||
"name": "gemini",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
}
|
||||
},
|
||||
"deepseek": {
|
||||
"id": "deepseek_default",
|
||||
"name": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "deepseek-chat",
|
||||
}
|
||||
},
|
||||
"zhipu": {
|
||||
"id": "zhipu_default",
|
||||
"name": "zhipu(glm)",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
|
||||
DEFAULT_CONFIG_VERSION_2 = {
|
||||
"config_version": 2,
|
||||
"platform": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "qq_official",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"enable_group_c2c": True,
|
||||
"enable_guild_direct_message": True,
|
||||
},
|
||||
{
|
||||
"id": "default",
|
||||
"name": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199,
|
||||
"qq_id_whitelist": [],
|
||||
"qq_group_id_whitelist": []
|
||||
},
|
||||
{
|
||||
"id": "default",
|
||||
"name": "wechat",
|
||||
"enable": False,
|
||||
"wechat_id_whitelist": []
|
||||
}
|
||||
],
|
||||
"platform_settings": {
|
||||
"unique_session": False,
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200, # 转发消息的阈值
|
||||
},
|
||||
"llm": [
|
||||
PROVIDER_CONFIG_TEMPLATE["openai"]
|
||||
],
|
||||
"llm_settings": {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"identifier": False,
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {
|
||||
"enable": True,
|
||||
"extra_keywords": [],
|
||||
}
|
||||
},
|
||||
"wake_prefix": ["/"],
|
||||
"t2i": True,
|
||||
"admins_id": [],
|
||||
"https_proxy": "",
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "default",
|
||||
"project_atri": {
|
||||
# 平台适配器配置模板
|
||||
ADAPTER_CONFIG_TEMPLATE = {
|
||||
"qq_official": {
|
||||
"id": "default",
|
||||
"name": "qq_official",
|
||||
"enable": False,
|
||||
"long_term_memory": {
|
||||
"enable": False,
|
||||
"summary_threshold_cnt": 6,
|
||||
},
|
||||
"active_message": {
|
||||
"enable": False,
|
||||
},
|
||||
"vision": {
|
||||
"enable": False,
|
||||
"provider_id_or_ofa_model_path": "",
|
||||
},
|
||||
"persona": "",
|
||||
"split_response": True,
|
||||
"embedding_provider_id": "",
|
||||
"summarize_provider_id": "",
|
||||
"chat_provider_id": "",
|
||||
"chat_base_model_path": "",
|
||||
"chat_adapter_model_path": "",
|
||||
"quantization_bit": 4
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"enable_group_c2c": True,
|
||||
"enable_guild_direct_message": True,
|
||||
},
|
||||
"aiocqhtp": {
|
||||
"id": "default",
|
||||
"name": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199
|
||||
},
|
||||
"wechat": {
|
||||
"id": "default",
|
||||
"name": "vchat",
|
||||
"enable": False
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +166,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"items": {
|
||||
"id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"},
|
||||
"name": {"description": "适配器类型", "type": "string", "hint": "当前版本下,内置支持 `qq_official`(QQ 官方机器人), `aiocqhttp`(Onebot 适用) 适配器类型。", "options": ["qq_official", "aiocqhttp", "wechat"], "readonly": True},
|
||||
"name": {"description": "适配器类型", "type": "string", "invisible": True},
|
||||
"enable": {"description": "启用", "type": "bool", "hint": "是否启用该适配器。未启用的适配器对应的消息平台将不会接收到消息。"},
|
||||
"appid": {"description": "appid", "type": "string", "hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。"},
|
||||
"secret": {"description": "secret", "type": "string", "hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。"},
|
||||
@@ -208,18 +191,20 @@ CONFIG_METADATA_2 = {
|
||||
"items": {
|
||||
"time": {"description": "消息速率限制时间", "type": "int"},
|
||||
"count": {"description": "消息速率限制计数", "type": "int"},
|
||||
"strategy": {"description": "速率限制策略", "type": "string", "options": ["stall", "discard"], "hint": "当消息速率超过限制时的处理策略。stall 为等待,discard 为丢弃。"}
|
||||
}
|
||||
},
|
||||
"reply_prefix": {"description": "回复前缀", "type": "string", "hint": "机器人回复消息时带有的前缀。"},
|
||||
"forward_threshold": {"description": "转发消息的字数阈值", "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。"},
|
||||
"id_whitelist": {"description": "ID 白名单", "type": "list", "items": {"type": "int"}, "hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的 ID。"},
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"provider": {
|
||||
"description": "大语言模型配置",
|
||||
"type": "list",
|
||||
"items": {
|
||||
"id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"},
|
||||
"name": {"description": "模型提供商类型", "type": "string", "hint": "如需变更模型提供商,请点击上面的 + 新建一个。如果没有找到你想要接入的提供商,可以前往你的提供商的官网查看是否兼容 OpenAI API,如兼容,可以选择 `openai`。大多数提供商都是兼容的。", "options": list(PROVIDER_CONFIG_TEMPLATE.keys()), "obvious_hint": True, "readonly": True},
|
||||
"type": {"description": "模型提供商类型", "type": "string", "invisible": True},
|
||||
"enable": {"description": "启用", "type": "bool", "hint": "是否启用该模型。未启用的模型将不会被使用。"},
|
||||
"key": {"description": "API Key", "type": "list", "items": {"type": "string"}, "hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。"},
|
||||
"api_base": {"description": "API Base URL", "type": "string", "hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。"},
|
||||
@@ -256,7 +241,7 @@ CONFIG_METADATA_2 = {
|
||||
}
|
||||
}
|
||||
},
|
||||
"llm_settings": {
|
||||
"provider_settings": {
|
||||
"description": "大语言模型设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
@@ -292,7 +277,6 @@ CONFIG_METADATA_2 = {
|
||||
"wake_prefix": {"description": "机器人唤醒前缀", "type": "list", "items": {"type": "string"}, "hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。"},
|
||||
"t2i": {"description": "文本转图像", "type": "bool", "hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。"},
|
||||
"admins_id": {"description": "管理员 ID", "type": "list", "items": {"type": "int"}, "hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。"},
|
||||
"https_proxy": {"description": "HTTPS 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"},
|
||||
"http_proxy": {"description": "HTTP 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"},
|
||||
"dashboard": {
|
||||
"description": "管理面板配置",
|
||||
@@ -318,6 +302,8 @@ CONFIG_METADATA_2 = {
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"},
|
||||
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
|
||||
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
}
|
||||
},
|
||||
"active_message": {
|
||||
@@ -335,10 +321,8 @@ CONFIG_METADATA_2 = {
|
||||
"provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"},
|
||||
}
|
||||
},
|
||||
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"},
|
||||
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。"},
|
||||
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
|
||||
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
|
||||
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
"chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
"chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True},
|
||||
"chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True},
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import asyncio, time, threading
|
||||
import asyncio, time, threading, os
|
||||
from .event_bus import EventBus
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.message.message_event_handler import MessageEventHandler
|
||||
from astrbot.core.plugin import PluginManager
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
@@ -15,21 +18,46 @@ class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
self.astrbot_config = AstrBotConfig()
|
||||
self.db = db
|
||||
|
||||
if self.astrbot_config['http_proxy']:
|
||||
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("AstrBot v"+ VERSION)
|
||||
logger.setLevel(self.astrbot_config.log_level)
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
self.event_queue = Queue()
|
||||
self.event_queue.closed = False
|
||||
self.plugin_manager = PluginManager(self.astrbot_config, self.event_queue, db)
|
||||
self.message_event_handler = MessageEventHandler(self.astrbot_config, self.plugin_manager)
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config.plugin_repo_mirror)
|
||||
self.event_bus = EventBus(self.event_queue, self.message_event_handler)
|
||||
self.stop_flag = False
|
||||
self.start_time = int(time.time())
|
||||
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
|
||||
self.star_context.platform_manager = self.platform_manager
|
||||
self.star_context.provider_manager = self.provider_manager
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
self.plugin_manager.reload()
|
||||
'''扫描、注册插件、实例化插件类'''
|
||||
|
||||
await self.provider_manager.initialize()
|
||||
'''根据配置实例化各个 Provider'''
|
||||
|
||||
await self.platform_manager.initialize()
|
||||
'''根据配置实例化各个平台适配器'''
|
||||
|
||||
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
|
||||
await self.pipeline_scheduler.initialize()
|
||||
'''初始化消息事件流水线调度器'''
|
||||
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror'])
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
self.start_time = int(time.time())
|
||||
self.curr_tasks: List[asyncio.Task] = []
|
||||
|
||||
def _load(self):
|
||||
self.plugin_manager.reload()
|
||||
|
||||
platform_tasks = self.load_platform()
|
||||
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
|
||||
@@ -41,16 +69,26 @@ class AstrBotCoreLifecycle:
|
||||
self._load()
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
def stop(self):
|
||||
self.stop_flag = True
|
||||
|
||||
async def stop(self):
|
||||
self.event_queue.closed = True
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
|
||||
for task in self.curr_tasks:
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||
|
||||
def restart(self):
|
||||
self.event_queue.closed = True
|
||||
threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start()
|
||||
|
||||
def load_platform(self) -> List[asyncio.Task]:
|
||||
tasks = []
|
||||
platform_insts = self.plugin_manager.get_platform_insts()
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name))
|
||||
return tasks
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -39,12 +39,12 @@ class BaseDatabase(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_llm_history(self, session_id: str, content: str):
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_llm_history(self, session_id: str = None) -> List[LLMHistory]:
|
||||
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]:
|
||||
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -62,3 +62,18 @@ class BaseDatabase(abc.ABC):
|
||||
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
||||
'''获取基础统计数据(合并)'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def insert_atri_vision_data(self, vision_data: ATRIVision):
|
||||
'''插入 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data(self) -> List[ATRIVision]:
|
||||
'''获取 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
'''通过 url 或 path 获取 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
+14
-1
@@ -37,5 +37,18 @@ class Stats():
|
||||
|
||||
@dataclass
|
||||
class LLMHistory():
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
content: str
|
||||
|
||||
@dataclass
|
||||
class ATRIVision():
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
is_meme: bool
|
||||
keywords: List[str]
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
+70
-31
@@ -6,7 +6,8 @@ from astrbot.core.db.po import (
|
||||
Command,
|
||||
Provider,
|
||||
Stats,
|
||||
LLMHistory
|
||||
LLMHistory,
|
||||
ATRIVision
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -75,28 +76,39 @@ class SQLiteDatabase(BaseDatabase):
|
||||
''', (k, v, int(time.time()))
|
||||
)
|
||||
|
||||
def update_llm_history(self, session_id: str, content: str):
|
||||
res = self.get_llm_history(session_id)
|
||||
def update_llm_history(self, session_id: str, content: str, provider_type: str):
|
||||
res = self.get_llm_history(session_id, provider_type)
|
||||
if res:
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE llm_history SET content = ? WHERE session_id = ?
|
||||
''', (content, session_id)
|
||||
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
|
||||
''', (content, session_id, provider_type)
|
||||
)
|
||||
else:
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO llm_history(session_id, content) VALUES (?, ?)
|
||||
''', (session_id, content)
|
||||
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
|
||||
''', (provider_type, session_id, content)
|
||||
)
|
||||
|
||||
def get_llm_history(self, session_id: str = None) -> Tuple:
|
||||
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
where_clause = "" if session_id is None else f"WHERE session_id = '{session_id}'"
|
||||
|
||||
where_clause = ""
|
||||
if session_id or provider_type:
|
||||
where_clause += " WHERE "
|
||||
has = False
|
||||
if session_id:
|
||||
where_clause += f"session_id = '{session_id}'"
|
||||
has = True
|
||||
if provider_type:
|
||||
if has:
|
||||
where_clause += " AND "
|
||||
where_clause += f"provider_type = '{provider_type}'"
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM llm_history
|
||||
@@ -186,26 +198,53 @@ class SQLiteDatabase(BaseDatabase):
|
||||
for row in c.fetchall():
|
||||
platform.append(Platform(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT name, SUM(count), timestamp FROM command
|
||||
# ''' + where_clause + " GROUP BY name"
|
||||
# )
|
||||
|
||||
# command = []
|
||||
# for row in c.fetchall():
|
||||
# command.append(Command(*row))
|
||||
|
||||
# c.execute(
|
||||
# '''
|
||||
# SELECT name, SUM(count), timestamp FROM llm
|
||||
# ''' + where_clause + " GROUP BY name"
|
||||
# )
|
||||
|
||||
# llm = []
|
||||
# for row in c.fetchall():
|
||||
# llm.append(Provider(*row))
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
ts = int(time.time())
|
||||
keywords = ",".join(vision.keywords)
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
|
||||
)
|
||||
|
||||
def get_atri_vision_data(self) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM atri_vision
|
||||
'''
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
visions = []
|
||||
for row in res:
|
||||
visions.append(ATRIVision(*row))
|
||||
c.close()
|
||||
return visions
|
||||
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
|
||||
''', (url_or_path, id)
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
if res:
|
||||
return ATRIVision(*res)
|
||||
return None
|
||||
@@ -19,6 +19,20 @@ CREATE TABLE IF NOT EXISTS command(
|
||||
timestamp INTEGER
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS llm_history(
|
||||
provider_type VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
content TEXT
|
||||
);
|
||||
|
||||
-- ATRI
|
||||
CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
id TEXT,
|
||||
url_or_path TEXT,
|
||||
caption TEXT,
|
||||
is_meme BOOLEAN,
|
||||
keywords TEXT,
|
||||
platform_name VARCHAR(32),
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
@@ -2,22 +2,22 @@ import asyncio
|
||||
from asyncio import Queue
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from astrbot.core.message.message_event_handler import MessageEventHandler
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||
from astrbot.core import logger
|
||||
from .platform import AstrMessageEvent
|
||||
from astrbot.core.message.components import Image, Plain
|
||||
|
||||
class EventBus:
|
||||
def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler):
|
||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||
self.event_queue = event_queue
|
||||
self.message_event_handler = message_event_handler
|
||||
self.pipeline_scheduler = pipeline_scheduler
|
||||
|
||||
async def dispatch(self):
|
||||
logger.info("事件总线已打开。")
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
self._print_event(event)
|
||||
asyncio.create_task(self.message_event_handler.handle(event))
|
||||
asyncio.create_task(self.pipeline_scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent):
|
||||
if event.get_sender_name():
|
||||
|
||||
@@ -271,6 +271,7 @@ class Image(BaseMessageComponent):
|
||||
c: T.Optional[int] = 2
|
||||
# 额外
|
||||
path: T.Optional[str] = ""
|
||||
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
# for k in _.keys():
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
import asyncio, re, time
|
||||
import inspect
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .message_event_result import MessageEventResult, CommandResult, MessageChain
|
||||
from astrbot.core.plugin import PluginManager, Context, CommandMetadata
|
||||
from .components import *
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
class CommandTokens():
|
||||
def __init__(self) -> None:
|
||||
self.tokens = []
|
||||
self.len = 0
|
||||
|
||||
def get(self, idx: int):
|
||||
if idx >= self.len:
|
||||
return None
|
||||
return self.tokens[idx].strip()
|
||||
|
||||
class CommandParser():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def parse(self, message: str):
|
||||
cmd_tokens = CommandTokens()
|
||||
cmd_tokens.tokens = message.split(" ")
|
||||
cmd_tokens.len = len(cmd_tokens.tokens)
|
||||
return cmd_tokens
|
||||
|
||||
def regex_match(self, message: str, command: str) -> bool:
|
||||
return re.search(command, message, re.MULTILINE) is not None
|
||||
|
||||
|
||||
class MessageEventHandler():
|
||||
'''
|
||||
处理消息事件。
|
||||
'''
|
||||
def __init__(self, config: AstrBotConfig, plugin_manager: PluginManager):
|
||||
self.config = config
|
||||
self.plugin_manager = plugin_manager
|
||||
self.command_parser = CommandParser()
|
||||
|
||||
async def handle(self, event: AstrMessageEvent):
|
||||
'''
|
||||
处理消息事件。
|
||||
'''
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.config.admins_id:
|
||||
if event.get_sender_id() == admin_id:
|
||||
event.role = "admin"
|
||||
break
|
||||
|
||||
# 检查 wake
|
||||
wake_prefixes = self.config.wake_prefix
|
||||
messages = event.get_messages()
|
||||
is_wake = False
|
||||
for wake_prefix in wake_prefixes:
|
||||
if event.message_str.startswith(wake_prefix):
|
||||
is_wake = True
|
||||
break
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"):
|
||||
is_wake = True
|
||||
wake_prefix = ""
|
||||
break
|
||||
# 检查是否是私聊
|
||||
if event.is_private_chat():
|
||||
is_wake = True
|
||||
wake_prefix = ""
|
||||
event.is_wake = is_wake
|
||||
|
||||
# 处理事件监听器(在指令扫描之前)
|
||||
listeners = self.plugin_manager.context.registered_listeners
|
||||
listeners_handler = self.plugin_manager.context.listeners_handler
|
||||
for name in listeners:
|
||||
if listeners_handler[name].after_commands:
|
||||
continue
|
||||
ret = await listeners_handler[name].handler(event)
|
||||
if ret:
|
||||
event.set_result(ret)
|
||||
if event.get_result():
|
||||
return await self.post_handle(event)
|
||||
|
||||
# 处理指令,指令带有指定过的前缀
|
||||
commands = self.plugin_manager.context.registered_commands
|
||||
commands_handler = self.plugin_manager.context.commands_handler
|
||||
|
||||
# 扫描指令
|
||||
for command in commands:
|
||||
command = command[1]
|
||||
trig = False
|
||||
pre_ = ""
|
||||
if not commands_handler[command].ignore_prefix:
|
||||
pre_ = wake_prefix
|
||||
|
||||
if commands_handler[command].use_regex:
|
||||
trig = self.command_parser.regex_match(event.message_str, pre_ + command)
|
||||
else:
|
||||
trig = event.message_str.startswith(pre_ + command)
|
||||
if trig:
|
||||
ret = await self.execute_handler(command, commands_handler[command], event)
|
||||
if ret:
|
||||
event.set_result(ret)
|
||||
if event.get_result():
|
||||
return await self.post_handle(event)
|
||||
|
||||
# 处理事件监听器(在指令扫描之后)
|
||||
for name in listeners:
|
||||
if not listeners_handler[name].after_commands:
|
||||
continue
|
||||
ret = await listeners_handler[name].handler(event)
|
||||
if ret:
|
||||
event.set_result(ret)
|
||||
if event.get_result():
|
||||
return await self.post_handle(event)
|
||||
|
||||
async def post_handle(self, event: AstrMessageEvent):
|
||||
result = event.get_result()
|
||||
if result.callback:
|
||||
await result.callback(event)
|
||||
|
||||
# prefix
|
||||
if self.config.platform_settings.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.config.platform_settings.reply_prefix))
|
||||
|
||||
# t2i
|
||||
if (result.use_t2i_ is None and self.config.t2i) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += "\n\n" + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str and len(plain_str) > 150:
|
||||
render_start = time.time()
|
||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
||||
if time.time() - render_start > 3:
|
||||
logger.warning(f"图片转文本耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
||||
if url:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
await event.send(result)
|
||||
|
||||
async def execute_handler(self,
|
||||
command: str,
|
||||
command_metadata: CommandMetadata,
|
||||
message_event: AstrMessageEvent):
|
||||
logger.info(f"触发 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令。")
|
||||
handler = command_metadata.handler
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
command_result = await handler(message_event)
|
||||
else:
|
||||
command_result = handler(message_event)
|
||||
|
||||
if command_result is not None:
|
||||
message_event.set_result(command_result)
|
||||
except TypeError as e:
|
||||
# 兼容旧版本插件
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
command_result = await handler(message_event, self.plugin_manager.context)
|
||||
else:
|
||||
command_result = handler(message_event, self.plugin_manager.context)
|
||||
|
||||
if command_result is not None:
|
||||
message_event.set_result(command_result)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
|
||||
message_event.set_result(MessageEventResult().message(text))
|
||||
@@ -1,43 +1,68 @@
|
||||
from typing import List, Union, Optional
|
||||
import enum, logging
|
||||
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.message.components import *
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@dataclass
|
||||
class MessageChain():
|
||||
'''MessageChain 描述了一整条消息中带有的所有组件。
|
||||
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
||||
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
'''
|
||||
|
||||
chain: List[BaseMessageComponent] = field(default_factory=list)
|
||||
use_t2i_: Optional[bool] = None # None 为跟随用户设置
|
||||
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
def message(self, message: str):
|
||||
'''
|
||||
快捷回复消息。
|
||||
'''添加一条文本消息到消息链 `chain` 中。
|
||||
|
||||
CommandResult().message("Hello, world!")
|
||||
Example:
|
||||
|
||||
CommandResult().message("Hello ").message("world!")
|
||||
# 输出 Hello world!
|
||||
|
||||
'''
|
||||
self.chain.append(Plain(message))
|
||||
return self
|
||||
|
||||
@deprecated("请使用 message 方法代替。")
|
||||
def error(self, message: str):
|
||||
'''
|
||||
快捷回复消息。
|
||||
'''添加一条错误消息到消息链 `chain` 中
|
||||
|
||||
CommandResult().error("Hello, world!")
|
||||
Example:
|
||||
|
||||
CommandResult().error("解析失败")
|
||||
|
||||
'''
|
||||
self.chain.append(Plain(message))
|
||||
return self
|
||||
|
||||
def url_image(self, url: str):
|
||||
'''
|
||||
快捷回复图片(网络url的格式)。
|
||||
'''添加一条图片消息(https 链接)到消息链 `chain` 中。
|
||||
|
||||
CommandResult().image("https://example.com/image.jpg")
|
||||
Note:
|
||||
如果需要发送本地图片,请使用 `file_image` 方法。
|
||||
|
||||
Example:
|
||||
|
||||
CommandResult().image("https://example.com/image.jpg")
|
||||
|
||||
'''
|
||||
self.chain.append(Image.fromURL(url))
|
||||
return self
|
||||
|
||||
def file_image(self, path: str):
|
||||
'''
|
||||
快捷回复图片(本地文件路径的格式)。
|
||||
'''添加一条图片消息(本地文件路径)到消息链 `chain` 中。
|
||||
|
||||
Note:
|
||||
如果需要发送网络图片,请使用 `url_image` 方法。
|
||||
|
||||
CommandResult().image("image.jpg")
|
||||
'''
|
||||
@@ -45,24 +70,65 @@ class MessageChain():
|
||||
return self
|
||||
|
||||
def use_t2i(self, use_t2i: bool):
|
||||
'''
|
||||
设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
|
||||
'''设置是否使用文本转图片服务。
|
||||
|
||||
Args:
|
||||
use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
'''
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def is_split(self, is_split: bool):
|
||||
'''
|
||||
设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
|
||||
具体的效果以各适配器实现为准。
|
||||
Note:
|
||||
具体的效果以各适配器实现为准。
|
||||
|
||||
'''
|
||||
self.is_split_ = is_split
|
||||
return self
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
'''用于描述事件处理的结果类型。
|
||||
|
||||
Attributes:
|
||||
CONTINUE: 事件将会继续传播
|
||||
STOP: 事件将会终止传播
|
||||
'''
|
||||
CONTINUE = enum.auto()
|
||||
STOP = enum.auto()
|
||||
|
||||
@dataclass
|
||||
class MessageEventResult(MessageChain):
|
||||
is_command_call: Optional[bool] = False
|
||||
callback: Optional[callable] = None
|
||||
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
|
||||
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
|
||||
|
||||
Attributes:
|
||||
`chain` (list): 用于顺序存储各个组件。
|
||||
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
|
||||
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
|
||||
`result_type` (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
|
||||
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
|
||||
|
||||
def stop_event(self) -> 'MessageEventResult':
|
||||
'''终止事件传播。
|
||||
'''
|
||||
self.result_type = EventResultType.STOP
|
||||
return self
|
||||
|
||||
def continue_event(self) -> 'MessageEventResult':
|
||||
'''继续事件传播。
|
||||
'''
|
||||
self.result_type = EventResultType.CONTINUE
|
||||
return self
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
'''
|
||||
是否终止事件传播。
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -0,0 +1,18 @@
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
|
||||
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitCheckStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage" # 发送消息
|
||||
]
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .respond.stage import RespondStage
|
||||
@@ -0,0 +1,31 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from .strategies.strategy import StrategySelector
|
||||
|
||||
@register_stage
|
||||
class ContentSafetyCheckStage(Stage):
|
||||
'''检查内容安全
|
||||
|
||||
当前只会检查文本的。
|
||||
'''
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
config = ctx.astrbot_config['content_safety']
|
||||
self.strategy_selector = StrategySelector(config)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''检查内容安全'''
|
||||
ok, info = self.strategy_selector.check(event.get_message_str())
|
||||
if not ok:
|
||||
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
|
||||
event.stop_event()
|
||||
logger.info(f"内容安全检查不通过,原因:{info}")
|
||||
return
|
||||
event.continue_event()
|
||||
@@ -0,0 +1,8 @@
|
||||
import abc
|
||||
from typing import Tuple
|
||||
|
||||
class ContentSafetyStrategy(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def check(self, content: str) -> Tuple[bool, str]:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,31 @@
|
||||
'''
|
||||
使用此功能应该先 pip install baidu-aip
|
||||
'''
|
||||
from . import ContentSafetyStrategy
|
||||
from aip import AipContentCensor
|
||||
from astrbot.core import logger
|
||||
|
||||
class BaiduAipStrategy(ContentSafetyStrategy):
|
||||
def __init__(self, appid: str, ak: str, sk: str) -> None:
|
||||
self.app_id = appid
|
||||
self.api_key = ak
|
||||
self.secret_key = sk
|
||||
self.client = AipContentCensor(self.app_id,
|
||||
self.api_key,
|
||||
self.secret_key)
|
||||
|
||||
def check(self, content: str):
|
||||
res = self.client.textCensorUserDefined(content)
|
||||
if 'conclusionType' not in res:
|
||||
return False, ""
|
||||
if res['conclusionType'] == 1:
|
||||
return True, ""
|
||||
else:
|
||||
if 'data' not in res:
|
||||
return False, ""
|
||||
count = len(res['data'])
|
||||
info = f"百度审核服务发现 {count} 处违规:\n"
|
||||
for i in res['data']:
|
||||
info += f"{i['msg']};\n"
|
||||
info += "\n判断结果:"+res['conclusion']
|
||||
return False, info
|
||||
@@ -0,0 +1,21 @@
|
||||
import re, os, json, base64
|
||||
from . import ContentSafetyStrategy
|
||||
from astrbot.core import logger
|
||||
|
||||
class KeywordsStrategy(ContentSafetyStrategy):
|
||||
def __init__(self, extra_keywords: list) -> None:
|
||||
self.keywords = []
|
||||
if extra_keywords is None:
|
||||
extra_keywords = []
|
||||
self.keywords.extend(extra_keywords)
|
||||
keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words')
|
||||
# internal keywords
|
||||
if os.path.exists(keywords_path):
|
||||
with open(keywords_path, "r", encoding="utf-8") as f:
|
||||
self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'])
|
||||
|
||||
def check(self, content: str) -> bool:
|
||||
for keyword in self.keywords:
|
||||
if re.search(keyword, content):
|
||||
return False, f"内容安全检查不通过,匹配到敏感词。"
|
||||
return True, ""
|
||||
@@ -0,0 +1,27 @@
|
||||
from . import ContentSafetyStrategy
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class StrategySelector():
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.enabled_strategies: List[ContentSafetyStrategy] = []
|
||||
if config['internal_keywords']['enable']:
|
||||
from .keywords import KeywordsStrategy
|
||||
self.enabled_strategies.append(KeywordsStrategy(
|
||||
config['internal_keywords']['extra_keywords']))
|
||||
if config['baidu_aip']['enable']:
|
||||
try:
|
||||
from .baidu_aip import BaiduAipStrategy
|
||||
except ImportError:
|
||||
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
|
||||
self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'],
|
||||
config['baidu_aip']['api_key'],
|
||||
config['baidu_aip']['secret_key']
|
||||
))
|
||||
|
||||
def check(self, content: str) -> Tuple[bool, str]:
|
||||
for strategy in self.enabled_strategies:
|
||||
ok, info = strategy.check(content)
|
||||
if not ok:
|
||||
return False, info
|
||||
return True, ""
|
||||
@@ -0,0 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star import PluginManager
|
||||
|
||||
@dataclass
|
||||
class PipelineContext:
|
||||
astrbot_config: AstrBotConfig
|
||||
plugin_manager: PluginManager
|
||||
@@ -0,0 +1,71 @@
|
||||
import asyncio, traceback, json
|
||||
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
|
||||
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if self.prompt_prefix:
|
||||
event.message_str = self.prompt_prefix + event.message_str
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
event.message_str = user_info + event.message_str
|
||||
|
||||
image_urls = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
image_urls.append(image_url)
|
||||
|
||||
tools = self.ctx.plugin_manager.context.get_llm_tools()
|
||||
|
||||
try:
|
||||
llm_response = await self.curr_provider.text_chat(
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_urls=image_urls,
|
||||
tools=tools
|
||||
)
|
||||
await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text))
|
||||
elif llm_response.role == 'tool':
|
||||
# function calling
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = tools.get_func(func_tool_name)
|
||||
logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
try:
|
||||
ret = await func_tool(event=event, *func_tool_args)
|
||||
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
yield
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
return
|
||||
@@ -0,0 +1,48 @@
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
class StarRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
|
||||
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
|
||||
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params")
|
||||
if not handlers_parsed_params:
|
||||
handlers_parsed_params = {}
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_str not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
|
||||
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
if hasattr(handler.handler, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ret = await handler.handler(event, **params)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ret = await handler.handler(event, self.ctx.plugin_manager.context, **params)
|
||||
else:
|
||||
ret = await handler.handler(star_cls_obj, event, **params)
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
yield
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except Exception as e:
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import List, Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.message.components import *
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
@register_stage
|
||||
class ProcessStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
self.llm_request_sub_stage = LLMRequestSubStage()
|
||||
await self.llm_request_sub_stage.initialize(ctx)
|
||||
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''处理事件
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
|
||||
if not activated_handlers:
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
else:
|
||||
async for _ in self.star_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||
|
||||
|
||||
@register_stage
|
||||
class RateLimitStage(Stage):
|
||||
"""
|
||||
检查是否需要限制消息发送的限流器。
|
||||
|
||||
使用 Fixed Window 算法。
|
||||
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储每个会话的请求时间队列
|
||||
self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque)
|
||||
# 为每个会话设置一个锁,避免并发冲突
|
||||
self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
# 限流参数
|
||||
self.rate_limit_count: int = 0
|
||||
self.rate_limit_time: timedelta = timedelta(0)
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""
|
||||
初始化限流器,根据配置设置限流参数。
|
||||
"""
|
||||
self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count']
|
||||
self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time'])
|
||||
self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""
|
||||
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 当前消息事件。
|
||||
ctx (PipelineContext): 流水线上下文。
|
||||
|
||||
Returns:
|
||||
MessageEventResult: 继续或停止事件处理的结果。
|
||||
"""
|
||||
session_id = event.session_id
|
||||
now = datetime.now()
|
||||
|
||||
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
|
||||
timestamps = self.event_timestamps[session_id]
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now)
|
||||
|
||||
if len(timestamps) >= self.rate_limit_count:
|
||||
# 达到限流阈值,计算下一个窗口的时间
|
||||
next_window_time = timestamps[0] + self.rate_limit_time
|
||||
stall_duration = (next_window_time - now).total_seconds()
|
||||
|
||||
match self.rl_strategy:
|
||||
case RateLimitStrategy.STALL:
|
||||
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
|
||||
await asyncio.sleep(stall_duration)
|
||||
case RateLimitStrategy.DISCARD:
|
||||
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
|
||||
return event.stop_event()
|
||||
|
||||
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
|
||||
|
||||
timestamps.append(now)
|
||||
|
||||
return event.continue_event()
|
||||
|
||||
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
|
||||
"""
|
||||
移除时间窗口外的时间戳。
|
||||
|
||||
Args:
|
||||
timestamps (Deque[datetime]): 当前会话的时间戳队列。
|
||||
now (datetime): 当前时间,用于计算过期时间。
|
||||
"""
|
||||
expiry_threshold: datetime = now - self.rate_limit_time
|
||||
while timestamps and timestamps[0] < expiry_threshold:
|
||||
timestamps.popleft()
|
||||
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||
|
||||
@register_stage
|
||||
class RespondStage:
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event.send(result)
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import asyncio, time
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||
from astrbot.core.message.components import Plain, Image
|
||||
from astrbot.core import html_renderer
|
||||
|
||||
@register_stage
|
||||
class ResultDecorateStage:
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
result.chain.insert(0, Plain(self.reply_prefix))
|
||||
|
||||
# 文本转图片
|
||||
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
|
||||
plain_str = ""
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain_str += "\n\n" + comp.text
|
||||
else:
|
||||
break
|
||||
if plain_str and len(plain_str) > 150:
|
||||
render_start = time.time()
|
||||
url = await html_renderer.render_t2i(plain_str, return_url=True)
|
||||
if time.time() - render_start > 3:
|
||||
logger.warning(f"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
||||
if url:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
@@ -0,0 +1,44 @@
|
||||
from . import STAGES_ORDER
|
||||
from .stage import registered_stages, Stage
|
||||
from .context import PipelineContext
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
|
||||
from astrbot.core import logger
|
||||
|
||||
class PipelineScheduler():
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__))
|
||||
self.ctx = context
|
||||
|
||||
async def initialize(self):
|
||||
for stage in registered_stages:
|
||||
logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i]
|
||||
logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
coro = stage.process(event)
|
||||
if isinstance(coro, AsyncGenerator):
|
||||
async for _ in coro:
|
||||
if event.is_stopped():
|
||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
||||
break
|
||||
await self._process_stages(event, i + 1)
|
||||
else:
|
||||
await coro
|
||||
|
||||
if event.is_stopped():
|
||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
||||
break
|
||||
|
||||
if event.is_stopped():
|
||||
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
|
||||
break
|
||||
|
||||
async def execute(self, event: AstrMessageEvent):
|
||||
'''执行 pipeline'''
|
||||
await self._process_stages(event)
|
||||
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, Dict, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
|
||||
registered_stages: List[Stage] = []
|
||||
'''维护了所有已注册的 Stage 实现类'''
|
||||
|
||||
def register_stage(cls):
|
||||
'''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类
|
||||
'''
|
||||
registered_stages.append(cls())
|
||||
return cls
|
||||
|
||||
class Stage(abc.ABC):
|
||||
'''描述一个 Pipeline 的某个阶段
|
||||
'''
|
||||
|
||||
@abc.abstractmethod
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
'''初始化阶段
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''处理事件
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
|
||||
from astrbot.core.message.components import At, Plain
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
'''检查是否需要唤醒。唤醒机器人有如下几点条件:
|
||||
|
||||
1. 机器人被 @ 了
|
||||
2. 机器人的消息被提到了
|
||||
3. 以 wake_prefix 前缀开头
|
||||
4. 插件(Star)的 handler filter 通过
|
||||
'''
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config['admins_id']:
|
||||
if event.get_sender_id() == admin_id:
|
||||
event.role = "admin"
|
||||
break
|
||||
|
||||
# 检查 wake
|
||||
wake_prefixes = self.ctx.astrbot_config['wake_prefix']
|
||||
messages = event.get_messages()
|
||||
is_wake = False
|
||||
for wake_prefix in wake_prefixes:
|
||||
if event.message_str.startswith(wake_prefix):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
event.message_str = event.message_str[len(wake_prefix):].strip()
|
||||
break
|
||||
if not is_wake:
|
||||
# 检查是否有 at 消息
|
||||
for message in messages:
|
||||
if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
break
|
||||
# 检查是否是私聊
|
||||
if event.is_private_chat():
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
|
||||
# 检查插件的 handler filter
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
for handler in star_handlers_registry:
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
passed = False
|
||||
child_command_handler_md = None
|
||||
for filter in handler.event_filters:
|
||||
try:
|
||||
if isinstance(filter, CommandGroupFilter):
|
||||
'''如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata'''
|
||||
ok, child_command_handler_md = filter.filter(event, self.ctx.astrbot_config)
|
||||
if ok:
|
||||
passed = True
|
||||
handler = child_command_handler_md # handler 覆盖
|
||||
break
|
||||
else:
|
||||
if filter.filter(event, self.ctx.astrbot_config):
|
||||
passed = True
|
||||
break
|
||||
except Exception as e:
|
||||
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
|
||||
# yield
|
||||
await event.send(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
|
||||
event.stop_event()
|
||||
passed = False
|
||||
break
|
||||
|
||||
if passed:
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
|
||||
activated_handlers.append(handler)
|
||||
if 'parsed_params' in event.get_extra():
|
||||
handlers_parsed_params[handler.handler_full_name] = event.get_extra('parsed_params')
|
||||
event.clear_extra()
|
||||
|
||||
event.set_extra('activated_handlers', activated_handlers)
|
||||
event.set_extra('handlers_parsed_params', handlers_parsed_params)
|
||||
|
||||
if not is_wake:
|
||||
event.stop_event()
|
||||
@@ -0,0 +1,19 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import List, Dict, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
|
||||
@register_stage
|
||||
class WhitelistCheckStage(Stage):
|
||||
'''检查是否在群聊/私聊白名单
|
||||
'''
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 检查是否在白名单
|
||||
if event.unified_msg_origin not in self.whitelist:
|
||||
logger.info(f"会话 {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。")
|
||||
event.stop_event()
|
||||
@@ -1,11 +1,11 @@
|
||||
import abc
|
||||
import abc, logging
|
||||
from dataclasses import dataclass
|
||||
from .astrbot_message import AstrBotMessage
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, EventResultType
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from typing import List
|
||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
||||
from astrbot.core.message.components import *
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
@dataclass
|
||||
@@ -34,8 +34,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.session_id = session_id
|
||||
self.role = "member"
|
||||
self.is_wake = False
|
||||
|
||||
self._result: MessageEventResult = None
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
@@ -43,6 +41,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
session_id=session_id
|
||||
)
|
||||
self.unified_msg_origin = str(self.session)
|
||||
|
||||
self._result: MessageEventResult = None
|
||||
'''消息事件的结果'''
|
||||
|
||||
def get_platform_name(self):
|
||||
return self.platform_meta.name
|
||||
@@ -58,8 +59,19 @@ class AstrMessageEvent(abc.ABC):
|
||||
for i in chain:
|
||||
if isinstance(i, Plain):
|
||||
outline += i.text
|
||||
if isinstance(i, Image):
|
||||
elif isinstance(i, Image):
|
||||
outline += "[图片]"
|
||||
elif isinstance(i, Face):
|
||||
outline += f"[表情:{i.id}]"
|
||||
elif isinstance(i, At):
|
||||
outline += f"[At:{i.qq}]"
|
||||
elif isinstance(i, AtAll):
|
||||
outline += "[At:全体成员]"
|
||||
elif isinstance(i, Forward):
|
||||
# 转发消息
|
||||
outline += f"[转发消息]"
|
||||
else:
|
||||
outline += f"[{i.type}]"
|
||||
return outline
|
||||
|
||||
def get_message_outline(self) -> str:
|
||||
@@ -76,12 +88,24 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''
|
||||
return self.message_obj.message
|
||||
|
||||
def get_message_type(self) -> MessageType:
|
||||
'''
|
||||
获取消息类型。
|
||||
'''
|
||||
return self.message_obj.type
|
||||
|
||||
def get_session_id(self) -> str:
|
||||
'''
|
||||
获取会话id。
|
||||
'''
|
||||
return self.session_id
|
||||
|
||||
def get_group_id(self) -> str:
|
||||
'''
|
||||
获取群组id。如果不是群组消息,返回空字符串。
|
||||
'''
|
||||
return self.message_obj.group_id
|
||||
|
||||
def get_self_id(self) -> str:
|
||||
'''
|
||||
获取机器人自身的id。
|
||||
@@ -101,16 +125,62 @@ class AstrMessageEvent(abc.ABC):
|
||||
return self.message_obj.sender.nickname
|
||||
|
||||
def set_result(self, result: MessageEventResult):
|
||||
'''
|
||||
设置消息事件的结果。当设置了结果后,消息事件将不再继续传递。
|
||||
'''设置消息事件的结果。
|
||||
|
||||
Note:
|
||||
事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。
|
||||
|
||||
如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。
|
||||
|
||||
Example:
|
||||
|
||||
async def ban_handler(self, event: AstrMessageEvent):
|
||||
if event.get_sender_id() in self.blacklist:
|
||||
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
|
||||
return
|
||||
|
||||
async def check_count(self, event: AstrMessageEvent):
|
||||
self.count += 1
|
||||
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
|
||||
return
|
||||
'''
|
||||
self._result = result
|
||||
|
||||
def stop_event(self):
|
||||
'''终止事件传播。
|
||||
'''
|
||||
if self._result is None:
|
||||
self.set_result(MessageEventResult().stop_event())
|
||||
else:
|
||||
self._result.stop_event()
|
||||
|
||||
def continue_event(self):
|
||||
'''继续事件传播。
|
||||
'''
|
||||
if self._result is None:
|
||||
self.set_result(MessageEventResult().continue_event())
|
||||
else:
|
||||
self._result.continue_event()
|
||||
|
||||
def is_stopped(self) -> bool:
|
||||
'''
|
||||
是否终止事件传播。
|
||||
'''
|
||||
if self._result is None:
|
||||
return False # 默认是继续传播
|
||||
return self._result.is_stopped()
|
||||
|
||||
def get_result(self) -> MessageEventResult:
|
||||
'''
|
||||
获取消息事件的结果。
|
||||
'''
|
||||
return self._result
|
||||
|
||||
def clear_result(self):
|
||||
'''
|
||||
清除消息事件的结果。
|
||||
'''
|
||||
self._result = None
|
||||
|
||||
def set_extra(self, key, value):
|
||||
'''
|
||||
@@ -118,6 +188,20 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(self, key = None):
|
||||
'''
|
||||
获取额外的信息。
|
||||
'''
|
||||
if key is None:
|
||||
return self._extras
|
||||
return self._extras.get(key, None)
|
||||
|
||||
def clear_extra(self):
|
||||
'''
|
||||
清除额外的信息。
|
||||
'''
|
||||
self._extras.clear()
|
||||
|
||||
def is_private_chat(self) -> bool:
|
||||
'''
|
||||
是否是私聊。
|
||||
|
||||
@@ -15,8 +15,9 @@ class AstrBotMessage:
|
||||
'''
|
||||
type: MessageType # 消息类型
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group_id: str = "" # 群组id,如果为私聊,则为空
|
||||
sender: MessageMember # 发送者
|
||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .platform import PlatformMetadata, Platform
|
||||
from typing import List
|
||||
from asyncio import Queue
|
||||
from .register import platform_registry, platform_cls_map
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class PlatformManager():
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
self.platform_insts: List[Platform] = []
|
||||
'''加载的 Platform 的实例'''
|
||||
|
||||
self.platforms_config = config['platform']
|
||||
self.settings = config['platform_settings']
|
||||
self.event_queue = event_queue
|
||||
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
match platform['name']:
|
||||
case "aiocqhttp":
|
||||
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter
|
||||
case "qqofficial":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialAdapter
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatAdapter
|
||||
|
||||
async def initialize(self):
|
||||
for platform in self.platforms_config:
|
||||
if not platform['enable']:
|
||||
continue
|
||||
if platform['name'] not in platform_cls_map:
|
||||
logger.error(f"未找到适用于 {platform['name']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
cls_type = platform_cls_map[platform['name']]
|
||||
logger.info(f"尝试实例化 {platform['name']}({platform['id']}) 平台适配器 ...")
|
||||
inst = cls_type(platform, self.settings, self.event_queue)
|
||||
self.platform_insts.append(inst)
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import Type
|
||||
@dataclass
|
||||
class PlatformMetadata():
|
||||
name: str # 平台的名称
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import List, Dict, Type
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
platform_registry: List[PlatformMetadata] = []
|
||||
'''维护了通过装饰器注册的平台适配器'''
|
||||
platform_cls_map: Dict[str, Type] = {}
|
||||
'''维护了平台适配器名称和适配器类的映射'''
|
||||
|
||||
def register_platform_adapter(adapter_name: str, desc: str):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
if adapter_name in platform_cls_map:
|
||||
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
logger.debug(f"平台适配器 {adapter_name} 已注册")
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
import os, traceback, random, asyncio
|
||||
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
+13
-33
@@ -4,24 +4,26 @@ import traceback
|
||||
import logging
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain, MessageEventResult
|
||||
from .aiocqhttp_message_event import *
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
||||
class AiocqhttpAdapter(Platform):
|
||||
def __init__(self, platform_config: AiocqhttpPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.unique_session
|
||||
self.host = platform_config.ws_reverse_host
|
||||
self.port = platform_config.ws_reverse_port
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
self.host = platform_config['ws_reverse_host']
|
||||
self.port = platform_config['ws_reverse_port']
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"aiocqhttp",
|
||||
@@ -51,6 +53,7 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if event['message_type'] == 'group':
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
elif event['message_type'] == 'private':
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
@@ -71,35 +74,18 @@ class AiocqhttpAdapter(Platform):
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return
|
||||
logger.debug(f"aiocqhttp: 收到消息: {event.message}")
|
||||
for m in event.message:
|
||||
t = m['type']
|
||||
a = None
|
||||
if t == 'at':
|
||||
a = At(**m['data'])
|
||||
abm.message.append(a)
|
||||
if t == 'text':
|
||||
a = Plain(text=m['data']['text'])
|
||||
message_str += m['data']['text'].strip()
|
||||
abm.message.append(a)
|
||||
if t == 'image':
|
||||
file = m['data']['file'] if 'file' in m['data'] else None
|
||||
url = m['data']['url'] if 'url' in m['data'] else None
|
||||
a = Image(file=file, url=url)
|
||||
abm.message.append(a)
|
||||
a = ComponentTypes[t](**m['data'])
|
||||
abm.message.append(a)
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = event
|
||||
return abm
|
||||
|
||||
def handle_whitelist(self, event: Event) -> bool:
|
||||
match event['message_type']:
|
||||
case "group":
|
||||
if self.config.qq_group_id_whitelist and str(event.group_id) in self.config.qq_group_id_whitelist:
|
||||
return True
|
||||
case "private":
|
||||
if self.config.qq_id_whitelist and str(event.sender['user_id']) in self.config.qq_id_whitelist:
|
||||
return True
|
||||
return False
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
if not self.host or not self.port:
|
||||
@@ -107,18 +93,12 @@ class AiocqhttpAdapter(Platform):
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
if not self.handle_whitelist(event):
|
||||
logger.debug(f"一个群消息({event.group_id})事件由于不在白名单而被过滤。")
|
||||
return
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
if not self.handle_whitelist(event):
|
||||
logger.debug(f"一个私聊消息({event.sender['nickname']}/{event.sender['user_id']})事件由于不在白名单而被过滤。")
|
||||
return
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
+2
-1
@@ -3,7 +3,8 @@ import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
+14
-10
@@ -6,17 +6,17 @@ import botpy.types
|
||||
import botpy.types.message
|
||||
|
||||
from botpy import Client
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List, Dict
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .qqofficial_message_event import QQOfficialMessageEvent
|
||||
from astrbot.core.config.astrbot_config import PlatformConfig, QQOfficialPlatformConfig, PlatformSettings
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
# QQ 机器人官方框架
|
||||
@register_platform_adapter("qqofficial", "QQ 机器人官方 API 适配器")
|
||||
class botClient(Client):
|
||||
def set_platform(self, platform: 'QQOfficialPlatformAdapter'):
|
||||
self.platform = platform
|
||||
@@ -56,18 +56,18 @@ class botClient(Client):
|
||||
|
||||
class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: QQOfficialPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
self.appid = platform_config.appid
|
||||
self.secret = platform_config.secret
|
||||
self.unique_session = platform_settings.unique_session
|
||||
qq_group = platform_config.enable_group_c2c
|
||||
guild_dm = platform_config.enable_guild_direct_message
|
||||
self.appid = platform_config['appid']
|
||||
self.secret = platform_config['secret']
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
qq_group = platform_config['enable_group_c2c']
|
||||
guild_dm = platform_config['enable_guild_direct_message']
|
||||
|
||||
if qq_group:
|
||||
self.intents = botpy.Intents(
|
||||
@@ -115,6 +115,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
message.author.member_openid,
|
||||
""
|
||||
)
|
||||
abm.group_id = message.group_openid
|
||||
else:
|
||||
abm.sender = MessageMember(
|
||||
message.author.user_openid,
|
||||
@@ -157,6 +158,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
str(message.author.id),
|
||||
str(message.author.username)
|
||||
)
|
||||
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.group_id = message.channel_id
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {message_type}")
|
||||
return abm
|
||||
+5
-3
@@ -1,10 +1,12 @@
|
||||
import random, asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from vchat import Core
|
||||
|
||||
class WechatPlatformEvent(AstrMessageEvent):
|
||||
class VChatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
@@ -36,6 +38,6 @@ class WechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await WechatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
|
||||
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
|
||||
await super().send(message)
|
||||
|
||||
+14
-18
@@ -1,15 +1,13 @@
|
||||
import sys, time, datetime, uuid
|
||||
import sys, time, uuid
|
||||
import asyncio
|
||||
|
||||
from astrbot.api import Platform
|
||||
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from typing import Union, List, Dict
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import *
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .wechat_message_event import WechatPlatformEvent
|
||||
from astrbot.core.config.astrbot_config import PlatformConfig, WechatPlatformConfig, PlatformSettings
|
||||
from astrbot.core.utils.io import save_temp_img, download_image_by_url
|
||||
from .vchat_message_event import VChatPlatformEvent
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
from vchat import Core
|
||||
from vchat import model
|
||||
@@ -18,10 +16,11 @@ if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
class WechatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
|
||||
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
|
||||
class VChatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
@@ -31,13 +30,13 @@ class WechatPlatformAdapter(Platform):
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
from_username = session.session_id.split('$$')[0]
|
||||
await WechatPlatformEvent.send_with_client(self.client, message_chain, from_username)
|
||||
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"wechat",
|
||||
"vchat",
|
||||
"基于 VChat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@@ -53,10 +52,6 @@ class WechatPlatformAdapter(Platform):
|
||||
logger.debug(f"忽略旧消息: {msg}")
|
||||
return
|
||||
logger.debug(f"收到消息: {msg.todict()}")
|
||||
if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist:
|
||||
logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}")
|
||||
return
|
||||
logger.info(f"收到消息: {msg.todict()}")
|
||||
abmsg = self.convert_message(msg)
|
||||
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
|
||||
asyncio.create_task(self.handle_msg(abmsg))
|
||||
@@ -92,12 +87,13 @@ class WechatPlatformAdapter(Platform):
|
||||
amsg.type = MessageType.FRIEND_MESSAGE
|
||||
elif isinstance(msg.from_, model.Chatroom):
|
||||
amsg.type = MessageType.GROUP_MESSAGE
|
||||
amsg.group_id = msg.from_.username
|
||||
else:
|
||||
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
|
||||
|
||||
amsg.raw_message = msg
|
||||
|
||||
if self.settingss.unique_session:
|
||||
if self.settingss['unique_session']:
|
||||
session_id = msg.from_.username + "$$" + msg.to.username
|
||||
if msg.chatroom_sender is not None:
|
||||
session_id += '$$' + msg.chatroom_sender.username
|
||||
@@ -108,7 +104,7 @@ class WechatPlatformAdapter(Platform):
|
||||
return amsg
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WechatPlatformEvent(
|
||||
message_event = VChatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
@@ -1,4 +0,0 @@
|
||||
from .plugin import Plugin, RegisteredPlugin, PluginMetadata
|
||||
from .plugin_manager import PluginManager
|
||||
from .context import CommandMetadata, Context
|
||||
from astrbot.core.provider import Provider
|
||||
@@ -1,217 +0,0 @@
|
||||
import heapq
|
||||
from asyncio import Queue
|
||||
from . import RegisteredPlugin, PluginMetadata
|
||||
from typing import List, Dict, Awaitable, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.utils.func_call import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
@dataclass
|
||||
class CommandMetadata():
|
||||
'''
|
||||
显式指令
|
||||
'''
|
||||
plugin_name: str
|
||||
plugin_metadata: PluginMetadata
|
||||
handler: Awaitable
|
||||
use_regex: bool = False
|
||||
ignore_prefix: bool = False
|
||||
description: str = ""
|
||||
|
||||
@dataclass
|
||||
class EventListenerMetadata():
|
||||
'''
|
||||
事件监听器
|
||||
'''
|
||||
plugin_name: str
|
||||
plugin_metadata: PluginMetadata
|
||||
handler: Awaitable
|
||||
description: str = ""
|
||||
after_commands: bool = False
|
||||
|
||||
|
||||
class Context:
|
||||
'''
|
||||
暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。
|
||||
'''
|
||||
# 事件队列。消息平台通过事件队列传递消息事件。
|
||||
_event_queue: Queue = None
|
||||
|
||||
# AstrBot 配置信息
|
||||
_config: AstrBotConfig = None
|
||||
|
||||
# AstrBot 数据库
|
||||
_db: BaseDatabase = None
|
||||
|
||||
# 维护了注册的插件的信息
|
||||
registered_plugins: List[RegisteredPlugin] = []
|
||||
|
||||
# 维护了插件注册的指令的信息的名字列表,用于优先级排序
|
||||
registered_commands: List[str] = []
|
||||
# 维护了插件注册的指令的信息
|
||||
commands_handler: Dict[str, CommandMetadata] = {}
|
||||
|
||||
# 维护了插件注册的中间件的名字列表,用于优先级排序
|
||||
registered_listeners: List[str] = []
|
||||
# 维护了插件注册的中间件的信息
|
||||
listeners_handler: Dict[str, EventListenerMetadata] = {}
|
||||
|
||||
# 维护了注册的平台的信息
|
||||
registered_platforms: List[Platform] = []
|
||||
|
||||
# 维护了 LLM Tools 信息
|
||||
llm_tools: FuncCall = FuncCall()
|
||||
|
||||
# 维护插件存储的数据
|
||||
plugin_data: Dict[str, Dict[str, any]] = {}
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
|
||||
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
|
||||
for plugin in self.registered_plugins:
|
||||
if plugin.metadata.plugin_name == plugin_name:
|
||||
return plugin
|
||||
return None
|
||||
|
||||
def register_listener(self,
|
||||
plugin_name: str,
|
||||
name: str,
|
||||
handler: Awaitable,
|
||||
description: str = None,
|
||||
after_commands: bool = False):
|
||||
'''
|
||||
注册一个事件监听器。
|
||||
|
||||
after_commands: 是否在指令处理后执行。
|
||||
'''
|
||||
if name in self.registered_listeners:
|
||||
raise ValueError(f"Middleware {name} already exists.")
|
||||
self.registered_listeners.append(name)
|
||||
self.listeners_handler[name] = EventListenerMetadata(
|
||||
plugin_name=plugin_name,
|
||||
plugin_metadata=None,
|
||||
handler=handler,
|
||||
description=description,
|
||||
after_commands=after_commands
|
||||
)
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
command_name: str,
|
||||
description: str,
|
||||
priority: int,
|
||||
handler: Awaitable,
|
||||
use_regex: bool = False,
|
||||
ignore_prefix: bool = False):
|
||||
'''
|
||||
注册插件指令。
|
||||
|
||||
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
|
||||
@param command_name: 指令名,如 "help"。不需要带前缀。
|
||||
@param description: 指令描述。
|
||||
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
|
||||
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
|
||||
@param use_regex: 是否使用正则表达式匹配指令名。
|
||||
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
|
||||
|
||||
.. Example::
|
||||
|
||||
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
|
||||
'''
|
||||
for command in self.registered_commands:
|
||||
if command_name in command[1]:
|
||||
raise ValueError(f"Command {command_name} already exists.")
|
||||
if not handler:
|
||||
raise ValueError(f"Handler of {command_name} is None.")
|
||||
|
||||
heapq.heappush(self.registered_commands, (-priority, command_name))
|
||||
self.commands_handler[command_name] = CommandMetadata(
|
||||
plugin_name=plugin_name,
|
||||
plugin_metadata=None,
|
||||
handler=handler,
|
||||
use_regex=use_regex,
|
||||
ignore_prefix=ignore_prefix,
|
||||
description=description
|
||||
)
|
||||
heapq.heapify(self.registered_commands)
|
||||
|
||||
def register_platform(self, platform: Platform):
|
||||
'''
|
||||
注册一个消息平台。
|
||||
'''
|
||||
self.registered_platforms.append(platform)
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
self.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
self.llm_tools.remove_func(name)
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''
|
||||
获取 AstrBot 配置信息。
|
||||
'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''
|
||||
获取 AstrBot 数据库。
|
||||
'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.registered_platforms:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_data(self, plugin_name: str, key: str, value: any):
|
||||
'''
|
||||
设置插件数据。
|
||||
'''
|
||||
self.plugin_data[plugin_name][key] = value
|
||||
@@ -1,43 +0,0 @@
|
||||
from enum import Enum
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class PluginMetadata:
|
||||
'''
|
||||
插件的元数据。
|
||||
'''
|
||||
# required
|
||||
plugin_name: str
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
|
||||
# optional
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"PluginMetadata({self.plugin_name}, {self.desc}, {self.version}, {self.repo})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredPlugin:
|
||||
'''
|
||||
注册在 AstrBot 中的插件。
|
||||
'''
|
||||
metadata: PluginMetadata
|
||||
plugin_instance: object
|
||||
module_path: str
|
||||
module: ModuleType
|
||||
root_dir_name: str
|
||||
reserved: bool # 是否是 AstrBot 的保留插件
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
|
||||
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -1 +1 @@
|
||||
from .provider import Provider, Personality
|
||||
from .provider import Provider, Personality, ProviderMetaData
|
||||
@@ -0,0 +1,13 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
completion_text: str = None
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = None
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = None
|
||||
'''工具调用名称'''
|
||||
@@ -0,0 +1,49 @@
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from collections import defaultdict
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from .register import provider_cls_map, provider_registry
|
||||
from astrbot.core import logger
|
||||
|
||||
class ProviderManager():
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.llm_tools: FuncCall = FuncCall()
|
||||
self.curr_provider_inst: Provider = None
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
for provider_cfg in self.providers_config:
|
||||
if not provider_cfg['enable']:
|
||||
continue
|
||||
|
||||
if provider_cfg['id'] in self.loaded_ids:
|
||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
||||
self.loaded_ids[provider_cfg['id']] = True
|
||||
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial
|
||||
|
||||
async def initialize(self):
|
||||
for provider_config in self.providers_config:
|
||||
if not provider_config['enable']:
|
||||
continue
|
||||
if provider_config['type'] not in provider_cls_map:
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
cls_type = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
@@ -1,91 +1,134 @@
|
||||
import abc, json, threading, time
|
||||
import abc, json
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core import logger
|
||||
from typing import TypedDict
|
||||
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
from dataclasses import dataclass
|
||||
class Personality(TypedDict):
|
||||
prompt: str
|
||||
name: str
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta():
|
||||
id: str
|
||||
model: str
|
||||
type: str
|
||||
|
||||
|
||||
class Provider(abc.ABC):
|
||||
def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None:
|
||||
self.model_name = "unknown"
|
||||
# 维护了 session_id 的上下文,不包含 system 指令
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
persistant_history: bool = True,
|
||||
db_helper: BaseDatabase = None
|
||||
) -> None:
|
||||
self.model_name = ""
|
||||
'''当前使用的模型名称'''
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
self.curr_personality = Personality(prompt=default_personality, name="")
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_config = provider_config
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
|
||||
'''维护了当前的使用的 persona,即人格。'''
|
||||
|
||||
self.db_helper = db_helper
|
||||
'''用于持久化的数据库操作对象。'''
|
||||
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history():
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self):
|
||||
def get_model(self) -> str:
|
||||
'''获得当前使用的模型名称'''
|
||||
return self.model_name
|
||||
|
||||
async def get_human_readable_context(self, session_id: str) -> List[str]:
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
'''获得提供商 Key'''
|
||||
return self.provider_config['key']
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_key(self, key: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_models(self) -> List[str]:
|
||||
'''获得支持的模型列表'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
|
||||
'''获取人类可读的上下文
|
||||
|
||||
Example:
|
||||
|
||||
["User: 你好", "Assistant: 你好!"]
|
||||
|
||||
Return:
|
||||
contexts: List[str]: 上下文列表
|
||||
total_pages: int: 总页数
|
||||
'''
|
||||
获取人类可读的上下文
|
||||
|
||||
example:
|
||||
["User: 你好", "Assistant: 你好"]
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
|
||||
return contexts
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str] = None,
|
||||
tools = None,
|
||||
contexts=None,
|
||||
**kwargs) -> str:
|
||||
'''
|
||||
prompt: 提示词
|
||||
session_id: 会话id
|
||||
session_id: str=None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts: List=None,
|
||||
**kwargs) -> LLMResponse:
|
||||
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
|
||||
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
|
||||
|
||||
[optional]
|
||||
image_url: 图片url(识图)
|
||||
tools: 函数调用工具
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def image_generate(self, prompt: str, session_id: str, **kwargs) -> str:
|
||||
'''
|
||||
prompt: 提示词
|
||||
session_id: 会话id
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_embedding(self, text: str) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
'''
|
||||
重置会话
|
||||
'''
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
@@ -0,0 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str # 提供商适配器名称,如 openai, ollama
|
||||
desc: str = "" # 提供商适配器描述.
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import List, Dict, Type
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from astrbot.core import logger
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
provider_cls_map: Dict[str, Type] = {}
|
||||
'''维护了 Provider 类型名称和 Provider 类的映射'''
|
||||
|
||||
def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
if provider_type_name in provider_cls_map:
|
||||
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
|
||||
|
||||
pm = ProviderMetaData(
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
)
|
||||
provider_registry.append(pm)
|
||||
provider_cls_map[provider_type_name] = cls
|
||||
logger.debug(f"Provider {provider_type_name} 已注册")
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,216 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import *
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from typing import List
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.llm_response import LLMResponse
|
||||
|
||||
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models_str = []
|
||||
models = await self.client.models.list()
|
||||
for model in models:
|
||||
models_str.append(model['id'])
|
||||
return models_str
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
payloads["tools"] = tools.get_func_desc_openai_style()
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion.usage}")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
return LLMResponse("assistant", completion_text)
|
||||
elif choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
func_name_ls = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool['name'] == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
|
||||
else:
|
||||
raise Exception("Internal Error")
|
||||
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
|
||||
context_query = []
|
||||
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
if self.curr_personality["prompt"]:
|
||||
context_query.insert(0, {"role": "system", "content": self.curr_personality["prompt"]})
|
||||
else:
|
||||
context_query = contexts
|
||||
|
||||
logger.debug(f"请求上下文:{context_query}")
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
}
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
return llm_response
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
if image_urls:
|
||||
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user","content": text}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
'''
|
||||
将图片转换为 base64
|
||||
'''
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ''
|
||||
@@ -1,7 +1,7 @@
|
||||
from astrbot.core.provider import Provider
|
||||
from typing import Awaitable
|
||||
import json
|
||||
import textwrap
|
||||
from typing import Awaitable, Dict, List
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
@@ -11,7 +11,6 @@ class FuncCallJsonFormatError(Exception):
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
class FuncNotFoundError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@@ -19,10 +18,19 @@ class FuncNotFoundError(Exception):
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
class FuncTool(TypedDict):
|
||||
'''
|
||||
用于描述一个函数调用工具。
|
||||
'''
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
func_obj: Awaitable
|
||||
|
||||
|
||||
class FuncCall():
|
||||
def __init__(self) -> None:
|
||||
self.func_list = []
|
||||
self.func_list: List[FuncTool] = []
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -45,12 +53,7 @@ class FuncCall():
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
_func = {
|
||||
"name": name,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
_func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj)
|
||||
self.func_list.append(_func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
@@ -62,17 +65,16 @@ class FuncCall():
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def func_dump(self) -> str:
|
||||
_l = []
|
||||
def get_func(self, name) -> FuncTool:
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
return json.dumps(_l, ensure_ascii=False)
|
||||
|
||||
def get_func(self) -> list:
|
||||
if f["name"] == name:
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_func_desc_openai_style(self) -> list:
|
||||
'''
|
||||
获得 OpenAI API 风格的工具描述
|
||||
'''
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
@@ -85,7 +87,17 @@ class FuncCall():
|
||||
})
|
||||
return _l
|
||||
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider) -> tuple:
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
@@ -111,7 +123,6 @@ class FuncCall():
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
print(res)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
res = json.loads(res)
|
||||
+7
-6
@@ -5,11 +5,13 @@ import os
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from openai._exceptions import *
|
||||
from .websearch.config import HEADERS, USER_AGENTS
|
||||
from .websearch.bing import Bing
|
||||
from .websearch.sogo import Sogo
|
||||
from .websearch.google import Google
|
||||
from astrbot.api import logger, AstrMessageEvent, Provider, MessageChain, MessageEventResult
|
||||
from engines.config import HEADERS, USER_AGENTS
|
||||
from engines.bing import Bing
|
||||
from engines.sogo import Sogo
|
||||
from engines.google import Google
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api import logger
|
||||
|
||||
bing_search = Bing()
|
||||
sogo_search = Sogo()
|
||||
@@ -61,7 +63,6 @@ async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provide
|
||||
|
||||
return await summarize(ret, event, provider)
|
||||
|
||||
|
||||
async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
@@ -0,0 +1,5 @@
|
||||
# AstrBot Star
|
||||
|
||||
`AstrBot Star` 就是插件。
|
||||
|
||||
在 AstrBot v4.0 版本后,AstrBot 内部将插件命名为 `star`。插件的 handler 称作 `star_handler`。
|
||||
@@ -0,0 +1,4 @@
|
||||
from .star import Star, StarMetadata
|
||||
from .star_manager import PluginManager
|
||||
from .context import Context
|
||||
from astrbot.core.provider import Provider
|
||||
@@ -0,0 +1,174 @@
|
||||
import heapq
|
||||
from asyncio import Queue
|
||||
from . import StarMetadata
|
||||
from typing import List, Dict, TypedDict, Union
|
||||
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, star_map, StarMetadata
|
||||
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
|
||||
class StarCommand(TypedDict):
|
||||
full_command_name: str
|
||||
command_name: str
|
||||
|
||||
class Context:
|
||||
'''
|
||||
暴露给插件的接口上下文。
|
||||
'''
|
||||
_event_queue: Queue = None
|
||||
'''事件队列。消息平台通过事件队列传递消息事件。'''
|
||||
|
||||
_config: AstrBotConfig = None
|
||||
'''AstrBot 配置信息'''
|
||||
|
||||
_db: BaseDatabase = None
|
||||
'''AstrBot 数据库'''
|
||||
|
||||
provider_manager: ProviderManager = None
|
||||
|
||||
platform_manager: PlatformManager = None
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
return star_map.get(star_name, None)
|
||||
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
return star_registry
|
||||
|
||||
def get_llm_tools(self) -> FuncCall:
|
||||
'''
|
||||
获取 LLM Tools。
|
||||
'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
|
||||
# def get_star_commands(self, star_name: str) -> List[]:
|
||||
# '''获得一个'''
|
||||
|
||||
# def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
# '''
|
||||
# 为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
# @param name: 函数名
|
||||
# @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
# @param desc: 函数描述
|
||||
# @param func_obj: 异步处理函数。
|
||||
|
||||
# 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
# '''
|
||||
# self.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||
|
||||
# def unregister_llm_tool(self, name: str) -> None:
|
||||
# '''
|
||||
# 删除一个函数调用工具。
|
||||
# '''
|
||||
# self.llm_tools.remove_func(name)
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
|
||||
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
|
||||
|
||||
@param star_name: 插件(Star)名称。
|
||||
@param command_name: 命令名称。
|
||||
@param desc: 命令描述。
|
||||
@param priority: 优先级。1-10。
|
||||
@param awaitable: 异步处理函数。
|
||||
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
if use_regex:
|
||||
md.event_filters.append(RegexFilter(
|
||||
regex=command_name
|
||||
))
|
||||
else:
|
||||
md.event_filters.append(CommandFilter(
|
||||
command_name=command_name,
|
||||
handler_md=md
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
获取所有 LLM Provider。
|
||||
'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的 LLM Provider。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''
|
||||
获取 AstrBot 配置信息。
|
||||
'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''
|
||||
获取 AstrBot 数据库。
|
||||
'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.registered_platforms:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,10 @@
|
||||
import abc
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
class HandlerFilter(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
'''是否应当被过滤'''
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,67 @@
|
||||
|
||||
import re, inspect
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
|
||||
from typing import Awaitable
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 标准指令受到 wake_prefix 的制约。
|
||||
class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
'''标准指令过滤器'''
|
||||
def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None):
|
||||
self.command_name = command_name
|
||||
if handler_md:
|
||||
self.init_handler_md(handler_md)
|
||||
|
||||
def print_types(self):
|
||||
result = ""
|
||||
print(self.handler_params)
|
||||
for k, v in self.handler_params.items():
|
||||
if isinstance(v, type):
|
||||
result += f"{k}({v.__name__}),"
|
||||
else:
|
||||
result += f"{k}({type(v).__name__})={v},"
|
||||
return result
|
||||
|
||||
def init_handler_md(self, handle_md: StarHandlerMetadata):
|
||||
self.handler_md = handle_md
|
||||
signature = inspect.signature(self.handler_md.handler)
|
||||
self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值
|
||||
idx = 0
|
||||
for k, v in signature.parameters.items():
|
||||
if idx < 2:
|
||||
# 忽略前两个参数,即 self 和 event
|
||||
idx += 1
|
||||
continue
|
||||
if v.default == inspect.Parameter.empty:
|
||||
self.handler_params[k] = v.annotation
|
||||
else:
|
||||
self.handler_params[k] = v.default
|
||||
|
||||
def get_handler_md(self) -> StarHandlerMetadata:
|
||||
return self.handler_md
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
if not event.is_wake_up():
|
||||
return False
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
# 分割为列表(每个参数之间可能会有多个空格)
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
# 去除空字符串
|
||||
ls = [param for param in ls if param]
|
||||
params = {}
|
||||
try:
|
||||
params = self.validate_and_convert_params(ls, self.handler_params)
|
||||
# 解析完成咱也不能丢掉呀,留着给后面的用
|
||||
except ValueError as e:
|
||||
raise e
|
||||
event.set_extra("parsed_params", params)
|
||||
|
||||
return True
|
||||
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Awaitable, List, Union, Tuple
|
||||
from . import HandlerFilter
|
||||
from .command import CommandFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from ..star_handler import StarHandlerMetadata
|
||||
|
||||
# 指令组受到 wake_prefix 的制约。
|
||||
class CommandGroupFilter(HandlerFilter):
|
||||
def __init__(self, group_name: str):
|
||||
self.group_name = group_name
|
||||
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
|
||||
|
||||
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
|
||||
self.sub_command_filters.append(sub_command_filter)
|
||||
|
||||
# 以树的形式打印出来
|
||||
def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str:
|
||||
result = ""
|
||||
for sub_filter in sub_command_filters:
|
||||
if isinstance(sub_filter, CommandFilter):
|
||||
cmd_th = sub_filter.print_types()
|
||||
result += f"{prefix}├── {sub_filter.command_name}"
|
||||
if cmd_th:
|
||||
result += f" ({cmd_th})"
|
||||
else:
|
||||
result += f" (无参数指令)"
|
||||
|
||||
result += "\n"
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
result += f"{prefix}├── {sub_filter.group_name}"
|
||||
result += "\n"
|
||||
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"│ ")
|
||||
return result
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
|
||||
if not event.is_wake_up():
|
||||
return False, None
|
||||
|
||||
message_str = event.get_message_str().strip()
|
||||
ls = re.split(r"\s+", message_str)
|
||||
|
||||
if ls[0] != self.group_name:
|
||||
return False, None
|
||||
# 改写 message_str
|
||||
ls = ls[1:]
|
||||
event.message_str = " ".join(ls)
|
||||
event.message_str = event.message_str.strip()
|
||||
|
||||
if event.message_str == "":
|
||||
# 当前还是指令组
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
|
||||
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
|
||||
|
||||
child_command_handler_md = None
|
||||
for sub_filter in self.sub_command_filters:
|
||||
if isinstance(sub_filter, CommandFilter):
|
||||
if sub_filter.filter(event, cfg):
|
||||
child_command_handler_md = sub_filter.get_handler_md()
|
||||
return True, child_command_handler_md
|
||||
elif isinstance(sub_filter, CommandGroupFilter):
|
||||
ok, handler = sub_filter.filter(event, cfg)
|
||||
if ok:
|
||||
child_command_handler_md = handler
|
||||
return True, child_command_handler_md
|
||||
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
|
||||
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
|
||||
@@ -0,0 +1,28 @@
|
||||
import enum
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
class EventMessageType(enum.Flag):
|
||||
GROUP_MESSAGE = enum.auto()
|
||||
PRIVATE_MESSAGE = enum.auto()
|
||||
OTHER_MESSAGE = enum.auto()
|
||||
ALL = GROUP_MESSAGE | PRIVATE_MESSAGE | OTHER_MESSAGE
|
||||
|
||||
MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE = {
|
||||
MessageType.GROUP_MESSAGE: EventMessageType.GROUP_MESSAGE,
|
||||
MessageType.FRIEND_MESSAGE: EventMessageType.PRIVATE_MESSAGE,
|
||||
MessageType.OTHER_MESSAGE: EventMessageType.OTHER_MESSAGE
|
||||
}
|
||||
|
||||
class EventMessageTypeFilter(HandlerFilter):
|
||||
def __init__(self, event_message_type: EventMessageType):
|
||||
self.event_message_type = event_message_type
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
message_type = event.get_message_type()
|
||||
if message_type in MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE:
|
||||
event_message_type = MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE[message_type]
|
||||
return bool(event_message_type & self.event_message_type)
|
||||
return False
|
||||
@@ -0,0 +1,27 @@
|
||||
import enum
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from typing import Union
|
||||
|
||||
class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT
|
||||
}
|
||||
|
||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
|
||||
self.type_or_str = platform_adapter_type_or_str
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
adapter_name = event.get_platform_name()
|
||||
if adapter_name in ADAPTER_NAME_2_TYPE:
|
||||
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
|
||||
return False
|
||||
@@ -0,0 +1,14 @@
|
||||
|
||||
import re
|
||||
from . import HandlerFilter
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
# 正则表达式过滤器不会受到 wake_prefix 的制约。
|
||||
class RegexFilter(HandlerFilter):
|
||||
'''正则表达式过滤器'''
|
||||
def __init__(self, regex: str):
|
||||
self.regex = re.compile(regex)
|
||||
|
||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||
return bool(self.regex.match(event.get_message_str().strip()))
|
||||
@@ -0,0 +1,8 @@
|
||||
from .star import register_star
|
||||
from .star_handler import (
|
||||
register_command,
|
||||
register_command_group,
|
||||
register_event_message_type,
|
||||
register_platform_adapter_type,
|
||||
register_regex
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
from ..star import star_registry, StarMetadata, star_map
|
||||
|
||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
||||
def decorator(cls):
|
||||
star_metadata = StarMetadata(
|
||||
name=name,
|
||||
author=author,
|
||||
desc=desc,
|
||||
version=version,
|
||||
repo=repo,
|
||||
star_cls_type=cls,
|
||||
module_path=cls.__module__
|
||||
)
|
||||
star_registry.append(star_metadata)
|
||||
star_map[cls.__module__] = star_metadata
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from ..filter.command import CommandFilter
|
||||
from ..filter.command_group import CommandGroupFilter
|
||||
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
|
||||
from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable, List, Dict
|
||||
|
||||
|
||||
def get_handler_full_name(awatable: Awaitable) -> str:
|
||||
'''获取 Handler 的全名'''
|
||||
return f"{awatable.__module__}_{awatable.__name__}"
|
||||
|
||||
def get_handler_or_create(handler: Awaitable) -> StarHandlerMetadata:
|
||||
'''获取 Handler 或者创建一个新的 Handler'''
|
||||
handler_full_name = get_handler_full_name(handler)
|
||||
if handler_full_name in star_handlers_map:
|
||||
return star_handlers_map[handler_full_name]
|
||||
else:
|
||||
md = StarHandlerMetadata(
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
star_handlers_registry.append(md)
|
||||
star_handlers_map[handler_full_name] = md
|
||||
return md
|
||||
|
||||
def register_command(command_name: str = None, *args):
|
||||
'''注册一个 Command'''
|
||||
|
||||
new_command = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_name, RegisteringCommandable):
|
||||
# 子指令
|
||||
new_command = CommandFilter(args[0], None)
|
||||
command_name.parent_group.add_sub_command_filter(new_command)
|
||||
else:
|
||||
# 裸指令
|
||||
new_command = CommandFilter(command_name, None)
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
handler_md.event_filters.append(new_command)
|
||||
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_command_group(command_group_name: str = None, *args):
|
||||
'''注册一个 CommandGroup'''
|
||||
|
||||
new_group = None
|
||||
add_to_event_filters = False
|
||||
if isinstance(command_group_name, RegisteringCommandable):
|
||||
# 子指令组
|
||||
new_group = CommandGroupFilter(args[0])
|
||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||
else:
|
||||
# 根指令组
|
||||
new_group = CommandGroupFilter(command_group_name)
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(obj):
|
||||
if add_to_event_filters:
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
|
||||
return decorator
|
||||
|
||||
class RegisteringCommandable():
|
||||
'''用于指令组级联注册'''
|
||||
group = register_command_group
|
||||
command = register_command
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
|
||||
def register_event_message_type(event_message_type: EventMessageType):
|
||||
'''注册一个 EventMessageType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
|
||||
return awatable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
|
||||
'''注册一个 PlatformAdapterType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
|
||||
return awatable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_regex(regex: str):
|
||||
'''注册一个 Regex'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
handler_md.event_filters.append(RegexFilter(regex))
|
||||
return awatable
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
'''key 是模块路径,__module__'''
|
||||
|
||||
class Star(CommandParserMixin):
|
||||
'''所有插件(Star)的父类,所有插件都应该继承于这个类'''
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
'''
|
||||
Star 的元数据。
|
||||
'''
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
star_cls_type: type = None
|
||||
'''Star 的类对象的类型'''
|
||||
module_path: str = None
|
||||
'''Star 的模块路径'''
|
||||
|
||||
star_cls: object = None
|
||||
'''Star 的类对象'''
|
||||
module: ModuleType = None
|
||||
'''Star 的模块对象'''
|
||||
root_dir_name: str = None
|
||||
'''Star 的根目录名'''
|
||||
reserved: bool = False
|
||||
'''是否是 AstrBot 的保留 Star'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
from .filter import HandlerFilter
|
||||
|
||||
star_handlers_registry: List[StarHandlerMetadata] = []
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata():
|
||||
'''描述一个 Star 所注册的某一个 Handler。'''
|
||||
|
||||
handler_full_name: str
|
||||
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
|
||||
|
||||
handler_name: str
|
||||
'''Handler 的名字,也就是方法名'''
|
||||
|
||||
handler_module_str: str
|
||||
'''Handler 所在的模块路径。'''
|
||||
|
||||
handler: Awaitable
|
||||
'''Handler 的函数对象,应当是一个异步函数'''
|
||||
|
||||
event_filters: List[HandlerFilter]
|
||||
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
|
||||
|
||||
desc: str = ""
|
||||
'''Handler 的描述信息'''
|
||||
@@ -1,31 +1,35 @@
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
import shutil
|
||||
import yaml
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from types import ModuleType
|
||||
from typing import List, Awaitable
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger
|
||||
from .context import Context
|
||||
from . import RegisteredPlugin, PluginMetadata
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase):
|
||||
self.updator = PluginUpdator(config.plugin_repo_mirror)
|
||||
self.context = Context(event_queue, config, db)
|
||||
def __init__(
|
||||
self,
|
||||
context: Context,
|
||||
config: AstrBotConfig
|
||||
):
|
||||
self.updator = PluginUpdator(config['plugin_repo_mirror'])
|
||||
|
||||
self.context = context
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
|
||||
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
|
||||
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
@@ -69,6 +73,9 @@ class PluginManager:
|
||||
return plugins
|
||||
|
||||
def _check_plugin_dept_update(self, target_plugin: str = None):
|
||||
'''检查插件的依赖
|
||||
如果 target_plugin 为 None,则检查所有插件的依赖
|
||||
'''
|
||||
plugin_dir = self.plugin_store_path
|
||||
if not os.path.exists(plugin_dir):
|
||||
return False
|
||||
@@ -76,7 +83,7 @@ class PluginManager:
|
||||
if target_plugin:
|
||||
to_update.append(target_plugin)
|
||||
else:
|
||||
for p in self.context.registered_plugins:
|
||||
for p in self.context.get_all_stars():
|
||||
to_update.append(p.root_dir_name)
|
||||
for p in to_update:
|
||||
plugin_path = os.path.join(plugin_dir, p)
|
||||
@@ -89,6 +96,7 @@ class PluginManager:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
def _update_plugin_dept(self, path):
|
||||
'''更新插件的依赖'''
|
||||
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
|
||||
if self.config.pip_install_arg:
|
||||
args.extend(self.config.pip_install_arg)
|
||||
@@ -96,7 +104,11 @@ class PluginManager:
|
||||
if result_code != 0:
|
||||
raise Exception(str(result_code))
|
||||
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
|
||||
'''v3.4.0 以前的方式载入插件元数据
|
||||
|
||||
先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
|
||||
'''
|
||||
metadata = None
|
||||
|
||||
if not os.path.exists(plugin_path):
|
||||
@@ -112,8 +124,8 @@ class PluginManager:
|
||||
if isinstance(metadata, dict):
|
||||
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
|
||||
raise Exception("插件元数据信息不完整。")
|
||||
metadata = PluginMetadata(
|
||||
plugin_name=metadata['name'],
|
||||
metadata = StarMetadata(
|
||||
name=metadata['name'],
|
||||
author=metadata['author'],
|
||||
desc=metadata['desc'],
|
||||
version=metadata['version'],
|
||||
@@ -123,72 +135,68 @@ class PluginManager:
|
||||
return metadata
|
||||
|
||||
def reload(self):
|
||||
'''
|
||||
加载插件类
|
||||
'''
|
||||
registered_plugins = self.context.registered_plugins
|
||||
plugins = self._get_plugin_modules()
|
||||
if plugins is None:
|
||||
'''扫描并加载所有的 Star'''
|
||||
star_handlers_registry.clear()
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
return False, "未找到任何插件模块"
|
||||
fail_rec = ""
|
||||
|
||||
registered_map = {}
|
||||
for p in registered_plugins:
|
||||
registered_map[p.module_path] = None
|
||||
|
||||
for plugin in plugins:
|
||||
|
||||
# 导入 Star 模块,并尝试实例化 Star 类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
p = plugin['module']
|
||||
module_path = plugin['module_path']
|
||||
root_dir_name = plugin['pname']
|
||||
reserved = plugin.get('reserved', False)
|
||||
module_str = plugin_module['module']
|
||||
module_path = plugin_module['module_path']
|
||||
root_dir_name = plugin_module['pname']
|
||||
reserved = plugin_module.get('reserved', False)
|
||||
|
||||
logger.info(f"正在加载插件 {root_dir_name} ...")
|
||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||
|
||||
pre = "data.plugins." if not reserved else "packages."
|
||||
|
||||
# 尝试导入插件模块
|
||||
# 尝试导入模块
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path += root_dir_name + "." + module_str
|
||||
try:
|
||||
module = __import__(pre + root_dir_name + "." + p, fromlist=[p])
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except (ModuleNotFoundError, ImportError) as e:
|
||||
# 尝试安装插件依赖
|
||||
# 尝试安装依赖
|
||||
self._check_plugin_dept_update(target_plugin=root_dir_name)
|
||||
module = __import__(pre + root_dir_name + "." + p, fromlist=[p])
|
||||
module = __import__(path, fromlist=[module_str])
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
|
||||
continue
|
||||
|
||||
cls = self._get_classes(module)
|
||||
|
||||
# 实例化插件类
|
||||
try:
|
||||
obj = getattr(module, cls[0])(context=self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"插件 {root_dir_name} 实例化失败。")
|
||||
raise e
|
||||
|
||||
# 解析插件元数据,加入注册列表
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
if module_path not in registered_map:
|
||||
registered_plugins.append(RegisteredPlugin(
|
||||
metadata=metadata,
|
||||
plugin_instance=obj,
|
||||
module=module,
|
||||
module_path=module_path,
|
||||
root_dir_name=root_dir_name,
|
||||
reserved=reserved
|
||||
))
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
star_metadata = star_map[path]
|
||||
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
|
||||
star_metadata.module = module
|
||||
star_metadata.root_dir_name = root_dir_name
|
||||
star_metadata.reserved = reserved
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
|
||||
classes = self._get_classes(module)
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"插件 {root_dir_name} 实例化失败。")
|
||||
raise e
|
||||
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
metadata.star_cls = obj
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
metadata.star_cls_type = obj.__class__
|
||||
metadata.module_path = path
|
||||
star_map[path] = metadata
|
||||
|
||||
for command in self.context.commands_handler:
|
||||
if self.context.commands_handler[command].plugin_name == metadata.plugin_name:
|
||||
self.context.commands_handler[command].plugin_metadata = metadata
|
||||
for listener in self.context.listeners_handler:
|
||||
if self.context.listeners_handler[listener].plugin_name == metadata.plugin_name:
|
||||
self.context.listeners_handler[listener].plugin_metadata = metadata
|
||||
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
|
||||
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
@@ -200,26 +208,26 @@ class PluginManager:
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.update(repo_url)
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(repo_url)
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
self._check_plugin_dept_update()
|
||||
return plugin_path
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_plugin(plugin_name)
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
self.context.registered_plugins.remove(plugin)
|
||||
|
||||
del star_map[plugin.module_path]
|
||||
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_plugin(plugin_name)
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
if plugin.reserved:
|
||||
@@ -228,41 +236,13 @@ class PluginManager:
|
||||
await self.updator.update(plugin)
|
||||
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
# try to unzip
|
||||
temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4()))
|
||||
self.updator.unzip_file(zip_file_path, temp_dir)
|
||||
# check if the plugin has metadata.yaml
|
||||
if not os.path.exists(os.path.join(temp_dir, "metadata.yaml")):
|
||||
remove_dir(temp_dir)
|
||||
raise Exception("插件缺少 metadata.yaml 文件。")
|
||||
|
||||
metadata = self._load_plugin_metadata(temp_dir)
|
||||
plugin_name = metadata.plugin_name
|
||||
if not plugin_name:
|
||||
remove_dir(temp_dir)
|
||||
raise Exception("插件 metadata.yaml 文件中 name 字段为空。")
|
||||
plugin_name = self.updator.format_name(plugin_name)
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
ppath = self.plugin_store_path
|
||||
plugin_path = os.path.join(ppath, plugin_name)
|
||||
if os.path.exists(plugin_path):
|
||||
remove_dir(plugin_path)
|
||||
|
||||
# move to the target path
|
||||
shutil.move(temp_dir, plugin_path)
|
||||
|
||||
if metadata.repo:
|
||||
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
|
||||
f.write(metadata.repo)
|
||||
|
||||
# remove the temp dir
|
||||
remove_dir(temp_dir)
|
||||
# remove the zip
|
||||
try:
|
||||
os.remove(zip_file_path)
|
||||
except BaseException as e:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
|
||||
def get_platform_insts(self):
|
||||
return self.context.registered_platforms
|
||||
|
||||
def get_loaded_plugins(self):
|
||||
return self.context.registered_plugins
|
||||
|
||||
@@ -2,7 +2,7 @@ import os, zipfile, shutil
|
||||
|
||||
from ..updator import RepoZipUpdator
|
||||
from astrbot.core.utils.io import remove_dir, on_error
|
||||
from ..plugin import RegisteredPlugin
|
||||
from ..star.star import StarMetadata
|
||||
from typing import Union
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -13,20 +13,22 @@ class PluginUpdator(RepoZipUpdator):
|
||||
|
||||
def get_plugin_store_path(self) -> str:
|
||||
return self.plugin_store_path
|
||||
|
||||
async def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
|
||||
repo_url = None
|
||||
|
||||
async def install(self, repo_url: str) -> str:
|
||||
repo_name = self.format_repo_name(repo_url)
|
||||
plugin_path = os.path.join(self.plugin_store_path, repo_name)
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
|
||||
if not isinstance(plugin, str):
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
if not os.path.exists(os.path.join(plugin_path, "REPO")):
|
||||
raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。")
|
||||
|
||||
with open(os.path.join(plugin_path, "REPO"), "r", encoding='utf-8') as f:
|
||||
repo_url = f.read()
|
||||
else:
|
||||
repo_url = plugin
|
||||
plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url))
|
||||
return plugin_path
|
||||
|
||||
async def update(self, plugin: StarMetadata) -> str:
|
||||
repo_url = plugin.repo
|
||||
|
||||
if not repo_url:
|
||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
@@ -34,7 +36,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
except BaseException as e:
|
||||
logger.error(f"删除旧版本插件 {plugin.metadata.plugin_name} 文件夹失败: {str(e)},使用覆盖安装。")
|
||||
logger.error(f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。")
|
||||
|
||||
self.unzip_file(plugin_path + ".zip", plugin_path)
|
||||
|
||||
@@ -48,13 +50,10 @@ class PluginUpdator(RepoZipUpdator):
|
||||
update_dir = z.namelist()[0]
|
||||
z.extractall(target_dir)
|
||||
|
||||
avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir]
|
||||
|
||||
files = os.listdir(os.path.join(target_dir, update_dir))
|
||||
for f in files:
|
||||
logger.info(f"移动更新文件/目录: {f}")
|
||||
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
|
||||
if f in avoid_dirs: continue
|
||||
if os.path.exists(os.path.join(target_dir, f)):
|
||||
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
|
||||
else:
|
||||
@@ -68,3 +67,4 @@ class PluginUpdator(RepoZipUpdator):
|
||||
os.remove(zip_path)
|
||||
except:
|
||||
logger.warning(f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
|
||||
@@ -10,13 +10,10 @@ class CommandTokens():
|
||||
return None
|
||||
return self.tokens[idx].strip()
|
||||
|
||||
class CommandParser():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def parse(self, message: str):
|
||||
class CommandParserMixin():
|
||||
def parse_commands(self, message: str):
|
||||
cmd_tokens = CommandTokens()
|
||||
cmd_tokens.tokens = message.split(" ")
|
||||
cmd_tokens.tokens = re.split(r"\s+", message)
|
||||
cmd_tokens.len = len(cmd_tokens.tokens)
|
||||
return cmd_tokens
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import aiohttp
|
||||
import uuid
|
||||
|
||||
class ImageUploader():
|
||||
def __init__(self) -> None:
|
||||
self.S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
|
||||
|
||||
async def upload_image(self, image_path: str) -> str:
|
||||
'''
|
||||
上传图像文件到S3
|
||||
'''
|
||||
with open(image_path, "rb") as f:
|
||||
image = f.read()
|
||||
|
||||
image_url = f"{self.S3_URL}/{uuid.uuid4().hex}.jpg"
|
||||
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
|
||||
async with session.put(image_url, data=image) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to upload image: {resp.status}")
|
||||
return image_url
|
||||
@@ -0,0 +1,31 @@
|
||||
import inspect
|
||||
from typing import Awaitable, List, Union, Dict, Any, Type
|
||||
|
||||
class ParameterValidationMixin:
|
||||
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
|
||||
'''将参数列表 params 根据 param_type 转换为参数字典。
|
||||
'''
|
||||
result = {}
|
||||
print(params, param_type)
|
||||
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
|
||||
if i >= len(params):
|
||||
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
|
||||
# 是类型
|
||||
raise ValueError(f"参数 {param_name} 缺失")
|
||||
else:
|
||||
# 是默认值
|
||||
result[param_name] = param_type_or_default_val
|
||||
else:
|
||||
# 尝试强制转换
|
||||
try:
|
||||
if param_type_or_default_val == None:
|
||||
if params[i].isdigit():
|
||||
result[param_name] = int(params[i])
|
||||
else:
|
||||
result[param_name] = params[i]
|
||||
else:
|
||||
result[param_name] = type(param_type_or_default_val)(params[i])
|
||||
except ValueError:
|
||||
raise ValueError(f"参数 {param_name} 类型错误")
|
||||
print(result)
|
||||
return result
|
||||
@@ -21,4 +21,4 @@ class AstrBotDashBoardLifecycle:
|
||||
await task
|
||||
except asyncio.CancelledError as e:
|
||||
logger.info("🌈 正在关闭 AstrBot...")
|
||||
core_lifecycle.stop()
|
||||
await core_lifecycle.stop()
|
||||
@@ -41,7 +41,7 @@ class AuthRoute(Route):
|
||||
if new_username:
|
||||
self.config.dashboard.username = new_username
|
||||
|
||||
self.config.flush_config()
|
||||
self.config.save_config()
|
||||
|
||||
return Response().ok(None, "修改成功").__dict__
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os, json
|
||||
import os, json, traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE, ADAPTER_CONFIG_TEMPLATE
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.plugin.config import update_config
|
||||
from astrbot.core.star.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from dataclasses import asdict
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
@@ -55,10 +54,6 @@ def validate_config(data, config: AstrBotConfig):
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
# hardcode warning
|
||||
data['config_version'] = config.config_version
|
||||
data['dashboard'] = asdict(config.dashboard)
|
||||
|
||||
return errors
|
||||
|
||||
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
|
||||
@@ -66,7 +61,7 @@ def save_astrbot_config(post_config: dict, config: AstrBotConfig):
|
||||
errors = validate_config(post_config, config)
|
||||
if errors:
|
||||
raise ValueError(f"格式校验未通过: {errors}")
|
||||
config.flush_config(post_config)
|
||||
config.save_config(post_config)
|
||||
|
||||
def save_extension_config(post_config: dict):
|
||||
if 'namespace' not in post_config:
|
||||
@@ -112,6 +107,7 @@ class ConfigRoute(Route):
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_extension_configs(self):
|
||||
@@ -123,14 +119,15 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config.to_dict()
|
||||
config = self.config
|
||||
for key in self.config_key_dont_show:
|
||||
if key in config:
|
||||
del config[key]
|
||||
return {
|
||||
"metadata": CONFIG_METADATA_2,
|
||||
"config": config,
|
||||
"provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE
|
||||
"provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE,
|
||||
"adapter_config_tmpl": ADAPTER_CONFIG_TEMPLATE
|
||||
}
|
||||
|
||||
async def _get_extension_config(self, namespace: str):
|
||||
|
||||
@@ -2,7 +2,7 @@ import threading, traceback, uuid
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import Quart, request
|
||||
from astrbot.core.plugin.plugin_manager import PluginManager
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
class PluginRoute(Route):
|
||||
@@ -21,14 +21,14 @@ class PluginRoute(Route):
|
||||
|
||||
async def get_plugins(self):
|
||||
_plugin_resp = []
|
||||
for plugin in self.plugin_manager.context.registered_plugins:
|
||||
_p = plugin.metadata
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
_t = {
|
||||
"name": _p.plugin_name,
|
||||
"repo": '' if _p.repo is None else _p.repo,
|
||||
"author": _p.author,
|
||||
"desc": _p.desc,
|
||||
"version": _p.version
|
||||
"name": plugin.name,
|
||||
"repo": '' if plugin.repo is None else plugin.repo,
|
||||
"author": plugin.author,
|
||||
"desc": plugin.desc,
|
||||
"version": plugin.version,
|
||||
"reserved": plugin.reserved
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return Response().ok(_plugin_resp).__dict__
|
||||
|
||||
@@ -72,8 +72,8 @@ class StatRoute(Route):
|
||||
stat_dict.update({
|
||||
"platform": self.db_helper.get_grouped_base_stats(offset_sec).platform,
|
||||
"message_count": self.db_helper.get_total_message_count() or 0,
|
||||
"platform_count": len(self.core_lifecycle.plugin_manager.get_platform_insts()),
|
||||
"plugin_count": len(self.core_lifecycle.plugin_manager.get_loaded_plugins()),
|
||||
"platform_count": len(self.core_lifecycle.platform_manager.get_insts()),
|
||||
"plugin_count": len(self.core_lifecycle.star_context.get_all_stars()),
|
||||
"message_time_series": message_time_based_stats,
|
||||
"running": self.format_sec(int(time.time()) - self.core_lifecycle.start_time),
|
||||
"memory": {
|
||||
|
||||
@@ -39,10 +39,10 @@ class AstrBotDashboard():
|
||||
return
|
||||
# claim jwt
|
||||
token = request.headers.get("Authorization")
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
if not token:
|
||||
return Response().error("未授权").__dict__
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
try:
|
||||
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError:
|
||||
|
||||
@@ -96,7 +96,8 @@ if __name__ == "__main__":
|
||||
# print logo
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
dashboard_lifecycle = AstrBotDashBoardLifecycle(db)
|
||||
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
|
||||
asyncio.run(core_lifecycle.initialize())
|
||||
|
||||
dashboard_lifecycle = AstrBotDashBoardLifecycle(db)
|
||||
asyncio.run(dashboard_lifecycle.start(core_lifecycle))
|
||||
+236
-67
@@ -1,59 +1,21 @@
|
||||
import aiohttp, base64, os, json, re, time
|
||||
import aiohttp
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from typing import Dict
|
||||
from astrbot.api import Context, AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import logger, command_parser
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api import logger
|
||||
from astrbot.api import personalities
|
||||
from astrbot.api.provider import Personality
|
||||
|
||||
class Main:
|
||||
def __init__(self, context: Context) -> None:
|
||||
from typing import Union
|
||||
|
||||
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
context.register_commands("astrbot", "help", "查看 AstrBot 帮助", 10, self.help)
|
||||
context.register_commands("astrbot", "plugin", "AstrBot 插件管理", 10, self.plugin)
|
||||
context.register_commands("astrbot", "t2i", "关闭/启动文本转图片", 10, self.t2i)
|
||||
context.register_commands("astrbot", "myid", "查看自己在该平台上的 ID", 10, self.myid)
|
||||
|
||||
context.register_listener("astrbot", "keywords_ban_rate_limit", self.keywords_ban, "关键词屏蔽和发言频率监听器")
|
||||
# keywords
|
||||
with open(os.path.join(os.path.dirname(__file__), "unfit_words"), "r", encoding="utf-8") as f:
|
||||
self.keywords: list = json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords']
|
||||
internal_keywords_cfg = context.get_config().content_safety.internal_keywords
|
||||
if internal_keywords_cfg.enable:
|
||||
self.keywords.extend(internal_keywords_cfg.extra_keywords)
|
||||
|
||||
# rate limit
|
||||
self.user_rate_limit: Dict[int, int] = {}
|
||||
rl_cfg = context.get_config().platform_settings.rate_limit
|
||||
self.rate_limit_time: int = rl_cfg.time
|
||||
self.rate_limit_count: int = rl_cfg.count
|
||||
self.user_frequency = {}
|
||||
|
||||
async def keywords_ban(self, event: AstrMessageEvent):
|
||||
if not event.is_wake_up():
|
||||
return
|
||||
|
||||
# keywords 检测
|
||||
for i in self.keywords:
|
||||
matches = re.match(i, event.get_message_str().strip(), re.I | re.M)
|
||||
if matches:
|
||||
event.set_result(MessageEventResult().message("你的消息中包含不适当的关键词,已被屏蔽。"))
|
||||
return
|
||||
|
||||
# rate limit 检测
|
||||
ts = int(time.time())
|
||||
if event.session_id in self.user_frequency:
|
||||
if ts-self.user_frequency[event.session_id]['time'] > self.rate_limit_time:
|
||||
# reset
|
||||
self.user_frequency[event.session_id]['time'] = ts
|
||||
self.user_frequency[event.session_id]['count'] = 1
|
||||
return
|
||||
if self.user_frequency[event.session_id]['count'] >= self.rate_limit_count:
|
||||
event.set_result(MessageEventResult().message("你发送消息的频率过快,请稍后再试。"))
|
||||
return
|
||||
self.user_frequency[event.session_id]['count'] += 1
|
||||
else:
|
||||
t = {'time': ts, 'count': 1}
|
||||
self.user_frequency[event.session_id] = t
|
||||
|
||||
|
||||
@filter.command("help")
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
notice = ""
|
||||
try:
|
||||
@@ -63,26 +25,41 @@ class Main:
|
||||
except BaseException as e:
|
||||
pass
|
||||
|
||||
msg = "# AstrBot 帮助\n## 已注册的指令\n"
|
||||
for key, value in self.context.commands_handler.items():
|
||||
if value.plugin_metadata:
|
||||
msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n"
|
||||
else: msg += f"- `{key}`: {value.description}\n"
|
||||
msg = "已注册的 AstrBot 内置指令:"
|
||||
msg += f"""[System]
|
||||
/plugin: 插件管理
|
||||
/t2i: 开启/关闭文本转图片模式
|
||||
/sid: 获取当前会话的 ID
|
||||
/op <admin_id>: 授权管理员
|
||||
/deop <admin_id>: 取消管理员
|
||||
/wl <sid>: 添加会话白名单
|
||||
/dwl <sid>: 删除会话白名单
|
||||
|
||||
msg += "\n> 提示:使用 /plugin 查看已加载的插件\n"
|
||||
msg += notice
|
||||
[大模型]
|
||||
/provider: 查看、切换大模型提供商
|
||||
/model: 查看、切换提供商模型列表
|
||||
/key: 查看、切换 API Key
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 获取会话历史记录
|
||||
/persona: 情境人格设置
|
||||
|
||||
event.set_result(MessageEventResult().message(msg))
|
||||
|
||||
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
|
||||
@filter.command("plugin")
|
||||
async def plugin(self, event: AstrMessageEvent):
|
||||
plugin_list_info = "已加载的插件:\n"
|
||||
for plugin in self.context.registered_plugins:
|
||||
plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n"
|
||||
plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {
|
||||
plugin.metadata.author}: {plugin.metadata.desc}\n"
|
||||
if plugin_list_info.strip() == "":
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
|
||||
|
||||
event.set_result(MessageEventResult().message(f"{plugin_list_info}"))
|
||||
|
||||
|
||||
@filter.command("t2i")
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
config = self.context.get_config()
|
||||
if config.t2i:
|
||||
@@ -93,7 +70,199 @@ class Main:
|
||||
config.t2i = True
|
||||
config.save_config()
|
||||
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
|
||||
|
||||
async def myid(self, event: AstrMessageEvent):
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
event.set_result(MessageEventResult().message(f"你的 ID 是 {user_id}。此 ID 可用于设置 AstrBot 管理员。"))
|
||||
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
|
||||
UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
|
||||
@filter.command("op")
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str):
|
||||
self.context.get_config()['admins_id'].append(admin_id)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
|
||||
@filter.command("deop")
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str):
|
||||
try:
|
||||
self.context.get_config()['admins_id'].remove(admin_id)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("取消授权成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此用户 ID 不在管理员名单内。"))
|
||||
|
||||
@filter.command("wl")
|
||||
async def wl(self, event: AstrMessageEvent, sid: str):
|
||||
self.context.get_config()['platform_settings']['id_whitelist'].append(sid)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
|
||||
@filter.command("dwl")
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str):
|
||||
try:
|
||||
self.context.get_config()['platform_settings']['id_whitelist'].remove(sid)
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("删除白名单成功。"))
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
@filter.command("provider")
|
||||
async def provider(self, event: AstrMessageEvent, idx: int = None):
|
||||
'''查看或者切换 LLM Provider'''
|
||||
|
||||
if idx is None:
|
||||
ret = "## 当前载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})"
|
||||
if self.provider == llm:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 /provider <序号> 切换提供商。"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
else:
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1]
|
||||
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}。"))
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
message.set_result(MessageEventResult().message("重置成功"))
|
||||
|
||||
@filter.command("model")
|
||||
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
models = await self.context.get_using_provider().get_models()
|
||||
except BaseException as e:
|
||||
message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)))
|
||||
return
|
||||
i = 1
|
||||
ret = "下面列出了此服务提供商可用模型:"
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model}"
|
||||
i += 1
|
||||
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
if isinstance(idx_or_name, int):
|
||||
models = []
|
||||
try:
|
||||
models = await self.context.get_using_provider().get_models()
|
||||
except BaseException as e:
|
||||
message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)))
|
||||
return
|
||||
if idx_or_name > len(models) or idx_or_name < 1:
|
||||
message.set_result(MessageEventResult().message("模型序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_model = models[idx_or_name-1]
|
||||
self.context.get_using_provider().set_model(new_model)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("切换模型未知错误: "+str(e)))
|
||||
message.set_result(MessageEventResult().message("切换模型成功。"))
|
||||
else:
|
||||
self.context.get_using_provider().set_model(idx_or_name)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}"))
|
||||
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
size_per_page = 3
|
||||
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
|
||||
|
||||
history = ""
|
||||
for context in contexts:
|
||||
history += f"{context}\n"
|
||||
|
||||
ret = f"""历史记录:
|
||||
{history}
|
||||
第 {page} 页 | 共 {total_pages} 页
|
||||
|
||||
*输入 /history 2 跳转到第 2 页
|
||||
"""
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
|
||||
if index == None:
|
||||
keys_data = self.context.get_using_provider().get_keys()
|
||||
curr_key = self.context.get_using_provider().get_current_key()
|
||||
ret = "Key:"
|
||||
for i, k in enumerate(keys_data):
|
||||
ret += f"\n{i+1}. {k[:8]}"
|
||||
|
||||
ret += f"\n当前 Key: {curr_key[:8]}"
|
||||
ret += "\n当前模型: " + self.context.get_using_provider().get_model()
|
||||
ret += "\n使用 /key <idx> 切换 Key。"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
else:
|
||||
keys_data = self.context.get_using_provider().get_keys()
|
||||
if index > len(keys_data) or index < 1:
|
||||
message.set_result(MessageEventResult().message("Key 序号错误。"))
|
||||
else:
|
||||
try:
|
||||
new_key = keys_data[index-1]
|
||||
self.context.get_using_provider().set_key(new_key)
|
||||
except BaseException as e:
|
||||
message.set_result(
|
||||
MessageEventResult().message("切换 Key 未知错误: "+str(e)))
|
||||
message.set_result(MessageEventResult().message("切换 Key 成功。"))
|
||||
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
l = message.message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"""[Persona]
|
||||
|
||||
- 设置人格: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格列表: `/persona list`
|
||||
- 人格详细信息: `/persona view 人格名`
|
||||
- 自定义人格: /persona 人格文本
|
||||
- 重置 LLM 会话(清除人格): /reset
|
||||
- 重置 LLM 会话(保留人格): /reset p
|
||||
|
||||
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
|
||||
"""))
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
msg += f"- {key}\n"
|
||||
msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息'
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "view":
|
||||
if len(l) == 2:
|
||||
message.set_result(MessageEventResult().message("请输入人格名"))
|
||||
ps = l[2].strip()
|
||||
if ps in personalities:
|
||||
msg = f"人格{ps}的详细信息:\n"
|
||||
msg += f"{personalities[ps]}\n"
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if ps in personalities:
|
||||
self.context.get_using_provider().curr_personality = Personality(
|
||||
name=ps, prompt=personalities[ps])
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
else:
|
||||
self.context.get_using_provider().curr_personality = Personality(
|
||||
name="自定义人格", prompt=ps)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
name: astrbot # 插件名称
|
||||
desc: AstrBot 内置指令集
|
||||
help:
|
||||
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
|
||||
author: AstrBot # 作者
|
||||
repo: https://github.com/Soulter/AstrBot
|
||||
@@ -1,13 +0,0 @@
|
||||
from astrbot.api import Context
|
||||
from .aiocqhttp_platform_adapter import AiocqhttpAdapter
|
||||
from astrbot.api import logger
|
||||
|
||||
class Main:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
platforms_config = context.get_config().platform
|
||||
settings = context.get_config().platform_settings
|
||||
for platform in platforms_config:
|
||||
if platform.name == "aiocqhttp" and platform.enable:
|
||||
self.context.register_platform(AiocqhttpAdapter(platform, settings, context.get_event_queue()))
|
||||
logger.info(f"已注册 aiocqhttp({platform.id}) 消息适配器。")
|
||||
@@ -1,6 +0,0 @@
|
||||
name: astrbot_adapter_aiocqhttp # 插件名称
|
||||
desc: 支持 OneBot 协议的消息平台适配器(反向 Websockets)
|
||||
help:
|
||||
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
|
||||
author: Soulter # 作者
|
||||
repo: https://github.com/Soulter/AstrBot
|
||||
@@ -1,18 +0,0 @@
|
||||
import botpy, logging
|
||||
# delete qqbotpy's logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
from astrbot.api import Context
|
||||
from .qqofficial_platform_adapter import QQOfficialPlatformAdapter
|
||||
from astrbot.api import logger
|
||||
|
||||
class Main:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
platforms_config = context.get_config().platform
|
||||
settings = context.get_config().platform_settings
|
||||
for platform in platforms_config:
|
||||
if platform.name == "qq_official" and platform.enable:
|
||||
self.context.register_platform(QQOfficialPlatformAdapter(platform, settings, context.get_event_queue()))
|
||||
logger.info(f"已注册 qq_official({platform.id}) 消息适配器。")
|
||||
@@ -1,6 +0,0 @@
|
||||
name: astrbot_adapter_qqofficial # 插件名称
|
||||
desc: 支持 QQ 官方机器人平台的消息平台适配器
|
||||
help:
|
||||
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
|
||||
author: Soulter # 作者
|
||||
repo: https://github.com/Soulter/AstrBot
|
||||
@@ -1,18 +0,0 @@
|
||||
from astrbot.api import Context, AstrMessageEvent, MessageEventResult
|
||||
from .wechat_platform_adapter import WechatPlatformAdapter
|
||||
from astrbot.api import logger
|
||||
|
||||
class Main:
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.context = context
|
||||
platforms_config = context.get_config().platform
|
||||
settings = context.get_config().platform_settings
|
||||
for platform in platforms_config:
|
||||
if platform.name == "wechat" and platform.enable:
|
||||
self.context.register_platform(WechatPlatformAdapter(platform, settings, context.get_event_queue()))
|
||||
logger.info(f"已注册 wechat({platform.id}) 消息适配器。")
|
||||
|
||||
self.context.register_commands("astrbot_adapter_wechat", "wechatid", "查看微信ID", 1, self.get_wechat_id)
|
||||
|
||||
async def get_wechat_id(self, event: AstrMessageEvent):
|
||||
event.set_result(MessageEventResult().message("这个会话的微信ID是" + event.message_obj.raw_message.from_.username))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user