Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37a1f144ab | |||
| 9a7a654596 | |||
| 9abccd63cf | |||
| 93fea77182 | |||
| 19797243f6 | |||
| c9c733d925 | |||
| a7d7678c78 | |||
| c0911921c7 | |||
| 4a4241d57a | |||
| c9426bb6eb | |||
| db4abd169a | |||
| 80b6958599 | |||
| 80058c781a | |||
| 44bd2e36f3 |
@@ -36,7 +36,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM,因此无法在聊天页使用大模型。
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
||||
|
||||
## ✨ 使用方式
|
||||
|
||||
@@ -78,6 +78,18 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
# 🦌 接下来的路线图
|
||||
|
||||
> [!TIP]
|
||||
> 欢迎在 Issue 提出更多建议 <3
|
||||
|
||||
- [ ] 完善并保证目前所有平台适配器的功能一致性
|
||||
- [ ] 优化插件接口
|
||||
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
|
||||
- [ ] 完善“聊天增强”部分,支持持久化记忆
|
||||
- [ ] 规划 i18n
|
||||
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
@@ -2,4 +2,5 @@ from astrbot.core.platform import (
|
||||
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
)
|
||||
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
from astrbot.core.platform.register import register_platform_adapter
|
||||
from astrbot.core.message.components import *
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import json
|
||||
import logging
|
||||
import enum
|
||||
from .default import DEFAULT_CONFIG
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
from typing import Dict
|
||||
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
@@ -13,29 +13,72 @@ class RateLimitStrategy(enum.Enum):
|
||||
DISCARD = "discard"
|
||||
|
||||
class AstrBotConfig(dict):
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
|
||||
'''从配置文件中加载的配置,支持直接通过点号操作符访问根配置项。
|
||||
|
||||
def __init__(self):
|
||||
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
|
||||
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
default_config: dict = DEFAULT_CONFIG,
|
||||
schema: dict = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
|
||||
object.__setattr__(self, 'config_path', config_path)
|
||||
object.__setattr__(self, 'default_config', default_config)
|
||||
object.__setattr__(self, 'schema', schema)
|
||||
|
||||
if schema:
|
||||
default_config = self._config_schema_to_default_config(schema)
|
||||
|
||||
if not self.check_exist():
|
||||
'''不存在时载入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
|
||||
with open(config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(default_config, f, indent=4, ensure_ascii=False)
|
||||
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
with open(config_path, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
|
||||
# 检查配置完整性,并插入
|
||||
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
|
||||
has_new = self.check_config_integrity(default_config, conf)
|
||||
self.update(conf)
|
||||
if has_new:
|
||||
self.save_config()
|
||||
|
||||
self.update(conf)
|
||||
|
||||
|
||||
def _config_schema_to_default_config(self, schema: dict) -> dict:
|
||||
'''将 Schema 转换成 Config'''
|
||||
conf = {}
|
||||
|
||||
def _parse_schema(schema: dict, conf: dict):
|
||||
for k, v in schema.items():
|
||||
if v['type'] not in DEFAULT_VALUE_MAP:
|
||||
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
|
||||
if 'default' in v:
|
||||
default = v['default']
|
||||
else:
|
||||
default = DEFAULT_VALUE_MAP[v['type']]
|
||||
|
||||
if v['type'] == 'object':
|
||||
conf[k] = {}
|
||||
_parse_schema(v['items'], conf[k])
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
_parse_schema(schema, conf)
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
|
||||
'''检查配置完整性,如果有新的配置项则返回 True'''
|
||||
has_new = False
|
||||
@@ -61,7 +104,7 @@ class AstrBotConfig(dict):
|
||||
'''
|
||||
if replace_config:
|
||||
self.update(replace_config)
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
with open(self.config_path, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(self, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def __getattr__(self, item):
|
||||
@@ -81,4 +124,4 @@ class AstrBotConfig(dict):
|
||||
self[key] = value
|
||||
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
return os.path.exists(self.config_path)
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.14"
|
||||
VERSION = "3.4.15"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -236,8 +236,8 @@ CONFIG_METADATA_2 = {
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "int"},
|
||||
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
"items": {"type": "string"},
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
"description": "打印白名单日志",
|
||||
@@ -265,6 +265,7 @@ CONFIG_METADATA_2 = {
|
||||
"path_mapping": {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
}
|
||||
@@ -589,14 +590,14 @@ CONFIG_METADATA_2 = {
|
||||
"begin_dialogs": {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
@@ -684,7 +685,7 @@ CONFIG_METADATA_2 = {
|
||||
"admins_id": {
|
||||
"description": "管理员 ID",
|
||||
"type": "list",
|
||||
"items": {"type": "int"},
|
||||
"items": {"type": "string"},
|
||||
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
|
||||
},
|
||||
"http_proxy": {
|
||||
|
||||
@@ -15,6 +15,7 @@ class RespondStage(Stage):
|
||||
|
||||
# 分段回复
|
||||
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
|
||||
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
|
||||
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
|
||||
interval_str_ls = interval_str.replace(" ", "").split(",")
|
||||
try:
|
||||
@@ -22,6 +23,7 @@ class RespondStage(Stage):
|
||||
except BaseException as e:
|
||||
logger.error(f'解析分段回复的间隔时间失败。{e}')
|
||||
self.interval = [1.5, 3.5]
|
||||
logger.info(f"分段回复间隔时间:{self.interval}")
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -31,11 +33,12 @@ class RespondStage(Stage):
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
if self.enable_seg:
|
||||
|
||||
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
|
||||
# 分段回复
|
||||
for comp in result.chain:
|
||||
await event.send(MessageChain([comp]))
|
||||
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
|
||||
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
|
||||
else:
|
||||
await event.send(result)
|
||||
await event._post_send()
|
||||
|
||||
@@ -18,6 +18,11 @@ class WhitelistCheckStage(Stage):
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if not self.enable_whitelist_check:
|
||||
# 白名单检查未启用
|
||||
return
|
||||
|
||||
if len(self.whitelist) == 0:
|
||||
# 白名单为空,不检查
|
||||
return
|
||||
|
||||
if event.get_platform_name() == 'webchat':
|
||||
|
||||
@@ -13,6 +13,7 @@ class ProviderManager():
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
||||
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
|
||||
self.persona_configs: list = config.get('persona', [])
|
||||
|
||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
||||
@@ -64,7 +65,7 @@ class ProviderManager():
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
self.tts_provider_insts: Lieist[TTSProvider] = []
|
||||
'''加载的 Text To Speech Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
@@ -123,7 +124,7 @@ class ProviderManager():
|
||||
selected_tts_provider_id = self.provider_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
tts_enabled = self.provider_settings.get("enable", False)
|
||||
tts_enabled = self.provider_tts_settings.get("enable", False)
|
||||
|
||||
for provider_config in self.providers_config:
|
||||
if not provider_config['enable']:
|
||||
|
||||
@@ -132,7 +132,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# 适配 deepseek-r1 模型
|
||||
if r'<think>' in completion_text:
|
||||
completion_text = re.sub(r'<think>.*?<think/>', '', completion_text).strip()
|
||||
completion_text = re.sub(r'<think>.*?</think>', '', completion_text, flags=re.DOTALL).strip()
|
||||
# 可能有单标签情况
|
||||
completion_text = completion_text.replace(r'<think>', '').replace(r'</think>', '').strip()
|
||||
|
||||
return LLMResponse("assistant", completion_text)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
'''
|
||||
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
|
||||
'''
|
||||
|
||||
from typing import Union
|
||||
import os
|
||||
import json
|
||||
|
||||
+104
-102
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, StarMetadata
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
@@ -54,46 +54,19 @@ class Context:
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
'''根据插件名获取插件的 Metadata'''
|
||||
for star in star_registry:
|
||||
if star.name == star_name:
|
||||
return star
|
||||
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
'''获取当前载入的所有插件 Metadata 的列表'''
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
'''
|
||||
获取 LLM Tool Manager
|
||||
'''
|
||||
'''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
star_handlers_registry.append(md)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
||||
|
||||
@@ -102,6 +75,11 @@ class Context:
|
||||
'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
@@ -129,6 +107,101 @@ class Context:
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''获取 AstrBot 的配置。'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''获取 AstrBot 数据库。'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
'''
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
'''
|
||||
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
star_handlers_registry.append(md)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
注册一个命令。
|
||||
@@ -162,77 +235,6 @@ class Context:
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''
|
||||
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
获取所有 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
return self.provider_manager.curr_provider_inst
|
||||
|
||||
def get_config(self) -> AstrBotConfig:
|
||||
'''
|
||||
获取 AstrBot 配置信息。
|
||||
'''
|
||||
return self._config
|
||||
|
||||
def get_db(self) -> BaseDatabase:
|
||||
'''
|
||||
获取 AstrBot 数据库。
|
||||
'''
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
'''
|
||||
获取事件队列。
|
||||
'''
|
||||
return self._event_queue
|
||||
|
||||
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
|
||||
'''
|
||||
根据 session(unified_msg_origin) 发送消息。
|
||||
|
||||
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
|
||||
@param message_chain: 消息链。
|
||||
|
||||
@return: 是否找到匹配的平台。
|
||||
|
||||
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
|
||||
'''
|
||||
|
||||
if isinstance(session, str):
|
||||
try:
|
||||
session = MessageSesion.from_str(session)
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_task(self, task: Awaitable, desc: str):
|
||||
'''
|
||||
注册一个异步任务。
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from types import ModuleType
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
star_registry: List[StarMetadata] = []
|
||||
star_map: Dict[str, StarMetadata] = {}
|
||||
@@ -11,7 +12,7 @@ star_map: Dict[str, StarMetadata] = {}
|
||||
@dataclass
|
||||
class StarMetadata:
|
||||
'''
|
||||
Star 的元数据。
|
||||
插件的元数据。
|
||||
'''
|
||||
name: str
|
||||
author: str # 插件作者
|
||||
@@ -20,21 +21,24 @@ class StarMetadata:
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
star_cls_type: type = None
|
||||
'''Star 的类对象的类型'''
|
||||
'''插件的类对象的类型'''
|
||||
module_path: str = None
|
||||
'''Star 的模块路径'''
|
||||
'''插件的模块路径'''
|
||||
|
||||
star_cls: object = None
|
||||
'''Star 的类对象'''
|
||||
'''插件的类对象'''
|
||||
module: ModuleType = None
|
||||
'''Star 的模块对象'''
|
||||
'''插件的模块对象'''
|
||||
root_dir_name: str = None
|
||||
'''Star 的根目录名'''
|
||||
'''插件的目录名称'''
|
||||
reserved: bool = False
|
||||
'''是否是 AstrBot 的保留 Star'''
|
||||
'''是否是 AstrBot 的保留插件'''
|
||||
|
||||
activated: bool = True
|
||||
'''是否被激活'''
|
||||
|
||||
config: AstrBotConfig = None
|
||||
'''插件配置'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -2,12 +2,14 @@ import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import DEFAULT_VALUE_MAP
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
@@ -26,13 +28,20 @@ class PluginManager:
|
||||
self.updator = PluginUpdator(config['plugin_repo_mirror'])
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self # 就这样吧,不想改了
|
||||
self.context._star_manager = self
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
|
||||
'''存储插件的路径。即 data/plugins'''
|
||||
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
|
||||
'''存储插件配置的路径。data/config'''
|
||||
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
|
||||
'''保留插件的路径。在 packages 目录下'''
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
'''插件配置 Schema 文件名'''
|
||||
|
||||
def _get_classes(self, arg: ModuleType):
|
||||
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
|
||||
classes = []
|
||||
clsmembers = inspect.getmembers(arg, inspect.isclass)
|
||||
for (name, _) in clsmembers:
|
||||
@@ -128,7 +137,7 @@ class PluginManager:
|
||||
return metadata
|
||||
|
||||
async def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
'''扫描并加载所有的插件'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
@@ -150,13 +159,13 @@ class PluginManager:
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
# 导入 Star 模块,并尝试实例化 Star 类
|
||||
# 导入插件模块,并尝试实例化插件类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
module_str = plugin_module['module']
|
||||
# module_path = plugin_module['module_path']
|
||||
root_dir_name = plugin_module['pname']
|
||||
reserved = plugin_module.get('reserved', False)
|
||||
root_dir_name = plugin_module['pname'] # 插件的目录名
|
||||
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
logger.info(f"正在载入插件 {root_dir_name} ...")
|
||||
|
||||
@@ -173,11 +182,33 @@ class PluginManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
|
||||
continue
|
||||
|
||||
# 检查 _conf_schema.json
|
||||
plugin_config = None
|
||||
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
|
||||
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
|
||||
if os.path.exists(plugin_schema_path):
|
||||
# 加载插件配置
|
||||
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
|
||||
plugin_config = AstrBotConfig(
|
||||
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
|
||||
schema=json.loads(f.read())
|
||||
)
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
metadata = star_map[path]
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
|
||||
if plugin_config:
|
||||
metadata.config = plugin_config
|
||||
try:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
|
||||
except TypeError as _:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
else:
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
@@ -199,16 +230,20 @@ class PluginManager:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
|
||||
classes = self._get_classes(module)
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context)
|
||||
except BaseException as e:
|
||||
logger.error(f"插件 {root_dir_name} 实例化失败。")
|
||||
raise e
|
||||
|
||||
if plugin_config:
|
||||
try:
|
||||
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
|
||||
except TypeError as _:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
else:
|
||||
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
|
||||
|
||||
metadata = None
|
||||
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
|
||||
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
|
||||
metadata.star_cls = obj
|
||||
metadata.config = plugin_config
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
@@ -221,7 +256,7 @@ class PluginManager:
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
# 执行 initialize 函数
|
||||
# 执行 initialize() 方法
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
@@ -292,13 +327,14 @@ class PluginManager:
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import os
|
||||
import json
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core import logger
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
@@ -19,9 +17,9 @@ def try_cast(value: str, type_: str):
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
|
||||
def validate_config(data, config: AstrBotConfig):
|
||||
def validate_config(data, schema: dict, is_core: bool):
|
||||
errors = []
|
||||
def validate(data, metadata=CONFIG_METADATA_2, path=""):
|
||||
def validate(data, metadata=schema, path=""):
|
||||
for key, meta in metadata.items():
|
||||
if key not in data:
|
||||
continue
|
||||
@@ -56,35 +54,33 @@ def validate_config(data, config: AstrBotConfig):
|
||||
elif meta["type"] == "object" and not isinstance(value, dict):
|
||||
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
if is_core:
|
||||
for key, group in schema.items():
|
||||
group_meta = group.get("metadata")
|
||||
if not group_meta:
|
||||
continue
|
||||
logger.info(f"验证配置: 组 {key} ...")
|
||||
validate(data, group_meta, path=f"{key}.")
|
||||
else:
|
||||
validate(data, schema)
|
||||
|
||||
return errors
|
||||
|
||||
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
|
||||
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
||||
'''验证并保存配置'''
|
||||
errors = validate_config(post_config, config)
|
||||
errors = None
|
||||
try:
|
||||
if is_core:
|
||||
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
|
||||
else:
|
||||
errors = validate_config(post_config, config.schema, is_core)
|
||||
except BaseException as e:
|
||||
logger.warning(f"验证配置时出现异常: {e}")
|
||||
if errors:
|
||||
raise ValueError(f"格式校验未通过: {errors}")
|
||||
config.save_config(post_config)
|
||||
|
||||
def save_extension_config(post_config: dict):
|
||||
if 'namespace' not in post_config:
|
||||
raise ValueError("Missing key: namespace")
|
||||
if 'config' not in post_config:
|
||||
raise ValueError("Missing key: config")
|
||||
|
||||
namespace = post_config['namespace']
|
||||
config: list = post_config['config'][0]['body']
|
||||
for item in config:
|
||||
key = item['path']
|
||||
value = item['value']
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
|
||||
class ConfigRoute(Route):
|
||||
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
@@ -92,17 +88,17 @@ class ConfigRoute(Route):
|
||||
self.routes = {
|
||||
'/config/get': ('GET', self.get_configs),
|
||||
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
|
||||
'/config/plugin/update': ('POST', self.post_extension_configs),
|
||||
'/config/plugin/update': ('POST', self.post_plugin_configs),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_configs(self):
|
||||
# namespace 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 namespace 的插件配置
|
||||
namespace = "" if "namespace" not in request.args else request.args["namespace"]
|
||||
if not namespace:
|
||||
# plugin_name 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 plugin_name 的插件配置
|
||||
plugin_name = request.args.get("plugin_name", None)
|
||||
if not plugin_name:
|
||||
return Response().ok(await self._get_astrbot_config()).__dict__
|
||||
return Response().ok(await self._get_extension_config(namespace)).__dict__
|
||||
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
|
||||
|
||||
async def post_astrbot_configs(self):
|
||||
post_configs = await request.json
|
||||
@@ -110,14 +106,15 @@ class ConfigRoute(Route):
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(e)
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_extension_configs(self):
|
||||
async def post_plugin_configs(self):
|
||||
post_configs = await request.json
|
||||
plugin_name = request.args.get("plugin_name", "unknown")
|
||||
try:
|
||||
await self._save_extension_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
await self._save_plugin_configs(post_configs, plugin_name)
|
||||
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -141,28 +138,48 @@ class ConfigRoute(Route):
|
||||
"config": config
|
||||
}
|
||||
|
||||
async def _get_extension_config(self, namespace: str):
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
return [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
|
||||
async def _get_plugin_config(self, plugin_name: str):
|
||||
ret = {
|
||||
"metadata": None,
|
||||
"config": None
|
||||
}
|
||||
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
if not plugin_md.config:
|
||||
break
|
||||
ret['config'] = plugin_md.config # 这是自定义的 Dict 类(AstrBotConfig)
|
||||
ret['metadata'] = {
|
||||
plugin_name: {
|
||||
"description": f"{plugin_name} 配置",
|
||||
"type": "object",
|
||||
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
return ret
|
||||
|
||||
async def _save_astrbot_configs(self, post_configs: dict):
|
||||
try:
|
||||
save_astrbot_config(post_configs, self.config)
|
||||
save_config(post_configs, self.config, is_core=True)
|
||||
self.core_lifecycle.restart()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def _save_extension_configs(self, post_configs: dict):
|
||||
|
||||
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
|
||||
md = None
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
md = plugin_md
|
||||
|
||||
if not md:
|
||||
raise ValueError(f"插件 {plugin_name} 不存在")
|
||||
if not md.config:
|
||||
raise ValueError(f"插件 {plugin_name} 没有注册配置")
|
||||
|
||||
try:
|
||||
save_extension_config(post_configs)
|
||||
save_config(post_configs, md.config)
|
||||
self.core_lifecycle.restart()
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复: 配置 Validator 不起效的问题
|
||||
- 修复: DeepSeek-R1 思考标签问题
|
||||
- 修复: 分段回复间隔时间不生效
|
||||
- 修复: 修复白名单为空时依然终止事件 #259
|
||||
- 修复: 群聊增强某些参数的类型转换问题
|
||||
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
|
||||
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑
|
||||
@@ -1,41 +0,0 @@
|
||||
<script setup>
|
||||
import UiParentCard from '@/components/shared/UiParentCard.vue';
|
||||
|
||||
const props = defineProps({
|
||||
config: Array
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<a v-show="config.length === 0">该插件没有配置</a>
|
||||
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
|
||||
<template v-for="item in group.body">
|
||||
<template v-if="item.config_type === 'item'">
|
||||
<template v-if="item.val_type === 'bool'">
|
||||
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'str'">
|
||||
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
|
||||
variant="outlined"></v-text-field>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'int'">
|
||||
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
|
||||
variant="outlined"></v-text-field>
|
||||
</template>
|
||||
<template v-else-if="item.val_type === 'list'">
|
||||
<span>{{ item.name }}</span>
|
||||
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
|
||||
<template v-slot:selection="{ attrs, item, select, selected }">
|
||||
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
|
||||
<strong>{{ item }}</strong>
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-combobox>
|
||||
</template>
|
||||
</template>
|
||||
<template v-else-if="item.config_type === 'divider'">
|
||||
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
|
||||
</template>
|
||||
</template>
|
||||
</UiParentCard>
|
||||
</template>
|
||||
@@ -1,7 +1,7 @@
|
||||
<script setup>
|
||||
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
|
||||
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import axios from 'axios';
|
||||
|
||||
@@ -52,11 +52,17 @@ import axios from 'axios';
|
||||
<v-btn v-else variant="plain" disabled>已安装</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
|
||||
</v-col>
|
||||
|
||||
<v-col style="margin-bottom: 16px;" cols="12" md="12">
|
||||
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
|
||||
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
|
||||
</v-col>
|
||||
|
||||
</v-row>
|
||||
|
||||
<v-dialog v-model="configDialog" width="750">
|
||||
<v-dialog v-model="configDialog" width="1000">
|
||||
<template v-slot:activator="{ props }">
|
||||
</template>
|
||||
<v-card>
|
||||
@@ -65,7 +71,8 @@ import axios from 'axios';
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
|
||||
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
|
||||
<p v-else>这个插件没有配置</p>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
@@ -172,9 +179,9 @@ export default {
|
||||
name: 'ExtensionPage',
|
||||
components: {
|
||||
ExtensionCard,
|
||||
ConfigDetailCard,
|
||||
WaitingForRestart,
|
||||
ConsoleDisplayer
|
||||
ConsoleDisplayer,
|
||||
AstrBotConfig
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
@@ -189,7 +196,10 @@ export default {
|
||||
snack_success: "success",
|
||||
loading_: false,
|
||||
configDialog: false,
|
||||
extension_config: {},
|
||||
extension_config: {
|
||||
"metadata": {},
|
||||
"config": {}
|
||||
},
|
||||
upload_file: null,
|
||||
pluginMarketData: {},
|
||||
loadingDialog: {
|
||||
@@ -364,7 +374,7 @@ export default {
|
||||
openExtensionConfig(extension_name) {
|
||||
this.curr_namespace = extension_name;
|
||||
this.configDialog = true;
|
||||
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
|
||||
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
|
||||
this.extension_config = res.data.data;
|
||||
console.log(this.extension_config);
|
||||
}).catch((err) => {
|
||||
@@ -372,10 +382,7 @@ export default {
|
||||
});
|
||||
},
|
||||
updateConfig() {
|
||||
axios.post('/api/config/plugin/update', {
|
||||
"config": this.extension_config,
|
||||
"namespace": this.curr_namespace
|
||||
}).then((res) => {
|
||||
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
this.toast(res.data.message, "success");
|
||||
this.$refs.wfr.check();
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
model_id = "openai/whisper-large-v3"
|
||||
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
chunk_length_s=30,
|
||||
batch_size=16, # batch size for inference - set based on your device
|
||||
torch_dtype=torch_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
result = pipe(sample)
|
||||
print(result["text"])
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
|
||||
Reference in New Issue
Block a user