fix: 修复 stage 在不同 pipeline 中被重复使用的问题和 persona 相关问题
This commit is contained in:
@@ -74,7 +74,10 @@ class AstrBotConfigManager:
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
umo = str(MessageSession.from_str(umo)) # validate
|
||||
try:
|
||||
umo = str(MessageSession.from_str(umo)) # validate
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
@@ -107,8 +110,10 @@ class AstrBotConfigManager:
|
||||
}
|
||||
self.sp.put("abconf_mapping", abconf_data)
|
||||
|
||||
def get_conf(self, umo: str | MessageSession) -> AstrBotConfig:
|
||||
def get_conf(self, umo: str | MessageSession | None) -> AstrBotConfig:
|
||||
"""获取指定 umo 的配置文件。如果不存在,则 fallback 到默认配置文件。"""
|
||||
if not umo:
|
||||
return self.confs["default"]
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = f"{umo.platform_id}:{umo.message_type}:{umo.session_id}"
|
||||
|
||||
@@ -182,7 +187,9 @@ class AstrBotConfigManager:
|
||||
return False
|
||||
|
||||
# 获取配置文件路径
|
||||
conf_path = os.path.join(get_astrbot_config_path(), abconf_data[conf_id]["path"])
|
||||
conf_path = os.path.join(
|
||||
get_astrbot_config_path(), abconf_data[conf_id]["path"]
|
||||
)
|
||||
|
||||
# 删除配置文件
|
||||
try:
|
||||
@@ -204,7 +211,9 @@ class AstrBotConfigManager:
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(self, conf_id: str, name: str = None, umo_parts: list[str] = None) -> bool:
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str = None, umo_parts: list[str] = None
|
||||
) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
@@ -234,7 +243,9 @@ class AstrBotConfigManager:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError("umo_parts must be a list of strings or MessageSession instances")
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
|
||||
@@ -91,7 +91,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(
|
||||
self.astrbot_config, self.db, self.persona_mgr
|
||||
self.astrbot_config_mgr, self.db, self.persona_mgr
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Persona, Personality
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot import logger
|
||||
|
||||
DEFAULT_PERSONALITY = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
)
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
|
||||
self.db = db_helper
|
||||
self.acm = acm
|
||||
default_ps = acm.default_conf.get("provider_settings", {})
|
||||
self.default_persona: str = default_ps.get("default_personality", "default")
|
||||
self.personas: list[Persona] = []
|
||||
@@ -28,6 +37,21 @@ class PersonaManager:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
return persona
|
||||
|
||||
async def get_default_persona_v3(
|
||||
self, umo: str | MessageSession | None = None
|
||||
) -> Personality:
|
||||
"""获取默认 persona"""
|
||||
cfg = self.acm.get_conf(umo)
|
||||
default_persona_id = cfg.get("provider_settings", {}).get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
if not default_persona_id or default_persona_id == "default":
|
||||
return DEFAULT_PERSONALITY
|
||||
try:
|
||||
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
|
||||
except ValueError:
|
||||
return DEFAULT_PERSONALITY
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
@@ -140,12 +164,7 @@ class PersonaManager:
|
||||
selected_default_persona = personas_v3[0]
|
||||
|
||||
if not selected_default_persona:
|
||||
selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
)
|
||||
selected_default_persona = DEFAULT_PERSONALITY
|
||||
personas_v3.append(selected_default_persona)
|
||||
|
||||
self.personas_v3 = personas_v3
|
||||
|
||||
@@ -64,9 +64,10 @@ class ResultDecorateStage(Stage):
|
||||
]
|
||||
self.content_safe_check_stage = None
|
||||
if self.content_safe_check_reply:
|
||||
for stage in registered_stages:
|
||||
if stage.__class__.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage
|
||||
for stage_cls in registered_stages:
|
||||
if stage_cls.__name__ == "ContentSafetyCheckStage":
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
|
||||
@@ -11,16 +11,17 @@ class PipelineScheduler:
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__)
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
self.stages = [] # 存储阶段实例
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage in registered_stages:
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
for stage_cls in registered_stages:
|
||||
stage_instance = stage_cls() # 创建实例
|
||||
await stage_instance.initialize(self.ctx)
|
||||
self.stages.append(stage_instance)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
"""依次执行各个阶段
|
||||
@@ -29,9 +30,9 @@ class PipelineScheduler:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||
"""
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
for i in range(from_stage, len(self.stages)):
|
||||
stage = self.stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__.__name__}")
|
||||
coroutine = stage.process(
|
||||
event
|
||||
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, AsyncGenerator, Union
|
||||
from typing import List, AsyncGenerator, Union, Type
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型
|
||||
|
||||
|
||||
def register_stage(cls):
|
||||
"""一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类"""
|
||||
registered_stages.append(cls())
|
||||
registered_stages.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import traceback
|
||||
from typing import List
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
@@ -15,12 +15,13 @@ from ..persona_mgr import PersonaManager
|
||||
class ProviderManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: AstrBotConfig,
|
||||
acm: AstrBotConfigManager,
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.persona_mgr = persona_mgr
|
||||
self.astrbot_config = config
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: List = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
@@ -42,11 +43,11 @@ class ProviderManager:
|
||||
self.llm_tools = llm_tools
|
||||
|
||||
self.curr_provider_inst: Provider | None = None
|
||||
"""默认的 Provider 实例"""
|
||||
"""默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_stt_provider_inst: STTProvider | None = None
|
||||
"""默认的 Speech To Text Provider 实例"""
|
||||
"""默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.curr_tts_provider_inst: TTSProvider | None = None
|
||||
"""默认的 Text To Speech Provider 实例"""
|
||||
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
|
||||
self.db_helper = db_helper
|
||||
|
||||
@property
|
||||
@@ -61,7 +62,7 @@ class ProviderManager:
|
||||
|
||||
@property
|
||||
def selected_default_persona(self):
|
||||
"""动态获取最新的默认选中 persona"""
|
||||
"""动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()"""
|
||||
return self.persona_mgr.selected_default_persona_v3
|
||||
|
||||
async def set_provider(
|
||||
@@ -72,7 +73,9 @@ class ProviderManager:
|
||||
Args:
|
||||
provider_id (str): 提供商 ID。
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。当用户启用了提供商会话隔离时此参数才生效。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Version 4.0.0: 这个版本下已经默认隔离提供商
|
||||
"""
|
||||
if provider_id not in self.inst_map:
|
||||
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
||||
@@ -92,6 +95,57 @@ class ProviderManager:
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
||||
sp.put("curr_provider", provider_id)
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
provider_type (ProviderType): 提供商类型。
|
||||
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
||||
|
||||
Returns:
|
||||
Provider: 正在使用的提供商实例。
|
||||
"""
|
||||
provider = None
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
session_perf = perf.get(umo, {})
|
||||
provider_id = session_perf.get(provider_type.value)
|
||||
if provider_id:
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
# default setting
|
||||
config = self.acm.get_conf(umo)
|
||||
if provider_type == ProviderType.CHAT_COMPLETION:
|
||||
provider_id = config["provider_settings"].get("default_provider_id")
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = self.provider_insts[0] if self.provider_insts else None
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
provider_id = config["provider_stt_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.stt_provider_insts[0] if self.stt_provider_insts else None
|
||||
)
|
||||
elif provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
provider_id = config["provider_tts_settings"].get("provider_id")
|
||||
if not provider_id:
|
||||
return None
|
||||
provider = self.inst_map.get(provider_id)
|
||||
if not provider:
|
||||
provider = (
|
||||
self.tts_provider_insts[0] if self.tts_provider_insts else None
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
# 逐个初始化提供商
|
||||
for provider_config in self.providers_config:
|
||||
|
||||
@@ -141,12 +141,10 @@ class Context:
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.CHAT_COMPLETION.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
|
||||
"""
|
||||
@@ -155,12 +153,10 @@ class Context:
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.TEXT_TO_SPEECH.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_tts_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
|
||||
"""
|
||||
@@ -169,12 +165,10 @@ class Context:
|
||||
Args:
|
||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||
"""
|
||||
if umo:
|
||||
perf = sp.get("session_provider_perf", {})
|
||||
prov_id = perf.get(umo, {}).get(ProviderType.SPEECH_TO_TEXT.value, None)
|
||||
if inst := self.provider_manager.inst_map.get(prov_id, None):
|
||||
return inst
|
||||
return self.provider_manager.curr_stt_provider_inst
|
||||
return self.provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||
umo=umo,
|
||||
)
|
||||
|
||||
def get_config(self, umo: str = None) -> AstrBotConfig:
|
||||
"""获取 AstrBot 的配置。"""
|
||||
|
||||
@@ -24,7 +24,6 @@ class SessionManagementRoute(Route):
|
||||
"/session/list": ("GET", self.list_sessions),
|
||||
"/session/update_persona": ("POST", self.update_session_persona),
|
||||
"/session/update_provider": ("POST", self.update_session_provider),
|
||||
"/session/get_session_info": ("POST", self.get_session_info),
|
||||
"/session/plugins": ("GET", self.get_session_plugins),
|
||||
"/session/update_plugin": ("POST", self.update_session_plugin),
|
||||
"/session/update_llm": ("POST", self.update_session_llm),
|
||||
@@ -39,17 +38,10 @@ class SessionManagementRoute(Route):
|
||||
async def list_sessions(self):
|
||||
"""获取所有会话的列表,包括 persona 和 provider 信息"""
|
||||
try:
|
||||
# 获取会话对话映射
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
|
||||
# 获取会话提供商偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
|
||||
# 获取可用的 personas
|
||||
personas = self.core_lifecycle.star_context.provider_manager.personas
|
||||
|
||||
# 获取可用的 providers
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
persona_mgr = self.core_lifecycle.persona_mgr
|
||||
personas = persona_mgr.personas_v3
|
||||
|
||||
sessions = []
|
||||
|
||||
@@ -95,6 +87,7 @@ class SessionManagementRoute(Route):
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
@@ -105,59 +98,34 @@ class SessionManagementRoute(Route):
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
default_persona = persona_mgr.selected_default_persona_v3
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = (
|
||||
default_stt_provider.meta().id
|
||||
)
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = (
|
||||
default_tts_provider.meta().id
|
||||
)
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.provider_manager
|
||||
chat_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.CHAT_COMPLETION, umo=session_id
|
||||
)
|
||||
tts_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.TEXT_TO_SPEECH, umo=session_id
|
||||
)
|
||||
stt_provider = provider_manager.get_using_provider(
|
||||
provider_type=ProviderType.SPEECH_TO_TEXT, umo=session_id
|
||||
)
|
||||
if chat_provider:
|
||||
meta = chat_provider.meta()
|
||||
session_info["chat_provider_id"] = meta.id
|
||||
session_info["chat_provider_name"] = meta.id
|
||||
if tts_provider:
|
||||
meta = tts_provider.meta()
|
||||
session_info["tts_provider_id"] = meta.id
|
||||
session_info["tts_provider_name"] = meta.id
|
||||
if stt_provider:
|
||||
meta = stt_provider.meta()
|
||||
session_info["stt_provider_id"] = meta.id
|
||||
session_info["stt_provider_name"] = meta.id
|
||||
|
||||
sessions.append(session_info)
|
||||
|
||||
@@ -311,133 +279,6 @@ class SessionManagementRoute(Route):
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话提供商失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_info(self):
|
||||
"""获取指定会话的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
# 获取会话对话信息
|
||||
session_conversations = sp.get("session_conversation", {}) or {}
|
||||
conversation_id = session_conversations.get(session_id)
|
||||
|
||||
if not conversation_id:
|
||||
return Response().error(f"会话 {session_id} 未找到对话").__dict__
|
||||
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"conversation_id": conversation_id,
|
||||
"persona_id": None,
|
||||
"persona_name": None,
|
||||
"chat_provider_id": None,
|
||||
"chat_provider_name": None,
|
||||
"stt_provider_id": None,
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionServiceManager.is_llm_enabled_for_session(
|
||||
session_id
|
||||
),
|
||||
"tts_enabled": None, # 将在下面设置
|
||||
"platform": session_id.split(":")[0]
|
||||
if ":" in session_id
|
||||
else "unknown",
|
||||
"message_type": session_id.split(":")[1]
|
||||
if session_id.count(":") >= 1
|
||||
else "unknown",
|
||||
"session_name": session_id.split(":")[2]
|
||||
if session_id.count(":") >= 2
|
||||
else session_id,
|
||||
}
|
||||
|
||||
# 获取TTS状态
|
||||
session_info["tts_enabled"] = (
|
||||
SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
)
|
||||
|
||||
# 获取对话信息
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=session_id, conversation_id=conversation_id
|
||||
)
|
||||
if conversation:
|
||||
session_info["persona_id"] = conversation.persona_id
|
||||
|
||||
# 查找 persona 名称
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
personas = provider_manager.personas
|
||||
|
||||
if conversation.persona_id and conversation.persona_id != "[%None]":
|
||||
for persona in personas:
|
||||
if persona["name"] == conversation.persona_id:
|
||||
session_info["persona_name"] = persona["name"]
|
||||
break
|
||||
elif conversation.persona_id == "[%None]":
|
||||
session_info["persona_name"] = "无人格"
|
||||
else:
|
||||
# 使用默认人格
|
||||
default_persona = provider_manager.selected_default_persona
|
||||
if default_persona:
|
||||
session_info["persona_id"] = default_persona["name"]
|
||||
session_info["persona_name"] = default_persona["name"]
|
||||
|
||||
# 获取会话的 provider 偏好设置
|
||||
session_provider_perf = sp.get("session_provider_perf", {}) or {}
|
||||
session_perf = session_provider_perf.get(session_id, {})
|
||||
|
||||
# 获取 provider 信息
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
# Chat completion provider
|
||||
chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value)
|
||||
if chat_provider_id:
|
||||
chat_provider = provider_manager.inst_map.get(chat_provider_id)
|
||||
if chat_provider:
|
||||
session_info["chat_provider_id"] = chat_provider_id
|
||||
session_info["chat_provider_name"] = chat_provider.meta().id
|
||||
else:
|
||||
# 使用默认 provider
|
||||
default_provider = provider_manager.curr_provider_inst
|
||||
if default_provider:
|
||||
session_info["chat_provider_id"] = default_provider.meta().id
|
||||
session_info["chat_provider_name"] = default_provider.meta().id
|
||||
|
||||
# STT provider
|
||||
stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value)
|
||||
if stt_provider_id:
|
||||
stt_provider = provider_manager.inst_map.get(stt_provider_id)
|
||||
if stt_provider:
|
||||
session_info["stt_provider_id"] = stt_provider_id
|
||||
session_info["stt_provider_name"] = stt_provider.meta().id
|
||||
else:
|
||||
# 使用默认 STT provider
|
||||
default_stt_provider = provider_manager.curr_stt_provider_inst
|
||||
if default_stt_provider:
|
||||
session_info["stt_provider_id"] = default_stt_provider.meta().id
|
||||
session_info["stt_provider_name"] = default_stt_provider.meta().id
|
||||
|
||||
# TTS provider
|
||||
tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value)
|
||||
if tts_provider_id:
|
||||
tts_provider = provider_manager.inst_map.get(tts_provider_id)
|
||||
if tts_provider:
|
||||
session_info["tts_provider_id"] = tts_provider_id
|
||||
session_info["tts_provider_name"] = tts_provider.meta().id
|
||||
else:
|
||||
# 使用默认 TTS provider
|
||||
default_tts_provider = provider_manager.curr_tts_provider_inst
|
||||
if default_tts_provider:
|
||||
session_info["tts_provider_id"] = default_tts_provider.meta().id
|
||||
session_info["tts_provider_name"] = default_tts_provider.meta().id
|
||||
|
||||
return Response().ok(session_info).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"获取会话信息失败: {str(e)}").__dict__
|
||||
|
||||
async def get_session_plugins(self):
|
||||
"""获取指定会话的插件配置信息"""
|
||||
try:
|
||||
|
||||
+11
-15
@@ -701,11 +701,6 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
"""生成所有对话的标题字典"""
|
||||
_titles = {}
|
||||
for conv in conversations_all:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona[
|
||||
"name"
|
||||
]
|
||||
title = conv.title if conv.title else "新对话"
|
||||
_titles[conv.cid] = title
|
||||
|
||||
@@ -713,9 +708,10 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
for conv in conversations_paged:
|
||||
persona_id = conv.persona_id
|
||||
if not persona_id or persona_id == "[%None]":
|
||||
persona_id = self.context.provider_manager.selected_default_persona[
|
||||
"name"
|
||||
]
|
||||
persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=message.unified_msg_origin
|
||||
)
|
||||
persona_id = persona["name"]
|
||||
title = _titles.get(conv.cid, "新对话")
|
||||
ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
global_index += 1
|
||||
@@ -981,22 +977,22 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
l = message.message_str.split(" ") # noqa: E741
|
||||
umo = message.unified_msg_origin
|
||||
|
||||
curr_persona_name = "无"
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
message.unified_msg_origin
|
||||
cid = await self.context.conversation_manager.get_curr_conversation_id(umo)
|
||||
default_persona = await self.context.persona_manager.get_default_persona_v3(
|
||||
umo=umo
|
||||
)
|
||||
curr_cid_title = "无"
|
||||
if cid:
|
||||
conversation = await self.context.conversation_manager.get_conversation(
|
||||
unified_msg_origin=message.unified_msg_origin,
|
||||
unified_msg_origin=umo,
|
||||
conversation_id=cid,
|
||||
create_if_not_exists=True,
|
||||
)
|
||||
if not conversation.persona_id and not conversation.persona_id == "[%None]":
|
||||
curr_persona_name = (
|
||||
self.context.provider_manager.selected_default_persona["name"]
|
||||
)
|
||||
curr_persona_name = default_persona["name"]
|
||||
else:
|
||||
curr_persona_name = conversation.persona_id
|
||||
|
||||
@@ -1014,7 +1010,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
- 人格情景详细信息: `/persona view 人格`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
默认人格情景: {self.context.provider_manager.selected_default_persona["name"]}
|
||||
默认人格情景: {default_persona["name"]}
|
||||
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
|
||||
Reference in New Issue
Block a user