Compare commits

...

23 Commits

Author SHA1 Message Date
Soulter f2e8303b66 fix: KeyError _mood_imitation_dialogs_processed 2025-02-05 18:52:55 +08:00
Soulter 2a614b545b fix: 修复可能的 KeyError 2025-02-05 17:17:05 +08:00
Soulter 5c0ab21f68 fix: 修复 /model 异常 2025-02-05 17:05:47 +08:00
Soulter 689d109438 typo: myid -> sid 2025-02-05 16:59:21 +08:00
Soulter 2a6934b283 perf: 无对话状态的提示 2025-02-05 16:56:13 +08:00
Soulter 760cb94e9a v3.4.20 2025-02-05 16:06:52 +08:00
Soulter 2a6cff0013 feat: 支持重命名对话 2025-02-05 16:06:18 +08:00
Soulter ce578f0417 feat: 支持使用 LLM 辅助分段回复 #338 2025-02-05 15:40:52 +08:00
Soulter 1745bdb9e2 perf: 优化一些问题 2025-02-05 15:39:59 +08:00
Soulter 3f90b89c3c 添加屏蔽无权限指令回复的功能 #361 2025-02-05 15:06:38 +08:00
Soulter f343e40d15 Merge pull request #370 from Soulter/feat-conversation
feat: 更好的对话管理
2025-02-05 14:56:47 +08:00
Soulter 5cc4be9e65 perf: 优化部分显示问题 2025-02-05 14:51:40 +08:00
Soulter da5aada002 fix: 修复指令组情况下可能造成多指令出触发的问题 2025-02-05 13:52:53 +08:00
Soulter 07f2ee9ad9 fix: 修复 /reset 指令 2025-02-05 13:33:36 +08:00
Soulter 12f4e1146f feat: 更好的对话管理 2025-02-05 13:26:53 +08:00
Soulter 92c57e5476 fix: 修复级联指令组时出现载入错误的问题 2025-02-05 11:11:04 +08:00
Soulter a923baacd8 Update README.md 2025-02-05 01:56:09 +08:00
Soulter 999b094d55 Merge pull request #358 from eltociear/patch-1
chore: update main.py
2025-02-05 01:34:04 +08:00
Soulter d4213f2352 perf: announcement plugin market 2025-02-05 01:19:54 +08:00
Ikko Eltociear Ashimine 3f65c9a066 chore: update main.py
occured -> occurred
2025-02-05 02:18:41 +09:00
Soulter 1d427e2645 perf: 优化插件页面 2025-02-05 01:10:53 +08:00
Soulter 36414c4b00 perf: 优化aiocqhttp适配器对用户非法输入的处理 2025-02-05 00:02:18 +08:00
Soulter 47e253d76c fix: 修复权限过滤算子导致的问题 #350 2025-02-04 23:31:46 +08:00
31 changed files with 659 additions and 400 deletions
+2
View File
@@ -9,6 +9,8 @@
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
+16 -4
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.19"
VERSION = "3.4.20"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -29,8 +29,10 @@ DEFAULT_CONFIG = {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"seg_prompt": "",
"regex": ".*?[。?!~…]+|.+$"
}
},
"no_permission_reply": True,
},
"provider": [],
"provider_settings": {
@@ -109,7 +111,7 @@ CONFIG_METADATA_2 = {
"id": "default",
"type": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_host": "0.0.0.0",
"ws_reverse_port": 6199,
},
"gewechat(微信)": {
@@ -194,6 +196,11 @@ CONFIG_METADATA_2 = {
},
},
},
"no_permission_reply": {
"description": "无权限回复",
"type": "bool",
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
@@ -211,6 +218,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"seg_prompt": {
"description": "分段提示词辅助",
"type": "string",
"hint": "此项为空时表达不启用这个方法。此方法会调用一次LLM请求。让 LLM 在某一句话中插入一个可以用正则表达式分隔的标记,来实现LLM基于情感分段。如: `请基于情感对以下文本进行分段, 并在两段之间添加`<seg>`以便我用正则匹配。` 然后将下面的正则表达式更换为`.+?<seg>`。",
},
"regex": {
"description": "正则表达式",
"type": "string",
@@ -761,7 +773,7 @@ CONFIG_METADATA_2 = {
"description": "管理员 ID",
"type": "list",
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/sid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
"description": "HTTP 代理",
+118
View File
@@ -0,0 +1,118 @@
import uuid
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation
class ConversationManager():
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
def __init__(self, db_helper: BaseDatabase):
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
'''新建对话,并将当前会话的对话转移到新对话'''
conversation_id = str(uuid.uuid4())
self.db.new_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
'''切换会话的对话'''
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(
user_id=unified_msg_origin,
cid=conversation_id
)
del self.session_conversations[unified_msg_origin]
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
'''获取会话当前的对话 ID'''
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
'''获取会话的对话'''
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
'''获取会话的所有对话'''
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
'''更新会话的对话'''
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
cid=conversation_id,
history=json.dumps(history)
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
'''更新会话的对话标题'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
user_id=unified_msg_origin,
cid=conversation_id,
title=title
)
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
'''更新会话的对话 Persona ID'''
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
user_id=unified_msg_origin,
cid=conversation_id,
persona_id=persona_id
)
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
contexts = []
temp_contexts = []
for record in history:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
+4 -1
View File
@@ -18,7 +18,7 @@ from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
@@ -43,12 +43,15 @@ class AstrBotCoreLifecycle:
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.conversation_manager = ConversationManager(self.db)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
+20 -10
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -79,25 +79,35 @@ class BaseDatabase(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
'''通过 user_id 和 cid 获取 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
def new_conversation(self, user_id: str, cid: str):
'''新建 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
def get_conversations(self, user_id: str) -> List[Conversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
def delete_conversation(self, user_id: str, cid: str):
'''删除 Conversation'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_title(self, user_id: str, cid: str, title: str):
'''更新 Conversation 标题'''
raise NotImplementedError
@abc.abstractmethod
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
'''更新 Conversation Persona ID'''
raise NotImplementedError
+11 -6
View File
@@ -33,16 +33,16 @@ class Stats():
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
'''LLM 聊天时持久化的信息'''
@dataclass
class LLMHistory():
'''LLM 聊天时持久化的信息'''
provider_type: str
session_id: str
content: str
@dataclass
class ATRIVision():
'''Deprecated'''
id: str
url_or_path: str
caption: str
@@ -53,13 +53,18 @@ class ATRIVision():
sender_nickname: str
timestamp: int = -1
@dataclass
class WebChatConversation():
class Conversation():
'''LLM 对话存储
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
'''
user_id: str
cid: str
history: str = ""
'''字符串格式的列表。'''
created_at: int = 0
updated_at: int = 0
title: str = ""
persona_id: str = ""
+61 -12
View File
@@ -6,7 +6,7 @@ from astrbot.core.db.po import (
Stats,
LLMHistory,
ATRIVision,
WebChatConversation
Conversation
)
from . import BaseDatabase
from typing import Tuple
@@ -25,6 +25,37 @@ class SQLiteDatabase(BaseDatabase):
c = self.conn.cursor()
c.executescript(sql)
self.conn.commit()
# 检查 webchat_conversation 的 title 字段是否存在
c.execute(
'''
PRAGMA table_info(webchat_conversation)
'''
)
res = c.fetchall()
has_title = False
has_persona_id = False
for row in res:
if row[1] == "title":
has_title = True
if row[1] == "persona_id":
has_persona_id = True
if not has_title:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
'''
)
self.conn.commit()
if not has_persona_id:
c.execute(
'''
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
'''
)
self.conn.commit()
c.close()
def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
@@ -202,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -216,9 +247,9 @@ class SQLiteDatabase(BaseDatabase):
res = c.fetchone()
c.close()
return WebChatConversation(*res)
return Conversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
def new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
@@ -228,7 +259,7 @@ class SQLiteDatabase(BaseDatabase):
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
def get_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -236,7 +267,7 @@ class SQLiteDatabase(BaseDatabase):
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
@@ -247,24 +278,42 @@ class SQLiteDatabase(BaseDatabase):
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
title = row[3]
persona_id = row[4]
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
def update_conversation(self, user_id: str, cid: str, history: str):
'''更新对话,并且同时更新时间'''
updated_at = int(time.time())
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
''', (history, updated_at, user_id, cid)
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
''', (title, user_id, cid)
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
''', (persona_id, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
def delete_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
ts = int(time.time())
keywords = ",".join(vision.keywords)
+3 -1
View File
@@ -42,5 +42,7 @@ CREATE TABLE IF NOT EXISTS webchat_conversation(
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
@@ -2,6 +2,7 @@
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
@@ -10,7 +11,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
@@ -24,6 +25,8 @@ class LLMRequestSubStage(Stage):
if self.provider_wake_prefix.startswith(bwp):
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
self.conv_manager = ctx.plugin_manager.context.conversation_manager
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
@@ -46,10 +49,17 @@ class LLMRequestSubStage(Stage):
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
if not conversation_id:
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
req.session_id = conversation_id
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt and not req.image_urls:
return
@@ -62,9 +72,12 @@ class LLMRequestSubStage(Stage):
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
try:
logger.debug(f"提供商请求 Payload: {req.__dict__}")
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
@@ -77,6 +90,9 @@ class LLMRequestSubStage(Stage):
except BaseException:
logger.error(traceback.format_exc())
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
@@ -117,4 +133,24 @@ class LLMRequestSubStage(Stage):
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
return
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {
"role": "user",
"content": req.prompt
}
contexts.append(new_record)
contexts.append({
"role": "assistant",
"content": llm_response.completion_text
})
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.session_id,
history=contexts_to_save
)
+17 -2
View File
@@ -23,6 +23,7 @@ class ResultDecorateStage:
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.seg_prompt = ctx.astrbot_config['platform_settings']['segmented_reply']['seg_prompt']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
@@ -49,12 +50,26 @@ class ResultDecorateStage:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if self.seg_prompt:
try:
llm_resp = await self.ctx.plugin_manager.context.get_using_provider().text_chat(
prompt=f"{self.seg_prompt}\n{comp.text}",
)
comp.text = llm_resp.completion_text
except BaseException as e:
traceback.print_exc()
logger.error("使用 LLM 分段回复失败: " + str(e))
new_chain.append(comp)
continue
split_response = re.findall(self.regex, comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
if seg:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
+18 -3
View File
@@ -2,11 +2,11 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
@register_stage
class WakingCheckStage(Stage):
@@ -21,6 +21,9 @@ class WakingCheckStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
async def process(
self, event: AstrMessageEvent
@@ -77,7 +80,9 @@ class WakingCheckStage(Stage):
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
permission_not_pass = False
if len(handler.event_filters) == 0:
# 不可能有这种情况, 也不允许有这种情况
continue
@@ -94,6 +99,9 @@ class WakingCheckStage(Stage):
else:
handler = child_command_handler_md # handler 覆盖
break
elif isinstance(filter, PermissionTypeFilter):
if not filter.filter(event, self.ctx.astrbot_config):
permission_not_pass = True
else:
if not filter.filter(event, self.ctx.astrbot_config):
passed = False
@@ -111,6 +119,13 @@ class WakingCheckStage(Stage):
break
if passed:
if permission_not_pass:
if self.no_permission_reply:
await event.send(MessageChain().message(f"ID {event.get_sender_id()} 权限不足"))
event.stop_event()
return
is_wake = True
event.is_wake = True
@@ -102,7 +102,7 @@ class AiocqhttpAdapter(Platform):
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
raise FileNotFoundError(f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot")
m['data'] = {
"file": ret['file'],
@@ -122,7 +122,10 @@ class AiocqhttpAdapter(Platform):
def run(self) -> Awaitable[Any]:
if not self.host or not self.port:
return
logger.warning("aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199")
self.host = "0.0.0.0"
self.port = 6199
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
+7 -1
View File
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from astrbot.core.db.po import Conversation
class ProviderType(enum.Enum):
@@ -38,10 +39,15 @@ class ProviderRequest():
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
conversation: Conversation = None
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
def __str__(self):
return self.__repr__()
@dataclass
class LLMResponse:
+10 -1
View File
@@ -66,7 +66,16 @@ class ProviderManager():
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed=""
)
self.personas.append(self.selected_default_persona)
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
+22 -41
View File
@@ -8,6 +8,8 @@ from typing import TypedDict
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
@@ -15,8 +17,8 @@ class Personality(TypedDict):
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict]
_mood_imitation_dialogs_processed: str
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
@@ -60,25 +62,11 @@ class Provider(AbstractProvider):
) -> None:
super().__init__(provider_config)
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona,即人格。可能为 None'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
if persistant_history:
# 读取历史记录
try:
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
@abc.abstractmethod
def get_current_key(self) -> str:
raise NotImplementedError()
@@ -96,22 +84,6 @@ class Provider(AbstractProvider):
'''获得支持的模型列表'''
raise NotImplementedError()
@abc.abstractmethod
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
'''获取人类可读的上下文
page 从 1 开始
Example:
["User: 你好", "Assistant: 你好!"]
Return:
contexts: List[str]: 上下文列表
total_pages: int: 总页数
'''
raise NotImplementedError()
@abc.abstractmethod
async def text_chat(self,
prompt: str,
@@ -125,26 +97,35 @@ class Provider(AbstractProvider):
Args:
prompt: 提示词
session_id: 会话 ID
session_id: 会话 ID(此属性已经被废弃)
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
kwargs: 其他参数
Notes:
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
'''
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, session_id: str) -> bool:
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
async def pop_record(self, context: List):
'''
弹出 context 第一条非系统提示词对话记录
'''
poped = 0
indexs_to_pop = []
for idx, record in enumerate(context):
if record["role"] == "system":
continue
else:
indexs_to_pop.append(idx)
poped += 1
if poped == 2:
break
for idx in reversed(indexs_to_pop):
context.pop(idx)
class STTProvider(AbstractProvider):
+8 -86
View File
@@ -73,60 +73,10 @@ class ProviderGoogleGenAI(Provider):
api_base=provider_config.get("api_base", None)
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
@@ -197,33 +147,32 @@ class ProviderGoogleGenAI(Provider):
llm_response.completion_text = llm_response.completion_text.strip()
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
**model_config
}
llm_response = None
try:
@@ -234,7 +183,7 @@ class ProviderGoogleGenAI(Provider):
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
@@ -254,34 +203,7 @@ class ProviderGoogleGenAI(Provider):
raise e
if kwargs.get("persist", True) and llm_response:
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
def get_current_key(self) -> str:
return self.client.api_key
@@ -57,20 +57,13 @@ class LLMTunerModelLoader(Provider):
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
contexts: List = [],
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
new_record = {"role": "user", "content": prompt}
if not contexts:
query_context = [
*self.session_memory[session_id],
new_record,
]
system_prompt = self.curr_personality["prompt"]
else:
query_context = [*contexts, new_record]
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
@@ -96,34 +89,8 @@ class LLMTunerModelLoader(Provider):
responses = await self.model.achat(**conf)
llm_response = LLMResponse("assistant", responses[-1].response_text)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def forget(self, session_id):
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
async def get_current_key(self):
return "none"
@@ -132,28 +99,4 @@ class LLMTunerModelLoader(Provider):
pass
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return [self.get_model()]
+8 -76
View File
@@ -48,30 +48,6 @@ class ProviderOpenAIOfficial(Provider):
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
try:
@@ -84,22 +60,6 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def pop_record(self, session_id: str):
'''
弹出最早的一个对话
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) < 2:
return
try:
self.session_memory[session_id].pop(0)
self.session_memory[session_id].pop(0)
except IndexError:
pass
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
tool_list = tools.get_func_desc_openai_style()
@@ -141,29 +101,29 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str=None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
model_config = self.provider_config.get("model_config", {})
model_config['model'] = self.get_model()
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
**model_config
}
llm_response = None
try:
@@ -214,9 +174,6 @@ class ProviderOpenAIOfficial(Provider):
raise e
if kwargs.get("persist", True) and llm_response:
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def _remove_image_from_context(self, contexts: List):
@@ -244,32 +201,7 @@ class ProviderOpenAIOfficial(Provider):
context['content'] = new_content
new_contexts.append(context)
return new_contexts
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
return True
def get_current_key(self) -> str:
return self.client.api_key
@@ -22,24 +22,21 @@ class ProviderZhipu(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str,
session_id: str = None,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
contexts=[],
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
model = self.get_model()
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
@@ -62,7 +59,6 @@ class ProviderZhipu(ProviderOpenAIOfficial):
}
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
+3
View File
@@ -16,6 +16,7 @@ from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
from astrbot.core.conversation_mgr import ConversationManager
class Context:
'''
@@ -44,6 +45,7 @@ class Context:
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
@@ -52,6 +54,7 @@ class Context:
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
self.conversation_manager = conversation_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
+5 -1
View File
@@ -46,7 +46,11 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
if not event.is_wake_up():
return False
message_str = event.get_message_str().strip()
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
# 分割为列表(每个参数之间可能会有多个空格)
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
+11 -4
View File
@@ -40,17 +40,24 @@ class CommandGroupFilter(HandlerFilter):
if not event.is_wake_up():
return False, None
message_str = event.get_message_str().strip()
if event.get_extra("parsing_command"):
message_str = event.get_extra("parsing_command").strip()
else:
message_str = event.get_message_str().strip()
ls = re.split(r"\s+", message_str)
if ls[0] != self.group_name:
return False, None
# 改写 message_str
ls = ls[1:]
event.message_str = " ".join(ls)
event.message_str = event.message_str.strip()
# event.message_str = " ".join(ls)
# event.message_str = event.message_str.strip()
parsing_command = " ".join(ls)
parsing_command = parsing_command.strip()
event.set_extra("parsing_command", parsing_command)
if event.message_str == "":
if parsing_command == "":
# 当前还是指令组
tree = self.group_name + "\n" + self.print_cmd_tree(self.sub_command_filters)
raise ValueError(f"指令组 {self.group_name} 未填写完全。这个指令组下有如下指令:\n"+tree)
+3 -2
View File
@@ -19,7 +19,8 @@ class PermissionTypeFilter(HandlerFilter):
'''
if self.permission_type == PermissionType.ADMIN:
if not event.is_admin():
event.stop_event()
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
# event.stop_event()
# raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限操作管理员指令。")
return False
return True
+1 -1
View File
@@ -64,7 +64,7 @@ def register_command(command_name: str = None, *args):
return decorator
def register_command_group(command_group_name: str = None, desc: str = "", *args):
def register_command_group(command_group_name: str = None, *args):
'''注册一个 CommandGroup
'''
@@ -25,6 +25,10 @@ class ParameterValidationMixin:
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, int):
result[param_name] = int(params[i])
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+8 -8
View File
@@ -121,7 +121,7 @@ class ChatRoute(Route):
}))
# 持久化
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -136,7 +136,7 @@ class ChatRoute(Route):
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
self.db.update_conversation(username, conversation_id, history=json.dumps(history))
return Response().ok().__dict__
@@ -168,7 +168,7 @@ class ChatRoute(Route):
continue
yield result_text + '\n'
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
conversation = self.db.get_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
@@ -178,7 +178,7 @@ class ChatRoute(Route):
'type': 'bot',
'message': result_text
})
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
self.db.update_conversation(username, cid, history=json.dumps(history))
await asyncio.sleep(0.5)
except BaseException as e:
@@ -204,20 +204,20 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
self.db.delete_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
self.db.new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
conversations = self.db.get_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
@@ -226,7 +226,7 @@ class ChatRoute(Route):
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
self.curr_user_cid[username] = conversation_id
+15
View File
@@ -0,0 +1,15 @@
# What's Changed
> 由于重写了会话记录部分,更新此版本后,将会造成之前的对话记录清空(但没有被删除)。
> 关于更好的对话管理,如果有任何报错或者优化建议,请直接提交 issue~
1. 更好的对话管理,支持 /ls, /del, /new, /switch, /rename 指令来操作对话。
2. 人格情境跟随对话。每个对话支持独立设置人格情境,只需要 /persona 指令切换即可。
3. 支持使用 LLM 辅助分段回复 #338
4. 优化 aiocqhttp 适配器对用户非法输入的处理
5. 优化插件页面
6. 修复权限过滤算子导致的问题 #350
7. 修复级联指令组时出现载入错误的问题 #366
8. 修复代码执行器的一个typo by @eltociear
9. 修复指令组情况下可能造成多指令出触发的问题
10. 添加屏蔽无权限指令回复的功能 #361
@@ -1,7 +1,8 @@
<script setup lang="ts">
const props = defineProps({
title: String,
link: String
link: String,
logo: String
});
const open = (link: string | undefined) => {
@@ -13,6 +14,7 @@ const open = (link: string | undefined) => {
<v-card variant="outlined" elevation="0" class="withbg">
<v-card-item style="padding: 10px 14px">
<div class="d-sm-flex align-center justify-space-between">
<img v-if="logo" :src="logo" alt="logo" style="width: 40px; height: 40px; margin-right: 8px;">
<v-card-title style="font-size: 17px;">{{ props.title }}</v-card-title>
<v-spacer></v-spacer>
<v-btn variant="plain" @click="open(props.link)">仓库</v-btn>
+53 -20
View File
@@ -9,29 +9,35 @@ import axios from 'axios';
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README" title="💡提示"
type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<h3>🧩 已安装的插件</h3>
</div>
</v-col>
<v-col cols="12" md="6" lg="4" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" style="margin-bottom: 4px;">
<v-col cols="12" md="6" lg="3" v-for="extension in extension_data.data">
<ExtensionCard :key="extension.name" :title="extension.name" :link="extension.repo" :logo="extension?.logo"
style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: none;">{{ extension.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ extension.author }}</span>
<v-spacer></v-spacer>
<div v-if="!extension.reserved">
<v-btn variant="plain" @click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="openExtensionConfig(extension.name)">配置</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="updateExtension(extension.name)">更新</v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="uninstallExtension(extension.name)">卸载</v-btn>
</div>
<!-- <span v-else>保留插件</span> -->
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
<v-btn variant="plain" v-else @click="pluginOn(extension)"></v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-if="extension.activated"
@click="pluginOff(extension)"></v-btn>
<v-btn class="text-none mr-2" size="small" text="Read" variant="flat" border v-else
@click="pluginOn(extension)">启用</v-btn>
</div>
</ExtensionCard>
</v-col>
@@ -39,28 +45,35 @@ import axios from 'axios';
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
<div style="display: flex; align-items: center;">
<h3>🧩 插件市场</h3>
<small style="margin-left: 16px;">如无法显示请打开 <a href="https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json">链接</a> 复制想安装插件对应的 `repo` 链接然后点击右下角 + 号安装或打开链接下载压缩包安装</small>
<small style="margin-left: 16px;">如无法显示请打开 <a
href="https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json">链接</a> 复制想安装插件对应的 `repo`
链接然后点击右下角 + 号安装或打开链接下载压缩包安装</small>
</div>
</div>
</v-col>
<v-col cols="12" md="6" lg="4" v-for="plugin in pluginMarketData">
<v-col cols="12" md="12" v-if="announcement">
<v-banner color="success" lines="one" :text="announcement" :stacked="false" >
</v-banner>
</v-col>
<v-col cols="12" md="6" lg="3" v-for="plugin in pluginMarketData">
<ExtensionCard :key="plugin.name" :title="plugin.name" :link="plugin.repo" style="margin-bottom: 4px;">
<p style="min-height: 130px; max-height: 130px; overflow: hidden;">{{ plugin.desc }}</p>
<div class="d-flex align-center gap-2">
<v-icon>mdi-account</v-icon>
<span>{{ plugin.author }}</span>
<v-spacer></v-spacer>
<v-btn v-if="!plugin.installed" variant="plain"
<v-btn v-if="!plugin.installed" class="text-none mr-2" size="small" text="Read" variant="flat" border
@click="extension_url = plugin.repo; newExtension()">安装</v-btn>
<v-btn v-else variant="plain" disabled>已安装</v-btn>
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
@@ -75,7 +88,8 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata"
:iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
</v-container>
</v-card-text>
@@ -211,12 +225,19 @@ export default {
title: "加载中...",
statusCode: 0, // 0: loading, 1: success, 2: error,
result: ""
}
},
announcement: ""
}
},
mounted() {
this.getExtensions();
this.fetchPluginCollection();
axios.get('https://api.soulter.top/astrbot-announcement-plugin-market').then((res) => {
let data = res.data.data;
this.announcement = data.text;
});
},
methods: {
toast(message, success) {
@@ -386,7 +407,7 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
axios.post('/api/config/plugin/update?plugin_name=' + this.curr_namespace, this.extension_config.config).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();
@@ -422,11 +443,23 @@ export default {
}
for (let i = 0; i < this.pluginMarketData.length; i++) {
for (let j = 0; j < this.extension_data.data.length; j++) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo) {
if (this.pluginMarketData[i].repo === this.extension_data.data[j].repo || this.pluginMarketData[i].name === this.extension_data.data[j].name) {
this.pluginMarketData[i].installed = true;
}
}
}
// 将已安装的插件移动到最后面
let installed = [];
let notInstalled = [];
for (let i = 0; i < this.pluginMarketData.length; i++) {
if (this.pluginMarketData[i].installed) {
installed.push(this.pluginMarketData[i]);
} else {
notInstalled.push(this.pluginMarketData[i]);
}
}
this.pluginMarketData = notInstalled.concat(installed);
}
},
}
+171 -40
View File
@@ -1,6 +1,7 @@
import aiohttp
import datetime
import builtins
import json
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
@@ -59,24 +60,29 @@ AstrBot 指令:
/plugin: 查看插件、插件帮助
/t2i: 开关文本转图片
/sid: 获取会话 ID
/op <admin_id>: 授权管理员
/deop <admin_id>: 取消管理员
/wl <sid>: 添加白名单
/dwl <sid>: 删除白名单
/dashboard_update: 更新管理面板
/alter_cmd: 设置指令权限
/op <admin_id>: 授权管理员(op)
/deop <admin_id>: 取消管理员(op)
/wl <sid>: 添加白名单(op)
/dwl <sid>: 删除白名单(op)
/dashboard_update: 更新管理面板(op)
/alter_cmd: 设置指令权限(op)
[大模型]
/provider: 大模型提供商
/model: 模型列表
/key: API Key
/reset: 重置 LLM 会
/history: 对话记录
/persona: 人格情景
/ls: 对话列表
/new: 创建新对
/switch: 切换对话
/rename: 重命名对话
/del: 删除当前会话对话(op)
/reset: 重置 LLM 会话(op)
/history: 当前对话的对话记录
/persona: 人格情景(op)
/tool ls: 函数工具
/key: API Key(op)
[其他]
/set <变量名> <值>: 为会话定义一个变量。适用于 Dify 工作流输入。
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除会话的变量。
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
@@ -273,7 +279,16 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
await self.context.get_using_provider().forget(message.session_id)
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.update_conversation(
message.unified_msg_origin, cid, []
)
ret = "清除会话 LLM 聊天历史成功。"
if self.ltm:
cnt = await self.ltm.remove_session(event=message)
@@ -329,20 +344,29 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
'''查看对话记录'''
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
size_per_page = 3
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
size_per_page = 6
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
contexts, total_pages = await self.context.conversation_manager.get_human_readable_context(
message.unified_msg_origin, session_curr_cid, page, size_per_page
)
history = ""
for context in contexts:
if len(context) > 150:
context = context[:150] + "..."
history += f"{context}\n"
ret = f"""历史记录:
ret = f"""当前对话历史记录:
{history}
{page} 页 | 共 {total_pages}
@@ -351,6 +375,88 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1):
'''查看对话列表'''
size_per_page = 6
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
total_pages = len(conversations) // size_per_page
if len(conversations) % size_per_page != 0:
total_pages += 1
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
ret = "对话列表:\n---\n"
global_index = (page - 1) * size_per_page + 1
_titles = {}
for conv in conversations:
persona_id = conv.persona_id
if not persona_id and not persona_id == "[%None]":
persona_id = self.context.provider_manager.selected_default_persona['name']
title = conv.title if conv.title else "新对话"
_titles[conv.cid] = title
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
global_index += 1
ret += "---\n"
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if curr_cid:
ret += f"\n当前对话: {_titles[curr_cid]}({curr_cid[:4]})"
else:
ret += "\n当前对话: 无"
unique_session = self.context.get_config()['platform_settings']['unique_session']
if unique_session:
ret += "\n会话隔离粒度: 个人"
else:
ret += "\n会话隔离粒度: 群聊"
ret += f"\n{page} 页 | 共 {total_pages}"
ret += "\n*输入 /ls 2 跳转到第 2 页"
message.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
'''创建新对话'''
cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin)
message.set_result(MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"))
@filter.command("switch")
async def switch_conv(self, message: AstrMessageEvent, index: int):
'''通过 /ls 前面的序号切换对话'''
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
if index > len(conversations) or index < 1:
message.set_result(MessageEventResult().message("对话序号错误,请使用 /ls 查看"))
else:
conversation = conversations[index-1]
title = conversation.title if conversation.title else "新对话"
await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid)
message.set_result(MessageEventResult().message(f"切换到对话: {title}({conversation.cid[:4]})。"))
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
'''重命名对话'''
await self.context.conversation_manager.update_conversation_title(message.unified_msg_origin, new_name)
message.set_result(MessageEventResult().message("重命名对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
'''删除当前对话'''
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
if not session_curr_cid:
message.set_result(MessageEventResult().message("当前未处于对话状态,请 /switch 切换或者 /new 创建。"))
return
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
message.set_result(MessageEventResult().message("删除当前对话成功。"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int=None):
@@ -387,28 +493,32 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
if not self.context.get_using_provider():
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
return
l = message.message_str.split(" ")
curr_persona_name = ""
if self.context.get_using_provider().curr_personality:
curr_persona_name = self.context.get_using_provider().curr_personality['name']
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
curr_cid_title = ""
if cid:
conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid)
if not conversation.persona_id and not conversation.persona_id == "[%None]":
curr_persona_name = self.context.provider_manager.selected_default_persona['name']
else:
curr_persona_name = conversation.persona_id
curr_cid_title = conversation.title if conversation.title else "新对话"
curr_cid_title += f"({cid[:4]})"
if len(l) == 1:
message.set_result(
MessageEventResult().message(f"""[Persona]
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
- 人格情景列表: `/persona list`
- 人格情景详细信息: `/persona view 人格`
- 设置人格情景: `/persona 人格`
- 人格情景详细信息: `/persona view 人格`
- 取消人格: `/persona unset`
当前人格情景: {curr_persona_name}
默认人格情景: {self.context.provider_manager.selected_default_persona['name']}
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""").use_t2i(False))
@@ -433,7 +543,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
msg = f"人格{ps}不存在"
message.set_result(MessageEventResult().message(msg))
elif l[1] == "unset":
self.context.get_using_provider().curr_personality = None
if not cid:
message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。"))
return
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]")
message.set_result(MessageEventResult().message("取消人格成功。"))
else:
ps = "".join(l[1:]).strip()
@@ -441,7 +554,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
lambda persona: persona['name'] == ps,
self.context.provider_manager.personas
), None):
self.context.get_using_provider().curr_personality = persona
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps)
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
else:
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
@@ -513,7 +626,17 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
return
try:
session_provider_context = provider.session_memory.get(event.session_id)
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
if not session_curr_cid:
logger.error("当前未处于对话状态,无法主动回复,请使用 /switch 切换或者 /new 创建。")
return
conv = await self.context.conversation_manager.get_conversation(
event.unified_msg_origin,
session_curr_cid
)
history = json.loads(conv.history)
prompt = self.ltm.ar_prompt
if not prompt:
@@ -523,7 +646,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
prompt=prompt,
func_tool_manager=self.context.get_llm_tool_manager(),
session_id=event.session_id,
contexts=session_provider_context if session_provider_context else []
contexts=history if history else []
)
except BaseException as e:
logger.error(f"主动回复失败: {e}")
@@ -545,14 +668,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
if persona := provider.curr_personality:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if req.conversation:
persona_id = req.conversation.persona_id
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
persona_id = self.context.provider_manager.selected_default_persona['name']
persona = next(builtins.filter(
lambda persona: persona['name'] == persona_id,
self.context.provider_manager.personas
), None)
if persona:
if prompt := persona['prompt']:
req.system_prompt += prompt
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
if self.ltm:
try:
+2 -2
View File
@@ -358,7 +358,7 @@ class Main(star.Star):
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
break
@@ -393,4 +393,4 @@ class Main(star.Star):
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()
await container.delete()