Files
AstrBot/packages/astrbot/main.py
T
2024-12-25 12:50:29 +08:00

340 lines
15 KiB
Python

import aiohttp
import datetime
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import personalities
from astrbot.api.provider import Personality, ProviderRequest
from typing import Union
# Built-in AstrBot command set: help, tool/plugin management, admin and
# whitelist management, LLM provider/model/key/persona switching, plus a hook
# that decorates every outgoing LLM request.
#
# NOTE: documentation in this class is written as `#` comments rather than
# docstrings on purpose. The `/plugin <name>` command below shows a plugin
# class's `__doc__` to the user as its help text, so adding docstrings here
# could change user-visible output.
@star.register(name="astrbot", desc="AstrBot 基础指令集合", author="Soulter", version="4.0.0")
class Main(star.Star):
    # Keep the context and cache the provider settings that
    # `decorate_llm_req` reads on every LLM request.
    def __init__(self, context: star.Context) -> None:
        self.context = context
        provider_settings = context.get_config()['provider_settings']
        self.prompt_prefix = provider_settings['prompt_prefix']
        self.identifier = provider_settings['identifier']
        self.enable_datetime = provider_settings["datetime_system_prompt"]

    # Best-effort fetch of the remote notice text.
    # Returns the "notice" field of the JSON document, or "" on any failure
    # (network error, timeout, bad JSON, missing key).
    async def _query_astrbot_notice(self):
        try:
            # Explicit ClientTimeout instead of a bare number (2 s in total).
            timeout = aiohttp.ClientTimeout(total=2)
            async with aiohttp.ClientSession() as session:
                async with session.get("https://astrbot.soulter.top/notice.json", timeout=timeout) as resp:
                    return (await resp.json())["notice"]
        except Exception:
            # `Exception`, not `BaseException`: task cancellation and
            # KeyboardInterrupt must keep propagating.
            return ""

    # /help: list the built-in commands followed by the remote notice.
    @filter.command("help")
    async def help(self, event: AstrMessageEvent):
        # The helper already degrades to "" on failure, so no extra guard.
        notice = await self._query_astrbot_notice()
        msg = "已注册的 AstrBot 内置指令:\n"
        msg += f"""[System]
/plugin: 查看注册的插件、插件帮助
/t2i: 开启/关闭文本转图片模式
/sid: 获取当前会话的 ID
/op <admin_id>: 授权管理员
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
[大模型]
/provider: 查看、切换大模型提供商
/model: 查看、切换提供商模型列表
/key: 查看、切换 API Key
/reset: 重置 LLM 会话
/history: 获取会话历史记录
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
        event.set_result(MessageEventResult().message(msg).use_t2i(False))

    # /tool: command group; the sub-commands below do the actual work.
    @filter.command_group("tool")
    def tool(self):
        pass

    # /tool ls: list every registered function tool with its active state.
    @tool.command("ls")
    async def tool_ls(self, event: AstrMessageEvent):
        tm = self.context.get_llm_tool_manager()
        msg = "函数工具:\n"
        for func_tool in tm.func_list:
            active = " (启用)" if func_tool.active else "(停用)"
            msg += f"- {func_tool.name}: {func_tool.description} {active}\n"
        msg += "\n使用 /tool on/off <工具名> 激活或者停用工具。"
        event.set_result(MessageEventResult().message(msg).use_t2i(False))

    # /tool on <name>: activate a function tool.
    @tool.command("on")
    async def tool_on(self, event: AstrMessageEvent, tool_name: str):
        if self.context.activate_llm_tool(tool_name):
            reply = f"激活工具 {tool_name} 成功。"
        else:
            reply = f"激活工具 {tool_name} 失败,未找到此工具。"
        event.set_result(MessageEventResult().message(reply))

    # /tool off <name>: deactivate a function tool.
    @tool.command("off")
    async def tool_off(self, event: AstrMessageEvent, tool_name: str):
        if self.context.deactivate_llm_tool(tool_name):
            reply = f"停用工具 {tool_name} 成功。"
        else:
            reply = f"停用工具 {tool_name} 失败,未找到此工具。"
        event.set_result(MessageEventResult().message(reply))

    # /plugin: without an argument list the loaded plugins; with a plugin
    # name show that plugin's help (the plugin class's `__doc__`).
    @filter.command("plugin")
    async def plugin(self, event: AstrMessageEvent, oper: str = None):
        if oper is None:
            loaded = list(self.context.get_all_stars())
            if loaded:
                plugin_list_info = "已加载的插件:\n"
                for star_md in loaded:
                    plugin_list_info += f"- `{star_md.name}` By {star_md.author}: {star_md.desc}\n"
            else:
                # Test the list itself: the header text above is never
                # empty, so checking the built string could not detect this.
                plugin_list_info = "没有加载任何插件。"
            plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。"
            event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
            return

        star_md = self.context.get_registered_star(oper)
        if star_md is None:
            event.set_result(MessageEventResult().message("未找到此插件。"))
            return
        help_msg = star_md.star_cls.__doc__ or "该插件未提供帮助信息"
        ret = f"插件 {oper} 帮助信息:\n" + help_msg
        event.set_result(MessageEventResult().message(ret).use_t2i(False))

    # /t2i: toggle the text-to-image mode and persist the configuration.
    @filter.command("t2i")
    async def t2i(self, event: AstrMessageEvent):
        config = self.context.get_config()
        enabled = not config['t2i']
        config['t2i'] = enabled
        config.save_config()
        reply = "已开启文本转图片模式。" if enabled else "已关闭文本转图片模式。"
        event.set_result(MessageEventResult().message(reply))

    # /sid: report the session ID and the sender's user ID.
    @filter.command("sid")
    async def sid(self, event: AstrMessageEvent):
        sid = event.unified_msg_origin
        user_id = str(event.get_sender_id())
        ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
        event.set_result(MessageEventResult().message(ret).use_t2i(False))

    # /op <admin_id>: grant admin rights (admin only).
    @filter.permission_type(filter.PermissionType.ADMIN)
    @filter.command("op")
    async def op(self, event: AstrMessageEvent, admin_id: str):
        config = self.context.get_config()
        admins = config['admins_id']
        # Skip duplicates: /deop removes a single entry, so a duplicated ID
        # would survive one revocation.
        if admin_id not in admins:
            admins.append(admin_id)
            config.save_config()
        event.set_result(MessageEventResult().message("授权成功。"))

    # /deop <admin_id>: revoke admin rights (admin only).
    @filter.permission_type(filter.PermissionType.ADMIN)
    @filter.command("deop")
    async def deop(self, event: AstrMessageEvent, admin_id: str):
        config = self.context.get_config()
        try:
            config['admins_id'].remove(admin_id)
        except ValueError:
            event.set_result(MessageEventResult().message("此用户 ID 不在管理员名单内。"))
        else:
            config.save_config()
            event.set_result(MessageEventResult().message("取消授权成功。"))

    # /wl <sid>: add a session to the whitelist (admin only).
    @filter.permission_type(filter.PermissionType.ADMIN)
    @filter.command("wl")
    async def wl(self, event: AstrMessageEvent, sid: str):
        config = self.context.get_config()
        whitelist = config['platform_settings']['id_whitelist']
        # Skip duplicates so that a single /dwl really removes the session.
        if sid not in whitelist:
            whitelist.append(sid)
            config.save_config()
        event.set_result(MessageEventResult().message("添加白名单成功。"))

    # /dwl <sid>: remove a session from the whitelist (admin only).
    @filter.permission_type(filter.PermissionType.ADMIN)
    @filter.command("dwl")
    async def dwl(self, event: AstrMessageEvent, sid: str):
        config = self.context.get_config()
        try:
            config['platform_settings']['id_whitelist'].remove(sid)
        except ValueError:
            event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
        else:
            config.save_config()
            event.set_result(MessageEventResult().message("删除白名单成功。"))

    # /provider [idx]: list the loaded LLM providers, or switch to the
    # provider with the given 1-based index.
    @filter.command("provider")
    async def provider(self, event: AstrMessageEvent, idx: int = None):
        '''查看或者切换 LLM Provider'''
        providers = self.context.get_all_providers()
        if idx is None:
            # `self.provider` is never assigned in this class; compare with
            # the provider the context reports as in use instead.
            # NOTE(review): assumes get_using_provider() returns
            # provider_manager.curr_provider_inst - confirm.
            current = self.context.get_using_provider()
            ret = "## 当前载入的 LLM 提供商\n"
            for position, llm in enumerate(providers, start=1):
                ret += f"{position}. {llm.meta().id} ({llm.meta().model})"
                if current == llm:
                    ret += " (当前使用)"
                ret += "\n"
            ret += "\n使用 /provider <序号> 切换提供商。"
            event.set_result(MessageEventResult().message(ret))
            return

        if idx > len(providers) or idx < 1:
            event.set_result(MessageEventResult().message("无效的序号。"))
            # Stop here: falling through would index with an invalid value
            # (0 would silently pick the last provider).
            return
        self.context.provider_manager.curr_provider_inst = providers[idx - 1]
        event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}"))

    # /reset: make the current provider forget this session.
    @filter.command("reset")
    async def reset(self, message: AstrMessageEvent):
        await self.context.get_using_provider().forget(message.session_id)
        message.set_result(MessageEventResult().message("重置成功"))

    # /model [idx|name]: list the provider's models, or switch model by
    # 1-based index or by name.
    @filter.command("model")
    async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
        provider = self.context.get_using_provider()

        if idx_or_name is None:
            try:
                models = await provider.get_models()
            except Exception as e:
                message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)).use_t2i(False))
                return
            ret = "下面列出了此服务提供商可用模型:"
            for number, model in enumerate(models, start=1):
                ret += f"\n{number}. {model}"
            ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
            message.set_result(MessageEventResult().message(ret).use_t2i(False))
            return

        if not isinstance(idx_or_name, int):
            # A model name: hand it to the provider as-is.
            provider.set_model(idx_or_name)
            message.set_result(
                MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}"))
            return

        try:
            models = await provider.get_models()
        except Exception as e:
            message.set_result(MessageEventResult().message("获取模型列表失败: " + str(e)))
            return
        if idx_or_name > len(models) or idx_or_name < 1:
            message.set_result(MessageEventResult().message("模型序号错误。"))
            return
        try:
            provider.set_model(models[idx_or_name - 1])
        except Exception as e:
            message.set_result(
                MessageEventResult().message("切换模型未知错误: " + str(e)))
        else:
            # Report success only when the switch really succeeded, so the
            # error message above is not overwritten.
            message.set_result(MessageEventResult().message("切换模型成功。"))

    # /history [page]: show one page (3 entries) of the session history.
    @filter.command("history")
    async def his(self, message: AstrMessageEvent, page: int = 1):
        size_per_page = 3
        contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
        history = "".join(f"{entry}\n" for entry in contexts)
        ret = f"""历史记录:
{history}
{page} 页 | 共 {total_pages}
*输入 /history 2 跳转到第 2 页
"""
        message.set_result(MessageEventResult().message(ret).use_t2i(False))

    # /key [index]: list the API keys (first 8 characters only), or switch
    # to the key with the given 1-based index (admin only).
    @filter.permission_type(filter.PermissionType.ADMIN)
    @filter.command("key")
    async def key(self, message: AstrMessageEvent, index: int = None):
        provider = self.context.get_using_provider()
        keys_data = provider.get_keys()

        if index is None:
            curr_key = provider.get_current_key()
            ret = "Key:"
            for number, api_key in enumerate(keys_data, start=1):
                ret += f"\n{number}. {api_key[:8]}"
            ret += f"\n当前 Key: {curr_key[:8]}"
            ret += "\n当前模型: " + provider.get_model()
            ret += "\n使用 /key <idx> 切换 Key。"
            message.set_result(MessageEventResult().message(ret).use_t2i(False))
            return

        if index > len(keys_data) or index < 1:
            message.set_result(MessageEventResult().message("Key 序号错误。"))
            return
        try:
            provider.set_key(keys_data[index - 1])
        except Exception as e:
            message.set_result(
                MessageEventResult().message("切换 Key 未知错误: " + str(e)))
        else:
            # Report success only when no error was reported above.
            message.set_result(MessageEventResult().message("切换 Key 成功。"))

    # /persona: show usage, list or view the predefined personas, or set the
    # current persona (a predefined name or free-form text).
    @filter.command("persona")
    async def persona(self, message: AstrMessageEvent):
        parts = message.message_str.split(" ")
        if len(parts) == 1:
            message.set_result(
                MessageEventResult().message(f"""[Persona]
