Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 37076d7920 | |||
| 78347ec91b | |||
| 9ded102a0a | |||
| 59b7d8b8cb | |||
| f5b97f6762 | |||
| d47da241af | |||
| 4611ce15eb | |||
| aa8c56a688 | |||
| ef44d4471a | |||
| 5581eae957 | |||
| ec46dfaac9 | |||
| 6042a047bd | |||
| 6ca9e2a753 | |||
| 618eabfe5c | |||
| bb5db2e9d0 | |||
| 97e4d169b3 | |||
| 50e44b1473 | |||
| 38588dd3fa | |||
| d183388347 | |||
| 1e69d59384 | |||
| 00f008f94d |
@@ -6,6 +6,7 @@ from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
register_permission_type as permission_type,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_on_llm_response as on_llm_response,
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result,
|
||||
register_after_message_sent as after_message_sent
|
||||
@@ -31,5 +32,6 @@ __all__ = [
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result',
|
||||
'after_message_sent'
|
||||
'after_message_sent',
|
||||
'on_llm_response'
|
||||
]
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.16"
|
||||
VERSION = "3.4.17"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -37,6 +37,7 @@ DEFAULT_CONFIG = {
|
||||
"enable": True,
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
"web_search_link": False,
|
||||
"identifier": False,
|
||||
"datetime_system_prompt": True,
|
||||
"default_personality": "default",
|
||||
@@ -55,6 +56,13 @@ DEFAULT_CONFIG = {
|
||||
"group_message_max_cnt": 300,
|
||||
"image_caption": False,
|
||||
"image_caption_prompt": "Please describe the image using Chinese.",
|
||||
"active_reply": {
|
||||
"enable": False,
|
||||
"method": "possibility_reply",
|
||||
"possibility_reply": 0.1,
|
||||
"prompt": "",
|
||||
},
|
||||
"put_history_to_prompt": True,
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
@@ -78,7 +86,7 @@ DEFAULT_CONFIG = {
|
||||
"persona": [
|
||||
{
|
||||
"name": "default",
|
||||
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"prompt": "",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
@@ -111,14 +119,13 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "localhost",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 11451,
|
||||
},
|
||||
},
|
||||
@@ -237,7 +244,8 @@ CONFIG_METADATA_2 = {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
"obvious_hint": True,
|
||||
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
|
||||
},
|
||||
"id_whitelist_log": {
|
||||
"description": "打印白名单日志",
|
||||
@@ -324,7 +332,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
@@ -420,9 +428,17 @@ CONFIG_METADATA_2 = {
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "tts-1",
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"openai-tts-voice": {
|
||||
"description": "voice",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "OpenAI TTS 的声音。OpenAI 默认支持:'alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer'",
|
||||
},
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
@@ -453,7 +469,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
@@ -539,16 +555,25 @@ CONFIG_METADATA_2 = {
|
||||
"web_search": {
|
||||
"description": "启用网页搜索",
|
||||
"type": "bool",
|
||||
"hint": "能访问 Google 时效果最佳。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
"obvious_hint": True,
|
||||
"hint": "能访问 Google 时效果最佳(国内需要在 `其他配置` 开启 HTTP 代理)。如果 Google 访问失败,程序会依次访问 Bing, Sogo 搜索引擎。",
|
||||
},
|
||||
"web_search_link": {
|
||||
"description": "网页搜索引用链接",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "开启后,将会传入网页搜索结果的链接给模型,并引导模型输出引用链接。",
|
||||
},
|
||||
"identifier": {
|
||||
"description": "启动识别群员",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "在 Prompt 前加上群成员的名字以让模型更好地了解群聊状态。启用将略微增加 token 开销。",
|
||||
},
|
||||
"datetime_system_prompt": {
|
||||
"description": "启用日期时间系统提示",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
|
||||
},
|
||||
"default_personality": {
|
||||
@@ -591,14 +616,14 @@ CONFIG_METADATA_2 = {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
@@ -644,25 +669,61 @@ CONFIG_METADATA_2 = {
|
||||
"group_icl_enable": {
|
||||
"description": "群聊内记录各群员对话",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"group_message_max_cnt": {
|
||||
"description": "群聊消息最大数量",
|
||||
"type": "int",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
|
||||
},
|
||||
"image_caption": {
|
||||
"description": "启用图像转述(需要模型支持)",
|
||||
"type": "bool",
|
||||
"obvious-hint": True,
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
|
||||
},
|
||||
"image_caption_prompt": {
|
||||
"description": "图像转述提示词",
|
||||
"type": "string"
|
||||
},
|
||||
"active_reply": {
|
||||
"description": "主动回复",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用主动回复",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "启用后,会根据触发概率主动回复群聊内的对话。",
|
||||
},
|
||||
"method": {
|
||||
"description": "回复方法",
|
||||
"type": "string",
|
||||
"options": ["possibility_reply"],
|
||||
"hint": "回复方法。possibility_reply 为根据概率回复",
|
||||
},
|
||||
"possibility_reply": {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"obvious_hint": True,
|
||||
"hint": "回复概率。当回复方法为 possibility_reply 时有效。当概率 >= 1 时,每条消息都会回复。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "提示词",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "提示词。当提示词为空时,如果触发回复,prompt是触发的消息的内容;否则是提示词。此项可以和定时回复(暂未实现)配合使用。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"put_history_to_prompt": {
|
||||
"description": "将群聊历史记录作为 prompt",
|
||||
"type": "bool",
|
||||
"obvious_hint": True,
|
||||
"hint": "需要先启用 group_icl_enable。此功能会将群聊历史记录放到 prompt 再请求。如果关闭,则是放在 system_prompt。如果开启了主动回复,建议启用,模型能够更好地完成回复任务。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -21,6 +21,10 @@ class DifyRequestSubStage(Stage):
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
return
|
||||
|
||||
if provider.meta().type != "dify":
|
||||
return
|
||||
|
||||
|
||||
@@ -68,6 +68,15 @@ class LLMRequestSubStage(Stage):
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
|
||||
# 执行 LLM 响应后的事件。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
|
||||
@@ -37,7 +37,11 @@ class ProcessStage(Stage):
|
||||
# Handler 的 LLM 请求
|
||||
logger.debug(f"llm request -> {resp.prompt}")
|
||||
event.set_extra("provider_request", resp)
|
||||
_t = False
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
_t = True
|
||||
yield
|
||||
if not _t:
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
@@ -49,6 +53,11 @@ class ProcessStage(Stage):
|
||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if not provider:
|
||||
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
|
||||
return
|
||||
|
||||
match provider.meta().type:
|
||||
case "dify":
|
||||
async for _ in self.dify_request_sub_stage.process(event):
|
||||
|
||||
@@ -296,6 +296,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
def request_llm(
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager = None,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
contexts: List = None,
|
||||
@@ -311,11 +312,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
|
||||
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
'''
|
||||
return ProviderRequest(
|
||||
prompt = prompt,
|
||||
session_id = session_id,
|
||||
image_urls = image_urls,
|
||||
func_tool = func_tool_manager,
|
||||
contexts = contexts,
|
||||
system_prompt = system_prompt
|
||||
)
|
||||
@@ -24,7 +24,10 @@ class PlatformManager():
|
||||
case "qq_official":
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
try:
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
except BaseException:
|
||||
logger.warning("当前 astrbot 已不维护 vchat 的接入,如有需要请 pip 安装 vchat 然后重启")
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
@@ -51,4 +52,7 @@ class LLMResponse:
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
'''工具调用名称'''
|
||||
'''工具调用名称'''
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
@@ -108,13 +108,19 @@ class FuncCall:
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
tools.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
)
|
||||
|
||||
func_declaration = {
|
||||
"name": f.name,
|
||||
"description": f.description
|
||||
}
|
||||
|
||||
# 检查并添加非空的properties参数
|
||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||
if params.get("properties", {}):
|
||||
func_declaration["parameters"] = params
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ class ProviderManager():
|
||||
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
|
||||
self.persona_configs: list = config.get('persona', [])
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
@@ -28,7 +30,7 @@ class ProviderManager():
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
|
||||
continue
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append({
|
||||
@@ -40,9 +42,9 @@ class ProviderManager():
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
|
||||
continue
|
||||
mood_imitation_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
for dialog in mood_imitation_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
@@ -61,6 +63,10 @@ class ProviderManager():
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
|
||||
@@ -74,7 +74,7 @@ class Provider(AbstractProvider):
|
||||
if persistant_history:
|
||||
# 读取历史记录
|
||||
try:
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['type']):
|
||||
for history in db_helper.get_llm_history(provider_type=provider_config['id']):
|
||||
self.session_memory[history.session_id] = json.loads(history.content)
|
||||
except BaseException as e:
|
||||
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
|
||||
|
||||
@@ -260,10 +260,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
@@ -118,10 +118,11 @@ class LLMTunerModelLoader(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
|
||||
@@ -72,31 +72,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
async def pop_record(self, session_id: str):
|
||||
'''
|
||||
弹出第一条记录
|
||||
弹出最早的一个对话
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
if len(self.session_memory[session_id]) < 2:
|
||||
return
|
||||
|
||||
try:
|
||||
self.session_memory[session_id].pop(0)
|
||||
self.session_memory[session_id].pop(0)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
@@ -104,20 +94,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
completion = None
|
||||
try:
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
except BaseException as e:
|
||||
if 'does not support Function Calling' \
|
||||
or 'does not support tools' in e: # ollama
|
||||
# 处理不支持 Function Calling 的模型
|
||||
if 'does not support Function Calling' in str(e) \
|
||||
or 'does not support tools' in str(e) \
|
||||
or 'Function call is not supported' in str(e): # siliconcloud
|
||||
del payloads['tools']
|
||||
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -129,14 +124,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
|
||||
# 适配 deepseek-r1 模型
|
||||
if r'<think>' in completion_text:
|
||||
completion_text = re.sub(r'<think>.*?</think>', '', completion_text, flags=re.DOTALL).strip()
|
||||
# 可能有单标签情况
|
||||
completion_text = completion_text.replace(r'<think>', '').replace(r'</think>', '').strip()
|
||||
|
||||
return LLMResponse("assistant", completion_text)
|
||||
|
||||
return LLMResponse("assistant", completion_text, raw_completion=completion)
|
||||
elif choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
args_ls = []
|
||||
@@ -147,8 +136,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
|
||||
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
|
||||
else:
|
||||
logger.error(f"API 返回的 completion 无法解析:{completion}。")
|
||||
raise Exception("Internal Error")
|
||||
|
||||
async def text_chat(
|
||||
@@ -188,19 +178,22 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
await self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
if kwargs.get("persist", True):
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
@@ -219,10 +212,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['id'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
@@ -15,7 +15,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
self.voice = provider_config.get("voice", "alloy")
|
||||
self.voice = provider_config.get("openai-tts-voice", "alloy")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
|
||||
@@ -7,6 +7,7 @@ from .star_handler import (
|
||||
register_regex,
|
||||
register_permission_type,
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_llm_tool,
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
'register_regex',
|
||||
'register_permission_type',
|
||||
'register_on_llm_request',
|
||||
'register_on_llm_response',
|
||||
'register_llm_tool',
|
||||
'register_on_decorating_result',
|
||||
'register_after_message_sent'
|
||||
|
||||
@@ -139,6 +139,8 @@ def register_on_llm_request():
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
|
||||
@on_llm_request()
|
||||
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
|
||||
request.system_prompt += "你是一个猫娘..."
|
||||
@@ -152,6 +154,27 @@ def register_on_llm_request():
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_response():
|
||||
'''当有 LLM 请求后的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
from astrbot.api.provider import LLMResponse
|
||||
|
||||
@on_llm_response()
|
||||
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
|
||||
...
|
||||
```
|
||||
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ class EventType(enum.Enum):
|
||||
'''
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
@@ -109,7 +109,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=120) as resp:
|
||||
async with session.get(url, timeout=1800) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
|
||||
@@ -6,6 +6,10 @@ class StaticFileRoute(Route):
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
@self.app.errorhandler(404)
|
||||
async def page_not_found(e):
|
||||
return "404 Not found。如果你初次使用打开面板发现 404,请参考文档: https://astrbot.app/deploy/dashboard-404.html"
|
||||
|
||||
async def index(self):
|
||||
return await self.app.send_static_file('index.html')
|
||||
@@ -0,0 +1,11 @@
|
||||
# What's Changed
|
||||
|
||||
- [beta] 支持群聊内基于概率的主动回复
|
||||
- openai tts 更换模型 #300
|
||||
- 增加模型响应后的插件钩子
|
||||
- 修复 相同type的provider共享了记忆
|
||||
- 优化 人格情景在发现格式不对时仍然加载而不是跳过 #282
|
||||
- 修复 Gemini函数调用时,parameters为空对象导致的错误 by @Camreishi
|
||||
- 修复 弹出记录报错的问题 #272
|
||||
- 优化 移除默认人格
|
||||
- 优化 未启用模型提供商时的异常处理
|
||||
+4
-2
@@ -4,6 +4,8 @@ services:
|
||||
astrbot:
|
||||
image: soulter/astrbot:latest
|
||||
container_name: astrbot
|
||||
network_mode: "host"
|
||||
ports:
|
||||
- "6180-6200:6180-6200"
|
||||
- "11451:11451"
|
||||
volumes:
|
||||
- ./data:/AstrBot/data
|
||||
- ./data:/AstrBot/data
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import random
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.platform import MessageType
|
||||
@@ -8,7 +9,9 @@ from astrbot.api.message_components import Plain, Image
|
||||
from astrbot import logger
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
'''
|
||||
聊天记忆增强
|
||||
'''
|
||||
class LongTermMemory:
|
||||
def __init__(self, config: dict, context: star.Context):
|
||||
self.config = config
|
||||
@@ -23,6 +26,13 @@ class LongTermMemory:
|
||||
self.image_caption = self.config["image_caption"]
|
||||
self.image_caption_prompt = self.config["image_caption_prompt"]
|
||||
|
||||
self.active_reply = self.config["active_reply"]
|
||||
self.ar_method = self.active_reply["method"]
|
||||
self.ar_possibility = self.active_reply["possibility_reply"]
|
||||
self.ar_prompt = self.active_reply.get("prompt", "")
|
||||
|
||||
self.put_history_to_prompt = self.config["put_history_to_prompt"]
|
||||
|
||||
async def remove_session(self, event: AstrMessageEvent) -> int:
|
||||
cnt = 0
|
||||
if event.unified_msg_origin in self.session_chats:
|
||||
@@ -39,8 +49,27 @@ class LongTermMemory:
|
||||
persist=False,
|
||||
)
|
||||
return response.completion_text
|
||||
|
||||
async def need_active_reply(self, event: AstrMessageEvent) -> bool:
|
||||
if not self.active_reply:
|
||||
return False
|
||||
if event.get_message_type() != MessageType.GROUP_MESSAGE:
|
||||
return False
|
||||
|
||||
if event.is_at_or_wake_command:
|
||||
# if the message is a command, let it pass
|
||||
return False
|
||||
|
||||
match self.ar_method:
|
||||
case "possibility_reply":
|
||||
trig = random.random() < self.ar_possibility
|
||||
return trig
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handle_message(self, event: AstrMessageEvent):
|
||||
'''仅支持群聊'''
|
||||
if event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
@@ -59,23 +88,33 @@ class LongTermMemory:
|
||||
final_message += f" [Image: {caption}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {e}")
|
||||
else:
|
||||
final_message += " [Image]"
|
||||
logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}")
|
||||
self.session_chats[event.unified_msg_origin].append(final_message)
|
||||
if len(self.session_chats[event.unified_msg_origin]) > self.max_cnt:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
'''当触发 LLM 请求前,调用此方法修改 req'''
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
chats_str = '\n---\n'.join(self.session_chats[event.unified_msg_origin])
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
|
||||
if self.put_history_to_prompt:
|
||||
prompt = req.prompt
|
||||
req.prompt = f"You are now in a chatroom. The chat history is as follows:\n{chats_str}"
|
||||
req.prompt += f"\nNow, a new message is coming: `{prompt}`. Please react to it. Only output your response and do not output any other information."
|
||||
req.contexts = [] # 清空上下文,当使用了群聊增强,所有聊天记录都在一个prompt中。
|
||||
else:
|
||||
req.system_prompt += "You are now in a chatroom. The chat history is as follows: \n"
|
||||
req.system_prompt += chats_str
|
||||
if self.image_caption:
|
||||
req.system_prompt += (
|
||||
"The images sent by the members are displayed in text form above."
|
||||
)
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent):
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.api.provider import Personality, ProviderRequest, LLMResponse
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
@@ -25,7 +25,7 @@ class Main(star.Star):
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.ltm = None
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable']:
|
||||
if self.context.get_config()['provider_ltm_settings']['group_icl_enable'] or self.context.get_config()['provider_ltm_settings']['active_reply']['enable']:
|
||||
try:
|
||||
self.ltm = LongTermMemory(self.context.get_config()['provider_ltm_settings'], self.context)
|
||||
except BaseException as e:
|
||||
@@ -203,6 +203,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
async def provider(self, event: AstrMessageEvent, idx: int = None):
|
||||
'''查看或者切换 LLM Provider'''
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
event.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
if idx is None:
|
||||
ret = "## 当前载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
@@ -227,6 +231,11 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
await self.context.get_using_provider().forget(message.session_id)
|
||||
ret = "清除会话 LLM 聊天历史成功。"
|
||||
if self.ltm:
|
||||
@@ -237,6 +246,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("model")
|
||||
async def model_ls(self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
@@ -277,6 +292,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
size_per_page = 3
|
||||
contexts, total_pages = await self.context.get_using_provider().get_human_readable_context(message.session_id, page, size_per_page)
|
||||
|
||||
@@ -296,6 +317,10 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int=None):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
if index is None:
|
||||
keys_data = self.context.get_using_provider().get_keys()
|
||||
@@ -324,6 +349,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
|
||||
if not self.context.get_using_provider():
|
||||
message.set_result(MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"))
|
||||
return
|
||||
|
||||
|
||||
l = message.message_str.split(" ")
|
||||
|
||||
curr_persona_name = "无"
|
||||
@@ -337,6 +368,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
- 设置人格情景: `/persona 人格名`, 如 /persona 编剧
|
||||
- 人格情景列表: `/persona list`
|
||||
- 人格情景详细信息: `/persona view 人格名`
|
||||
- 取消人格: `/persona unset`
|
||||
|
||||
当前人格情景: {curr_persona_name}
|
||||
|
||||
@@ -362,6 +394,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
else:
|
||||
msg = f"人格{ps}不存在"
|
||||
message.set_result(MessageEventResult().message(msg))
|
||||
elif l[1] == "unset":
|
||||
self.context.get_using_provider().curr_personality = None
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if persona := next(builtins.filter(
|
||||
@@ -421,12 +456,39 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
|
||||
@filter.platform_adapter_type(filter.PlatformAdapterType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''长期记忆'''
|
||||
'''群聊记忆增强'''
|
||||
if self.ltm:
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
need_active = await self.ltm.need_active_reply(event)
|
||||
|
||||
group_icl_enable = self.context.get_config()['provider_ltm_settings']['group_icl_enable']
|
||||
if group_icl_enable:
|
||||
'''记录对话'''
|
||||
try:
|
||||
await self.ltm.handle_message(event)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
|
||||
if need_active:
|
||||
'''主动回复'''
|
||||
provider = self.context.get_using_provider()
|
||||
if not provider:
|
||||
logger.error("未找到任何 LLM 提供商。请先配置。无法主动回复")
|
||||
return
|
||||
try:
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
|
||||
prompt = self.ltm.ar_prompt
|
||||
if not prompt:
|
||||
prompt = event.message_str
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
contexts=session_provider_context if session_provider_context else []
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
|
||||
@filter.on_llm_request()
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import List
|
||||
class Google(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proxy = os.environ.get("HTTPS_PROXY")
|
||||
self.proxy = os.environ.get("https_proxy")
|
||||
print(f"Google Search using proxy: {self.proxy}")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = []
|
||||
|
||||
@@ -22,6 +22,8 @@ class Main(star.Star):
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
|
||||
self.websearch_link = self.context.get_config()['provider_settings'].get('web_search_link', False)
|
||||
|
||||
async def initialize(self):
|
||||
websearch = self.context.get_config()['provider_settings']['web_search']
|
||||
if websearch:
|
||||
@@ -109,8 +111,17 @@ class Main(star.Star):
|
||||
except BaseException:
|
||||
site_result = ""
|
||||
site_result = site_result[:700] + "..." if len(site_result) > 700 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
|
||||
header = f"{idx}. {i.title} "
|
||||
|
||||
if self.websearch_link and i.url:
|
||||
header += i.url
|
||||
|
||||
ret += f"{header}\n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
|
||||
if self.websearch_link:
|
||||
ret += "针对问题,请根据上面的结果分点总结,并且在结尾处附上对应内容的参考链接(如有)。"
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user