Compare commits

...

66 Commits

Author SHA1 Message Date
Soulter 17d4bf8f22 perf: 管理面板优化新增列表项的提示 2025-02-06 20:19:53 +08:00
Soulter 836be3b097 update: changelogs 2025-02-06 18:51:47 +08:00
Soulter 310415bea9 feat: 聊天增强图像转述支持自定义 Provider id 2025-02-06 18:49:16 +08:00
Soulter aafc1276a9 v3.4.21 2025-02-06 18:34:43 +08:00
Soulter 2993e794cc perf: hint 2025-02-06 17:45:15 +08:00
Soulter 58cb9cfb2d chore: clean code 2025-02-06 17:43:04 +08:00
Soulter fbdf0901d5 fix: 修复reminder时区问题 2025-02-06 17:41:34 +08:00
Soulter af8c81b621 feat: 支持重载插件 2025-02-06 17:27:53 +08:00
Soulter 06b5275e48 perf: 增加报错显示 2025-02-06 16:43:40 +08:00
Soulter ad95572d5f perf: 更好的 list 可视化 2025-02-06 15:59:45 +08:00
Soulter aebc7850f4 fix: openrouter 报错 no endpoints found that support tool use #371 2025-02-06 15:25:15 +08:00
Soulter 1b7efbc607 支持列表展示插件市场 2025-02-06 15:18:11 +08:00
Soulter 3800e96d14 fix: 修复metadata不生效的问题
feat: 支持查看插件行为
2025-02-06 15:10:24 +08:00
Soulter 461f1bb07c feat: 支持插件handler优先级 2025-02-06 12:35:43 +08:00
Soulter 7d4c07e4f6 feat: 支持设置 timeout 2025-02-06 12:31:39 +08:00
Soulter 31b788f463 fix: 修复不支持图片的模型请求异常 2025-02-06 01:50:53 +08:00
Soulter 96ab761f73 fix: 修复reminder无法删除的问题 2025-02-05 22:45:02 +08:00
Soulter 2b3f05c039 update: 优化部分注释 2025-02-05 19:58:48 +08:00
Soulter f2e8303b66 fix: KeyError _mood_imitation_dialogs_processed 2025-02-05 18:52:55 +08:00
Soulter 2a614b545b fix: 修复可能的 KeyError 2025-02-05 17:17:05 +08:00
Soulter 5c0ab21f68 fix: 修复 /model 异常 2025-02-05 17:05:47 +08:00
Soulter 689d109438 typo: myid -> sid 2025-02-05 16:59:21 +08:00
Soulter 2a6934b283 perf: 无对话状态的提示 2025-02-05 16:56:13 +08:00
Soulter 760cb94e9a v3.4.20 2025-02-05 16:06:52 +08:00
Soulter 2a6cff0013 feat: 支持重命名对话 2025-02-05 16:06:18 +08:00
Soulter ce578f0417 feat: 支持使用 LLM 辅助分段回复 #338 2025-02-05 15:40:52 +08:00
Soulter 1745bdb9e2 perf: 优化一些问题 2025-02-05 15:39:59 +08:00
Soulter 3f90b89c3c 添加屏蔽无权限指令回复的功能 #361 2025-02-05 15:06:38 +08:00
Soulter f343e40d15 Merge pull request #370 from Soulter/feat-conversation
feat: 更好的对话管理
2025-02-05 14:56:47 +08:00
Soulter 5cc4be9e65 perf: 优化部分显示问题 2025-02-05 14:51:40 +08:00
Soulter da5aada002 fix: 修复指令组情况下可能造成多指令出触发的问题 2025-02-05 13:52:53 +08:00
Soulter 07f2ee9ad9 fix: 修复 /reset 指令 2025-02-05 13:33:36 +08:00
Soulter 12f4e1146f feat: 更好的对话管理 2025-02-05 13:26:53 +08:00
Soulter 92c57e5476 fix: 修复级联指令组时出现载入错误的问题 2025-02-05 11:11:04 +08:00
Soulter a923baacd8 Update README.md 2025-02-05 01:56:09 +08:00
Soulter 999b094d55 Merge pull request #358 from eltociear/patch-1
chore: update main.py
2025-02-05 01:34:04 +08:00
Soulter d4213f2352 perf: announcement plugin market 2025-02-05 01:19:54 +08:00
Ikko Eltociear Ashimine 3f65c9a066 chore: update main.py
occured -> occurred
2025-02-05 02:18:41 +09:00
Soulter 1d427e2645 perf: 优化插件页面 2025-02-05 01:10:53 +08:00
Soulter 36414c4b00 perf: 优化aiocqhttp适配器对用户非法输入的处理 2025-02-05 00:02:18 +08:00
Soulter 47e253d76c fix: 修复权限过滤算子导致的问题 #350 2025-02-04 23:31:46 +08:00
Soulter b73cf84df0 v3.4.19 2025-02-04 16:37:15 +08:00
Soulter a5b885a774 fix: schema 中 object hint 不显示 #290
feat: 优化插件市场的访问
2025-02-04 16:36:00 +08:00
Soulter 0c785413da chore: clean code 2025-02-04 15:51:26 +08:00
Soulter 482d7ef5f7 v3.4.19 2025-02-04 15:47:24 +08:00
Soulter 9f9073c0ff feat: 支持设置所有指令的权限
feat: 插件指令支持设置指令描述
feat: plugin 指令支持查看插件的指令
2025-02-04 15:41:45 +08:00
Soulter ef05ff4abd fix: 管理员指令 /reset /persona 2025-02-04 13:50:23 +08:00
Soulter 5848aae435 Update README.md 2025-02-04 13:44:02 +08:00
Soulter fb06f33de0 Update README.md 2025-02-04 12:51:17 +08:00
Soulter 0d7ddb149e fix: 修复请求 gemini 推理模型出现 candidates 错误的问题 #333 2025-02-04 00:30:23 +08:00
Soulter 4f2d7b9c4e feat: 适配 Azure OpenAI #332 2025-02-03 23:59:04 +08:00
Soulter c02ed96f6f perf: gewechat 服务端回调接口默认暴露在所有地址 2025-02-03 18:51:19 +08:00
Soulter 3b2ac891b2 fix: 修复限流器不可用的问题 #263 2025-02-03 18:51:19 +08:00
Soulter ef0108881b Update Dockerfile 2025-02-03 17:48:17 +08:00
Soulter af48975a6b chore: v3.4.18 2025-02-03 16:14:27 +08:00
Soulter 6441b149ab fix: 修复主动概率回复关闭后仍然回复的问题 #317 2025-02-03 14:33:53 +08:00
Soulter f8892881f8 fix: 尝试修复 gewechat 群聊收不到 at 的回复 #294 2025-02-03 14:28:14 +08:00
Soulter 228aec5401 perf: 移除了默认人格 2025-02-03 14:17:45 +08:00
Soulter 68ad48ff55 fix: 修复HTTP代理删除后不生效 #319 2025-02-03 14:11:50 +08:00
Soulter 541ba64032 fix: 调用Gemini API输出多余空行问题 #318 2025-02-03 13:27:56 +08:00
Soulter 2d870b798c feat: 添加硅基流动模版 2025-02-03 13:24:22 +08:00
Soulter 0f1fe1ab63 fix: 硅基流动 not a vlm 和 tool calling not supported 报错 #305 # 291
perf: 安装和更新插件后全量重启避免奇奇怪怪的bug
feat: 支持 /tool off_all 停用所有函数工具
2025-02-03 13:20:49 +08:00
Soulter 73cc86ddb1 perf: 回复时艾特发送者之后添加空格或换行 #312 2025-02-03 12:04:26 +08:00
Soulter 23128f4be2 perf: 主动回复不支持 qq_official 的 hint 2025-02-03 12:00:05 +08:00
Soulter 92200d0e82 fix: docker容器内时区不对 2025-02-03 01:15:09 +08:00
Soulter d6e8655792 fix: 抱错时首先移除 tool 2025-02-02 23:17:59 +08:00
55 changed files with 1714 additions and 663 deletions
+3 -1
View File
@@ -9,6 +9,8 @@
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
@@ -70,7 +72,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| QQ(OneBot) | ✔ | 私聊、群聊 | 文字、图片、语音 |
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| 微信(企业微信) | 🚧 | 计划内 | - |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | | 私聊 | 文字、图片、语音 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
+68 -17
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.17"
VERSION = "3.4.21"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -29,8 +29,10 @@ DEFAULT_CONFIG = {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"seg_prompt": "",
"regex": ".*?[。?!~…]+|.+$"
}
},
"no_permission_reply": True,
},
"provider": [],
"provider_settings": {
@@ -55,6 +57,7 @@ DEFAULT_CONFIG = {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"active_reply": {
"enable": False,
@@ -70,6 +73,7 @@ DEFAULT_CONFIG = {
},
"admins_id": [],
"t2i": False,
"t2i_word_threshold": 150,
"http_proxy": "",
"dashboard": {
"enable": True,
@@ -83,14 +87,7 @@ DEFAULT_CONFIG = {
"pip_install_arg": "",
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [
{
"name": "default",
"prompt": "",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
],
"persona": [],
}
@@ -116,7 +113,7 @@ CONFIG_METADATA_2 = {
"id": "default",
"type": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
},
"gewechat(微信)": {
@@ -201,6 +198,11 @@ CONFIG_METADATA_2 = {
},
},
},
"no_permission_reply": {
"description": "无权限回复",
"type": "bool",
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
@@ -218,6 +220,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"seg_prompt": {
"description": "分段提示词辅助",
"type": "string",
"hint": "此项为空时表达不启用这个方法。此方法会调用一次LLM请求。让 LLM 在某一句话中插入一个可以用正则表达式分隔的标记,来实现LLM基于情感分段。如: `请基于情感对以下文本进行分段, 并在两段之间添加`<seg>`以便我用正则匹配。` 然后将下面的正则表达式更换为`.+?<seg>`。",
},
"regex": {
"description": "正则表达式",
"type": "string",
@@ -328,11 +335,24 @@ CONFIG_METADATA_2 = {
"type": "list",
"config_template": {
"openai": {
"id": "default",
"id": "openai",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {
"model": "gpt-4o-mini",
},
},
"azure_openai": {
"id": "azure",
"type": "openai_chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"model_config": {
"model": "gpt-4o-mini",
},
@@ -353,6 +373,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
},
@@ -363,6 +384,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
},
@@ -373,6 +395,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {
"model": "deepseek-chat",
},
@@ -382,11 +405,23 @@ CONFIG_METADATA_2 = {
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
},
"硅基流动": {
"id": "siliconflow",
"type": "openai_chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
},
},
"llmtuner": {
"id": "llmtuner_default",
"type": "llm_tuner",
@@ -433,6 +468,11 @@ CONFIG_METADATA_2 = {
},
},
"items": {
"timeout": {
"description": "超时时间",
"type": "int",
"hint": "超时时间,单位为秒。",
},
"openai-tts-voice": {
"description": "voice",
"type": "string",
@@ -616,14 +656,14 @@ CONFIG_METADATA_2 = {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色之后按回车",
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按回车】。需要偶数个对话",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一。对话需要成对(用户和助手),输入完一个角色之后按回车",
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一。对话需要成对(用户和助手),输入完一个角色的内容之后按回车】。需要偶数个对话",
"obvious_hint": True,
},
},
@@ -684,6 +724,12 @@ CONFIG_METADATA_2 = {
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_provider_id": {
"description": "图像转述提供商 ID",
"type": "string",
"obvious_hint": True,
"hint": "可选。图像转述提供商 ID。如为空将选择聊天使用的提供商。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
@@ -696,7 +742,7 @@ CONFIG_METADATA_2 = {
"description": "启用主动回复",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
"hint": "启用后,会根据触发概率主动回复群聊内的对话。QQ官方API(qq_official)不可用",
},
"method": {
"description": "回复方法",
@@ -714,7 +760,7 @@ CONFIG_METADATA_2 = {
"description": "提示词",
"type": "string",
"obvious_hint": True,
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
"hint": "提示词。当提示词为空时,如果触发回复,则向 LLM 请求的是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
},
},
},
@@ -743,11 +789,16 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,超出一定长度的文本将会通过 AstrBot API 渲染成 Markdown 图片发送。可以缓解审核和消息过长刷屏的问题,并提高 Markdown 文本的可读性。",
},
"t2i_word_threshold": {
"description": "文本转图像字数阈值",
"type": "int",
"hint": "超出此字符长度的文本将会被转换成图片。字数不能低于 50。",
},
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
"description": "HTTP 代理",
+118
View File
@@ -0,0 +1,118 @@
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
del self.session_conversations[unified_msg_origin]
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
+6 -4
View File
@@ -18,16 +18,15 @@ from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
@@ -44,12 +43,15 @@ class AstrBotCoreLifecycle:
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
+20 -10
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -79,25 +79,35 @@ class BaseDatabase(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
'''通过 user_id 和 cid 获取 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
def new_conversation(self, user_id: str, cid: str):
'''新建 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
def delete_conversation(self, user_id: str, cid: str):
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
'''更新 Conversation Persona ID'''
raise NotImplementedError
+11 -6
View File
@@ -33,16 +33,16 @@ class Stats():
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
'''LLM 聊天时持久化的信息'''
@dataclass
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
@@ -53,13 +53,18 @@ class ATRIVision():
sender_nickname: str
timestamp: int = -1
@dataclass
class WebChatConversation():
class Conversation():
'''LLM 对话存储
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
'''
user_id: str
cid: str
history: str = ""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
+61 -12
View File
@@ -6,7 +6,7 @@ from astrbot.core.db.po import (
Stats,
LLMHistory,
ATRIVision,
WebChatConversation
Conversation
)
from . import BaseDatabase
from typing import Tuple
@@ -25,6 +25,37 @@ class SQLiteDatabase(BaseDatabase):
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
'''
PRAGMA table_info(webchat_conversation)
'''
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
'''
)
self.conn.commit()
if not has_persona_id:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
'''
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
@@ -202,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -216,9 +247,9 @@ class SQLiteDatabase(BaseDatabase):
res = c.fetchone()
c.close()
return WebChatConversation(*res)
return Conversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
@@ -228,7 +259,7 @@ class SQLiteDatabase(BaseDatabase):
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -236,7 +267,7 @@ class SQLiteDatabase(BaseDatabase):
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
@@ -247,24 +278,42 @@ class SQLiteDatabase(BaseDatabase):
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
title = row[3]
persona_id = row[4]
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
''', (persona_id, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)
+3 -1
View File
@@ -42,5 +42,7 @@ CREATE TABLE IF NOT EXISTS webchat_conversation(
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
+3 -1
View File
@@ -2,6 +2,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
@@ -11,7 +12,7 @@ from .respond.stage import RespondStage
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"RateLimitStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
@@ -22,6 +23,7 @@ STAGES_ORDER = [
__all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"RateLimitStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
@@ -2,6 +2,7 @@
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
@@ -10,7 +11,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
@@ -24,6 +25,8 @@ class LLMRequestSubStage(Stage):
if self.provider_wake_prefix.startswith(bwp):
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
@@ -46,10 +49,17 @@ class LLMRequestSubStage(Stage):
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = conversation_id
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt and not req.image_urls:
return
@@ -62,9 +72,12 @@ class LLMRequestSubStage(Stage):
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
try:
logger.debug(f"提供商请求 Payload: {req.__dict__}")
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
@@ -77,12 +90,17 @@ class LLMRequestSubStage(Stage):
except BaseException:
logger.error(traceback.format_exc())
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'err':
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
@@ -115,4 +133,24 @@ class LLMRequestSubStage(Stage):
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.session_id,
history=contexts_to_save
)
@@ -61,11 +61,12 @@ class RateLimitStage(Stage):
stall_duration = (next_window_time - now).total_seconds()
match self.rl_strategy:
case RateLimitStrategy.STALL:
case RateLimitStrategy.STALL.value:
logger.info(f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。")
await asyncio.sleep(stall_duration)
case RateLimitStrategy.DISCARD:
event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
case RateLimitStrategy.DISCARD.value:
# event.set_result(MessageEventResult().message(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到您的限额于 {stall_duration:.2f} 秒后重置。"))
logger.info(f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。")
return event.stop_event()
self._remove_expired_timestamps(timestamps, now + timedelta(seconds=stall_duration))
+29 -3
View File
@@ -19,10 +19,18 @@ class ResultDecorateStage:
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
self.t2i_word_threshold = ctx.astrbot_config['t2i_word_threshold']
try:
self.t2i_word_threshold = int(self.t2i_word_threshold)
if self.t2i_word_threshold < 50:
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.seg_prompt = ctx.astrbot_config['platform_settings']['segmented_reply']['seg_prompt']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
@@ -49,12 +57,26 @@ class ResultDecorateStage:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if self.seg_prompt:
try:
llm_resp = await self.ctx.plugin_manager.context.get_using_provider().text_chat(
prompt=f"{self.seg_prompt}\n{comp.text}",
)
comp.text = llm_resp.completion_text
except BaseException as e:
traceback.print_exc()
logger.warning("使用 LLM 分段回复失败。将不分段回复。: " + str(e))
new_chain.append(comp)
continue
split_response = re.findall(self.regex, comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
if seg:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
@@ -91,7 +113,7 @@ class ResultDecorateStage:
plain_str += "\n\n" + comp.text
else:
break
if plain_str and len(plain_str) > 150:
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
url = await html_renderer.render_t2i(plain_str, return_url=True)
@@ -103,8 +125,12 @@ class ResultDecorateStage:
if url:
result.chain = [Image.fromURL(url)]
# at 回复
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id(), name=event.get_sender_name()))
if len(result.chain) > 1 and isinstance(result.chain[1], Plain):
result.chain[1].text = "\n" + result.chain[1].text
# 引用回复
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))
+18 -3
View File
@@ -2,11 +2,11 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@register_stage
class WakingCheckStage(Stage):
@@ -21,6 +21,9 @@ class WakingCheckStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
async def process(
self, event: AstrMessageEvent
@@ -77,7 +80,9 @@ class WakingCheckStage(Stage):
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
permission_not_pass = False
if len(handler.event_filters) == 0:
# 不可能有这种情况, 也不允许有这种情况
continue
@@ -94,6 +99,9 @@ class WakingCheckStage(Stage):
else:
handler = child_command_handler_md # handler 覆盖
break
elif isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -111,6 +119,13 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if self.no_permission_reply:
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"))
event.stop_event()
return
is_wake = True
event.is_wake = True
+9 -2
View File
@@ -31,12 +31,19 @@ class AstrMessageEvent(abc.ABC):
platform_meta: PlatformMetadata,
session_id: str,):
self.message_str = message_str
'''纯文本的消息'''
self.message_obj = message_obj
'''消息对象,AstrBotMessage。带有完整的消息结构。'''
self.platform_meta = platform_meta
'''消息平台的信息, 其中 name 是平台的类型,如 aiocqhttp'''
self.session_id = session_id
'''用户的会话 ID。可以直接使用下面的 unified_msg_origin'''
self.role = "member"
'''用户是否是管理员。如果是管理员,这里是 admin'''
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True
'''是否唤醒'''
self.is_at_or_wake_command = False
'''是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True,但是不会让这个属性置为 True)'''
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -44,7 +51,7 @@ class AstrMessageEvent(abc.ABC):
session_id=session_id
)
self.unified_msg_origin = str(self.session)
'''统一的消息来源字符串。格式为 platform_name:message_type:session_id'''
self._result: MessageEventResult = None
'''消息事件的结果'''
@@ -102,7 +102,7 @@ class AiocqhttpAdapter(Platform):
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
m['data'] = {
"file": ret['file'],
@@ -122,7 +122,10 @@ class AiocqhttpAdapter(Platform):
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
return
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
@@ -95,6 +95,8 @@ class SimpleGewechatClient():
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
if '在群聊中@了你' in d.get('PushContent', ''):
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
@@ -192,7 +194,7 @@ class SimpleGewechatClient():
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host=self.host,
host='0.0.0.0',
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder
)
+7 -1
View File
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from astrbot.core.db.po import Conversation
class ProviderType(enum.Enum):
@@ -38,10 +39,15 @@ class ProviderRequest():
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
conversation: Conversation = None
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
def __str__(self):
return self.__repr__()
@dataclass
class LLMResponse:
+2 -1
View File
@@ -121,7 +121,8 @@ class FuncCall:
tools.append(func_declaration)
declarations["function_declarations"] = tools
if tools:
declarations["function_declarations"] = tools
return declarations
+10 -1
View File
@@ -66,7 +66,16 @@ class ProviderManager():
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
+22 -41
View File
@@ -8,6 +8,8 @@ from typing import TypedDict
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
@@ -15,8 +17,8 @@ class Personality(TypedDict):
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict]
_mood_imitation_dialogs_processed: str
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
@@ -60,25 +62,11 @@ class Provider(AbstractProvider):
) -> None:
super().__init__(provider_config)
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona,即人格。可能为 None'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError()
@@ -96,22 +84,6 @@ class Provider(AbstractProvider):
'''获得支持的模型列表'''
raise NotImplementedError()
@abc.abstractmethod
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
'''获取人类可读的上下文
page 从 1 开始
Example:
["User: 你好", "Assistant: 你好!"]
Return:
contexts: List[str]: 上下文列表
total_pages: int: 总页数
'''
raise NotImplementedError()
@abc.abstractmethod
async def text_chat(self,
prompt: str,
@@ -125,26 +97,35 @@ class Provider(AbstractProvider):
Args:
prompt: 提示词
session_id: 会话 ID
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
kwargs: 其他参数
Notes:
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
'''
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
async def pop_record(self, context: List):
'''
弹出 context 第一条非系统提示词对话记录
'''
poped = 0
indexs_to_pop = []
for idx, record in enumerate(context):
if record["role"] == "system":
continue
else:
indexs_to_pop.append(idx)
poped += 1
if poped == 2:
break
for idx in reversed(indexs_to_pop):
context.pop(idx)
class STTProvider(AbstractProvider):
+1 -4
View File
@@ -39,7 +39,7 @@ class ProviderDify(Provider):
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
@@ -64,8 +64,6 @@ class ProviderDify(Provider):
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
@@ -115,7 +113,6 @@ class ProviderDify(Provider):
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
+37 -94
View File
@@ -1,6 +1,4 @@
import traceback
import base64
import json
import aiohttp
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
@@ -12,17 +10,18 @@ from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
class SimpleGoogleGenAIClient():
def __init__(self, api_key: str, api_base: str):
def __init__(self, api_key: str, api_base: str, timeout: int=120) -> None:
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession(trust_env=True)
self.timeout = timeout
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=10) as resp:
async with self.client.get(request_url, timeout=self.timeout) as resp:
response = await resp.json()
models = []
@@ -48,7 +47,7 @@ class SimpleGoogleGenAIClient():
payload["contents"] = contents
logger.debug(f"payload: {payload}")
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
async with self.client.post(request_url, json=payload, timeout=10) as resp:
async with self.client.post(request_url, json=payload, timeout=self.timeout) as resp:
response = await resp.json()
return response
@@ -67,66 +66,19 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 180)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None)
api_base=provider_config.get("api_base", None),
timeout=self.timeout
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
@@ -181,6 +133,9 @@ class ProviderGoogleGenAI(Provider):
)
logger.debug(f"result: {result}")
if "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
@@ -190,49 +145,47 @@ class ProviderGoogleGenAI(Provider):
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
llm_response.completion_text = llm_response.completion_text.strip()
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
**model_config
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
@@ -240,32 +193,19 @@ class ProviderGoogleGenAI(Provider):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 重置会话")
elif "Function calling is not enabled" in str(e):
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}")
raise e
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
return llm_response
def get_current_key(self) -> str:
return self.client.api_key
@@ -286,6 +226,9 @@ class ProviderGoogleGenAI(Provider):
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
@@ -57,20 +57,13 @@ class LLMTunerModelLoader(Provider):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
contexts: List = [],
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
new_record = {"role": "user", "content": prompt}
if not contexts:
query_context = [
*self.session_memory[session_id],
new_record,
]
system_prompt = self.curr_personality["prompt"]
else:
query_context = [*contexts, new_record]
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
@@ -96,34 +89,8 @@ class LLMTunerModelLoader(Provider):
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def forget(self, session_id):
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
async def get_current_key(self):
return "none"
@@ -132,28 +99,4 @@ class LLMTunerModelLoader(Provider):
pass
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return [self.get_model()]
+102 -109
View File
@@ -1,10 +1,10 @@
import base64
import json
import re
import os
from openai import AsyncOpenAI, NOT_GIVEN
from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import NotFoundError
from openai._exceptions import NotFoundError, UnprocessableEntityError
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
@@ -29,37 +29,27 @@ class ProviderOpenAIOfficial(Provider):
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
# 适配 azure openai #332
if "api_version" in provider_config:
# 使用 azure api
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
base_url=provider_config.get("api_base", None),
timeout=self.timeout
)
else:
# 使用 openai api
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=self.timeout
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
try:
@@ -72,48 +62,17 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str):
'''
弹出最早的一个对话
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) < 2:
return
try:
self.session_memory[session_id].pop(0)
self.session_memory[session_id].pop(0)
except IndexError:
pass
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_openai_style()
if tool_list:
payloads['tools'] = tool_list
completion = None
try:
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
except BaseException as e:
# 处理不支持 Function Calling 的模型
if 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e) \
or 'Function call is not supported' in str(e): # siliconcloud
del payloads['tools']
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
else:
raise e
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
@@ -144,81 +103,114 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str=None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
**model_config
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads['messages'] = new_contexts
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
# 重试 10 次
retry_cnt = 10
while retry_cnt > 0:
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
await self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse("err", "err: 请尝试 /reset 清除会话记录。")
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads['messages'] = new_contexts
llm_response = await self._query(payloads, func_tool)
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
elif 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e) \
or 'Function call is not supported' in str(e) \
or 'Function calling is not enabled' in str(e) \
or 'Tool calling is not supported' in str(e) \
or 'No endpoints found that support tool use' in str(e): # siliconcloud
logger.info(f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。")
if 'tools' in payloads:
del payloads['tools']
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if 'tool' in str(e).lower() and 'support' in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if 'Connection error.' in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}")
raise e
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def _remove_image_from_context(self, contexts: List):
'''
从上下文中删除所有带有 image 的记录
'''
new_contexts = []
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
flag = False
for context in contexts:
if flag:
flag = False # 删除 image 后,下一条(LLM 响应)也要删除
continue
if isinstance(context['content'], list):
flag = True
# continue
new_content = []
for item in context['content']:
if isinstance(item, dict) and 'image_url' in item:
continue
new_content.append(item)
if not new_content:
# 用户只发了图片
new_content = [{"type": "text", "text": "[图片]"}]
context['content'] = new_content
new_contexts.append(context)
return new_contexts
def get_current_key(self) -> str:
return self.client.api_key
@@ -238,9 +230,10 @@ class ProviderOpenAIOfficial(Provider):
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
@@ -22,24 +22,21 @@ class ProviderZhipu(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
@@ -62,7 +59,6 @@ class ProviderZhipu(ProviderOpenAIOfficial):
}
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
+3
View File
@@ -16,6 +16,7 @@ from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class Context:
'''
@@ -44,6 +45,7 @@ class Context:
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
@@ -52,6 +54,7 @@ class Context:
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
+5 -1
View File
@@ -46,7 +46,11 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
if not event.is_wake_up():
return False
message_str = event.get_message_str().strip()
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
+11 -4
View File
@@ -40,17 +40,24 @@ class CommandGroupFilter(HandlerFilter):
if not event.is_wake_up():
return False, None
message_str = event.get_message_str().strip()
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name:
return False, None
# 改写 message_str
ls = ls[1:]
event.message_str = " ".join(ls)
event.message_str = event.message_str.strip()
# event.message_str = " ".join(ls)
# event.message_str = event.message_str.strip()
parsing_command = " ".join(ls)
parsing_command = parsing_command.strip()
event.set_extra("parsing_command", parsing_command)
if event.message_str == "":
if parsing_command == "":
# 当前还是指令组
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
+3 -2
View File
@@ -19,7 +19,8 @@ class PermissionTypeFilter(HandlerFilter):
'''
if self.permission_type == PermissionType.ADMIN:
if not event.is_admin():
event.stop_event()
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限执行此操作。")
# event.stop_event()
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
return False
return True
+1
View File
@@ -8,6 +8,7 @@ from astrbot.core.config import AstrBotConfig
class RegexFilter(HandlerFilter):
'''正则表达式过滤器'''
def __init__(self, regex: str):
self.regex_str = regex
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
+40 -20
View File
@@ -17,7 +17,12 @@ def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
def get_handler_or_create(
handler: Awaitable,
event_type: EventType,
dont_add = False,
**kwargs
) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
@@ -32,12 +37,24 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
handler=handler,
event_filters=[]
)
# 插件handler的附加额外信息
if handler.__doc__:
md.desc = handler.__doc__.strip()
if 'desc' in kwargs:
md.desc = kwargs['desc']
del kwargs['desc']
md.extras_configs = kwargs
if not dont_add:
star_handlers_registry.append(md)
return md
def register_command(command_name: str = None, *args):
'''注册一个 Command'''
def register_command(command_name: str = None, *args, **kwargs):
'''注册一个 Command.
'''
# print("command: ", command_name, args, kwargs)
new_command = None
add_to_event_filters = False
@@ -51,7 +68,7 @@ def register_command(command_name: str = None, *args):
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
@@ -61,8 +78,11 @@ def register_command(command_name: str = None, *args):
return decorator
def register_command_group(command_group_name: str = None, *args):
'''注册一个 CommandGroup'''
def register_command_group(command_group_name: str = None, *args, **kwargs):
'''注册一个 CommandGroup
'''
# print("commandgroup: ", command_group_name,args, kwargs)
new_group = None
add_to_event_filters = False
@@ -78,7 +98,7 @@ def register_command_group(command_group_name: str = None, *args):
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
@@ -93,16 +113,16 @@ class RegisteringCommandable():
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
def register_event_message_type(event_message_type: EventMessageType):
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
'''注册一个 EventMessageType'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, kwargs)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awaitable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType, **kwargs):
'''注册一个 PlatformAdapterType'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
@@ -111,10 +131,10 @@ def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
return decorator
def register_regex(regex: str):
def register_regex(regex: str, **kwargs):
'''注册一个 Regex'''
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent, **kwargs)
handler_md.event_filters.append(RegexFilter(regex))
return awaitable
@@ -134,7 +154,7 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
return decorator
def register_on_llm_request():
def register_on_llm_request(**kwargs):
'''当有 LLM 请求时的事件
Examples:
@@ -149,12 +169,12 @@ def register_on_llm_request():
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
return awaitable
return decorator
def register_on_llm_response():
def register_on_llm_response(**kwargs):
'''当有 LLM 请求后的事件
Examples:
@@ -169,7 +189,7 @@ def register_on_llm_response():
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
return awaitable
return decorator
@@ -215,18 +235,18 @@ def register_llm_tool(name: str = None):
return decorator
def register_on_decorating_result():
def register_on_decorating_result(**kwargs):
'''在发送消息前的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
return awaitable
return decorator
def register_after_message_sent():
def register_after_message_sent(**kwargs):
'''在消息发送后的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent)
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
return awaitable
return decorator
+4 -1
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from dataclasses import dataclass, field
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
@@ -39,6 +39,9 @@ class StarMetadata:
config: AstrBotConfig = None
'''插件配置'''
star_handler_full_names: List[str] = field(default_factory=list)
'''注册的 Handler 的全名列表'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+54 -15
View File
@@ -1,34 +1,41 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
import heapq
from dataclasses import dataclass, field
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
class StarHandlerRegistry(Generic[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
_handlers = []
def append(self, handler: StarHandlerMetadata):
'''添加一个 Handler'''
super().append(handler)
if 'priority' not in handler.extras_configs:
handler.extras_configs['priority'] = 0
heapq.heappush(self._handlers, (-handler.extras_configs['priority'], handler))
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
def _print_handlers(self):
'''打印所有的 Handler'''
for _, handler in self._handlers:
print(handler.handler_full_name)
def get_handlers_by_event_type(self, event_type: EventType, only_activated=True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
if only_activated:
return [
handler
for handler in self
if handler.event_type == event_type and
star_map[handler.handler_module_path] and
star_map[handler.handler_module_path].activated
]
else:
return [handler for handler in self if handler.event_type == event_type]
handlers = [
handler
for _, handler in self._handlers
if handler.event_type == event_type and
(not only_activated or (star_map[handler.handler_module_path] and star_map[handler.handler_module_path].activated))
]
return handlers
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
@@ -36,7 +43,32 @@ class StarHandlerRegistry(Generic[T], List[T]):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_path == module_name]
return [handler for _, handler in self._handlers if handler.handler_module_path == module_name]
def clear(self):
'''清空所有的 Handler'''
self.star_handlers_map.clear()
self._handlers.clear()
def remove(self, handler: StarHandlerMetadata):
'''删除一个 Handler'''
# self._handlers.remove(handler)
for i, h in enumerate(self._handlers):
if h[1] == handler:
self._handlers.pop(i)
break
try:
del self.star_handlers_map[handler.handler_full_name]
except KeyError:
pass
def __iter__(self):
'''使 StarHandlerRegistry 支持迭代'''
return (handler for _, handler in self._handlers)
def __len__(self):
'''返回 Handler 的数量'''
return len(self._handlers)
star_handlers_registry = StarHandlerRegistry()
@@ -76,3 +108,10 @@ class StarHandlerMetadata():
desc: str = ""
'''Handler 的描述信息'''
extras_configs: dict = field(default_factory=dict)
'''插件注册的一些其他的信息, 如 priority 等'''
def __lt__(self, other: StarHandlerMetadata):
'''定义小于运算符以支持优先队列'''
return self.extras_configs.get('priority', 0) < other.extras_configs.get('priority', 0)
+95 -28
View File
@@ -9,7 +9,6 @@ import logging
from types import ModuleType
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
@@ -19,6 +18,8 @@ from .star import star_registry, star_map
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
from .filter.permission import PermissionTypeFilter, PermissionType
class PluginManager:
def __init__(
self,
@@ -39,6 +40,8 @@ class PluginManager:
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
self.failed_plugin_info = ""
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
@@ -125,7 +128,7 @@ class PluginManager:
if isinstance(metadata, dict):
if 'name' not in metadata or 'desc' not in metadata or 'version' not in metadata or 'author' not in metadata:
raise Exception("插件元数据信息不完整。")
raise Exception("插件元数据信息不完整。name, desc, version, author 是必须的字段。")
metadata = StarMetadata(
name=metadata['name'],
author=metadata['author'],
@@ -135,30 +138,52 @@ class PluginManager:
)
return metadata
async def reload(self):
'''扫描并加载所有的插件'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
async def reload(self, specified_plugin_name=None):
'''扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件'''
specified_module_path = None
if specified_plugin_name:
for smd in star_registry:
if smd.name == specified_plugin_name:
specified_module_path = smd.module_path
break
# 终止插件
if not specified_module_path:
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
else:
# 只重载指定插件
smd = star_map.get(specified_module_path)
if smd:
await self._unbind_plugin(smd.name, specified_module_path)
try:
del sys.modules[specified_module_path]
except KeyError:
logger.warning(f"模块 {specified_module_path} 未载入")
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
# 导入插件模块,并尝试实例化插件类
for plugin_module in plugin_modules:
try:
@@ -167,11 +192,15 @@ class PluginManager:
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
if specified_module_path and path != specified_module_path:
continue
logger.info(f"正在载入插件 {root_dir_name} ...")
# 尝试导入模块
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
try:
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError):
@@ -200,6 +229,18 @@ class PluginManager:
# 通过装饰器的方式注册插件
metadata = star_map[path]
try:
# yaml 文件的元数据优先
metadata_yaml = self._load_plugin_metadata(plugin_path=plugin_dir_path)
if metadata_yaml:
metadata.name = metadata_yaml.name
metadata.author = metadata_yaml.author
metadata.desc = metadata_yaml.desc
metadata.version = metadata_yaml.version
metadata.repo = metadata_yaml.repo
except Exception:
pass
if plugin_config:
metadata.config = plugin_config
try:
@@ -213,12 +254,11 @@ class PluginManager:
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
# 绑定 handler
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# llm_tool
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
@@ -240,8 +280,7 @@ class PluginManager:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata = self._load_plugin_metadata(plugin_path=plugin_dir_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
@@ -251,26 +290,54 @@ class PluginManager:
metadata.module_path = path
star_map[path] = metadata
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
# 禁用/启用插件
if metadata.module_path in inactivated_plugins:
metadata.activated = False
full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name(metadata.module_path):
full_names.append(handler.handler_full_name)
# 检查并且植入自定义的权限过滤器(alter_cmd)
if metadata.name in alter_cmd and handler.handler_name in alter_cmd[metadata.name]:
cmd_type = alter_cmd[metadata.name][handler.handler_name].get("permission", "member")
found_permission_filter = False
for filter_ in handler.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = PermissionType.ADMIN
else:
filter_.permission_type = PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
handler.event_filters.append(PermissionTypeFilter(PermissionType.ADMIN if cmd_type == "admin" else PermissionType.MEMBER))
logger.debug(f"插入权限过滤器 {cmd_type}{metadata.name}{handler.handler_name} 方法。")
metadata.star_handler_full_names = full_names
# 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
logger.error(f"----- 插件 {root_dir_name} 载入失败 -----")
errors = traceback.format_exc()
for line in errors.split('\n'):
logger.error(f"| {line}")
logger.error("----------------------------------")
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {str(e)}\n"
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if not fail_rec:
return True, None
else:
self.failed_plugin_info = fail_rec
return False, fail_rec
async def install_plugin(self, repo_url: str):
+9
View File
@@ -27,9 +27,14 @@ class DifyAPIClient:
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"chat_messages payload: {payload}")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
@@ -55,9 +60,13 @@ class DifyAPIClient:
payload = locals()
payload.pop("self")
payload.pop("timeout")
logger.info(f"workflow_run payload: {payload}")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
if resp.status != 200:
text = await resp.text()
raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
@@ -25,6 +25,10 @@ class ParameterValidationMixin:
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+8 -8
View File
@@ -121,7 +121,7 @@ class ChatRoute(Route):
}))
# 持久化
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -136,7 +136,7 @@ class ChatRoute(Route):
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
self.db.update_conversation(username, conversation_id, history=json.dumps(history))
return Response().ok().__dict__
@@ -168,7 +168,7 @@ class ChatRoute(Route):
continue
yield result_text + '\n'
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
conversation = self.db.get_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -178,7 +178,7 @@ class ChatRoute(Route):
'type': 'bot',
'message': result_text
})
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
self.db.update_conversation(username, cid, history=json.dumps(history))
await asyncio.sleep(0.5)
except BaseException as e:
@@ -204,20 +204,20 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
self.db.delete_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
self.db.new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
conversations = self.db.get_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
@@ -226,7 +226,7 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
self.curr_user_cid[username] = conversation_id
+105 -12
View File
@@ -6,6 +6,12 @@ from astrbot.core import logger
from quart import request
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
class PluginRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
@@ -18,22 +24,58 @@ class PluginRoute(Route):
'/plugin/uninstall': ('POST', self.uninstall_plugin),
'/plugin/market_list': ('GET', self.get_online_plugins),
'/plugin/off': ('POST', self.off_plugin),
'/plugin/on': ('POST', self.on_plugin)
'/plugin/on': ('POST', self.on_plugin),
'/plugin/reload': ('POST', self.reload_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
self.register_routes()
self.translated_event_type = {
EventType.AdapterMessageEvent: "平台消息下发时",
EventType.OnLLMRequestEvent: "LLM 请求时",
EventType.OnLLMResponseEvent: "LLM 响应后",
EventType.OnDecoratingResultEvent: "回复消息前",
EventType.OnCallingFuncToolEvent: "函数工具",
EventType.OnAfterMessageSentEvent: "发送消息后"
}
async def reload_plugins(self):
data = await request.json
plugin_name = data.get("name", None)
try:
success, message = await self.plugin_manager.reload(plugin_name)
if not success:
return Response().error(message).__dict__
return Response().ok(None, "重载成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/reload: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def get_online_plugins(self):
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
return Response().ok(result).__dict__
except Exception as e:
logger.error(f"获取插件列表失败:{e}")
return Response().error(str(e)).__dict__
custom = request.args.get("custom_registry")
if custom:
urls = [custom]
else:
urls = [
"https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json",
"https://api.soulter.top/astrbot/plugins"
]
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
return Response().ok(result).__dict__
else:
logger.error(f"请求 {url} 失败,状态码:{response.status}")
except Exception as e:
logger.error(f"请求 {url} 失败,错误:{e}")
return Response().error("获取插件列表失败").__dict__
async def get_plugins(self):
_plugin_resp = []
@@ -45,10 +87,58 @@ class PluginRoute(Route):
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved,
"activated": plugin.activated
"activated": plugin.activated,
"handlers": await self.get_plugin_handlers_info(plugin.star_handler_full_names),
}
_plugin_resp.append(_t)
return Response().ok(_plugin_resp).__dict__
return Response().ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info).__dict__
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
'''解析插件行为'''
handlers = []
for handler_full_name in handler_full_names:
info = {}
handler = star_handlers_registry.star_handlers_map.get(handler_full_name, None)
if handler is None:
continue
info["event_type"] = handler.event_type.name
info["event_type_h"] = self.translated_event_type.get(handler.event_type, handler.event_type.name)
info["handler_full_name"] = handler.handler_full_name
info["desc"] = handler.desc
info["handler_name"] = handler.handler_name
if handler.event_type == EventType.AdapterMessageEvent:
# 处理平台适配器消息事件
has_admin = False
for filter in handler.event_filters: # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
info["type"] = "指令"
info["cmd"] = filter.command_name
elif isinstance(filter, CommandGroupFilter):
info["type"] = "指令组"
info["cmd"] = filter.group_name
info["sub_command"] = filter.print_cmd_tree(filter.sub_command_filters)
elif isinstance(filter, RegexFilter):
info["type"] = "正则匹配"
info["cmd"] = filter.regex_str
elif isinstance(filter, PermissionTypeFilter):
has_admin = True
info["has_admin"] = has_admin
if "cmd" not in info:
info["cmd"] = "未知"
if "type" not in info:
info["type"] = "事件监听器"
else:
info["cmd"] = "自动触发"
info["type"] = ""
if not info["desc"]:
info["desc"] = "无描述"
handlers.append(info)
return handlers
async def install_plugin(self):
post_data = await request.json
@@ -56,6 +146,7 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -70,6 +161,7 @@ class PluginRoute(Route):
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
self.core_lifecycle.restart()
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -94,6 +186,7 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
+12
View File
@@ -0,0 +1,12 @@
# What's Changed
- fix: 修复主动概率回复关闭后仍然回复的问题 #317
- fix: 尝试修复 gewechat 群聊收不到 at 的回复 #294
- perf: 移除了默认人格
- fix: 修复HTTP代理删除后不生效 #319
- fix: 调用Gemini API输出多余空行问题 #318
- feat: 添加硅基流动模版
- fix: 硅基流动 not a vlm 和 tool calling not supported 报错 #305 #291
- perf: 回复时艾特发送者之后添加空格或换行 #312
- fix: docker容器内时区不对导致 reminder 时间错误
- perf: siliconcloud 不支持 tool 的模型
+13
View File
@@ -0,0 +1,13 @@
# What's Changed
1. 支持接入企业微信(测试)
2. 修复速率限制不可用的问题
3. gewechat 回调接口默认暴露在所有 IP
4. 适配 Azure OpenAI
5. 修复请求 gemini 出现 KeyError 'candidates' 的错误
6. 将 /reset /persona 挪入管理员指令 #308
7. 支持通过 /alter_cmd 设置所有指令是否只能管理员操作
8. /plugin 指令支持查看插件注册的指令和指令组
9. 插件注册指令支持传入指令的描述以方便 /plugin 查看。需要写在函数的第一行的 docstring 中。
10. 修复 schema 中 object hint 不显示 #290
11. feat: 优化插件市场的访问速度
+15
View File
@@ -0,0 +1,15 @@
# What's Changed
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
1. 更好的对话管理,支持 /ls, /del, /new, /switch, /rename 指令来操作对话。
2. 人格情境跟随对话。每个对话支持独立设置人格情境,只需要 /persona 指令切换即可。
3. 支持使用 LLM 辅助分段回复 #338
4. 优化 aiocqhttp 适配器对用户非法输入的处理
5. 优化插件页面
6. 修复权限过滤算子导致的问题 #350
7. 修复级联指令组时出现载入错误的问题 #366
8. 修复代码执行器的一个typo by @eltociear
9. 修复指令组情况下可能造成多指令出触发的问题
10. 添加屏蔽无权限指令回复的功能 #361
+19
View File
@@ -0,0 +1,19 @@
# What's Changed
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
1. 修复 reminder 时区问题
2. 面板支持重载单个插件 #297
3. 面板支持列表展示插件市场
4. 文字转图片支持自定义字数阈值(配置->其他配置)
5. 面板更好的列表可视化 #274
6. 面板支持查看插件行为
7. 支持设置 timeout 超时时间参数,防止思考模型太长达到超时时间。(需要重新配置服务提供商或者在服务提供商 config 中配置 timeout 参数) #378
8. openrouter 报错 no endpoints found that support tool use #371
9. 修复插件 metadata 不生效的问题
10. 修复不支持图片的模型请求异常
11. 修复 reminder 无法删除的问题
12. 修复 /model 切换不了模型的问题
13. 插件支持设置优先级
14. 聊天增强图像转述支持自定义 provider id。#274
+1 -1
View File
@@ -32,7 +32,7 @@
"vue-router": "4.2.4",
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
"vuetify": "3.7.11",
"yup": "1.2.0"
},
"devDependencies": {
@@ -1,22 +1,26 @@
<template>
<h3 style="margin-bottom: 8px;" v-if="iterable && metadata[metadataKey]?.type === 'object'">
{{metadata[metadataKey]?.description }}
{{ metadata[metadataKey]?.description }}
</h3>
<v-card-text>
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;" v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object'" style="margin-bottom: 16px"
:text="metadata[metadataKey].items[key]?.hint" :title="'💡 关于' + metadata[metadataKey].items[key]?.description"
type="info" variant="tonal">
<div v-for="(index, key) in iterable" :key="key" style="margin-bottom: 0.5px;"
v-if="metadata[metadataKey]?.type === 'object' || metadata[metadataKey]?.config_template">
<v-alert v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
style="margin-bottom: 16px" :text="metadata[metadataKey].items[key]?.hint"
:title="'💡 关于' + metadata[metadataKey].items[key]?.description" type="info" variant="tonal">
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
<div style="width: 100%;" v-if="metadata[metadataKey].items[key]">
<v-select v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
variant="outlined" :items="metadata[metadataKey].items[key]?.options"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense :disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
<v-text-field v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
<v-select
v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
v-model="iterable[key]" variant="outlined" :items="metadata[metadataKey].items[key]?.options"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense
:disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
<v-text-field
v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
v-model="iterable[key]" :label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"
variant="outlined" dense ></v-text-field>
variant="outlined" dense></v-text-field>
<v-text-field
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey].items[key]?.invisible"
v-model="iterable[key]" :label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"
@@ -27,17 +31,11 @@
<v-switch v-else-if="metadata[metadataKey].items[key]?.type === 'bool' && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" color="primary"
inset></v-switch>
<v-combobox variant="outlined" v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
v-model="iterable[key]" chips clearable
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" multiple
prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select"
@click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
<ListConfigItem
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
:value="iterable[key]"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'"/>
<div v-else-if="metadata[metadataKey].items[key]?.type === 'object' && !metadata[metadataKey].items[key]?.invisible"
style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px;">
<AstrBotConfig :metadata="metadata[metadataKey].items" :iterable="iterable[key]"
@@ -52,51 +50,54 @@
</div>
<div
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && !metadata[metadataKey].items[key]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
<v-icon size="x-small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey].items[key]?.hint
}}</v-tooltip>
</v-btn>
</div>
<div>
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible" color="primary">{{ metadata[metadataKey].items[key]?.type }}</v-chip>
</div>
</div>
</div>
<div v-else>
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object'" style="margin-bottom: 16px"
:text="metadata[metadataKey]?.hint" :title="'💡 关于' + metadata[metadataKey]?.description"
type="info" variant="tonal">
<v-alert v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint"
style="margin-bottom: 16px" :text="metadata[metadataKey]?.hint"
:title="'💡 关于' + metadata[metadataKey]?.description" type="info" variant="tonal">
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
<div style="width: 100%;">
<v-select v-if="metadata[metadataKey]?.options && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
variant="outlined" :items="metadata[metadataKey]?.options"
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" dense :disabled="metadata[metadataKey]?.readonly"></v-select>
<v-text-field v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]" :label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"
variant="outlined" dense ></v-text-field>
<v-select v-if="metadata[metadataKey]?.options && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]" variant="outlined" :items="metadata[metadataKey]?.options"
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" dense
:disabled="metadata[metadataKey]?.readonly"></v-select>
<v-text-field
v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
dense></v-text-field>
<v-text-field
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]" :label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"
variant="outlined" dense></v-text-field>
<v-textarea v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" variant="outlined"
v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
dense></v-text-field>
<v-textarea v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" variant="outlined"
dense></v-textarea>
<v-switch v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible" v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" color="primary"
<v-switch v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey + ')'" color="primary"
inset></v-switch>
<v-combobox variant="outlined" v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
v-model="iterable[metadataKey]" chips clearable
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'" multiple
prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select"
@click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
<ListConfigItem
v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
:value="iterable[metadataKey]"
:label="metadata[metadataKey]?.description + '(' + metadataKey+ ')'"/>
<div v-else-if="metadata[metadataKey]?.type === 'object' && !metadata[metadataKey]?.invisible"
style="border: 1px solid #e0e0e0; padding: 8px; margin-bottom: 16px; border-radius: 10px;">
<AstrBotConfig :metadata="metadata[metadataKey].items" :iterable="iterable[metadataKey]"
@@ -106,13 +107,17 @@
</div>
<div
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && metadata[metadataKey]?.type !== 'object' && !metadata[metadataKey]?.invisible">
v-if="!metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint && !metadata[metadataKey]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
<v-icon size="x-small">mdi-help</v-icon>
<v-tooltip activator="parent" location="start">{{ metadata[metadataKey]?.hint
}}</v-tooltip>
</v-btn>
</div>
<div>
<v-chip v-if="!metadata[metadataKey]?.invisible" color="primary">{{ metadata[metadataKey]?.type }}</v-chip>
</div>
</div>
</div>
</v-card-text>
@@ -120,8 +125,12 @@
<script>
import { readonly } from 'vue';
import ListConfigItem from './ListConfigItem.vue';
export default {
components: {
ListConfigItem
},
props: {
metadata: Object,
iterable: Object,
@@ -1,7 +1,8 @@
<script setup lang="ts">
const props = defineProps({
title: String,
link: String
link: String,
logo: String
});
const open = (link: string | undefined) => {
@@ -11,15 +12,16 @@ const open = (link: string | undefined) => {
<template>
<v-card variant="outlined" elevation="0" class="withbg">
<v-card-item style="padding: 10px 14px">
<v-card-item style="padding: 10px 12px">
<div class="d-sm-flex align-center justify-space-between">
<v-card-title style="font-size: 17px;">{{ props.title }}</v-card-title>
<img v-if="logo" :src="logo" alt="logo" style="width: 40px; height: 40px; margin-right: 8px;">
<v-card-title style="font-size: 16px;">{{ props.title }}</v-card-title>
<v-spacer></v-spacer>
<v-btn variant="plain" @click="open(props.link)">仓库</v-btn>
<v-btn size="small" text="Read" variant="flat" border @click="open(props.link)">帮助</v-btn>
</div>
</v-card-item>
<v-divider></v-divider>
<v-card-text>
<v-card-text style="padding: 16px;">
<slot />
</v-card-text>
</v-card>
@@ -0,0 +1,93 @@
<template>
<div class="list-config-item">
<h3>{{ label }}</h3>
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: scroll;" >
<v-list-item v-for="(item, index) in items" :key="index">
<v-list-item-content style="display: flex; justify-content: space-between;">
<v-list-item-title>
<v-chip>{{ item }}</v-chip>
</v-list-item-title>
<v-btn @click="removeItem(index)" variant="plain">
<v-icon>mdi-close</v-icon>
</v-btn>
</v-list-item-content>
</v-list-item>
</v-list>
<v-text-field
v-model="newItem"
label="添加新项,按回车确认添加"
@keyup.enter="addItem"
clearable
dense
hide-details
variant="outlined"
></v-text-field>
</div>
</template>
<script>
export default {
name: 'ListConfigItem',
props: {
value: {
type: Array,
default: () => [],
},
label: {
type: String,
default: '',
},
},
data() {
return {
newItem: '',
items: this.value,
};
},
watch: {
items(newVal) {
this.$emit('input', newVal);
},
},
methods: {
addItem() {
if (this.newItem.trim() !== '') {
this.items.push(this.newItem.trim());
this.newItem = '';
}
},
removeItem(index) {
this.items.splice(index, 1);
},
},
};
</script>
<style scoped>
.list-config-item {
border: 1px solid #e0e0e0;
padding: 16px;
margin-bottom: 16px;
border-radius: 10px;
background-color: #ffffff;
}
.list-config-item h3 {
margin-top: 0;
margin-bottom: 16px;
font-size: 18px;
font-weight: 500;
}
.v-list-item {
padding: 0;
}
.v-list-item-title {
font-size: 14px;
}
.v-btn {
margin-left: 8px;
}
</style>
@@ -9,7 +9,7 @@ const sidebarMenu = shallowRef(sidebarItems);
</script>
<template>
<v-navigation-drawer left v-model="customizer.Sidebar_drawer" elevation="0" rail-width="80" mobile-breakpoint="960"
<v-navigation-drawer left v-model="customizer.Sidebar_drawer" elevation="0" rail-width="80"
app class="leftSidebar" :rail="customizer.mini_sidebar">
<v-list class="pa-4 listitem" style="height: auto">
<template v-for="(item, i) in sidebarMenu" :key="i">
+15 -1
View File
@@ -68,7 +68,7 @@ import config from '@/config';
<v-tabs-window-item v-for="(config_item, index) in config_data[key2]" v-show="config_template_tab === index"
:key="index" :value="index">
<v-container>
<v-btn variant="tonal" rounded="xl" color="error" @click="config_data[key2].splice(index, 1)">
<v-btn variant="tonal" rounded="xl" color="error" @click="deleteItem(key2, index)">
删除这项
</v-btn>
@@ -215,6 +215,20 @@ export default {
// new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
this.config_data[config_item_name].push(new_tmpl_cfg);
this.config_template_tab = this.config_data[config_item_name].length - 1;
},
deleteItem(config_item_name, index) {
console.log(config_item_name, index);
let new_list = [];
for (let i = 0; i < this.config_data[config_item_name].length; i++) {
if (i !== index) {
new_list.push(this.config_data[config_item_name][i]);
}
}
this.config_data[config_item_name] = new_list;
if (this.config_template_tab > 0) {
this.config_template_tab -= 1;
}
}
},
}
+195 -43
View File
@@ -4,59 +4,118 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
import { max } from 'date-fns';
</script>
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README" title="💡提示"
type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<h3>🧩 已安装的插件</h3>
<div style="display: flex; align-items: center;">
<h3>🧩 已安装的插件</h3>
<v-dialog max-width="500px">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" v-if="extension_data.message" icon size="small" color="error"
style="margin-left: auto;" variant="plain">
<v-icon>mdi-alert-circle</v-icon>
</v-btn>
</template>
<v-card>
<v-card-title class="headline">错误信息</v-card-title>
<v-card-text>{{ extension_data.message }}
<br>
<small>详情请检查控制台</small>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="primary" text>关闭</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
</div>
</v-col>
<v-col cols="12" md="6" lg="4" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: none;">{{ extension.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ extension.author }}</span>
<v-spacer></v-spacer>
<div v-if="!extension.reserved">
<v-btn variant="plain" @click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
</div>
<v-col cols="12" md="6" lg="3" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" :logo="extension?.logo"
style="margin-bottom: 4px;">
<div style="min-height: 135px; max-height: 135px; overflow: none;">
<span style="font-weight: bold;">By @{{ extension.author }}</span>
<span> | 插件有 {{ extension.handlers.length }} 个行为</span>
<p style="margin-top: 8px;">{{ extension.desc }}</p>
<a style="font-size: 12px; cursor: pointer; text-decoration: underline; color: #555;" @click="reloadPlugin(extension.name)">重载插件</a>
</div>
<div class="d-flex align-center gap-2 " style="overflow-x: auto;">
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="updateExtension(extension.name)">更新</v-btn>
<v-btn v-if="!extension.reserved" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="uninstallExtension(extension.name)">卸载</v-btn>
<!-- <span v-else>保留插件</span> -->
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
<v-btn variant="plain" v-else @click="pluginOn(extension)"></v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-if="extension.activated"
@click="pluginOff(extension)"></v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-else
@click="pluginOn(extension)">启用</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="showPluginInfo(extension)">行为</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<h3>🧩 插件市场</h3>
<div style="display: flex; align-items: center;">
<h3>🧩 插件市场</h3>
<small style="margin-left: 16px;">如无法显示请打开 <a
href="https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json">链接</a> 复制想安装插件对应的 `repo`
链接然后点击右下角 + 号安装或打开链接下载压缩包安装</small>
<v-btn icon @click="isListView = !isListView" size="small" style="margin-left: auto;" variant="plain">
<v-icon>{{ isListView ? 'mdi-view-grid' : 'mdi-view-list' }}</v-icon>
</v-btn>
</div>
</div>
</v-col>
<v-col cols="12" md="6" lg="4" v-for="plugin in pluginMarketData">
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: hidden;">{{ plugin.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ plugin.author }}</span>
<v-spacer></v-spacer>
<v-btn v-if="!plugin.installed" variant="plain"
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
<v-btn v-else variant="plain" disabled>已安装</v-btn>
</div>
</ExtensionCard>
<v-col cols="12" md="12" v-if="announcement">
<v-banner color="success" lines="one" :text="announcement" :stacked="false">
</v-banner>
</v-col>
<template v-if="isListView">
<v-col cols="12" md="12">
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name">
<template v-slot:item.actions="{ item }">
<v-btn v-if="!item.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="extension_url = item.repo; newExtension()">安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
</template>
</v-data-table>
</v-col>
</template>
<template v-else>
<v-col cols="12" md="6" lg="3" v-for="plugin in pluginMarketData">
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
<div style="min-height: 130px; max-height: 130px; overflow: hidden;">
<p style="font-weight: bold;">By @{{ plugin.author }}</p>
{{ plugin.desc }}
</div>
<div class="d-flex align-center gap-2">
<v-btn v-if="!plugin.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
</template>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
@@ -71,7 +130,8 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata"
:iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
</v-container>
</v-card-text>
@@ -166,6 +226,44 @@ import axios from 'axios';
</v-card>
</v-dialog>
<v-dialog v-model="showPluginInfoDialog" width="1200">
<template v-slot:activator="{ props }">
</template>
<v-card>
<v-card-title>
<span class="text-h5">{{ selectedPlugin.name }} 插件行为</span>
</v-card-title>
<v-card-text>
<v-data-table style="font-size: 17px;" :headers="plugin_handler_info_headers" :items="selectedPlugin.handlers"
item-key="name">
<template v-slot:header.id="{ column }">
<p style="font-weight: bold;">{{ column.title }}</p>
</template>
<template v-slot:item.event_type="{ item }">
{{ item.event_type }}
</template>
<template v-slot:item.desc="{ item }">
{{ item.desc }}
</template>
<template v-slot:item.type="{ item }">
<v-chip color="success">
{{ item.type }}
</v-chip>
</template>
<template v-slot:item.cmd="{ item }">
<span style="font-weight: bold;">{{ item.cmd }}</span>
</template>
</v-data-table>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="showPluginInfoDialog = false">
关闭
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<v-snackbar :timeout="2000" elevation="24" :color="snack_success" v-model="snack_show">
{{ snack_message }}
</v-snackbar>
@@ -186,7 +284,8 @@ export default {
data() {
return {
extension_data: {
"data": []
"data": [],
"message": ""
},
extension_url: "",
status: "",
@@ -207,12 +306,34 @@ export default {
title: "加载中...",
statusCode: 0, // 0: loading, 1: success, 2: error,
result: ""
}
},
announcement: "",
showPluginInfoDialog: false,
selectedPlugin: {},
plugin_handler_info_headers: [
{ title: '行为类型', key: 'event_type_h' },
{ title: '描述', key: 'desc', maxWidth: '250px' },
{ title: '具体类型', key: 'type' },
{ title: '触发方式', key: 'cmd' },
],
isListView: false,
pluginMarketHeaders: [
{ title: '名称', value: 'name' },
{ title: '描述', value: 'desc' },
{ title: '作者', value: 'author' },
{ title: '操作', value: 'actions', sortable: false }
],
}
},
mounted() {
this.getExtensions();
this.fetchPluginCollection();
axios.get('https://api.soulter.top/astrbot-announcement-plugin-market').then((res) => {
let data = res.data.data;
this.announcement = data.text;
});
},
methods: {
toast(message, success) {
@@ -240,7 +361,8 @@ export default {
},
getExtensions() {
axios.get('/api/plugin/get').then((res) => {
this.extension_data.data = res.data.data;
this.extension_data = res.data;
this.checkAlreadyInstalled();
});
},
@@ -270,7 +392,7 @@ export default {
this.onLoadingDialogResult(2, res.data.message, -1);
return;
}
this.extension_data.data = res.data.data;
this.extension_data = res.data;
this.upload_file = "";
this.onLoadingDialogResult(1, res.data.message);
this.dialog = false;
@@ -291,8 +413,7 @@ export default {
this.onLoadingDialogResult(2, res.data.message, -1);
return;
}
this.extension_data.data = res.data.data;
console.log(this.extension_data);
this.extension_data = res.data;
this.extension_url = "";
this.onLoadingDialogResult(1, res.data.message);
this.dialog = false;
@@ -314,7 +435,7 @@ export default {
this.toast(res.data.message, "error");
return;
}
this.extension_data.data = res.data.data;
this.extension_data = res.data;
this.toast(res.data.message, "success");
this.dialog = false;
this.getExtensions();
@@ -332,7 +453,7 @@ export default {
this.onLoadingDialogResult(2, res.data.message, -1);
return;
}
this.extension_data.data = res.data.data;
this.extension_data = res.data;
console.log(this.extension_data);
this.onLoadingDialogResult(1, res.data.message);
this.dialog = false;
@@ -382,7 +503,7 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
axios.post('/api/config/plugin/update?plugin_name=' + this.curr_namespace, this.extension_config.config).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();
@@ -418,11 +539,42 @@ export default {
}
for (let i = 0; i < this.pluginMarketData.length; i++) {
for (let j = 0; j < this.extension_data.data.length; j++) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo || this.pluginMarketData[i].name === this.extension_data.data[j].name) {
this.pluginMarketData[i].installed = true;
}
}
}
// 将已安装的插件移动到最后面
let installed = [];
let notInstalled = [];
for (let i = 0; i < this.pluginMarketData.length; i++) {
if (this.pluginMarketData[i].installed) {
installed.push(this.pluginMarketData[i]);
} else {
notInstalled.push(this.pluginMarketData[i]);
}
}
this.pluginMarketData = notInstalled.concat(installed);
},
showPluginInfo(plugin) {
this.selectedPlugin = plugin;
this.showPluginInfoDialog = true;
},
reloadPlugin(plugin_name) {
axios.post('/api/plugin/reload',
{
name: plugin_name
}).then((res) => {
if (res.data.status === "error") {
this.onLoadingDialogResult(2, res.data.message, -1);
return;
}
this.toast("重载成功", "success");
this.getExtensions();
}).catch((err) => {
this.toast(err, "error");
});
}
},
}
+10 -2
View File
@@ -25,8 +25,10 @@ class LongTermMemory:
self.max_cnt = 300
self.image_caption = self.config["image_caption"]
self.image_caption_prompt = self.config["image_caption_prompt"]
self.image_caption_provider_id = self.config["image_caption_provider_id"]
self.active_reply = self.config["active_reply"]
self.enable_active_reply = self.active_reply.get("enable", False)
self.ar_method = self.active_reply["method"]
self.ar_possibility = self.active_reply["possibility_reply"]
self.ar_prompt = self.active_reply.get("prompt", "")
@@ -41,7 +43,13 @@ class LongTermMemory:
return cnt
async def get_image_caption(self, image_url: str) -> str:
provider = self.context.get_using_provider()
if not self.image_caption_provider_id:
provider = self.context.get_using_provider()
else:
provider = self.context.get_provider_by_id(self.image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {self.image_caption_provider_id} 的提供商")
response = await provider.text_chat(
prompt=self.image_caption_prompt,
session_id=uuid.uuid4().hex,
@@ -51,7 +59,7 @@ class LongTermMemory:
return response.completion_text
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
if not self.active_reply:
if not self.enable_active_reply:
return False
if event.get_message_type() != MessageType.GROUP_MESSAGE:
return False
+291 -60
View File
@@ -1,15 +1,19 @@
import aiohttp
import datetime
import builtins
import json
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import sp
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
from astrbot.api.platform import MessageType
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata
from astrbot.core.star.star import star_map
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.config.default import VERSION
from collections import defaultdict
from .long_term_memory import LongTermMemory
from astrbot.core import logger
@@ -41,6 +45,7 @@ class Main(star.Star):
@filter.command("help")
async def help(self, event: AstrMessageEvent):
'''查看帮助'''
notice = ""
try:
notice = await self._query_astrbot_notice()
@@ -50,31 +55,37 @@ class Main(star.Star):
dashboard_version = await get_dashboard_version()
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
已注册的 AstrBot 内置指令:
AstrBot 指令:
[System]
/plugin: 查看注册的插件、插件帮助
/t2i: 开启/关闭文本转图片模式
/sid: 获取当前会话 ID
/op <admin_id>: 授权管理员
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
/dashboard_update: 更新管理面板
/plugin: 查看插件、插件帮助
/t2i: 开文本转图片
/sid: 获取会话 ID
/op <admin_id>: 授权管理员(op)
/deop <admin_id>: 取消管理员(op)
/wl <sid>: 添加白名单(op)
/dwl <sid>: 删除白名单(op)
/dashboard_update: 更新管理面板(op)
/alter_cmd: 设置指令权限(op)
[大模型]
/provider: 查看、切换大模型提供商
/model: 查看、切换提供商模型列表
/key: 查看、切换 API Key
/reset: 重置 LLM 会
/history: 获取会话历史记录
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
/provider: 大模型提供商
/model: 模型列表
/ls: 对话列表
/new: 创建新对
/switch: 切换对话
/rename: 重命名对话
/del: 删除当前会话对话(op)
/reset: 重置 LLM 会话(op)
/history: 当前对话的对话记录
/persona: 人格情景(op)
/tool ls: 函数工具
/key: API Key(op)
[其他]
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除会话的变量。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@@ -91,7 +102,7 @@ class Main(star.Star):
active = " (启用)" if tool.active else "(停用)"
msg += f"- {tool.name}: {tool.description} {active}\n"
msg += "\n使用 /tool on/off <工具名> 激活或者停用工具。"
msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。"
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@tool.command("on")
@@ -107,6 +118,13 @@ class Main(star.Star):
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
else:
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
@tool.command("off_all")
async def tool_all_off(self, event: AstrMessageEvent):
tm = self.context.get_llm_tool_manager()
for tool in tm.func_list:
self.context.deactivate_llm_tool(tool.name)
event.set_result(MessageEventResult().message(f"停用所有工具成功。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
@@ -117,7 +135,7 @@ class Main(star.Star):
if plugin_list_info.strip() == "":
plugin_list_info = "没有加载任何插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助和加载的指令\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
else:
if oper1 == "off":
@@ -140,10 +158,34 @@ class Main(star.Star):
plugin = self.context.get_registered_star(oper1)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
return
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "帮助信息: 未提供"
help_msg += f"\n\n作者: {plugin.author}\n版本: {plugin.version}"
command_handlers = []
command_names = []
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
if handler.handler_module_path != plugin.module_path:
continue
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
command_handlers.append(handler)
command_names.append(filter_.command_name)
break
elif isinstance(filter_, CommandGroupFilter):
command_handlers.append(handler)
command_names.append(filter_.group_name)
if len(command_handlers) > 0:
help_msg += "\n\n指令列表:\n"
for i in range(len(command_handlers)):
help_msg += f"{command_names[i]}: {command_handlers[i].desc}\n"
help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
ret += "更多帮助信息请查看插件仓库 README。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
@@ -229,6 +271,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
@@ -236,7 +279,16 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
await self.context.get_using_provider().forget(message.session_id)
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.update_conversation(
message.unified_msg_origin, cid, []
)
ret = "清除会话 LLM 聊天历史成功。"
if self.ltm:
cnt = await self.ltm.remove_session(event=message)
@@ -287,25 +339,34 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
else:
self.context.get_using_provider().set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型成功。 \n模型信息: {idx_or_name}"))
MessageEventResult().message(f"切换模型{self.context.get_using_provider().get_model()}"))
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
'''查看对话记录'''
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
size_per_page = 3
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
size_per_page = 6
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
contexts, total_pages = await self.context.conversation_manager.get_human_readable_context(
message.unified_msg_origin, session_curr_cid, page, size_per_page
)
history = ""
for context in contexts:
if len(context) > 150:
context = context[:150] + "..."
history += f"{context}\n"
ret = f"""历史记录:
ret = f"""当前对话历史记录:
{history}
{page} 页 | 共 {total_pages}
@@ -314,6 +375,88 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1):
'''查看对话列表'''
size_per_page = 6
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
total_pages = len(conversations) // size_per_page
if len(conversations) % size_per_page != 0:
total_pages += 1
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
ret = "对话列表:\n---\n"
global_index = (page - 1) * size_per_page + 1
_titles = {}
for conv in conversations:
persona_id = conv.persona_id
if not persona_id and not persona_id == "[%None]":
persona_id = self.context.provider_manager.selected_default_persona['name']
title = conv.title if conv.title else "新对话"
_titles[conv.cid] = title
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
global_index += 1
ret += "---\n"
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if curr_cid:
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
else:
ret += "\n当前对话: 无"
unique_session = self.context.get_config()['platform_settings']['unique_session']
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
ret += "\n会话隔离粒度: 群聊"
ret += f"\n{page} 页 | 共 {total_pages}"
ret += "\n*输入 /ls 2 跳转到第 2 页"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
'''创建新对话'''
cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin)
message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"))
@filter.command("switch")
async def switch_conv(self, message: AstrMessageEvent, index: int):
'''通过 /ls 前面的序号切换对话'''
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
if index > len(conversations) or index < 1:
message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看"))
else:
conversation = conversations[index-1]
title = conversation.title if conversation.title else "新对话"
await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid)
message.set_result(MessageEventResult().message(f"切换到对话: {title}({conversation.cid[:4]})。"))
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
'''重命名对话'''
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
message.set_result(MessageEventResult().message("重命名对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
'''删除当前对话'''
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
message.set_result(MessageEventResult().message("删除当前对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
@@ -347,30 +490,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
MessageEventResult().message("切换 Key 未知错误: "+str(e)))
message.set_result(MessageEventResult().message("切换 Key 成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
l = message.message_str.split(" ")
curr_persona_name = ""
if self.context.get_using_provider().curr_personality:
curr_persona_name = self.context.get_using_provider().curr_personality['name']
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
curr_cid_title = ""
if cid:
conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid)
if not conversation.persona_id and not conversation.persona_id == "[%None]":
curr_persona_name = self.context.provider_manager.selected_default_persona['name']
else:
curr_persona_name = conversation.persona_id
curr_cid_title = conversation.title if conversation.title else "新对话"
curr_cid_title += f"({cid[:4]})"
if len(l) == 1:
message.set_result(
MessageEventResult().message(f"""[Persona]
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
- 人格情景列表: `/persona list`
- 人格情景详细信息: `/persona view 人格`
- 设置人格情景: `/persona 人格`
- 人格情景详细信息: `/persona view 人格`
- 取消人格: `/persona unset`
当前人格情景: {curr_persona_name}
默认人格情景: {self.context.provider_manager.selected_default_persona['name']}
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""").use_t2i(False))
@@ -395,7 +543,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif l[1] == "unset":
self.context.get_using_provider().curr_personality = None
if not cid:
message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。"))
return
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]")
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(l[1:]).strip()
@@ -403,7 +554,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
lambda persona: persona['name'] == ps,
self.context.provider_manager.personas
), None):
self.context.get_using_provider().curr_personality = persona
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps)
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
else:
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
@@ -475,7 +626,17 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
try:
session_provider_context = provider.session_memory.get(event.session_id)
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
if not session_curr_cid:
logger.error("当前未处于对话状态,无法主动回复,请使用 /switch 切换或者 /new 创建。")
return
conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid
)
history = json.loads(conv.history)
prompt = self.ltm.ar_prompt
if not prompt:
@@ -485,7 +646,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
prompt=prompt,
func_tool_manager=self.context.get_llm_tool_manager(),
session_id=event.session_id,
contexts=session_provider_context if session_provider_context else []
contexts=history if history else []
)
except BaseException as e:
logger.error(f"主动回复失败: {e}")
@@ -494,7 +655,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
'''在请求 LLM 前注入人格信息、Identifier、时间等 System Prompt'''
provider = self.context.get_using_provider()
if self.prompt_prefix:
req.prompt = self.prompt_prefix + req.prompt
@@ -505,16 +665,27 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
tz_offset = datetime.timedelta(hours=8)
tz = datetime.timezone(tz_offset)
current_time = datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M')
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
if req.conversation:
persona_id = req.conversation.persona_id
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
persona_id = self.context.provider_manager.selected_default_persona['name']
persona = next(builtins.filter(
lambda persona: persona['name'] == persona_id,
self.context.provider_manager.personas
), None)
if persona:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if self.ltm:
try:
@@ -531,6 +702,66 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
await self.ltm.after_req_llm(event)
except BaseException as e:
logger.error(f"ltm: {e}")
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd")
async def alter_cmd(self, event: AstrMessageEvent):
# token = event.message_str.split(" ")
token = self.parse_commands(event.message_str)
if token.len < 2:
yield event.plain_result("可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令")
return
cmd_name = token.get(1)
cmd_type = token.get(2)
if cmd_type not in ["admin", "member"]:
yield event.plain_result("指令类型错误,可选类型有 admin, member")
return
# 查找指令
found_command = None
for handler in star_handlers_registry:
assert isinstance(handler, StarHandlerMetadata)
for filter_ in handler.event_filters:
if isinstance(filter_, CommandFilter):
if filter_.command_name == cmd_name:
found_command = handler
break
elif isinstance(filter_, CommandGroupFilter):
if cmd_name == filter_.group_name:
found_command = handler
break
if not found_command:
yield event.plain_result("未找到该指令")
return
found_plugin = star_map[found_command.handler_module_path]
alter_cmd_cfg = sp.get("alter_cmd", {})
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
cfg = plugin_.get(found_command.handler_name, {})
cfg["permission"] = cmd_type
plugin_[found_command.handler_name] = cfg
alter_cmd_cfg[found_plugin.name] = plugin_
sp.put("alter_cmd", alter_cmd_cfg)
# 注入权限过滤器
found_permission_filter = False
for filter_ in found_command.event_filters:
if isinstance(filter_, PermissionTypeFilter):
if cmd_type == "admin":
filter_.permission_type = filter.PermissionType.ADMIN
else:
filter_.permission_type = filter.PermissionType.MEMBER
found_permission_filter = True
break
if not found_permission_filter:
found_command.event_filters.insert(0, PermissionTypeFilter(filter.PermissionType.ADMIN if cmd_type == "admin" else filter.PermissionType.MEMBER))
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
# @filter.command_group("kdb")
# def kdb(self):
+2 -2
View File
@@ -358,7 +358,7 @@ class Main(star.Star):
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
break
@@ -393,4 +393,4 @@ class Main(star.Star):
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()
await container.delete()
+45 -13
View File
@@ -1,6 +1,7 @@
import os
import json
import datetime
import uuid
import astrbot.api.star as star
from astrbot.api.event import filter
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -28,11 +29,18 @@ class Main(star.Star):
'''Initialize the scheduler.'''
for group in self.reminder_data:
for reminder in self.reminder_data[group]:
if 'id' not in reminder:
id_ = str(uuid.uuid4())
reminder['id'] = id_
else:
id_ = reminder['id']
if "datetime" in reminder:
if self.check_is_outdated(reminder):
continue
self.scheduler.add_job(
self._reminder_callback,
id=id_,
trigger='date',
args=[group, reminder],
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
@@ -42,6 +50,7 @@ class Main(star.Star):
self.scheduler.add_job(
self._reminder_callback,
trigger='cron',
id=id_,
args=[group, reminder],
misfire_grace_time=60,
**self._parse_cron_expr(reminder["cron"])
@@ -69,11 +78,11 @@ class Main(star.Star):
}
@llm_tool("reminder")
async def reminder_tool(self, event: AstrMessageEvent, text: str, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
async def reminder_tool(self, event: AstrMessageEvent, text: str=None, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
'''Call this function when user ask for setting a reminder.
Args:
text(string): The content of the reminder.
text(string): Must Required. The content of the reminder.
datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M
cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder.
human_readable_cron(string): Optional. The human readable cron expression of the reminder.
@@ -88,65 +97,88 @@ class Main(star.Star):
if not cron_expression and not datetime_str:
raise ValueError("The cron_expression and datetime_str cannot be both None.")
reminder_time = ""
if not text:
text = "未命名待办事项"
if cron_expression:
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron, "id": str(uuid.uuid4()) }
self.reminder_data[event.unified_msg_origin].append(d)
self.scheduler.add_job(
self._reminder_callback,
'cron',
id=d["id"],
misfire_grace_time=60,
**self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]
)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
d = { "text": text, "datetime": datetime_str }
d = { "text": text, "datetime": datetime_str, "id": str(uuid.uuid4()) }
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
self.scheduler.add_job(
self._reminder_callback,
'date',
id=d["id"],
args=[event.unified_msg_origin, d],
run_date=datetime_scheduled,
misfire_grace_time=60
)
reminder_time = datetime_str
await self._save_data()
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。\n使用 /tool off reminder 关闭此功能。")
@filter.command_group("reminder")
def reminder(self):
'''The command group of the reminder.'''
pass
async def get_upcoming_reminders(self, unified_msg_origin: str):
'''Get upcoming reminders.'''
reminders = self.reminder_data.get(unified_msg_origin, [])
if not reminders:
return []
now = datetime.datetime.now()
upcoming_reminders = [
reminder for reminder in reminders
if "datetime" not in reminder or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now
]
return upcoming_reminders
@reminder.command("ls")
async def reminder_ls(self, event: AstrMessageEvent):
'''List all reminders.'''
reminders = self.reminder_data.get(event.unified_msg_origin, [])
'''List upcoming reminders.'''
reminders = await self.get_upcoming_reminders(event.unified_msg_origin)
if not reminders:
yield event.plain_result("没有待办事项。")
yield event.plain_result("没有正在进行的待办事项。")
else:
reminder_str = "待办事项:\n"
reminder_str = "正在进行的待办事项:\n"
for i, reminder in enumerate(reminders):
time_ = reminder.get("datetime", "")
if not time_:
cron_expr = reminder.get("cron", "")
time_ = reminder.get("cron_h", "") + f"(Cron: {cron_expr})"
reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n"
reminder_str += "\n使用 /reminder rm <index> 删除待办事项。"
reminder_str += "\n使用 /reminder rm <id> 删除待办事项。\n"
yield event.plain_result(reminder_str)
@reminder.command("rm")
async def reminder_rm(self, event: AstrMessageEvent, index: int):
'''Remove a reminder by index.'''
reminders = self.reminder_data.get(event.unified_msg_origin, [])
reminders = await self.get_upcoming_reminders(event.unified_msg_origin)
if not reminders:
yield event.plain_result("没有待办事项。")
elif index < 1 or index > len(reminders):
yield event.plain_result("索引越界。")
else:
reminder = reminders.pop(index - 1)
self.scheduler.remove_job(event.unified_msg_origin)
job_id = reminder.get("id")
try:
self.scheduler.remove_job(job_id)
except Exception as e:
logger.error(f"Remove job error: {e}")
await self._save_data()
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
-1
View File
@@ -9,7 +9,6 @@ class Google(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.proxy = os.environ.get("https_proxy")
print(f"Google Search using proxy: {self.proxy}")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = []
-1
View File
@@ -30,7 +30,6 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):