feat: 更好的对话管理
This commit is contained in:
@@ -0,0 +1,118 @@
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from astrbot.core import sp
|
||||
from typing import Dict, List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
class ConversationManager():
|
||||
'''负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。'''
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
asyncio.create_task(self._periodic_save())
|
||||
|
||||
async def _periodic_save(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
'''新建对话,并将当前会话的对话转移到新对话'''
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
return conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
'''切换会话的对话'''
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def delete_conversation(self, unified_msg_origin: str, conversation_id: str=None):
|
||||
'''删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id
|
||||
)
|
||||
del self.session_conversations[unified_msg_origin]
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
'''获取会话当前的对话 ID'''
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
|
||||
async def get_conversation(self, unified_msg_origin: str, conversation_id: str) -> Conversation:
|
||||
'''获取会话的对话'''
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
'''获取会话的所有对话'''
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
|
||||
async def update_conversation(self, unified_msg_origin: str, conversation_id: str, history: List[Dict]):
|
||||
'''更新会话的对话'''
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
history=json.dumps(history)
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
'''更新会话的对话标题'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
title=title
|
||||
)
|
||||
|
||||
async def update_conversation_persona_id(self, unified_msg_origin: str, persona_id: str):
|
||||
'''更新会话的对话 Persona ID'''
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
user_id=unified_msg_origin,
|
||||
cid=conversation_id,
|
||||
persona_id=persona_id
|
||||
)
|
||||
|
||||
async def get_human_readable_context(self, unified_msg_origin, conversation_id, page=1, page_size=10):
|
||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||
history = json.loads(conversation.history)
|
||||
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in history:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
@@ -18,7 +18,7 @@ from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
@@ -43,12 +43,15 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
+20
-10
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -79,25 +79,35 @@ class BaseDatabase(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
'''通过 user_id 和 cid 获取 WebChatConversation'''
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
'''通过 user_id 和 cid 获取 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 WebChatConversation'''
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
|
||||
def get_conversations(self, user_id: str) -> List[Conversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 WebChatConversation'''
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
'''删除 WebChatConversation'''
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
'''删除 Conversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
'''更新 Conversation 标题'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
'''更新 Conversation Persona ID'''
|
||||
raise NotImplementedError
|
||||
+11
-6
@@ -33,16 +33,16 @@ class Stats():
|
||||
command: List[Command] = field(default_factory=list)
|
||||
llm: List[Provider] = field(default_factory=list)
|
||||
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
|
||||
@dataclass
|
||||
class LLMHistory():
|
||||
'''LLM 聊天时持久化的信息'''
|
||||
provider_type: str
|
||||
session_id: str
|
||||
content: str
|
||||
|
||||
@dataclass
|
||||
class ATRIVision():
|
||||
'''Deprecated'''
|
||||
id: str
|
||||
url_or_path: str
|
||||
caption: str
|
||||
@@ -53,13 +53,18 @@ class ATRIVision():
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebChatConversation():
|
||||
class Conversation():
|
||||
'''LLM 对话存储
|
||||
|
||||
对于网页聊天,history 存储了包括指令、回复、图片等在内的所有消息。
|
||||
对于其他平台的聊天,不存储非 LLM 的回复(因为考虑到已经存储在各自的平台上)。
|
||||
'''
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
'''字符串格式的列表。'''
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
title: str = ""
|
||||
persona_id: str = ""
|
||||
+61
-12
@@ -6,7 +6,7 @@ from astrbot.core.db.po import (
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision,
|
||||
WebChatConversation
|
||||
Conversation
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -25,6 +25,37 @@ class SQLiteDatabase(BaseDatabase):
|
||||
c = self.conn.cursor()
|
||||
c.executescript(sql)
|
||||
self.conn.commit()
|
||||
|
||||
# 检查 webchat_conversation 的 title 字段是否存在
|
||||
c.execute(
|
||||
'''
|
||||
PRAGMA table_info(webchat_conversation)
|
||||
'''
|
||||
)
|
||||
res = c.fetchall()
|
||||
has_title = False
|
||||
has_persona_id = False
|
||||
for row in res:
|
||||
if row[1] == "title":
|
||||
has_title = True
|
||||
if row[1] == "persona_id":
|
||||
has_persona_id = True
|
||||
if not has_title:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
if not has_persona_id:
|
||||
c.execute(
|
||||
'''
|
||||
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
||||
'''
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
c.close()
|
||||
|
||||
def _get_conn(self, db_path: str) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -202,7 +233,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -216,9 +247,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return WebChatConversation(*res)
|
||||
return Conversation(*res)
|
||||
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
@@ -228,7 +259,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
)
|
||||
|
||||
def get_webchat_conversations(self, user_id: str) -> Tuple:
|
||||
def get_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -236,7 +267,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
)
|
||||
|
||||
@@ -247,24 +278,42 @@ class SQLiteDatabase(BaseDatabase):
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
|
||||
title = row[3]
|
||||
persona_id = row[4]
|
||||
conversations.append(Conversation("", cid, '[]', created_at, updated_at, title, persona_id))
|
||||
return conversations
|
||||
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新对话,并且同时更新时间'''
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, user_id, cid)
|
||||
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, updated_at, user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
''', (title, user_id, cid)
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
''', (persona_id, user_id, cid)
|
||||
)
|
||||
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
ts = int(time.time())
|
||||
keywords = ",".join(vision.keywords)
|
||||
|
||||
@@ -42,5 +42,7 @@ CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
cid TEXT,
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
updated_at INTEGER,
|
||||
title TEXT,
|
||||
persona_id TEXT
|
||||
);
|
||||
@@ -2,6 +2,7 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
@@ -10,7 +11,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Result
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -24,6 +25,8 @@ class LLMRequestSubStage(Stage):
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
@@ -46,10 +49,17 @@ class LLMRequestSubStage(Stage):
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
|
||||
# 获取对话上下文
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(event.unified_msg_origin)
|
||||
req.session_id = conversation_id
|
||||
conversation = await self.conv_manager.get_conversation(event.unified_msg_origin, conversation_id)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
@@ -62,6 +72,9 @@ class LLMRequestSubStage(Stage):
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
try:
|
||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
||||
@@ -77,6 +90,9 @@ class LLMRequestSubStage(Stage):
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, llm_response)
|
||||
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
@@ -117,4 +133,24 @@ class LLMRequestSubStage(Stage):
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
||||
return
|
||||
return
|
||||
|
||||
async def _save_to_history(self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts
|
||||
new_record = {
|
||||
"role": "user",
|
||||
"content": req.prompt
|
||||
}
|
||||
contexts.append(new_record)
|
||||
contexts.append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.session_id,
|
||||
history=contexts_to_save
|
||||
)
|
||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from astrbot.core.db.po import Conversation
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
@@ -38,9 +39,9 @@ class ProviderRequest():
|
||||
'''上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
'''
|
||||
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
conversation: Conversation = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -66,7 +66,14 @@ class ProviderManager():
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import TypedDict
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
@@ -60,25 +62,11 @@ class Provider(AbstractProvider):
|
||||
) -> None:
|
||||
super().__init__(provider_config)
|
||||
|
||||
self.session_memory = defaultdict(list)
|
||||
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
|
||||
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
self.curr_personality: Personality = default_persona
|
||||
'''维护了当前的使用的 persona,即人格。可能为 None'''
|
||||
|
||||
self.db_helper = db_helper
|
||||
'''用于持久化的数据库操作对象。'''
|
||||
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_key(self) -> str:
|
||||
raise NotImplementedError()
|
||||
@@ -96,22 +84,6 @@ class Provider(AbstractProvider):
|
||||
'''获得支持的模型列表'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_human_readable_context(self, session_id: str, page: int, page_size: int):
|
||||
'''获取人类可读的上下文
|
||||
|
||||
page 从 1 开始
|
||||
|
||||
Example:
|
||||
|
||||
["User: 你好", "Assistant: 你好!"]
|
||||
|
||||
Return:
|
||||
contexts: List[str]: 上下文列表
|
||||
total_pages: int: 总页数
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
@@ -125,26 +97,35 @@ class Provider(AbstractProvider):
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID
|
||||
session_id: 会话 ID(此属性已经被废弃)
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
|
||||
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
|
||||
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
'''
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: List):
|
||||
'''
|
||||
弹出 context 第一条非系统提示词对话记录
|
||||
'''
|
||||
poped = 0
|
||||
indexs_to_pop = []
|
||||
for idx, record in enumerate(context):
|
||||
if record["role"] == "system":
|
||||
continue
|
||||
else:
|
||||
indexs_to_pop.append(idx)
|
||||
poped += 1
|
||||
if poped == 2:
|
||||
break
|
||||
|
||||
for idx in reversed(indexs_to_pop):
|
||||
context.pop(idx)
|
||||
|
||||
|
||||
class STTProvider(AbstractProvider):
|
||||
|
||||
@@ -73,60 +73,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
api_base=provider_config.get("api_base", None)
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
@@ -197,11 +147,10 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.completion_text = llm_response.completion_text.strip()
|
||||
return llm_response
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
@@ -210,10 +159,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
@@ -234,7 +180,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -254,34 +200,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
raise e
|
||||
|
||||
if kwargs.get("persist", True) and llm_response:
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
@@ -63,14 +63,7 @@ class LLMTunerModelLoader(Provider):
|
||||
) -> LLMResponse:
|
||||
system_prompt = ""
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
if not contexts:
|
||||
query_context = [
|
||||
*self.session_memory[session_id],
|
||||
new_record,
|
||||
]
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
else:
|
||||
query_context = [*contexts, new_record]
|
||||
query_context = [*contexts, new_record]
|
||||
|
||||
# 提取出系统提示
|
||||
system_idxs = []
|
||||
@@ -96,34 +89,8 @@ class LLMTunerModelLoader(Provider):
|
||||
responses = await self.model.achat(**conf)
|
||||
|
||||
llm_response = LLMResponse("assistant", responses[-1].response_text)
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
@@ -132,28 +99,4 @@ class LLMTunerModelLoader(Provider):
|
||||
pass
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record["role"] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record["role"] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
return [self.get_model()]
|
||||
@@ -48,30 +48,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
@@ -84,22 +60,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str):
|
||||
'''
|
||||
弹出最早的一个对话
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) < 2:
|
||||
return
|
||||
|
||||
try:
|
||||
self.session_memory[session_id].pop(0)
|
||||
self.session_memory[session_id].pop(0)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
@@ -141,7 +101,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str=None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
@@ -149,11 +109,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
@@ -214,9 +170,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
raise e
|
||||
|
||||
if kwargs.get("persist", True) and llm_response:
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
@@ -244,32 +197,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
context['content'] = new_content
|
||||
new_contexts.append(context)
|
||||
return new_contexts
|
||||
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
|
||||
self.session_memory[session_id] = [*contexts_to_save, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
@@ -32,10 +32,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
# glm-4v-flash 只支持一张图片
|
||||
@@ -62,7 +59,6 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
}
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
|
||||
@@ -16,6 +16,7 @@ from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
|
||||
class Context:
|
||||
'''
|
||||
@@ -44,6 +45,7 @@ class Context:
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
@@ -52,6 +54,7 @@ class Context:
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
'''根据插件名获取插件的 Metadata'''
|
||||
|
||||
@@ -25,6 +25,10 @@ class ParameterValidationMixin:
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, int):
|
||||
result[param_name] = int(params[i])
|
||||
elif isinstance(param_type_or_default_val, float):
|
||||
result[param_name] = float(params[i])
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
|
||||
@@ -121,7 +121,7 @@ class ChatRoute(Route):
|
||||
}))
|
||||
|
||||
# 持久化
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
@@ -136,7 +136,7 @@ class ChatRoute(Route):
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
self.db.update_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return Response().ok().__dict__
|
||||
|
||||
@@ -168,7 +168,7 @@ class ChatRoute(Route):
|
||||
continue
|
||||
yield result_text + '\n'
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
|
||||
conversation = self.db.get_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
@@ -178,7 +178,7 @@ class ChatRoute(Route):
|
||||
'type': 'bot',
|
||||
'message': result_text
|
||||
})
|
||||
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
|
||||
self.db.update_conversation(username, cid, history=json.dumps(history))
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
@@ -204,20 +204,20 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
self.db.delete_webchat_conversation(username, conversation_id)
|
||||
self.db.delete_conversation(username, conversation_id)
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.webchat_new_conversation(username, conversation_id)
|
||||
self.db.new_conversation(username, conversation_id)
|
||||
return Response().ok(data={
|
||||
'conversation_id': conversation_id
|
||||
}).__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversations = self.db.get_webchat_conversations(username)
|
||||
conversations = self.db.get_conversations(username)
|
||||
return Response().ok(data=conversations).__dict__
|
||||
|
||||
async def get_conversation(self):
|
||||
@@ -226,7 +226,7 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
conversation = self.db.get_conversation_by_user_id(username, conversation_id)
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
|
||||
+124
-40
@@ -1,6 +1,7 @@
|
||||
import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
import json
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
@@ -59,24 +60,28 @@ AstrBot 指令:
|
||||
/plugin: 查看插件、插件帮助
|
||||
/t2i: 开关文本转图片
|
||||
/sid: 获取会话 ID
|
||||
/op <admin_id>: 授权管理员
|
||||
/deop <admin_id>: 取消管理员
|
||||
/wl <sid>: 添加白名单
|
||||
/dwl <sid>: 删除白名单
|
||||
/dashboard_update: 更新管理面板
|
||||
/alter_cmd: 设置指令权限
|
||||
/op <admin_id>: 授权管理员(op)
|
||||
/deop <admin_id>: 取消管理员(op)
|
||||
/wl <sid>: 添加白名单(op)
|
||||
/dwl <sid>: 删除白名单(op)
|
||||
/dashboard_update: 更新管理面板(op)
|
||||
/alter_cmd: 设置指令权限(op)
|
||||
|
||||
[大模型]
|
||||
/provider: 大模型提供商
|
||||
/model: 模型列表
|
||||
/key: API Key
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 对话记录
|
||||
/persona: 人格情景
|
||||
/ls: 对话列表
|
||||
/new: 创建新对话
|
||||
/switch: 切换对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话(op)
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/tool ls: 函数工具
|
||||
/key: API Key(op)
|
||||
|
||||
[其他]
|
||||
/set <变量名> <值>: 为会话定义一个变量。适用于 Dify 工作流输入。
|
||||
/set <变量名> <值>: 为会话定义变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除会话的变量。
|
||||
|
||||
提示:如要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
@@ -273,7 +278,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
await self.context.conversation_manager.update_conversation(
|
||||
message.unified_msg_origin, message.session_id, []
|
||||
)
|
||||
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm:
|
||||
cnt = await self.ltm.remove_session(event=message)
|
||||
@@ -329,20 +337,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
|
||||
|
||||
'''查看对话记录'''
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
size_per_page = 3
|
||||
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
|
||||
size_per_page = 6
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
contexts, total_pages = await self.context.conversation_manager.get_human_readable_context(
|
||||
message.unified_msg_origin, session_curr_cid, page, size_per_page
|
||||
)
|
||||
|
||||
history = ""
|
||||
for context in contexts:
|
||||
history += f"{context}\n"
|
||||
|
||||
ret = f"""历史记录:
|
||||
ret = f"""当前对话历史记录:
|
||||
{history}
|
||||
第 {page} 页 | 共 {total_pages} 页
|
||||
|
||||
@@ -351,6 +361,64 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("ls")
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
'''查看对话列表'''
|
||||
size_per_page = 6
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
total_pages = len(conversations) // size_per_page
|
||||
if len(conversations) % size_per_page != 0:
|
||||
total_pages += 1
|
||||
conversations = conversations[(page-1)*size_per_page:page*size_per_page]
|
||||
|
||||
ret = "\n对话列表:\n"
|
||||
global_index = (page - 1) * size_per_page + 1
|
||||
|
||||
for conv in conversations:
|
||||
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id and not persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
|
||||
ret += f"{global_index}. 新对话{conv.cid[:4]}\n 人格情景: {persona_id}\n上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
if curr_cid:
|
||||
ret += f"\n当前对话: {curr_cid[:4]}"
|
||||
else:
|
||||
ret += "\n当前对话: 无"
|
||||
ret += f"\n第 {page} 页 | 共 {total_pages} 页"
|
||||
ret += "\n*输入 /ls 2 跳转到第 2 页"
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("new")
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
'''创建新对话'''
|
||||
cid = await self.context.conversation_manager.new_conversation(message.unified_msg_origin)
|
||||
message.set_result(MessageEventResult().message(f"切换到新对话: {cid[:4]}。"))
|
||||
|
||||
@filter.command("switch")
|
||||
async def switch_conv(self, message: AstrMessageEvent, index: int):
|
||||
'''切换对话'''
|
||||
conversations = await self.context.conversation_manager.get_conversations(message.unified_msg_origin)
|
||||
if index > len(conversations) or index < 1:
|
||||
message.set_result(MessageEventResult().message("对话序号错误。"))
|
||||
else:
|
||||
conversation = conversations[index-1]
|
||||
await self.context.conversation_manager.switch_conversation(message.unified_msg_origin, conversation.cid)
|
||||
message.set_result(MessageEventResult().message(f"切换到对话: {conversation.cid[:4]}。"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("del")
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
'''删除当前对话'''
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
await self.context.conversation_manager.delete_conversation(message.unified_msg_origin, session_curr_cid)
|
||||
message.set_result(MessageEventResult().message("删除当前对话成功。"))
|
||||
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
@@ -387,28 +455,28 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
l = message.message_str.split(" ")
|
||||
|
||||
curr_persona_name = "无"
|
||||
if self.context.get_using_provider().curr_personality:
|
||||
curr_persona_name = self.context.get_using_provider().curr_personality['name']
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(message.unified_msg_origin)
|
||||
if cid:
|
||||
conversation = await self.context.conversation_manager.get_conversation(message.unified_msg_origin, cid)
|
||||
if not conversation.persona_id and not conversation.persona_id == "[%None]":
|
||||
curr_persona_name = self.context.provider_manager.selected_default_persona['name']
|
||||
else:
|
||||
curr_persona_name = conversation.persona_id
|
||||
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"""[Persona]
|
||||
|
||||
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格情景列表: `/persona list`
|
||||
- 人格情景详细信息: `/persona view 人格名`
|
||||
- 设置人格情景: `/persona 人格`
|
||||
- 人格情景详细信息: `/persona view 人格`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
当前人格情景: {curr_persona_name}
|
||||
默认人格情景: {self.context.provider_manager.selected_default_persona['name']}
|
||||
当前对话 {cid[:4]} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
""").use_t2i(False))
|
||||
@@ -433,7 +501,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
self.context.get_using_provider().curr_personality = None
|
||||
if not cid:
|
||||
message.set_result(MessageEventResult().message("当前没有对话,无法取消人格。"))
|
||||
return
|
||||
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, "[%None]")
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
@@ -441,7 +512,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
lambda persona: persona['name'] == ps,
|
||||
self.context.provider_manager.personas
|
||||
), None):
|
||||
self.context.get_using_provider().curr_personality = persona
|
||||
await self.context.conversation_manager.update_conversation_persona_id(message.unified_msg_origin, ps)
|
||||
message.set_result(MessageEventResult().message("设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。"))
|
||||
else:
|
||||
message.set_result(MessageEventResult().message("不存在该人格情景。使用 /persona list 查看所有。"))
|
||||
@@ -513,7 +584,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id(event.unified_msg_origin)
|
||||
conv = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin,
|
||||
session_curr_cid
|
||||
)
|
||||
history = json.loads(conv.history)
|
||||
|
||||
prompt = self.ltm.ar_prompt
|
||||
if not prompt:
|
||||
@@ -523,7 +599,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=session_provider_context if session_provider_context else []
|
||||
contexts=history if history else []
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
@@ -545,14 +621,22 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
|
||||
|
||||
if persona := provider.curr_personality:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
if req.conversation:
|
||||
persona_id = req.conversation.persona_id
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
persona_id = self.context.provider_manager.selected_default_persona['name']
|
||||
persona = next(builtins.filter(
|
||||
lambda persona: persona['name'] == persona_id,
|
||||
self.context.provider_manager.personas
|
||||
), None)
|
||||
if persona:
|
||||
if prompt := persona['prompt']:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona['_mood_imitation_dialogs_processed']:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
if self.ltm:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user