refactor: im so tired :)

This commit is contained in:
Soulter
2024-12-09 22:38:42 +08:00
parent 7abe90f2ac
commit bdfc77d349
109 changed files with 2843 additions and 2104 deletions
+6 -1
View File
@@ -2,6 +2,7 @@ __pycache__
botpy.log
.vscode
data_v2.db
data_v3.db
configs/session
configs/config.yaml
**/.DS_Store
@@ -11,4 +12,8 @@ data
cookies.json
logs/
addons/plugins
.coverage
.coverage
tests/astrbot_plugin_openai
chroma
+1 -13
View File
@@ -1,16 +1,4 @@
from astrbot.core.plugin import Context
from astrbot.core.platform import AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, CommandResult
from astrbot.core.provider import Provider, Personality
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core.utils.command_parser import CommandParser, CommandTokens
from astrbot.core.utils.func_call import FuncCall
from astrbot.core import html_renderer
from astrbot.core.plugin.config import *
command_parser = CommandParser()
from astrbot.core import html_renderer
+40
View File
@@ -0,0 +1,40 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
# event
from astrbot.core.message.message_event_result import (
MessageEventResult, MessageChain, CommandResult, EventResultType
)
from astrbot.core.platform import AstrMessageEvent
# star register
from astrbot.core.star.register import (
register_command as command,
register_command_group as command_group,
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from astrbot.core.star.register import (
register_star as register # 注册插件(Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
# provider
from astrbot.core.provider import Provider, Personality, ProviderMetaData
# platform
from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from .message_components import *
+5
View File
@@ -0,0 +1,5 @@
from astrbot.core.message.message_event_result import (
MessageEventResult, MessageChain, CommandResult, EventResultType
)
from astrbot.core.platform import AstrMessageEvent
+10
View File
@@ -0,0 +1,10 @@
from astrbot.core.star.register import (
register_command as command,
register_command_group as command_group,
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
from astrbot.core.star.filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
+5
View File
@@ -0,0 +1,5 @@
from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
+1
View File
@@ -0,0 +1 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
+6
View File
@@ -0,0 +1,6 @@
from astrbot.core.star.register import (
register_star as register # 注册插件(Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star.config import *
+2 -2
View File
@@ -1,2 +1,2 @@
from .default import DEFAULT_CONFIG_VERSION_2, VERSION, DB_PATH
from .astrbot_config import AstrBotConfig
from .default import DEFAULT_CONFIG, VERSION, DB_PATH
from .astrbot_config import *
+68 -281
View File
@@ -1,295 +1,82 @@
import os
import json
import shutil
import logging
from . import DEFAULT_CONFIG_VERSION_2
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional
import enum
from .default import DEFAULT_CONFIG
from typing import List, Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
@dataclass
class RateLimit:
time: int = 60
count: int = 30
class RateLimitStrategy(enum.Enum):
STALL = "stall"
DISCARD = "discard"
@dataclass
class PlatformSettings:
unique_session: bool = False
rate_limit: RateLimit = field(default_factory=RateLimit)
reply_prefix: str = ""
forward_threshold: int = 200
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
def __post_init__(self):
self.rate_limit = RateLimit(**self.rate_limit)
@dataclass
class PlatformConfig():
id: str = ""
name: str = ""
enable: bool = False
@dataclass
class QQOfficialPlatformConfig(PlatformConfig):
appid: str = ""
secret: str = ""
enable_group_c2c: bool = True
enable_guild_direct_message: bool = True
@dataclass
class AiocqhttpPlatformConfig(PlatformConfig):
ws_reverse_host: str = ""
ws_reverse_port: int = 6199
qq_id_whitelist: List[str] = field(default_factory=list)
qq_group_id_whitelist: List[str] = field(default_factory=list)
@dataclass
class WechatPlatformConfig(PlatformConfig):
wechat_id_whitelist: List[str] = field(default_factory=list)
@dataclass
class ModelConfig:
model: str = "gpt-4o"
max_tokens: int = 6000
temperature: float = 0.9
top_p: float = 1
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@dataclass
class ImageGenerationModelConfig:
enable: bool = True
model: str = "dall-e-3"
size: str = "1024x1024"
style: str = "vivid"
quality: str = "standard"
@dataclass
class EmbeddingModel:
enable: bool = False
model: str = ""
@dataclass
class LLMConfig:
id: str = ""
name: str = "openai"
enable: bool = True
key: List[str] = field(default_factory=list)
api_base: str = ""
prompt_prefix: str = ""
default_personality: str = ""
model_config: ModelConfig = field(default_factory=ModelConfig)
image_generation_model_config: Optional[ImageGenerationModelConfig] = field(default_factory=ImageGenerationModelConfig)
embedding_model: Optional[EmbeddingModel] = field(default_factory=EmbeddingModel)
def __post_init__(self):
if isinstance(self.model_config, dict):
self.model_config = ModelConfig(**self.model_config)
if isinstance(self.image_generation_model_config, dict):
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) if self.image_generation_model_config else None
if isinstance(self.embedding_model, dict):
self.embedding_model = EmbeddingModel(**self.embedding_model) if self.embedding_model else None
@dataclass
class LLMSettings:
wake_prefix: str = ""
web_search: bool = False
identifier: bool = False
@dataclass
class BaiduAIPConfig:
enable: bool = False
app_id: str = ""
api_key: str = ""
secret_key: str = ""
@dataclass
class InternalKeywordsConfig:
enable: bool = True
extra_keywords: List[str] = field(default_factory=list)
@dataclass
class ContentSafetyConfig:
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
def __post_init__(self):
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
@dataclass
class DashboardConfig:
enable: bool = True
username: str = ""
password: str = ""
@dataclass
class ATRILongTermMemory:
enable: bool = False
summary_threshold_cnt: int = 5
@dataclass
class ATRIActiveMessage:
enable: bool = False
@dataclass
class ProjectATRI:
enable: bool = False
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
persona: str = ""
split_response: bool = True
embedding_provider_id: str = ""
summarize_provider_id: str = ""
chat_provider_id: str = ""
chat_base_model_path: str = ""
chat_adapter_model_path: str = ""
quantization_bit: int = 4
def __post_init__(self):
if isinstance(self.long_term_memory, dict):
self.long_term_memory = ATRILongTermMemory(**self.long_term_memory)
if isinstance(self.active_message, dict):
self.active_message = ATRIActiveMessage(**self.active_message)
@dataclass
class AstrBotConfig():
config_version: int = 2
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
llm: List[LLMConfig] = field(default_factory=list)
llm_settings: LLMSettings = field(default_factory=LLMSettings)
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
t2i: bool = True
admins_id: List[str] = field(default_factory=list)
https_proxy: str = ""
http_proxy: str = ""
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
log_level: str = "INFO"
t2i_endpoint: str = ""
pip_install_arg: str = ""
plugin_repo_mirror: str = ""
project_atri: ProjectATRI = field(default_factory=ProjectATRI)
def __init__(self) -> None:
self.init_configs()
# compability
if isinstance(self.wake_prefix, str):
self.wake_prefix = [self.wake_prefix]
if len(self.wake_prefix) == 0:
self.wake_prefix.append("/")
def load_from_dict(self, data: Dict):
'''从字典中加载配置到对象。
@note: 适用于 version 2 配置文件。
'''
self.config_version=data.get("version", 2)
self.platform=[]
left_platforms = ["qq_official", "aiocqhttp", "wechat"]
for p in data.get("platform", []):
if 'name' not in p:
logger.warning("A platform config missing name, skipping.")
continue
if p["name"] == "qq_official":
self.platform.append(QQOfficialPlatformConfig(**p))
left_platforms.remove(p["name"])
elif p["name"] == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(**p))
left_platforms.remove(p["name"])
elif p["name"] == "wechat":
self.platform.append(WechatPlatformConfig(**p))
left_platforms.remove(p["name"])
# 注入默认配置
for p in left_platforms:
if p == "qq_official":
self.platform.append(QQOfficialPlatformConfig(id="default", name=p))
elif p == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(id="default", name=p))
elif p == "wechat":
self.platform.append(WechatPlatformConfig(id="default", name=p))
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
self.t2i=data.get("t2i", True)
self.admins_id=data.get("admins_id", [])
self.https_proxy=data.get("https_proxy", "")
self.http_proxy=data.get("http_proxy", "")
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", ["/"])
self.log_level=data.get("log_level", "INFO")
self.t2i_endpoint=data.get("t2i_endpoint", "")
self.pip_install_arg=data.get("pip_install_arg", "")
self.plugin_repo_mirror=data.get("plugin_repo_mirror", "")
self.project_atri=ProjectATRI(**data.get("project_atri", {}))
def flush_config(self, config: dict = None):
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
f.flush()
def save_config(self):
'''将现存配置写入文件'''
self.flush_config(self.to_dict())
def init_configs(self):
'''初始化必需的配置项'''
config = None
def __init__(self):
super().__init__()
if not self.check_exist():
self.flush_config()
config = DEFAULT_CONFIG_VERSION_2
else:
config = self.get_all()
# 加载配置到对象
self.load_from_dict(config)
# 保存到文件
# 这一步操作是为了保证配置文件中的字段的完整性。
# 在版本变动新增配置项时,将对象中新增的配置项的默认值写入文件。
self.save_config()
def get(self, key: str, default=None):
'''从文件系统中直接获取配置'''
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
if key in d:
return d[key]
else:
return default
def get_all(self):
'''从文件系统中获取所有配置'''
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
if not conf_str:
return {}
conf = json.loads(conf_str)
return conf
def put(self, key, value):
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
d[key] = value
'''不存在时载入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
def to_dict(self) -> Dict:
return asdict(self)
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
for key, value in refer_conf.items():
if key not in conf:
logger.info(f"检查到配置项 {path + "." + key if path else key} 不存在,已插入默认值 {value}")
conf[key] = value
has_new = True
else:
if conf[key] == None:
conf[key] = value
has_new = True
elif isinstance(value, dict):
has_new |= self.check_config_integrity(value, conf[key], path + "." + key if path else key)
return has_new
def save_config(self, replace_config: Dict = None):
'''将配置写入文件
如果传入 replace_config,则将配置替换为 replace_config
'''
if replace_config:
self.update(replace_config)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
return None
def __delattr__(self, key):
try:
del self[key]
self.save_config()
except KeyError:
raise AttributeError(f"没有找到 Key: '{key}'")
def __setattr__(self, key, value):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
+111 -127
View File
@@ -1,177 +1,160 @@
'''
这里定义了一些默认配置文件,请不要修改这个文件。如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
'''
VERSION = '3.4.0'
DB_PATH = 'data/data_v3.db'
# 默认配置
DEFAULT_CONFIG = {
"config_version": 2,
"platform_settings": {
"unique_session": False,
"rate_limit": {
"time": 60,
"count": 30,
"strategy": "stall" # stall, discard
},
"reply_prefix": "",
"forward_threshold": 200,
"id_whitelist": []
},
"provider": [],
"provider_settings": {
"wake_prefix": "",
"web_search": False,
"identifier": False,
"default_personality": "",
"prompt_prefix": ""
},
"content_safety": {
"internal_keywords": {
"enable": True,
"extra_keywords": []
},
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
}
},
"admins_id": [],
"t2i": False,
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9"
},
"platform": [],
"wake_prefix": ["/"],
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "",
"project_atri": {
"enable": False,
"long_term_memory": {
"enable": False,
"summary_threshold_cnt": 5,
"embedding_provider_id": "",
"summarize_provider_id": ""
},
"active_message": {
"enable": False
},
"vision": {
"enable": False,
"provider_id_or_ofa_model_path": "",
"reply_meme_prob": 0.4,
"reply_meme_similar_threshold": 0.7
},
"persona": "",
"split_response": True,
"chat_provider_id": "",
"chat_base_model_path": "",
"chat_adapter_model_path": "",
"quantization_bit": 4
}
}
# LLM 提供商配置模板
PROVIDER_CONFIG_TEMPLATE = {
"openai": {
"id": "default",
"name": "openai",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
},
"image_generation_model_config": {
"enable": False,
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
},
"embedding_model": {
"enable": False,
"model": "text-embedding-3-small"
"model": "gpt-4o-mini",
}
},
"ollama": {
"id": "ollama_default",
"name": "ollama",
"type": "openai_chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "llama3.1-8b",
"temperature": 0.9,
"top_p": 1,
}
},
"gemini": {
"id": "gemini_default",
"name": "gemini",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "gemini-1.5-flash",
}
},
"deepseek": {
"id": "deepseek_default",
"name": "deepseek",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.deepseek.com/v1",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "deepseek-chat",
}
},
"zhipu": {
"id": "zhipu_default",
"name": "zhipu(glm)",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "glm-4-flash",
}
},
}
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
DEFAULT_CONFIG_VERSION_2 = {
"config_version": 2,
"platform": [
{
"id": "default",
"name": "qq_official",
"enable": False,
"appid": "",
"secret": "",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
},
{
"id": "default",
"name": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 6199,
"qq_id_whitelist": [],
"qq_group_id_whitelist": []
},
{
"id": "default",
"name": "wechat",
"enable": False,
"wechat_id_whitelist": []
}
],
"platform_settings": {
"unique_session": False,
"rate_limit": {
"time": 60,
"count": 30,
},
"reply_prefix": "",
"forward_threshold": 200, # 转发消息的阈值
},
"llm": [
PROVIDER_CONFIG_TEMPLATE["openai"]
],
"llm_settings": {
"wake_prefix": "",
"web_search": False,
"identifier": False,
},
"content_safety": {
"internal_keywords": {
"enable": True,
"extra_keywords": [],
}
},
"wake_prefix": ["/"],
"t2i": True,
"admins_id": [],
"https_proxy": "",
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9",
},
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "default",
"project_atri": {
# 平台适配器配置模板
ADAPTER_CONFIG_TEMPLATE = {
"qq_official": {
"id": "default",
"name": "qq_official",
"enable": False,
"long_term_memory": {
"enable": False,
"summary_threshold_cnt": 6,
},
"active_message": {
"enable": False,
},
"vision": {
"enable": False,
"provider_id_or_ofa_model_path": "",
},
"persona": "",
"split_response": True,
"embedding_provider_id": "",
"summarize_provider_id": "",
"chat_provider_id": "",
"chat_base_model_path": "",
"chat_adapter_model_path": "",
"quantization_bit": 4
"appid": "",
"secret": "",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
},
"aiocqhtp": {
"id": "default",
"name": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 6199
},
"wechat": {
"id": "default",
"name": "vchat",
"enable": False
}
}
@@ -183,7 +166,7 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {
"id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"},
"name": {"description": "适配器类型", "type": "string", "hint": "当前版本下,内置支持 `qq_official`QQ 官方机器人), `aiocqhttp`(Onebot 适用) 适配器类型。", "options": ["qq_official", "aiocqhttp", "wechat"], "readonly": True},
"name": {"description": "适配器类型", "type": "string", "invisible": True},
"enable": {"description": "启用", "type": "bool", "hint": "是否启用该适配器。未启用的适配器对应的消息平台将不会接收到消息。"},
"appid": {"description": "appid", "type": "string", "hint": "必填项。QQ 官方机器人平台的 appid。如何获取请参考文档。"},
"secret": {"description": "secret", "type": "string", "hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。"},
@@ -208,18 +191,20 @@ CONFIG_METADATA_2 = {
"items": {
"time": {"description": "消息速率限制时间", "type": "int"},
"count": {"description": "消息速率限制计数", "type": "int"},
"strategy": {"description": "速率限制策略", "type": "string", "options": ["stall", "discard"], "hint": "当消息速率超过限制时的处理策略。stall 为等待,discard 为丢弃。"}
}
},
"reply_prefix": {"description": "回复前缀", "type": "string", "hint": "机器人回复消息时带有的前缀。"},
"forward_threshold": {"description": "转发消息的字数阈值", "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。"},
"id_whitelist": {"description": "ID 白名单", "type": "list", "items": {"type": "int"}, "hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的 ID。"},
}
},
"llm": {
"provider": {
"description": "大语言模型配置",
"type": "list",
"items": {
"id": {"description": "ID", "type": "string", "hint": "提供商 ID 名,用于在多实例下方便管理和识别。自定义,ID 不能重复。"},
"name": {"description": "模型提供商类型", "type": "string", "hint": "如需变更模型提供商,请点击上面的 + 新建一个。如果没有找到你想要接入的提供商,可以前往你的提供商的官网查看是否兼容 OpenAI API,如兼容,可以选择 `openai`。大多数提供商都是兼容的。", "options": list(PROVIDER_CONFIG_TEMPLATE.keys()), "obvious_hint": True, "readonly": True},
"type": {"description": "模型提供商类型", "type": "string", "invisible": True},
"enable": {"description": "启用", "type": "bool", "hint": "是否启用该模型。未启用的模型将不会被使用。"},
"key": {"description": "API Key", "type": "list", "items": {"type": "string"}, "hint": "API Key 列表。填写好后输入回车即可添加 API Key。支持多个 API Key。"},
"api_base": {"description": "API Base URL", "type": "string", "hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。"},
@@ -256,7 +241,7 @@ CONFIG_METADATA_2 = {
}
}
},
"llm_settings": {
"provider_settings": {
"description": "大语言模型设置",
"type": "object",
"items": {
@@ -292,7 +277,6 @@ CONFIG_METADATA_2 = {
"wake_prefix": {"description": "机器人唤醒前缀", "type": "list", "items": {"type": "string"}, "hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。"},
"t2i": {"description": "文本转图像", "type": "bool", "hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。"},
"admins_id": {"description": "管理员 ID", "type": "list", "items": {"type": "int"}, "hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。"},
"https_proxy": {"description": "HTTPS 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"},
"http_proxy": {"description": "HTTP 代理", "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`"},
"dashboard": {
"description": "管理面板配置",
@@ -318,6 +302,8 @@ CONFIG_METADATA_2 = {
"items": {
"enable": {"description": "启用", "type": "bool"},
"summary_threshold_cnt": {"description": "摘要阈值", "type": "int", "hint": "当一个会话的对话记录数量超过该阈值时,会自动进行摘要。"},
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
}
},
"active_message": {
@@ -335,10 +321,8 @@ CONFIG_METADATA_2 = {
"provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"},
}
},
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"},
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的时间间隔。默认启用。"},
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
"chat_provider_id": {"description": "Chat provider ID", "type": "string", "hint": "将会使用指定的 provider 来进行文本聊天,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
"chat_base_model_path": {"description": "用于聊天的基座模型路径", "type": "string", "hint": "用于聊天的基座模型路径。当填写此项和 Lora 路径后,将会忽略上面设置的 Chat provider ID。", "obvious_hint": True},
"chat_adapter_model_path": {"description": "用于聊天的 Lora 模型路径", "type": "string", "hint": "Lora 模型路径。", "obvious_hint": True},
+53 -15
View File
@@ -1,10 +1,13 @@
import asyncio, time, threading
import asyncio, time, threading, os
from .event_bus import EventBus
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.message.message_event_handler import MessageEventHandler
from astrbot.core.plugin import PluginManager
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
@@ -15,21 +18,46 @@ class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.db = db
if self.astrbot_config['http_proxy']:
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config.log_level)
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
self.plugin_manager = PluginManager(self.astrbot_config, self.event_queue, db)
self.message_event_handler = MessageEventHandler(self.astrbot_config, self.plugin_manager)
self.astrbot_updator = AstrBotUpdator(self.astrbot_config.plugin_repo_mirror)
self.event_bus = EventBus(self.event_queue, self.message_event_handler)
self.stop_flag = False
self.start_time = int(time.time())
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
'''根据配置实例化各个 Provider'''
await self.platform_manager.initialize()
'''根据配置实例化各个平台适配器'''
self.pipeline_scheduler = PipelineScheduler(PipelineContext(self.astrbot_config, self.plugin_manager))
await self.pipeline_scheduler.initialize()
'''初始化消息事件流水线调度器'''
self.astrbot_updator = AstrBotUpdator(self.astrbot_config['plugin_repo_mirror'])
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
self.start_time = int(time.time())
self.curr_tasks: List[asyncio.Task] = []
def _load(self):
self.plugin_manager.reload()
platform_tasks = self.load_platform()
event_bus_task = asyncio.create_task(self.event_bus.dispatch(), name="event_bus")
@@ -41,16 +69,26 @@ class AstrBotCoreLifecycle:
self._load()
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
def stop(self):
self.stop_flag = True
async def stop(self):
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
for task in self.curr_tasks:
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
def restart(self):
self.event_queue.closed = True
threading.Thread(target=self.astrbot_updator._reboot, name="restart", daemon=True).start()
def load_platform(self) -> List[asyncio.Task]:
tasks = []
platform_insts = self.plugin_manager.get_platform_insts()
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name))
return tasks
+18 -3
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
@dataclass
class BaseDatabase(abc.ABC):
@@ -39,12 +39,12 @@ class BaseDatabase(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def update_llm_history(self, session_id: str, content: str):
def update_llm_history(self, session_id: str, content: str, provider_type: str):
'''更新 LLM 历史记录。当不存在 session_id 时插入'''
raise NotImplementedError
@abc.abstractmethod
def get_llm_history(self, session_id: str = None) -> List[LLMHistory]:
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> List[LLMHistory]:
'''获取 LLM 历史记录, 如果 session_id 为 None, 返回所有'''
raise NotImplementedError
@@ -62,3 +62,18 @@ class BaseDatabase(abc.ABC):
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
'''获取基础统计数据(合并)'''
raise NotImplementedError
@abc.abstractmethod
def insert_atri_vision_data(self, vision_data: ATRIVision):
'''插入 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data(self) -> List[ATRIVision]:
'''获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
+14 -1
View File
@@ -37,5 +37,18 @@ class Stats():
@dataclass
class LLMHistory():
provider_type: str
session_id: str
content: str
content: str
@dataclass
class ATRIVision():
id: str
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
+70 -31
View File
@@ -6,7 +6,8 @@ from astrbot.core.db.po import (
Command,
Provider,
Stats,
LLMHistory
LLMHistory,
ATRIVision
)
from . import BaseDatabase
from typing import Tuple
@@ -75,28 +76,39 @@ class SQLiteDatabase(BaseDatabase):
''', (k, v, int(time.time()))
)
def update_llm_history(self, session_id: str, content: str):
res = self.get_llm_history(session_id)
def update_llm_history(self, session_id: str, content: str, provider_type: str):
res = self.get_llm_history(session_id, provider_type)
if res:
self._exec_sql(
'''
UPDATE llm_history SET content = ? WHERE session_id = ?
''', (content, session_id)
UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
''', (content, session_id, provider_type)
)
else:
self._exec_sql(
'''
INSERT INTO llm_history(session_id, content) VALUES (?, ?)
''', (session_id, content)
INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
''', (provider_type, session_id, content)
)
def get_llm_history(self, session_id: str = None) -> Tuple:
def get_llm_history(self, session_id: str = None, provider_type: str = None) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
where_clause = "" if session_id is None else f"WHERE session_id = '{session_id}'"
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
c.execute(
'''
SELECT * FROM llm_history
@@ -186,26 +198,53 @@ class SQLiteDatabase(BaseDatabase):
for row in c.fetchall():
platform.append(Platform(*row))
# c.execute(
# '''
# SELECT name, SUM(count), timestamp FROM command
# ''' + where_clause + " GROUP BY name"
# )
# command = []
# for row in c.fetchall():
# command.append(Command(*row))
# c.execute(
# '''
# SELECT name, SUM(count), timestamp FROM llm
# ''' + where_clause + " GROUP BY name"
# )
# llm = []
# for row in c.fetchall():
# llm.append(Provider(*row))
c.close()
return Stats(platform, [], [])
return Stats(platform, [], [])
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)
self._exec_sql(
'''
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (vision.id, vision.url_or_path, vision.caption, vision.is_meme, keywords, vision.platform_name, vision.session_id, vision.sender_nickname, ts)
)
def get_atri_vision_data(self) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM atri_vision
'''
)
res = c.fetchall()
visions = []
for row in res:
visions.append(ATRIVision(*row))
c.close()
return visions
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
''', (url_or_path, id)
)
res = c.fetchone()
c.close()
if res:
return ATRIVision(*res)
return None
+14
View File
@@ -19,6 +19,20 @@ CREATE TABLE IF NOT EXISTS command(
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
provider_type VARCHAR(32),
session_id VARCHAR(32),
content TEXT
);
-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
id TEXT,
url_or_path TEXT,
caption TEXT,
is_meme BOOLEAN,
keywords TEXT,
platform_name VARCHAR(32),
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
+4 -4
View File
@@ -2,22 +2,22 @@ import asyncio
from asyncio import Queue
from collections import defaultdict
from typing import List
from astrbot.core.message.message_event_handler import MessageEventHandler
from astrbot.core.pipeline.scheduler import PipelineScheduler
from astrbot.core import logger
from .platform import AstrMessageEvent
from astrbot.core.message.components import Image, Plain
class EventBus:
def __init__(self, event_queue: Queue, message_event_handler: MessageEventHandler):
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue
self.message_event_handler = message_event_handler
self.pipeline_scheduler = pipeline_scheduler
async def dispatch(self):
logger.info("事件总线已打开。")
while True:
event: AstrMessageEvent = await self.event_queue.get()
self._print_event(event)
asyncio.create_task(self.message_event_handler.handle(event))
asyncio.create_task(self.pipeline_scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent):
if event.get_sender_name():
+1
View File
@@ -271,6 +271,7 @@ class Image(BaseMessageComponent):
c: T.Optional[int] = 2
# 额外
path: T.Optional[str] = ""
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
@@ -1,177 +0,0 @@
import asyncio, re, time
import inspect
import traceback
from typing import List, Union
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.config.astrbot_config import AstrBotConfig
from .message_event_result import MessageEventResult, CommandResult, MessageChain
from astrbot.core.plugin import PluginManager, Context, CommandMetadata
from .components import *
from astrbot.core import logger
from astrbot.core import html_renderer
class CommandTokens():
def __init__(self) -> None:
self.tokens = []
self.len = 0
def get(self, idx: int):
if idx >= self.len:
return None
return self.tokens[idx].strip()
class CommandParser():
def __init__(self):
pass
def parse(self, message: str):
cmd_tokens = CommandTokens()
cmd_tokens.tokens = message.split(" ")
cmd_tokens.len = len(cmd_tokens.tokens)
return cmd_tokens
def regex_match(self, message: str, command: str) -> bool:
return re.search(command, message, re.MULTILINE) is not None
class MessageEventHandler():
'''
处理消息事件。
'''
def __init__(self, config: AstrBotConfig, plugin_manager: PluginManager):
self.config = config
self.plugin_manager = plugin_manager
self.command_parser = CommandParser()
async def handle(self, event: AstrMessageEvent):
'''
处理消息事件。
'''
event.message_str = event.message_str.strip()
for admin_id in self.config.admins_id:
if event.get_sender_id() == admin_id:
event.role = "admin"
break
# 检查 wake
wake_prefixes = self.config.wake_prefix
messages = event.get_messages()
is_wake = False
for wake_prefix in wake_prefixes:
if event.message_str.startswith(wake_prefix):
is_wake = True
break
if not is_wake:
# 检查是否有 at 消息
for message in messages:
if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"):
is_wake = True
wake_prefix = ""
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
wake_prefix = ""
event.is_wake = is_wake
# 处理事件监听器(在指令扫描之前)
listeners = self.plugin_manager.context.registered_listeners
listeners_handler = self.plugin_manager.context.listeners_handler
for name in listeners:
if listeners_handler[name].after_commands:
continue
ret = await listeners_handler[name].handler(event)
if ret:
event.set_result(ret)
if event.get_result():
return await self.post_handle(event)
# 处理指令,指令带有指定过的前缀
commands = self.plugin_manager.context.registered_commands
commands_handler = self.plugin_manager.context.commands_handler
# 扫描指令
for command in commands:
command = command[1]
trig = False
pre_ = ""
if not commands_handler[command].ignore_prefix:
pre_ = wake_prefix
if commands_handler[command].use_regex:
trig = self.command_parser.regex_match(event.message_str, pre_ + command)
else:
trig = event.message_str.startswith(pre_ + command)
if trig:
ret = await self.execute_handler(command, commands_handler[command], event)
if ret:
event.set_result(ret)
if event.get_result():
return await self.post_handle(event)
# 处理事件监听器(在指令扫描之后)
for name in listeners:
if not listeners_handler[name].after_commands:
continue
ret = await listeners_handler[name].handler(event)
if ret:
event.set_result(ret)
if event.get_result():
return await self.post_handle(event)
async def post_handle(self, event: AstrMessageEvent):
result = event.get_result()
if result.callback:
await result.callback(event)
# prefix
if self.config.platform_settings.reply_prefix:
result.chain.insert(0, Plain(self.config.platform_settings.reply_prefix))
# t2i
if (result.use_t2i_ is None and self.config.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
plain_str += "\n\n" + comp.text
else:
break
if plain_str and len(plain_str) > 150:
render_start = time.time()
url = await html_renderer.render_t2i(plain_str, return_url=True)
if time.time() - render_start > 3:
logger.warning(f"图片转文本耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
result.chain = [Image.fromURL(url)]
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
await event.send(result)
async def execute_handler(self,
command: str,
command_metadata: CommandMetadata,
message_event: AstrMessageEvent):
logger.info(f"触发 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令。")
handler = command_metadata.handler
try:
if inspect.iscoroutinefunction(handler):
command_result = await handler(message_event)
else:
command_result = handler(message_event)
if command_result is not None:
message_event.set_result(command_result)
except TypeError as e:
# 兼容旧版本插件
if inspect.iscoroutinefunction(handler):
command_result = await handler(message_event, self.plugin_manager.context)
else:
command_result = handler(message_event, self.plugin_manager.context)
if command_result is not None:
message_event.set_result(command_result)
except BaseException as e:
logger.error(traceback.format_exc())
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
message_event.set_result(MessageEventResult().message(text))
+85 -19
View File
@@ -1,43 +1,68 @@
from typing import List, Union, Optional
import enum, logging
from typing import List, Optional
from dataclasses import dataclass, field
from astrbot.core.message.components import *
from typing_extensions import deprecated
@dataclass
class MessageChain():
'''MessageChain 描述了一整条消息中带有的所有组件。
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''
快捷回复消息。
'''添加一条文本消息到消息链 `chain` 中。
CommandResult().message("Hello, world!")
Example:
CommandResult().message("Hello ").message("world!")
# 输出 Hello world!
'''
self.chain.append(Plain(message))
return self
@deprecated("请使用 message 方法代替。")
def error(self, message: str):
'''
快捷回复消息。
'''添加一条错误消息到消息链 `chain` 中
CommandResult().error("Hello, world!")
Example:
CommandResult().error("解析失败")
'''
self.chain.append(Plain(message))
return self
def url_image(self, url: str):
'''
快捷回复图片(网络url的格式)。
'''添加一条图片消息(https 链接)到消息链 `chain` 中。
CommandResult().image("https://example.com/image.jpg")
Note:
如果需要发送本地图片,请使用 `file_image` 方法。
Example:
CommandResult().image("https://example.com/image.jpg")
'''
self.chain.append(Image.fromURL(url))
return self
def file_image(self, path: str):
'''
快捷回复图片(本地文件路径的格式)。
'''添加一条图片消息(本地文件路径)到消息链 `chain` 中。
Note:
如果需要发送网络图片,请使用 `url_image` 方法。
CommandResult().image("image.jpg")
'''
@@ -45,24 +70,65 @@ class MessageChain():
return self
def use_t2i(self, use_t2i: bool):
'''
设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
'''设置是否使用文本转图片服务。
Args:
use_t2i (bool): 是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''
设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
具体的效果以各适配器实现为准。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
Attributes:
CONTINUE: 事件将会继续传播
STOP: 事件将会终止传播
'''
CONTINUE = enum.auto()
STOP = enum.auto()
@dataclass
class MessageEventResult(MessageChain):
is_command_call: Optional[bool] = False
callback: Optional[callable] = None
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
现代消息平台的一条富文本消息中可能由多个组件构成,如文本、图片、At 等,并且保留了顺序。
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
self.result_type = EventResultType.STOP
return self
def continue_event(self) -> 'MessageEventResult':
'''继续事件传播。
'''
self.result_type = EventResultType.CONTINUE
return self
def is_stopped(self) -> bool:
'''
是否终止事件传播。
'''
return self.result_type == EventResultType.STOP
CommandResult = MessageEventResult
+18
View File
@@ -0,0 +1,18 @@
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
]
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -0,0 +1,31 @@
import asyncio
from datetime import datetime, timedelta
from collections import defaultdict, deque
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from .strategies.strategy import StrategySelector
@register_stage
class ContentSafetyCheckStage(Stage):
'''检查内容安全
当前只会检查文本的。
'''
async def initialize(self, ctx: PipelineContext):
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
ok, info = self.strategy_selector.check(event.get_message_str())
if not ok:
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
event.continue_event()
@@ -0,0 +1,8 @@
import abc
from typing import Tuple
class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod
def check(self, content: str) -> Tuple[bool, str]:
raise NotImplementedError
@@ -0,0 +1,31 @@
'''
使用此功能应该先 pip install baidu-aip
'''
from . import ContentSafetyStrategy
from aip import AipContentCensor
from astrbot.core import logger
class BaiduAipStrategy(ContentSafetyStrategy):
def __init__(self, appid: str, ak: str, sk: str) -> None:
self.app_id = appid
self.api_key = ak
self.secret_key = sk
self.client = AipContentCensor(self.app_id,
self.api_key,
self.secret_key)
def check(self, content: str):
res = self.client.textCensorUserDefined(content)
if 'conclusionType' not in res:
return False, ""
if res['conclusionType'] == 1:
return True, ""
else:
if 'data' not in res:
return False, ""
count = len(res['data'])
info = f"百度审核服务发现 {count} 处违规:\n"
for i in res['data']:
info += f"{i['msg']}\n"
info += "\n判断结果:"+res['conclusion']
return False, info
@@ -0,0 +1,21 @@
import re, os, json, base64
from . import ContentSafetyStrategy
from astrbot.core import logger
class KeywordsStrategy(ContentSafetyStrategy):
def __init__(self, extra_keywords: list) -> None:
self.keywords = []
if extra_keywords is None:
extra_keywords = []
self.keywords.extend(extra_keywords)
keywords_path = os.path.join(os.path.dirname(__file__), 'unfit_words')
# internal keywords
if os.path.exists(keywords_path):
with open(keywords_path, "r", encoding="utf-8") as f:
self.keywords.extend(json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords'])
def check(self, content: str) -> bool:
for keyword in self.keywords:
if re.search(keyword, content):
return False, f"内容安全检查不通过,匹配到敏感词。"
return True, ""
@@ -0,0 +1,27 @@
from . import ContentSafetyStrategy
from typing import List, Tuple
class StrategySelector():
def __init__(self, config: dict) -> None:
self.enabled_strategies: List[ContentSafetyStrategy] = []
if config['internal_keywords']['enable']:
from .keywords import KeywordsStrategy
self.enabled_strategies.append(KeywordsStrategy(
config['internal_keywords']['extra_keywords']))
if config['baidu_aip']['enable']:
try:
from .baidu_aip import BaiduAipStrategy
except ImportError:
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'],
config['baidu_aip']['api_key'],
config['baidu_aip']['secret_key']
))
def check(self, content: str) -> Tuple[bool, str]:
for strategy in self.enabled_strategies:
ok, info = strategy.check(content)
if not ok:
return False, info
return True, ""
+8
View File
@@ -0,0 +1,8 @@
from dataclasses import dataclass
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star import PluginManager
@dataclass
class PipelineContext:
astrbot_config: AstrBotConfig
plugin_manager: PluginManager
@@ -0,0 +1,71 @@
import asyncio, traceback, json
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.llm_response import LLMResponse
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if self.prompt_prefix:
event.message_str = self.prompt_prefix + event.message_str
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
event.message_str = user_info + event.message_str
image_urls = []
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
image_urls.append(image_url)
tools = self.ctx.plugin_manager.context.get_llm_tools()
try:
llm_response = await self.curr_provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_urls=image_urls,
tools=tools
)
await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text))
elif llm_response.role == 'tool':
# function calling
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = tools.get_func(func_tool_name)
logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
ret = await func_tool(event=event, *func_tool_args)
if ret:
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
event.stop_event()
event.set_result(ret)
# 执行后续步骤来发送消息
yield
except BaseException as e:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
return
@@ -0,0 +1,48 @@
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra("handlers_parsed_params")
if not handlers_parsed_params:
handlers_parsed_params = {}
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
# 孤立无援的 star handler
continue
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
if hasattr(handler.handler, '__self__'):
# 猜测没有通过装饰器去注册
try:
ret = await handler.handler(event, **params)
except TypeError:
# 向下兼容
ret = await handler.handler(event, self.ctx.plugin_manager.context, **params)
else:
ret = await handler.handler(star_cls_obj, event, **params)
if ret:
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
event.stop_event()
event.set_result(ret)
# 执行后续步骤来发送消息
yield
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
@@ -0,0 +1,36 @@
from typing import List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult, EventResultType
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.message.components import *
from astrbot.core import html_renderer
@register_stage
class ProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
if not activated_handlers:
async for _ in self.llm_request_sub_stage.process(event):
yield
else:
async for _ in self.star_request_sub_stage.process(event):
yield
@@ -0,0 +1,87 @@
import asyncio
from datetime import datetime, timedelta
from collections import defaultdict, deque
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
@register_stage
class RateLimitStage(Stage):
"""
检查是否需要限制消息发送的限流器。
使用 Fixed Window 算法。
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
"""
def __init__(self):
# 存储每个会话的请求时间队列
self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque)
# 为每个会话设置一个锁,避免并发冲突
self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
# 限流参数
self.rate_limit_count: int = 0
self.rate_limit_time: timedelta = timedelta(0)
async def initialize(self, ctx: PipelineContext) -> None:
"""
初始化限流器,根据配置设置限流参数。
"""
self.rate_limit_count = ctx.astrbot_config['platform_settings']['rate_limit']['count']
self.rate_limit_time = timedelta(seconds=ctx.astrbot_config['platform_settings']['rate_limit']['time'])
self.rl_strategy = ctx.astrbot_config['platform_settings']['rate_limit']['strategy'] # stall or discard
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
"""
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
Args:
event (AstrMessageEvent): 当前消息事件。
ctx (PipelineContext): 流水线上下文。
Returns:
MessageEventResult: 继续或停止事件处理的结果。
"""
session_id = event.session_id
now = datetime.now()
async with self.locks[session_id]: # 确保同一会话不会并发修改队列
timestamps = self.event_timestamps[session_id]
self._remove_expired_timestamps(timestamps, now)
if len(timestamps) >= self.rate_limit_count:
# 达到限流阈值,计算下一个窗口的时间
next_window_time = timestamps[0] + self.rate_limit_time
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD:
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
timestamps.append(now)
return event.continue_event()
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
"""
移除时间窗口外的时间戳。
Args:
timestamps (Deque[datetime]): 当前会话的时间戳队列。
now (datetime): 当前时间,用于计算过期时间。
"""
expiry_threshold: datetime = now - self.rate_limit_time
while timestamps and timestamps[0] < expiry_threshold:
timestamps.popleft()
+25
View File
@@ -0,0 +1,25 @@
import asyncio
from datetime import datetime, timedelta
from collections import defaultdict, deque
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
@register_stage
class RespondStage:
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
@@ -0,0 +1,45 @@
import asyncio, time
from datetime import datetime, timedelta
from collections import defaultdict, deque
from typing import DefaultDict, Deque, List, Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.config.astrbot_config import RateLimitStrategy
from astrbot.core.message.components import Plain, Image
from astrbot.core import html_renderer
@register_stage
class ResultDecorateStage:
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.t2i = ctx.astrbot_config['t2i']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
result.chain.insert(0, Plain(self.reply_prefix))
# 文本转图片
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
plain_str += "\n\n" + comp.text
else:
break
if plain_str and len(plain_str) > 150:
render_start = time.time()
url = await html_renderer.render_t2i(plain_str, return_url=True)
if time.time() - render_start > 3:
logger.warning(f"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
result.chain = [Image.fromURL(url)]
+44
View File
@@ -0,0 +1,44 @@
from . import STAGES_ORDER
from .stage import registered_stages, Stage
from .context import PipelineContext
from typing import AsyncGenerator
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
from astrbot.core import logger
class PipelineScheduler():
def __init__(self, context: PipelineContext):
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__ .__name__))
self.ctx = context
async def initialize(self):
for stage in registered_stages:
logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i]
logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coro = stage.process(event)
if isinstance(coro, AsyncGenerator):
async for _ in coro:
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
await self._process_stages(event, i + 1)
else:
await coro
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__ .__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
+32
View File
@@ -0,0 +1,32 @@
from __future__ import annotations
import abc
from typing import List, Dict, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
registered_stages: List[Stage] = []
'''维护了所有已注册的 Stage 实现类'''
def register_stage(cls):
'''一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类
'''
registered_stages.append(cls())
return cls
class Stage(abc.ABC):
'''描述一个 Pipeline 的某个阶段
'''
@abc.abstractmethod
async def initialize(self, ctx: PipelineContext) -> None:
'''初始化阶段
'''
raise NotImplementedError
@abc.abstractmethod
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
raise NotImplementedError
@@ -0,0 +1,96 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, EventResultType
from astrbot.core.message.components import At, Plain
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.filter.command_group import CommandGroupFilter
@register_stage
class WakingCheckStage(Stage):
'''检查是否需要唤醒。唤醒机器人有如下几点条件:
1. 机器人被 @ 了
2. 机器人的消息被提到了
3. 以 wake_prefix 前缀开头
4. 插件(Star)的 handler filter 通过
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
# 设置 sender 身份
event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config['admins_id']:
if event.get_sender_id() == admin_id:
event.role = "admin"
break
# 检查 wake
wake_prefixes = self.ctx.astrbot_config['wake_prefix']
messages = event.get_messages()
is_wake = False
for wake_prefix in wake_prefixes:
if event.message_str.startswith(wake_prefix):
is_wake = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix):].strip()
break
if not is_wake:
# 检查是否有 at 消息
for message in messages:
if isinstance(message, At) and (str(message.qq) == str(event.get_self_id()) or str(message.qq) == "all"):
is_wake = True
event.is_wake = True
wake_prefix = ""
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
wake_prefix = ""
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry:
# filter 需要满足 AND 的逻辑关系
passed = False
child_command_handler_md = None
for filter in handler.event_filters:
try:
if isinstance(filter, CommandGroupFilter):
'''如果指令组过滤成功, 会返回叶子指令的 StarHandlerMetadata'''
ok, child_command_handler_md = filter.filter(event, self.ctx.astrbot_config)
if ok:
passed = True
handler = child_command_handler_md # handler 覆盖
break
else:
if filter.filter(event, self.ctx.astrbot_config):
passed = True
break
except Exception as e:
# event.set_result(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
# yield
await event.send(MessageEventResult().message(f"插件 {handler.handler_full_name} 报错:{e}"))
event.stop_event()
passed = False
break
if passed:
is_wake = True
event.is_wake = True
activated_handlers.append(handler)
if 'parsed_params' in event.get_extra():
handlers_parsed_params[handler.handler_full_name] = event.get_extra('parsed_params')
event.clear_extra()
event.set_extra('activated_handlers', activated_handlers)
event.set_extra('handlers_parsed_params', handlers_parsed_params)
if not is_wake:
event.stop_event()
@@ -0,0 +1,19 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import List, Dict, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
@register_stage
class WhitelistCheckStage(Stage):
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
# 检查是否在白名单
if event.unified_msg_origin not in self.whitelist:
logger.info(f"会话 {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。")
event.stop_event()
+92 -8
View File
@@ -1,11 +1,11 @@
import abc
import abc, logging
from dataclasses import dataclass
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain, EventResultType
from astrbot.core.platform.message_type import MessageType
from typing import List
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from astrbot.core.message.components import *
from astrbot.core.utils.metrics import Metric
@dataclass
@@ -34,8 +34,6 @@ class AstrMessageEvent(abc.ABC):
self.session_id = session_id
self.role = "member"
self.is_wake = False
self._result: MessageEventResult = None
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -43,6 +41,9 @@ class AstrMessageEvent(abc.ABC):
session_id=session_id
)
self.unified_msg_origin = str(self.session)
self._result: MessageEventResult = None
'''消息事件的结果'''
def get_platform_name(self):
return self.platform_meta.name
@@ -58,8 +59,19 @@ class AstrMessageEvent(abc.ABC):
for i in chain:
if isinstance(i, Plain):
outline += i.text
if isinstance(i, Image):
elif isinstance(i, Image):
outline += "[图片]"
elif isinstance(i, Face):
outline += f"[表情:{i.id}]"
elif isinstance(i, At):
outline += f"[At:{i.qq}]"
elif isinstance(i, AtAll):
outline += "[At:全体成员]"
elif isinstance(i, Forward):
# 转发消息
outline += f"[转发消息]"
else:
outline += f"[{i.type}]"
return outline
def get_message_outline(self) -> str:
@@ -76,12 +88,24 @@ class AstrMessageEvent(abc.ABC):
'''
return self.message_obj.message
def get_message_type(self) -> MessageType:
'''
获取消息类型。
'''
return self.message_obj.type
def get_session_id(self) -> str:
'''
获取会话id。
'''
return self.session_id
def get_group_id(self) -> str:
'''
获取群组id。如果不是群组消息,返回空字符串。
'''
return self.message_obj.group_id
def get_self_id(self) -> str:
'''
获取机器人自身的id。
@@ -101,16 +125,62 @@ class AstrMessageEvent(abc.ABC):
return self.message_obj.sender.nickname
def set_result(self, result: MessageEventResult):
'''
设置消息事件的结果。当设置了结果后,消息事件将不再继续传递。
'''设置消息事件的结果。
Note:
事件处理器可以通过设置结果来控制事件是否继续传播,并向消息适配器发送消息。
如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。
Example:
async def ban_handler(self, event: AstrMessageEvent):
if event.get_sender_id() in self.blacklist:
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
return
async def check_count(self, event: AstrMessageEvent):
self.count += 1
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
return
'''
self._result = result
def stop_event(self):
'''终止事件传播。
'''
if self._result is None:
self.set_result(MessageEventResult().stop_event())
else:
self._result.stop_event()
def continue_event(self):
'''继续事件传播。
'''
if self._result is None:
self.set_result(MessageEventResult().continue_event())
else:
self._result.continue_event()
def is_stopped(self) -> bool:
'''
是否终止事件传播。
'''
if self._result is None:
return False # 默认是继续传播
return self._result.is_stopped()
def get_result(self) -> MessageEventResult:
'''
获取消息事件的结果。
'''
return self._result
def clear_result(self):
'''
清除消息事件的结果。
'''
self._result = None
def set_extra(self, key, value):
'''
@@ -118,6 +188,20 @@ class AstrMessageEvent(abc.ABC):
'''
self._extras[key] = value
def get_extra(self, key = None):
'''
获取额外的信息。
'''
if key is None:
return self._extras
return self._extras.get(key, None)
def clear_extra(self):
'''
清除额外的信息。
'''
self._extras.clear()
def is_private_chat(self) -> bool:
'''
是否是私聊。
+2 -1
View File
@@ -15,8 +15,9 @@ class AstrBotMessage:
'''
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group_id: str = "" # 群组id,如果为私聊,则为空
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
+42
View File
@@ -0,0 +1,42 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from .platform import PlatformMetadata, Platform
from typing import List
from asyncio import Queue
from .register import platform_registry, platform_cls_map
from astrbot.core import logger
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
self.platform_insts: List[Platform] = []
'''加载的 Platform 的实例'''
self.platforms_config = config['platform']
self.settings = config['platform_settings']
self.event_queue = event_queue
for platform in self.platforms_config:
if not platform['enable']:
continue
match platform['name']:
case "aiocqhttp":
from .sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter
case "qqofficial":
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialAdapter
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatAdapter
async def initialize(self):
for platform in self.platforms_config:
if not platform['enable']:
continue
if platform['name'] not in platform_cls_map:
logger.error(f"未找到适用于 {platform['name']}({platform['id']}) 平台适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = platform_cls_map[platform['name']]
logger.info(f"尝试实例化 {platform['name']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
def get_insts(self):
return self.platform_insts
+1 -1
View File
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Type
@dataclass
class PlatformMetadata():
name: str # 平台的名称
+25
View File
@@ -0,0 +1,25 @@
from typing import List, Dict, Type
from .platform_metadata import PlatformMetadata
from astrbot.core import logger
platform_registry: List[PlatformMetadata] = []
'''维护了通过装饰器注册的平台适配器'''
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
pm = PlatformMetadata(
name=adapter_name,
description=desc,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
logger.debug(f"平台适配器 {adapter_name} 已注册")
return cls
return decorator
@@ -1,6 +1,6 @@
import os, traceback, random, asyncio
from astrbot.api import AstrMessageEvent, MessageChain, logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -4,24 +4,26 @@ import traceback
import logging
from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain, MessageEventResult
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.config.astrbot_config import PlatformConfig, AiocqhttpPlatformConfig, PlatformSettings
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
def __init__(self, platform_config: AiocqhttpPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.unique_session
self.host = platform_config.ws_reverse_host
self.port = platform_config.ws_reverse_port
self.unique_session = platform_settings['unique_session']
self.host = platform_config['ws_reverse_host']
self.port = platform_config['ws_reverse_port']
self.metadata = PlatformMetadata(
"aiocqhttp",
@@ -51,6 +53,7 @@ class AiocqhttpAdapter(Platform):
if event['message_type'] == 'group':
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
elif event['message_type'] == 'private':
abm.type = MessageType.FRIEND_MESSAGE
@@ -71,35 +74,18 @@ class AiocqhttpAdapter(Platform):
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return
logger.debug(f"aiocqhttp: 收到消息: {event.message}")
for m in event.message:
t = m['type']
a = None
if t == 'at':
a = At(**m['data'])
abm.message.append(a)
if t == 'text':
a = Plain(text=m['data']['text'])
message_str += m['data']['text'].strip()
abm.message.append(a)
if t == 'image':
file = m['data']['file'] if 'file' in m['data'] else None
url = m['data']['url'] if 'url' in m['data'] else None
a = Image(file=file, url=url)
abm.message.append(a)
a = ComponentTypes[t](**m['data'])
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = event
return abm
def handle_whitelist(self, event: Event) -> bool:
match event['message_type']:
case "group":
if self.config.qq_group_id_whitelist and str(event.group_id) in self.config.qq_group_id_whitelist:
return True
case "private":
if self.config.qq_id_whitelist and str(event.sender['user_id']) in self.config.qq_id_whitelist:
return True
return False
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
@@ -107,18 +93,12 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
if not self.handle_whitelist(event):
logger.debug(f"一个群消息({event.group_id})事件由于不在白名单而被过滤。")
return
abm = self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
if not self.handle_whitelist(event):
logger.debug(f"一个私聊消息({event.sender['nickname']}/{event.sender['user_id']})事件由于不在白名单而被过滤。")
return
abm = self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -3,7 +3,8 @@ import botpy.message
import botpy.types
import botpy.types.message
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from botpy import Client
from botpy.http import Route
@@ -6,17 +6,17 @@ import botpy.types
import botpy.types.message
from botpy import Client
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List, Dict
from astrbot.api.message_components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent
from astrbot.core.config.astrbot_config import PlatformConfig, QQOfficialPlatformConfig, PlatformSettings
from astrbot.core.utils.io import save_temp_img, download_image_by_url
from ...register import register_platform_adapter
# QQ 机器人官方框架
@register_platform_adapter("qqofficial", "QQ 机器人官方 API 适配器")
class botClient(Client):
def set_platform(self, platform: 'QQOfficialPlatformAdapter'):
self.platform = platform
@@ -56,18 +56,18 @@ class botClient(Client):
class QQOfficialPlatformAdapter(Platform):
def __init__(self, platform_config: QQOfficialPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.config = platform_config
self.appid = platform_config.appid
self.secret = platform_config.secret
self.unique_session = platform_settings.unique_session
qq_group = platform_config.enable_group_c2c
guild_dm = platform_config.enable_guild_direct_message
self.appid = platform_config['appid']
self.secret = platform_config['secret']
self.unique_session = platform_settings['unique_session']
qq_group = platform_config['enable_group_c2c']
guild_dm = platform_config['enable_guild_direct_message']
if qq_group:
self.intents = botpy.Intents(
@@ -115,6 +115,7 @@ class QQOfficialPlatformAdapter(Platform):
message.author.member_openid,
""
)
abm.group_id = message.group_openid
else:
abm.sender = MessageMember(
message.author.user_openid,
@@ -157,6 +158,9 @@ class QQOfficialPlatformAdapter(Platform):
str(message.author.id),
str(message.author.username)
)
if isinstance(message, botpy.message.Message):
abm.group_id = message.channel_id
else:
raise ValueError(f"Unknown message type: {message_type}")
return abm
@@ -1,10 +1,12 @@
import random, asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from vchat import Core
class WechatPlatformEvent(AstrMessageEvent):
class VChatPlatformEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@@ -36,6 +38,6 @@ class WechatPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
await WechatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
await VChatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
await super().send(message)
@@ -1,15 +1,13 @@
import sys, time, datetime, uuid
import sys, time, uuid
import asyncio
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from typing import Union, List, Dict
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .wechat_message_event import WechatPlatformEvent
from astrbot.core.config.astrbot_config import PlatformConfig, WechatPlatformConfig, PlatformSettings
from astrbot.core.utils.io import save_temp_img, download_image_by_url
from .vchat_message_event import VChatPlatformEvent
from ...register import register_platform_adapter
from vchat import Core
from vchat import model
@@ -18,10 +16,11 @@ if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class WechatPlatformAdapter(Platform):
def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
@register_platform_adapter("vchat", "基于 VChat 的 Wechat 适配器")
class VChatPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
@@ -31,13 +30,13 @@ class WechatPlatformAdapter(Platform):
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
from_username = session.session_id.split('$$')[0]
await WechatPlatformEvent.send_with_client(self.client, message_chain, from_username)
await VChatPlatformEvent.send_with_client(self.client, message_chain, from_username)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"wechat",
"vchat",
"基于 VChat 的 Wechat 适配器",
)
@@ -53,10 +52,6 @@ class WechatPlatformAdapter(Platform):
logger.debug(f"忽略旧消息: {msg}")
return
logger.debug(f"收到消息: {msg.todict()}")
if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist:
logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}")
return
logger.info(f"收到消息: {msg.todict()}")
abmsg = self.convert_message(msg)
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
asyncio.create_task(self.handle_msg(abmsg))
@@ -92,12 +87,13 @@ class WechatPlatformAdapter(Platform):
amsg.type = MessageType.FRIEND_MESSAGE
elif isinstance(msg.from_, model.Chatroom):
amsg.type = MessageType.GROUP_MESSAGE
amsg.group_id = msg.from_.username
else:
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
amsg.raw_message = msg
if self.settingss.unique_session:
if self.settingss['unique_session']:
session_id = msg.from_.username + "$$" + msg.to.username
if msg.chatroom_sender is not None:
session_id += '$$' + msg.chatroom_sender.username
@@ -108,7 +104,7 @@ class WechatPlatformAdapter(Platform):
return amsg
async def handle_msg(self, message: AstrBotMessage):
message_event = WechatPlatformEvent(
message_event = VChatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
-4
View File
@@ -1,4 +0,0 @@
from .plugin import Plugin, RegisteredPlugin, PluginMetadata
from .plugin_manager import PluginManager
from .context import CommandMetadata, Context
from astrbot.core.provider import Provider
-217
View File
@@ -1,217 +0,0 @@
import heapq
from asyncio import Queue
from . import RegisteredPlugin, PluginMetadata
from typing import List, Dict, Awaitable, Union
from dataclasses import dataclass
from astrbot.core.platform import Platform
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.utils.func_call import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
@dataclass
class CommandMetadata():
'''
显式指令
'''
plugin_name: str
plugin_metadata: PluginMetadata
handler: Awaitable
use_regex: bool = False
ignore_prefix: bool = False
description: str = ""
@dataclass
class EventListenerMetadata():
'''
事件监听器
'''
plugin_name: str
plugin_metadata: PluginMetadata
handler: Awaitable
description: str = ""
after_commands: bool = False
class Context:
'''
暴露给插件的接口上下文,用于注册指令、事件监听器、消息平台、模型提供商等。
'''
# 事件队列。消息平台通过事件队列传递消息事件。
_event_queue: Queue = None
# AstrBot 配置信息
_config: AstrBotConfig = None
# AstrBot 数据库
_db: BaseDatabase = None
# 维护了注册的插件的信息
registered_plugins: List[RegisteredPlugin] = []
# 维护了插件注册的指令的信息的名字列表,用于优先级排序
registered_commands: List[str] = []
# 维护了插件注册的指令的信息
commands_handler: Dict[str, CommandMetadata] = {}
# 维护了插件注册的中间件的名字列表,用于优先级排序
registered_listeners: List[str] = []
# 维护了插件注册的中间件的信息
listeners_handler: Dict[str, EventListenerMetadata] = {}
# 维护了注册的平台的信息
registered_platforms: List[Platform] = []
# 维护了 LLM Tools 信息
llm_tools: FuncCall = FuncCall()
# 维护插件存储的数据
plugin_data: Dict[str, Dict[str, any]] = {}
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
self._event_queue = event_queue
self._config = config
self._db = db
def get_registered_plugin(self, plugin_name: str) -> RegisteredPlugin:
for plugin in self.registered_plugins:
if plugin.metadata.plugin_name == plugin_name:
return plugin
return None
def register_listener(self,
plugin_name: str,
name: str,
handler: Awaitable,
description: str = None,
after_commands: bool = False):
'''
注册一个事件监听器。
after_commands: 是否在指令处理后执行。
'''
if name in self.registered_listeners:
raise ValueError(f"Middleware {name} already exists.")
self.registered_listeners.append(name)
self.listeners_handler[name] = EventListenerMetadata(
plugin_name=plugin_name,
plugin_metadata=None,
handler=handler,
description=description,
after_commands=after_commands
)
def register_commands(self,
plugin_name: str,
command_name: str,
description: str,
priority: int,
handler: Awaitable,
use_regex: bool = False,
ignore_prefix: bool = False):
'''
注册插件指令。
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
@param command_name: 指令名,如 "help"。不需要带前缀。
@param description: 指令描述。
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
@param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
.. Example::
ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
'''
for command in self.registered_commands:
if command_name in command[1]:
raise ValueError(f"Command {command_name} already exists.")
if not handler:
raise ValueError(f"Handler of {command_name} is None.")
heapq.heappush(self.registered_commands, (-priority, command_name))
self.commands_handler[command_name] = CommandMetadata(
plugin_name=plugin_name,
plugin_metadata=None,
handler=handler,
use_regex=use_regex,
ignore_prefix=ignore_prefix,
description=description
)
heapq.heapify(self.registered_commands)
def register_platform(self, platform: Platform):
'''
注册一个消息平台。
'''
self.registered_platforms.append(platform)
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
self.llm_tools.add_func(name, func_args, desc, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''
删除一个函数调用工具。
'''
self.llm_tools.remove_func(name)
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.registered_platforms:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
def set_data(self, plugin_name: str, key: str, value: any):
'''
设置插件数据。
'''
self.plugin_data[plugin_name][key] = value
-43
View File
@@ -1,43 +0,0 @@
from enum import Enum
from types import ModuleType
from typing import List
from dataclasses import dataclass
@dataclass
class PluginMetadata:
'''
插件的元数据。
'''
# required
plugin_name: str
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.desc}, {self.version}, {self.repo})"
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件。
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
reserved: bool # 是否是 AstrBot 的保留插件
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
class Plugin:
def __init__(self):
pass
+1 -1
View File
@@ -1 +1 @@
from .provider import Provider, Personality
from .provider import Provider, Personality, ProviderMetaData
+13
View File
@@ -0,0 +1,13 @@
from typing import Dict, List
from dataclasses import dataclass
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
'''工具调用参数'''
tools_call_name: List[str] = None
'''工具调用名称'''
+49
View File
@@ -0,0 +1,49 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from astrbot.core.provider.tool import FuncCall
from .register import provider_cls_map, provider_registry
from astrbot.core import logger
class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.llm_tools: FuncCall = FuncCall()
self.curr_provider_inst: Provider = None
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
if provider_cfg['id'] in self.loaded_ids:
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial
async def initialize(self):
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
if len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
def get_insts(self):
return self.provider_insts
+100 -57
View File
@@ -1,91 +1,134 @@
import abc, json, threading, time
import abc, json
from collections import defaultdict
from typing import List
from astrbot.core.db import BaseDatabase
from astrbot.core import logger
from typing import TypedDict
from .provider_metadata import ProviderMetaData
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.llm_response import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str
name: str
prompt: str = ""
name: str = ""
@dataclass
class ProviderMeta():
id: str
model: str
type: str
class Provider(abc.ABC):
def __init__(self, db_helper: BaseDatabase, default_personality: str = None, persistant_history: bool = True) -> None:
self.model_name = "unknown"
# 维护了 session_id 的上下文,不包含 system 指令
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None
) -> None:
self.model_name = ""
'''当前使用的模型名称'''
self.session_memory = defaultdict(list)
self.curr_personality = Personality(prompt=default_personality, name="")
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
'''维护了当前的使用的 persona,即人格。'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history():
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self):
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
async def get_human_readable_context(self, session_id: str) -> List[str]:
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError()
def get_keys(self) -> List[str]:
'''获得提供商 Key'''
return self.provider_config['key']
@abc.abstractmethod
def set_key(self, key: str):
raise NotImplementedError()
@abc.abstractmethod
def get_models(self) -> List[str]:
'''获得支持的模型列表'''
raise NotImplementedError()
@abc.abstractmethod
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
'''获取人类可读的上下文
Example:
["User: 你好", "Assistant: 你好!"]
Return:
contexts: List[str]: 上下文列表
total_pages: int: 总页数
'''
获取人类可读的上下文
example:
["User: 你好", "Assistant: 你好"]
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
contexts.append(f"Assistant: {record['content']}")
return contexts
raise NotImplementedError()
@abc.abstractmethod
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str] = None,
tools = None,
contexts=None,
**kwargs) -> str:
'''
prompt: 提示词
session_id: 会话id
session_id: str=None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts: List=None,
**kwargs) -> LLMResponse:
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
Args:
prompt: 提示词
session_id: 会话 ID
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
kwargs: 其他参数
Notes:
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
[optional]
image_url: 图片url(识图)
tools: 函数调用工具
'''
raise NotImplementedError()
@abc.abstractmethod
async def image_generate(self, prompt: str, session_id: str, **kwargs) -> str:
'''
prompt: 提示词
session_id: 会话id
'''
raise NotImplementedError()
@abc.abstractmethod
async def get_embedding(self, text: str) -> List[float]:
'''
获取文本的嵌入
'''
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''
重置会话
'''
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
@@ -0,0 +1,6 @@
from dataclasses import dataclass
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
+25
View File
@@ -0,0 +1,25 @@
from typing import List, Dict, Type
from .provider_metadata import ProviderMetaData
from astrbot.core import logger
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 类的映射'''
def register_provider_adapter(provider_type_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = cls
logger.debug(f"Provider {provider_type_name} 已注册")
return cls
return decorator
@@ -0,0 +1,216 @@
import asyncio
import traceback
import base64
import json
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.tool import FuncCall
from typing import List
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from ..register import register_provider_adapter
from astrbot.core.provider.llm_response import LLMResponse
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
class ProviderOpenAIOfficial(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
contexts.append(f"Assistant: {record['content']}")
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
for model in models:
models_str.append(model['id'])
return models_str
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
payloads["tools"] = tools.get_func_desc_openai_style()
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
# tools call (function calling)
args_ls = []
func_name_ls = []
for tool_call in choice.message.tool_calls:
for tool in tools.func_list:
if tool['name'] == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
else:
raise Exception("Internal Error")
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
if self.curr_personality["prompt"]:
context_query.insert(0, {"role": "system", "content": self.curr_personality["prompt"]})
else:
context_query = contexts
logger.debug(f"请求上下文:{context_query}")
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
if llm_response.role == "assistant":
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
@@ -1,7 +1,7 @@
from astrbot.core.provider import Provider
from typing import Awaitable
import json
import textwrap
from typing import Awaitable, Dict, List
from typing_extensions import TypedDict
class FuncCallJsonFormatError(Exception):
@@ -11,7 +11,6 @@ class FuncCallJsonFormatError(Exception):
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
@@ -19,10 +18,19 @@ class FuncNotFoundError(Exception):
def __str__(self):
return self.msg
class FuncTool(TypedDict):
'''
用于描述一个函数调用工具
'''
name: str
parameters: Dict
description: str
func_obj: Awaitable
class FuncCall():
def __init__(self) -> None:
self.func_list = []
self.func_list: List[FuncTool] = []
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -45,12 +53,7 @@ class FuncCall():
"type": param['type'],
"description": param['description']
}
_func = {
"name": name,
"parameters": params,
"description": desc,
"func_obj": func_obj,
}
_func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj)
self.func_list.append(_func)
def remove_func(self, name: str) -> None:
@@ -62,17 +65,16 @@ class FuncCall():
self.func_list.pop(i)
break
def func_dump(self) -> str:
_l = []
def get_func(self, name) -> FuncTool:
for f in self.func_list:
_l.append({
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
})
return json.dumps(_l, ensure_ascii=False)
def get_func(self) -> list:
if f["name"] == name:
return f
return None
def get_func_desc_openai_style(self) -> list:
'''
获得 OpenAI API 风格的工具描述
'''
_l = []
for f in self.func_list:
_l.append({
@@ -85,7 +87,17 @@ class FuncCall():
})
return _l
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider) -> tuple:
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
_l.append({
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
})
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用
@@ -111,7 +123,6 @@ class FuncCall():
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
print(res)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
res = json.loads(res)
@@ -5,11 +5,13 @@ import os
from readability import Document
from bs4 import BeautifulSoup
from openai._exceptions import *
from .websearch.config import HEADERS, USER_AGENTS
from .websearch.bing import Bing
from .websearch.sogo import Sogo
from .websearch.google import Google
from astrbot.api import logger, AstrMessageEvent, Provider, MessageChain, MessageEventResult
from engines.config import HEADERS, USER_AGENTS
from engines.bing import Bing
from engines.sogo import Sogo
from engines.google import Google
from astrbot.api.event import AstrMessageEvent, MessageChain, MessageEventResult
from astrbot.api.provider import Provider
from astrbot.api import logger
bing_search = Bing()
sogo_search = Sogo()
@@ -61,7 +63,6 @@ async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provide
return await summarize(ret, event, provider)
async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
+5
View File
@@ -0,0 +1,5 @@
# AstrBot Star
`AstrBot Star` 就是插件。
在 AstrBot v4.0 版本后,AstrBot 内部将插件命名为 `star`。插件的 handler 称作 `star_handler`
+4
View File
@@ -0,0 +1,4 @@
from .star import Star, StarMetadata
from .star_manager import PluginManager
from .context import Context
from astrbot.core.provider import Provider
+174
View File
@@ -0,0 +1,174 @@
import heapq
from asyncio import Queue
from . import StarMetadata
from typing import List, Dict, TypedDict, Union
from astrbot.core.platform import Platform
from astrbot.core.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.tool import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, star_map, StarMetadata
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
class StarCommand(TypedDict):
full_command_name: str
command_name: str
class Context:
'''
暴露给插件的接口上下文。
'''
_event_queue: Queue = None
'''事件队列。消息平台通过事件队列传递消息事件。'''
_config: AstrBotConfig = None
'''AstrBot 配置信息'''
_db: BaseDatabase = None
'''AstrBot 数据库'''
provider_manager: ProviderManager = None
platform_manager: PlatformManager = None
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
self._event_queue = event_queue
self._config = config
self._db = db
def get_registered_star(self, star_name: str) -> StarMetadata:
return star_map.get(star_name, None)
def get_all_stars(self) -> List[StarMetadata]:
return star_registry
def get_llm_tools(self) -> FuncCall:
'''
获取 LLM Tools。
'''
return self.provider_manager.llm_tools
# def get_star_commands(self, star_name: str) -> List[]:
# '''获得一个'''
# def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
# '''
# 为函数调用(function-calling / tools-use)添加工具。
# @param name: 函数名
# @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
# @param desc: 函数描述
# @param func_obj: 异步处理函数。
# 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
# '''
# self.llm_tools.add_func(name, func_args, desc, func_obj)
# def unregister_llm_tool(self, name: str) -> None:
# '''
# 删除一个函数调用工具。
# '''
# self.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
[Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。
@param star_name: 插件(Star)名称。
@param command_name: 命令名称。
@param desc: 命令描述。
@param priority: 优先级。1-10。
@param awaitable: 异步处理函数。
'''
md = StarHandlerMetadata(
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
)
if use_regex:
md.event_filters.append(RegexFilter(
regex=command_name
))
else:
md.event_filters.append(CommandFilter(
command_name=command_name,
handler_md=md
))
star_handlers_registry.append(md)
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider。
'''
self.provider_manager.provider_insts.append(provider)
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.registered_platforms:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
+10
View File
@@ -0,0 +1,10 @@
import abc
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
class HandlerFilter(abc.ABC):
@abc.abstractmethod
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
'''是否应当被过滤'''
raise NotImplementedError
+67
View File
@@ -0,0 +1,67 @@
import re, inspect
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from astrbot.core.utils.param_validation_mixin import ParameterValidationMixin
from typing import Awaitable
from ..star_handler import StarHandlerMetadata
# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter, ParameterValidationMixin):
'''标准指令过滤器'''
def __init__(self, command_name: str, handler_md: StarHandlerMetadata = None):
self.command_name = command_name
if handler_md:
self.init_handler_md(handler_md)
def print_types(self):
result = ""
print(self.handler_params)
for k, v in self.handler_params.items():
if isinstance(v, type):
result += f"{k}({v.__name__}),"
else:
result += f"{k}({type(v).__name__})={v},"
return result
def init_handler_md(self, handle_md: StarHandlerMetadata):
self.handler_md = handle_md
signature = inspect.signature(self.handler_md.handler)
self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值
idx = 0
for k, v in signature.parameters.items():
if idx < 2:
# 忽略前两个参数,即 self 和 event
idx += 1
continue
if v.default == inspect.Parameter.empty:
self.handler_params[k] = v.annotation
else:
self.handler_params[k] = v.default
def get_handler_md(self) -> StarHandlerMetadata:
return self.handler_md
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
if not event.is_wake_up():
return False
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
ls = [param for param in ls if param]
params = {}
try:
params = self.validate_and_convert_params(ls, self.handler_params)
# 解析完成咱也不能丢掉呀,留着给后面的用
except ValueError as e:
raise e
event.set_extra("parsed_params", params)
return True
+70
View File
@@ -0,0 +1,70 @@
from __future__ import annotations
import re
from typing import Awaitable, List, Union, Tuple
from . import HandlerFilter
from .command import CommandFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from ..star_handler import StarHandlerMetadata
# 指令组受到 wake_prefix 的制约。
class CommandGroupFilter(HandlerFilter):
def __init__(self, group_name: str):
self.group_name = group_name
self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = []
def add_sub_command_filter(self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]):
self.sub_command_filters.append(sub_command_filter)
# 以树的形式打印出来
def print_cmd_tree(self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "") -> str:
result = ""
for sub_filter in sub_command_filters:
if isinstance(sub_filter, CommandFilter):
cmd_th = sub_filter.print_types()
result += f"{prefix}├── {sub_filter.command_name}"
if cmd_th:
result += f" ({cmd_th})"
else:
result += f" (无参数指令)"
result += "\n"
elif isinstance(sub_filter, CommandGroupFilter):
result += f"{prefix}├── {sub_filter.group_name}"
result += "\n"
result += sub_filter.print_cmd_tree(sub_filter.sub_command_filters, prefix+"")
return result
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> Tuple[bool, StarHandlerMetadata]:
if not event.is_wake_up():
return False, None
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name:
return False, None
# 改写 message_str
ls = ls[1:]
event.message_str = " ".join(ls)
event.message_str = event.message_str.strip()
if event.message_str == "":
# 当前还是指令组
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
child_command_handler_md = None
for sub_filter in self.sub_command_filters:
if isinstance(sub_filter, CommandFilter):
if sub_filter.filter(event, cfg):
child_command_handler_md = sub_filter.get_handler_md()
return True, child_command_handler_md
elif isinstance(sub_filter, CommandGroupFilter):
ok, handler = sub_filter.filter(event, cfg)
if ok:
child_command_handler_md = handler
return True, child_command_handler_md
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 下没有找到对应的指令。这个指令组下有如下指令:\n"+tree)
@@ -0,0 +1,28 @@
import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.message_type import MessageType
class EventMessageType(enum.Flag):
GROUP_MESSAGE = enum.auto()
PRIVATE_MESSAGE = enum.auto()
OTHER_MESSAGE = enum.auto()
ALL = GROUP_MESSAGE | PRIVATE_MESSAGE | OTHER_MESSAGE
MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE = {
MessageType.GROUP_MESSAGE: EventMessageType.GROUP_MESSAGE,
MessageType.FRIEND_MESSAGE: EventMessageType.PRIVATE_MESSAGE,
MessageType.OTHER_MESSAGE: EventMessageType.OTHER_MESSAGE
}
class EventMessageTypeFilter(HandlerFilter):
def __init__(self, event_message_type: EventMessageType):
self.event_message_type = event_message_type
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
message_type = event.get_message_type()
if message_type in MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE:
event_message_type = MESSAGE_TYPE_2_EVENT_MESSAGE_TYPE[message_type]
return bool(event_message_type & self.event_message_type)
return False
@@ -0,0 +1,27 @@
import enum
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
from typing import Union
class PlatformAdapterType(enum.Flag):
AIOCQHTTP = enum.auto()
QQOFFICIAL = enum.auto()
VCHAT = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
ADAPTER_NAME_2_TYPE = {
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
"qq_official": PlatformAdapterType.QQOFFICIAL,
"vchat": PlatformAdapterType.VCHAT
}
class PlatformAdapterTypeFilter(HandlerFilter):
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
self.type_or_str = platform_adapter_type_or_str
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
adapter_name = event.get_platform_name()
if adapter_name in ADAPTER_NAME_2_TYPE:
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
return False
+14
View File
@@ -0,0 +1,14 @@
import re
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.config import AstrBotConfig
# 正则表达式过滤器不会受到 wake_prefix 的制约。
class RegexFilter(HandlerFilter):
'''正则表达式过滤器'''
def __init__(self, regex: str):
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return bool(self.regex.match(event.get_message_str().strip()))
+8
View File
@@ -0,0 +1,8 @@
from .star import register_star
from .star_handler import (
register_command,
register_command_group,
register_event_message_type,
register_platform_adapter_type,
register_regex
)
+18
View File
@@ -0,0 +1,18 @@
from ..star import star_registry, StarMetadata, star_map
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
def decorator(cls):
star_metadata = StarMetadata(
name=name,
author=author,
desc=desc,
version=version,
repo=repo,
star_cls_type=cls,
module_path=cls.__module__
)
star_registry.append(star_metadata)
star_map[cls.__module__] = star_metadata
return cls
return decorator
+115
View File
@@ -0,0 +1,115 @@
from __future__ import annotations
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAdapterType
from ..filter.regex import RegexFilter
from typing import Awaitable, List, Dict
def get_handler_full_name(awatable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awatable.__module__}_{awatable.__name__}"
def get_handler_or_create(handler: Awaitable) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
if handler_full_name in star_handlers_map:
return star_handlers_map[handler_full_name]
else:
md = StarHandlerMetadata(
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler=handler,
event_filters=[]
)
star_handlers_registry.append(md)
star_handlers_map[handler_full_name] = md
return md
def register_command(command_name: str = None, *args):
'''注册一个 Command'''
new_command = None
add_to_event_filters = False
if isinstance(command_name, RegisteringCommandable):
# 子指令
new_command = CommandFilter(args[0], None)
command_name.parent_group.add_sub_command_filter(new_command)
else:
# 裸指令
new_command = CommandFilter(command_name, None)
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
handler_md.event_filters.append(new_command)
return awaitable
return decorator
def register_command_group(command_group_name: str = None, *args):
'''注册一个 CommandGroup'''
new_group = None
add_to_event_filters = False
if isinstance(command_group_name, RegisteringCommandable):
# 子指令组
new_group = CommandGroupFilter(args[0])
command_group_name.parent_group.add_sub_command_filter(new_group)
else:
# 根指令组
new_group = CommandGroupFilter(command_group_name)
add_to_event_filters = True
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
return decorator
class RegisteringCommandable():
'''用于指令组级联注册'''
group = register_command_group
command = register_command
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
def register_event_message_type(event_message_type: EventMessageType):
'''注册一个 EventMessageType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awatable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
'''注册一个 PlatformAdapterType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
return awatable
return decorator
def register_regex(regex: str):
'''注册一个 Regex'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
handler_md.event_filters.append(RegexFilter(regex))
return awatable
return decorator
+43
View File
@@ -0,0 +1,43 @@
from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from astrbot.core.utils.command_parser import CommandParserMixin
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
'''key 是模块路径,__module__'''
class Star(CommandParserMixin):
'''所有插件(Star)的父类,所有插件都应该继承于这个类'''
def __init__(self):
pass
@dataclass
class StarMetadata:
'''
Star 的元数据。
'''
name: str
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''Star 的类对象的类型'''
module_path: str = None
'''Star 的模块路径'''
star_cls: object = None
'''Star 的类对象'''
module: ModuleType = None
'''Star 的模块对象'''
root_dir_name: str = None
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+31
View File
@@ -0,0 +1,31 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from .filter import HandlerFilter
star_handlers_registry: List[StarHandlerMetadata] = []
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
@dataclass
class StarHandlerMetadata():
'''描述一个 Star 所注册的某一个 Handler。'''
handler_full_name: str
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
'''Handler 所在的模块路径。'''
handler: Awaitable
'''Handler 的函数对象,应当是一个异步函数'''
event_filters: List[HandlerFilter]
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
desc: str = ""
'''Handler 的描述信息'''
@@ -1,31 +1,35 @@
import inspect
import os
import sys
import traceback
import uuid
import shutil
import yaml
import logging
from asyncio import Queue
from types import ModuleType
from typing import List, Awaitable
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from .context import Context
from . import RegisteredPlugin, PluginMetadata
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from .star_handler import star_handlers_registry
class PluginManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue, db: BaseDatabase):
self.updator = PluginUpdator(config.plugin_repo_mirror)
self.context = Context(event_queue, config, db)
def __init__(
self,
context: Context,
config: AstrBotConfig
):
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
def _get_classes(self, arg: ModuleType):
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
@@ -69,6 +73,9 @@ class PluginManager:
return plugins
def _check_plugin_dept_update(self, target_plugin: str = None):
'''检查插件的依赖
如果 target_plugin None则检查所有插件的依赖
'''
plugin_dir = self.plugin_store_path
if not os.path.exists(plugin_dir):
return False
@@ -76,7 +83,7 @@ class PluginManager:
if target_plugin:
to_update.append(target_plugin)
else:
for p in self.context.registered_plugins:
for p in self.context.get_all_stars():
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
@@ -89,6 +96,7 @@ class PluginManager:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend(self.config.pip_install_arg)
@@ -96,7 +104,11 @@ class PluginManager:
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
先寻找 metadata.yaml 文件如果不存在则使用插件对象的 info() 函数获取元数据
'''
metadata = None
if not os.path.exists(plugin_path):
@@ -112,8 +124,8 @@ class PluginManager:
if isinstance(metadata, dict):
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
raise Exception("插件元数据信息不完整。")
metadata = PluginMetadata(
plugin_name=metadata['name'],
metadata = StarMetadata(
name=metadata['name'],
author=metadata['author'],
desc=metadata['desc'],
version=metadata['version'],
@@ -123,72 +135,68 @@ class PluginManager:
return metadata
def reload(self):
'''
加载插件类
'''
registered_plugins = self.context.registered_plugins
plugins = self._get_plugin_modules()
if plugins is None:
'''扫描并加载所有的 Star'''
star_handlers_registry.clear()
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
registered_map = {}
for p in registered_plugins:
registered_map[p.module_path] = None
for plugin in plugins:
# 导入 Star 模块,并尝试实例化 Star 类
for plugin_module in plugin_modules:
try:
p = plugin['module']
module_path = plugin['module_path']
root_dir_name = plugin['pname']
reserved = plugin.get('reserved', False)
module_str = plugin_module['module']
module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
logger.info(f"正在载插件 {root_dir_name} ...")
logger.info(f"正在载插件 {root_dir_name} ...")
pre = "data.plugins." if not reserved else "packages."
# 尝试导入插件模块
# 尝试导入模块
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
try:
module = __import__(pre + root_dir_name + "." + p, fromlist=[p])
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError) as e:
# 尝试安装插件依赖
# 尝试安装依赖
self._check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__(pre + root_dir_name + "." + p, fromlist=[p])
module = __import__(path, fromlist=[module_str])
except Exception as e:
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
cls = self._get_classes(module)
# 实例化插件类
try:
obj = getattr(module, cls[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
# 解析插件元数据,加入注册列表
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
if module_path not in registered_map:
registered_plugins.append(RegisteredPlugin(
metadata=metadata,
plugin_instance=obj,
module=module,
module_path=module_path,
root_dir_name=root_dir_name,
reserved=reserved
))
if path in star_map:
# 通过装饰器的方式注册插件
star_metadata = star_map[path]
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
else:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
metadata.star_cls_type = obj.__class__
metadata.module_path = path
star_map[path] = metadata
for command in self.context.commands_handler:
if self.context.commands_handler[command].plugin_name == metadata.plugin_name:
self.context.commands_handler[command].plugin_metadata = metadata
for listener in self.context.listeners_handler:
if self.context.listeners_handler[listener].plugin_name == metadata.plugin_name:
self.context.listeners_handler[listener].plugin_metadata = metadata
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
fail_rec += f"加载 {path} 插件出现问题,原因 {str(e)}\n"
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
@@ -200,26 +208,26 @@ class PluginManager:
return False, fail_rec
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.update(repo_url)
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(repo_url)
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_plugin(plugin_name)
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
if plugin.reserved:
raise Exception("该插件是 AstrBot 保留插件,无法卸载。")
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
self.context.registered_plugins.remove(plugin)
del star_map[plugin.module_path]
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
async def update_plugin(self, plugin_name: str):
plugin = self.context.get_registered_plugin(plugin_name)
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
if plugin.reserved:
@@ -228,41 +236,13 @@ class PluginManager:
await self.updator.update(plugin)
def install_plugin_from_file(self, zip_file_path: str):
# try to unzip
temp_dir = os.path.join(os.path.dirname(zip_file_path), str(uuid.uuid4()))
self.updator.unzip_file(zip_file_path, temp_dir)
# check if the plugin has metadata.yaml
if not os.path.exists(os.path.join(temp_dir, "metadata.yaml")):
remove_dir(temp_dir)
raise Exception("插件缺少 metadata.yaml 文件。")
metadata = self._load_plugin_metadata(temp_dir)
plugin_name = metadata.plugin_name
if not plugin_name:
remove_dir(temp_dir)
raise Exception("插件 metadata.yaml 文件中 name 字段为空。")
plugin_name = self.updator.format_name(plugin_name)
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
self.updator.unzip_file(zip_file_path, desti_dir)
ppath = self.plugin_store_path
plugin_path = os.path.join(ppath, plugin_name)
if os.path.exists(plugin_path):
remove_dir(plugin_path)
# move to the target path
shutil.move(temp_dir, plugin_path)
if metadata.repo:
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(metadata.repo)
# remove the temp dir
remove_dir(temp_dir)
# remove the zip
try:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
def get_platform_insts(self):
return self.context.registered_platforms
def get_loaded_plugins(self):
return self.context.registered_plugins
@@ -2,7 +2,7 @@ import os, zipfile, shutil
from ..updator import RepoZipUpdator
from astrbot.core.utils.io import remove_dir, on_error
from ..plugin import RegisteredPlugin
from ..star.star import StarMetadata
from typing import Union
from astrbot.core import logger
@@ -13,20 +13,22 @@ class PluginUpdator(RepoZipUpdator):
def get_plugin_store_path(self) -> str:
return self.plugin_store_path
async def update(self, plugin: Union[RegisteredPlugin, str]) -> str:
repo_url = None
async def install(self, repo_url: str) -> str:
repo_name = self.format_repo_name(repo_url)
plugin_path = os.path.join(self.plugin_store_path, repo_name)
await self.download_from_repo_url(plugin_path, repo_url)
self.unzip_file(plugin_path + ".zip", plugin_path)
if not isinstance(plugin, str):
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
if not os.path.exists(os.path.join(plugin_path, "REPO")):
raise Exception("插件更新信息文件 `REPO` 不存在,请手动升级,或者先卸载然后重新安装该插件。")
with open(os.path.join(plugin_path, "REPO"), "r", encoding='utf-8') as f:
repo_url = f.read()
else:
repo_url = plugin
plugin_path = os.path.join(self.plugin_store_path, self.format_repo_name(repo_url))
return plugin_path
async def update(self, plugin: StarMetadata) -> str:
repo_url = plugin.repo
if not repo_url:
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
await self.download_from_repo_url(plugin_path, repo_url)
@@ -34,7 +36,7 @@ class PluginUpdator(RepoZipUpdator):
try:
remove_dir(plugin_path)
except BaseException as e:
logger.error(f"删除旧版本插件 {plugin.metadata.plugin_name} 文件夹失败: {str(e)},使用覆盖安装。")
logger.error(f"删除旧版本插件 {plugin_path} 文件夹失败: {str(e)},使用覆盖安装。")
self.unzip_file(plugin_path + ".zip", plugin_path)
@@ -48,13 +50,10 @@ class PluginUpdator(RepoZipUpdator):
update_dir = z.namelist()[0]
z.extractall(target_dir)
avoid_dirs = ["logs", "data", "configs", "temp_plugins", update_dir]
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if f in avoid_dirs: continue
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
else:
@@ -68,3 +67,4 @@ class PluginUpdator(RepoZipUpdator):
os.remove(zip_path)
except:
logger.warning(f"删除更新文件失败,可以手动删除 {zip_path}{os.path.join(target_dir, update_dir)}")
+3 -6
View File
@@ -10,13 +10,10 @@ class CommandTokens():
return None
return self.tokens[idx].strip()
class CommandParser():
def __init__(self):
pass
def parse(self, message: str):
class CommandParserMixin():
def parse_commands(self, message: str):
cmd_tokens = CommandTokens()
cmd_tokens.tokens = message.split(" ")
cmd_tokens.tokens = re.split(r"\s+", message)
cmd_tokens.len = len(cmd_tokens.tokens)
return cmd_tokens
-21
View File
@@ -1,21 +0,0 @@
import aiohttp
import uuid
class ImageUploader():
def __init__(self) -> None:
self.S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
async def upload_image(self, image_path: str) -> str:
'''
上传图像文件到S3
'''
with open(image_path, "rb") as f:
image = f.read()
image_url = f"{self.S3_URL}/{uuid.uuid4().hex}.jpg"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with session.put(image_url, data=image) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return image_url
@@ -0,0 +1,31 @@
import inspect
from typing import Awaitable, List, Union, Dict, Any, Type
class ParameterValidationMixin:
def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, Type]) -> Dict[str, Any]:
'''将参数列表 params 根据 param_type 转换为参数字典。
'''
result = {}
print(params, param_type)
for i, (param_name, param_type_or_default_val) in enumerate(param_type.items()):
if i >= len(params):
if isinstance(param_type_or_default_val, Type) or param_type_or_default_val is inspect.Parameter.empty:
# 是类型
raise ValueError(f"参数 {param_name} 缺失")
else:
# 是默认值
result[param_name] = param_type_or_default_val
else:
# 尝试强制转换
try:
if param_type_or_default_val == None:
if params[i].isdigit():
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
else:
result[param_name] = type(param_type_or_default_val)(params[i])
except ValueError:
raise ValueError(f"参数 {param_name} 类型错误")
print(result)
return result
+1 -1
View File
@@ -21,4 +21,4 @@ class AstrBotDashBoardLifecycle:
await task
except asyncio.CancelledError as e:
logger.info("🌈 正在关闭 AstrBot...")
core_lifecycle.stop()
await core_lifecycle.stop()
+1 -1
View File
@@ -41,7 +41,7 @@ class AuthRoute(Route):
if new_username:
self.config.dashboard.username = new_username
self.config.flush_config()
self.config.save_config()
return Response().ok(None, "修改成功").__dict__
+8 -11
View File
@@ -1,11 +1,10 @@
import os, json
import os, json, traceback
from .route import Route, Response, RouteContext
from quart import Quart, request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE, ADAPTER_CONFIG_TEMPLATE
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.plugin.config import update_config
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from dataclasses import asdict
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -55,10 +54,6 @@ def validate_config(data, config: AstrBotConfig):
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
# hardcode warning
data['config_version'] = config.config_version
data['dashboard'] = asdict(config.dashboard)
return errors
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
@@ -66,7 +61,7 @@ def save_astrbot_config(post_config: dict, config: AstrBotConfig):
errors = validate_config(post_config, config)
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.flush_config(post_config)
config.save_config(post_config)
def save_extension_config(post_config: dict):
if 'namespace' not in post_config:
@@ -112,6 +107,7 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
traceback.print_exc()
return Response().error(str(e)).__dict__
async def post_extension_configs(self):
@@ -123,14 +119,15 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
async def _get_astrbot_config(self):
config = self.config.to_dict()
config = self.config
for key in self.config_key_dont_show:
if key in config:
del config[key]
return {
"metadata": CONFIG_METADATA_2,
"config": config,
"provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE
"provider_config_tmpl": PROVIDER_CONFIG_TEMPLATE,
"adapter_config_tmpl": ADAPTER_CONFIG_TEMPLATE
}
async def _get_extension_config(self, namespace: str):
+8 -8
View File
@@ -2,7 +2,7 @@ import threading, traceback, uuid
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import Quart, request
from astrbot.core.plugin.plugin_manager import PluginManager
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class PluginRoute(Route):
@@ -21,14 +21,14 @@ class PluginRoute(Route):
async def get_plugins(self):
_plugin_resp = []
for plugin in self.plugin_manager.context.registered_plugins:
_p = plugin.metadata
for plugin in self.plugin_manager.context.get_all_stars():
_t = {
"name": _p.plugin_name,
"repo": '' if _p.repo is None else _p.repo,
"author": _p.author,
"desc": _p.desc,
"version": _p.version
"name": plugin.name,
"repo": '' if plugin.repo is None else plugin.repo,
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved
}
_plugin_resp.append(_t)
return Response().ok(_plugin_resp).__dict__
+2 -2
View File
@@ -72,8 +72,8 @@ class StatRoute(Route):
stat_dict.update({
"platform": self.db_helper.get_grouped_base_stats(offset_sec).platform,
"message_count": self.db_helper.get_total_message_count() or 0,
"platform_count": len(self.core_lifecycle.plugin_manager.get_platform_insts()),
"plugin_count": len(self.core_lifecycle.plugin_manager.get_loaded_plugins()),
"platform_count": len(self.core_lifecycle.platform_manager.get_insts()),
"plugin_count": len(self.core_lifecycle.star_context.get_all_stars()),
"message_time_series": message_time_based_stats,
"running": self.format_sec(int(time.time()) - self.core_lifecycle.start_time),
"memory": {
+2 -2
View File
@@ -39,10 +39,10 @@ class AstrBotDashboard():
return
# claim jwt
token = request.headers.get("Authorization")
if token.startswith("Bearer "):
token = token[7:]
if not token:
return Response().error("未授权").__dict__
if token.startswith("Bearer "):
token = token[7:]
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
+2 -1
View File
@@ -96,7 +96,8 @@ if __name__ == "__main__":
# print logo
logger.info(logo_tmpl)
dashboard_lifecycle = AstrBotDashBoardLifecycle(db)
core_lifecycle = AstrBotCoreLifecycle(log_broker, db)
asyncio.run(core_lifecycle.initialize())
dashboard_lifecycle = AstrBotDashBoardLifecycle(db)
asyncio.run(dashboard_lifecycle.start(core_lifecycle))
+236 -67
View File
@@ -1,59 +1,21 @@
import aiohttp, base64, os, json, re, time
import aiohttp
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from typing import Dict
from astrbot.api import Context, AstrMessageEvent, MessageEventResult
from astrbot.api import logger, command_parser
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api.platform import MessageType
from astrbot.api import logger
from astrbot.api import personalities
from astrbot.api.provider import Personality
class Main:
def __init__(self, context: Context) -> None:
from typing import Union
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
context.register_commands("astrbot", "help", "查看 AstrBot 帮助", 10, self.help)
context.register_commands("astrbot", "plugin", "AstrBot 插件管理", 10, self.plugin)
context.register_commands("astrbot", "t2i", "关闭/启动文本转图片", 10, self.t2i)
context.register_commands("astrbot", "myid", "查看自己在该平台上的 ID", 10, self.myid)
context.register_listener("astrbot", "keywords_ban_rate_limit", self.keywords_ban, "关键词屏蔽和发言频率监听器")
# keywords
with open(os.path.join(os.path.dirname(__file__), "unfit_words"), "r", encoding="utf-8") as f:
self.keywords: list = json.loads(base64.b64decode(f.read()).decode("utf-8"))['keywords']
internal_keywords_cfg = context.get_config().content_safety.internal_keywords
if internal_keywords_cfg.enable:
self.keywords.extend(internal_keywords_cfg.extra_keywords)
# rate limit
self.user_rate_limit: Dict[int, int] = {}
rl_cfg = context.get_config().platform_settings.rate_limit
self.rate_limit_time: int = rl_cfg.time
self.rate_limit_count: int = rl_cfg.count
self.user_frequency = {}
async def keywords_ban(self, event: AstrMessageEvent):
if not event.is_wake_up():
return
# keywords 检测
for i in self.keywords:
matches = re.match(i, event.get_message_str().strip(), re.I | re.M)
if matches:
event.set_result(MessageEventResult().message("你的消息中包含不适当的关键词,已被屏蔽。"))
return
# rate limit 检测
ts = int(time.time())
if event.session_id in self.user_frequency:
if ts-self.user_frequency[event.session_id]['time'] > self.rate_limit_time:
# reset
self.user_frequency[event.session_id]['time'] = ts
self.user_frequency[event.session_id]['count'] = 1
return
if self.user_frequency[event.session_id]['count'] >= self.rate_limit_count:
event.set_result(MessageEventResult().message("你发送消息的频率过快,请稍后再试。"))
return
self.user_frequency[event.session_id]['count'] += 1
else:
t = {'time': ts, 'count': 1}
self.user_frequency[event.session_id] = t
@filter.command("help")
async def help(self, event: AstrMessageEvent):
notice = ""
try:
@@ -63,26 +25,41 @@ class Main:
except BaseException as e:
pass
msg = "# AstrBot 帮助\n## 已注册的指令\n"
for key, value in self.context.commands_handler.items():
if value.plugin_metadata:
msg += f"- `{key}` ({value.plugin_metadata.plugin_name}): {value.description}\n"
else: msg += f"- `{key}`: {value.description}\n"
msg = "已注册的 AstrBot 内置指令:"
msg += f"""[System]
/plugin: 插件管理
/t2i: 开启/关闭文本转图片模式
/sid: 获取当前会话的 ID
/op <admin_id>: 授权管理员
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
msg += "\n> 提示:使用 /plugin 查看已加载的插件\n"
msg += notice
[大模型]
/provider: 查看、切换大模型提供商
/model: 查看、切换提供商模型列表
/key: 查看、切换 API Key
/reset: 重置 LLM 会话
/history: 获取会话历史记录
/persona: 情境人格设置
event.set_result(MessageEventResult().message(msg))
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent):
plugin_list_info = "已加载的插件:\n"
for plugin in self.context.registered_plugins:
plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {plugin.metadata.author}: {plugin.metadata.desc}\n"
plugin_list_info += f"- `{plugin.metadata.plugin_name}` By {
plugin.metadata.author}: {plugin.metadata.desc}\n"
if plugin_list_info.strip() == "":
plugin_list_info = "没有加载任何插件。"
event.set_result(MessageEventResult().message(f"{plugin_list_info}"))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
config = self.context.get_config()
if config.t2i:
@@ -93,7 +70,199 @@ class Main:
config.t2i = True
config.save_config()
event.set_result(MessageEventResult().message("已开启文本转图片模式。"))
async def myid(self, event: AstrMessageEvent):
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):
sid = event.unified_msg_origin
user_id = str(event.get_sender_id())
event.set_result(MessageEventResult().message(f"你的 ID 是 {user_id}此 ID 可用于设置 AstrBot 管理员。"))
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
event.set_result(MessageEventResult().message(ret))
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str):
self.context.get_config()['admins_id'].append(admin_id)
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str):
try:
self.context.get_config()['admins_id'].remove(admin_id)
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("取消授权成功。"))
except ValueError:
event.set_result(MessageEventResult().message("此用户 ID 不在管理员名单内。"))
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str):
self.context.get_config()['platform_settings']['id_whitelist'].append(sid)
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str):
try:
self.context.get_config()['platform_settings']['id_whitelist'].remove(sid)
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("删除白名单成功。"))
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
@filter.command("provider")
async def provider(self, event: AstrMessageEvent, idx: int = None):
'''查看或者切换 LLM Provider'''
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})"
if self.provider == llm:
ret += " (当前使用)"
ret += "\n"
ret += "\n使用 /provider <序号> 切换提供商。"
event.set_result(MessageEventResult().message(ret))
else:
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1]
event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}"))
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
await self.context.get_using_provider().forget(message.session_id)
message.set_result(MessageEventResult().message("重置成功"))
@filter.command("model")
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
if idx_or_name is None:
models = []
try:
models = await self.context.get_using_provider().get_models()
except BaseException as e:
message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)))
return
i = 1
ret = "下面列出了此服务提供商可用模型:"
for model in models:
ret += f"\n{i}. {model}"
i += 1
ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
if isinstance(idx_or_name, int):
models = []
try:
models = await self.context.get_using_provider().get_models()
except BaseException as e:
message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)))
return
if idx_or_name > len(models) or idx_or_name < 1:
message.set_result(MessageEventResult().message("模型序号错误。"))
else:
try:
new_model = models[idx_or_name-1]
self.context.get_using_provider().set_model(new_model)
except BaseException as e:
message.set_result(
MessageEventResult().message("切换模型未知错误: "+str(e)))
message.set_result(MessageEventResult().message("切换模型成功。"))
else:
self.context.get_using_provider().set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}"))
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
size_per_page = 3
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
history = ""
for context in contexts:
history += f"{context}\n"
ret = f"""历史记录:
{history}
{page} 页 | 共 {total_pages}
*输入 /history 2 跳转到第 2 页
"""
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
if index == None:
keys_data = self.context.get_using_provider().get_keys()
curr_key = self.context.get_using_provider().get_current_key()
ret = "Key:"
for i, k in enumerate(keys_data):
ret += f"\n{i+1}. {k[:8]}"
ret += f"\n当前 Key: {curr_key[:8]}"
ret += "\n当前模型: " + self.context.get_using_provider().get_model()
ret += "\n使用 /key <idx> 切换 Key。"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
else:
keys_data = self.context.get_using_provider().get_keys()
if index > len(keys_data) or index < 1:
message.set_result(MessageEventResult().message("Key 序号错误。"))
else:
try:
new_key = keys_data[index-1]
self.context.get_using_provider().set_key(new_key)
except BaseException as e:
message.set_result(
MessageEventResult().message("切换 Key 未知错误: "+str(e)))
message.set_result(MessageEventResult().message("切换 Key 成功。"))
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
l = message.message_str.split(" ")
if len(l) == 1:
message.set_result(
MessageEventResult().message(f"""[Persona]
- 设置人格: `/persona 人格名`, 如 /persona 编剧
- 人格列表: `/persona list`
- 人格详细信息: `/persona view 人格名`
- 自定义人格: /persona 人格文本
- 重置 LLM 会话(清除人格): /reset
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
"""))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
msg += f"- {key}\n"
msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息'
message.set_result(MessageEventResult().message(msg))
elif l[1] == "view":
if len(l) == 2:
message.set_result(MessageEventResult().message("请输入人格名"))
ps = l[2].strip()
if ps in personalities:
msg = f"人格{ps}的详细信息:\n"
msg += f"{personalities[ps]}\n"
else:
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
else:
ps = "".join(l[1:]).strip()
if ps in personalities:
self.context.get_using_provider().curr_personality = Personality(
name=ps, prompt=personalities[ps])
message.set_result(
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
else:
self.context.get_using_provider().curr_personality = Personality(
name="自定义人格", prompt=ps)
message.set_result(
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
-6
View File
@@ -1,6 +0,0 @@
name: astrbot # 插件名称
desc: AstrBot 内置指令集
help:
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
author: AstrBot # 作者
repo: https://github.com/Soulter/AstrBot
@@ -1,13 +0,0 @@
from astrbot.api import Context
from .aiocqhttp_platform_adapter import AiocqhttpAdapter
from astrbot.api import logger
class Main:
def __init__(self, context: Context) -> None:
self.context = context
platforms_config = context.get_config().platform
settings = context.get_config().platform_settings
for platform in platforms_config:
if platform.name == "aiocqhttp" and platform.enable:
self.context.register_platform(AiocqhttpAdapter(platform, settings, context.get_event_queue()))
logger.info(f"已注册 aiocqhttp({platform.id}) 消息适配器。")
@@ -1,6 +0,0 @@
name: astrbot_adapter_aiocqhttp # 插件名称
desc: 支持 OneBot 协议的消息平台适配器(反向 Websockets)
help:
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
author: Soulter # 作者
repo: https://github.com/Soulter/AstrBot
@@ -1,18 +0,0 @@
import botpy, logging
# delete qqbotpy's logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
from astrbot.api import Context
from .qqofficial_platform_adapter import QQOfficialPlatformAdapter
from astrbot.api import logger
class Main:
def __init__(self, context: Context) -> None:
self.context = context
platforms_config = context.get_config().platform
settings = context.get_config().platform_settings
for platform in platforms_config:
if platform.name == "qq_official" and platform.enable:
self.context.register_platform(QQOfficialPlatformAdapter(platform, settings, context.get_event_queue()))
logger.info(f"已注册 qq_official({platform.id}) 消息适配器。")
@@ -1,6 +0,0 @@
name: astrbot_adapter_qqofficial # 插件名称
desc: 支持 QQ 官方机器人平台的消息平台适配器
help:
version: v1.3.0 # 插件版本号。格式:v1.1.1 或者 v1.1
author: Soulter # 作者
repo: https://github.com/Soulter/AstrBot
-18
View File
@@ -1,18 +0,0 @@
from astrbot.api import Context, AstrMessageEvent, MessageEventResult
from .wechat_platform_adapter import WechatPlatformAdapter
from astrbot.api import logger
class Main:
def __init__(self, context: Context) -> None:
self.context = context
platforms_config = context.get_config().platform
settings = context.get_config().platform_settings
for platform in platforms_config:
if platform.name == "wechat" and platform.enable:
self.context.register_platform(WechatPlatformAdapter(platform, settings, context.get_event_queue()))
logger.info(f"已注册 wechat({platform.id}) 消息适配器。")
self.context.register_commands("astrbot_adapter_wechat", "wechatid", "查看微信ID", 1, self.get_wechat_id)
async def get_wechat_id(self, event: AstrMessageEvent):
event.set_result(MessageEventResult().message("这个会话的微信ID是" + event.message_obj.raw_message.from_.username))

Some files were not shown because too many files have changed in this diff Show More