- 设置人格: `/persona 人格名`, 如 /persona 编剧
- 人格列表: `/persona list`
- 人格详细信息: `/persona view 人格名`
- 自定义人格: /persona 人格文本
- 重置 LLM 会话(清除人格): /reset
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
""").use_t2i(False))
        elif parts[1] == "list":
            msg = "人格列表:\n"
            for name in personalities:
                msg += f"- {name}\n"
            msg += '\n\n*输入 `/persona view 人格名` 查看人格详细信息'
            message.set_result(MessageEventResult().message(msg))
        elif parts[1] == "view":
            if len(parts) == 2:
                message.set_result(MessageEventResult().message("请输入人格名"))
                # Nothing to look up; continuing would index past the end.
                return
            ps = parts[2].strip()
            if ps in personalities:
                msg = f"人格{ps}的详细信息:\n"
                msg += f"{personalities[ps]}\n"
            else:
                msg = f"人格{ps}不存在"
            message.set_result(MessageEventResult().message(msg))
        else:
            # Re-join with single spaces so free-form persona text keeps its
            # word boundaries.
            ps = " ".join(parts[1:]).strip()
            if ps in personalities:
                new_personality = Personality(name=ps, prompt=personalities[ps])
            else:
                new_personality = Personality(name="自定义人格", prompt=ps)
            self.context.get_using_provider().curr_personality = new_personality
            message.set_result(
                MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))

    # LLM request hook: prepend the configured prompt prefix and the sender
    # identity to the prompt, and append the current datetime and persona
    # prompt to the system prompt. Each part is controlled by its own setting.
    @filter.on_llm_request()
    async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
        provider = self.context.get_using_provider()
        if self.prompt_prefix:
            req.prompt = self.prompt_prefix + req.prompt
        if self.identifier:
            sender = event.message_obj.sender
            user_info = f"[User ID: {sender.user_id}, Nickname: {sender.nickname}]\n"
            req.prompt = user_info + req.prompt
        if self.enable_datetime:
            now_text = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
            req.system_prompt += f"\nCurrent datetime: {now_text}"
        persona_prompt = provider.curr_personality['prompt']
        if persona_prompt:
            req.system_prompt += f"\n{persona_prompt}"

    # Handler for OTHER_MESSAGE events: stops further propagation.
    @filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
    async def other_message(self, event: AstrMessageEvent):
        # NOTE(review): looks like a leftover debug print - confirm whether
        # it should go through the project logger instead.
        print("triggered")
        event.stop_event()