Refactor cron job handling and enhance proactive agent capabilities
- Updated FunctionToolExecutor to improve background task handling and integrate new system prompts for proactive agents. - Enhanced MainAgentBuildConfig with additional configuration options for tool management and context handling. - Introduced new system prompts for proactive agents triggered by cron jobs and background tasks to improve user interaction. - Refactored cron job management to utilize ProviderRequest for better context management and tool integration. - Renamed cron job tools for clarity, changing "create_cron_job" to "create_future_task" and similar adjustments for consistency. - Improved error handling and logging for cron job execution and agent responses. - Added support for image captioning and persona management in agent requests.
This commit is contained in:
@@ -7,7 +7,6 @@ from astrbot.api.provider import LLMResponse, ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
from .long_term_memory import LongTermMemory
|
||||
from .process_llm_request import ProcessLLMRequest
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
@@ -19,8 +18,6 @@ class Main(star.Star):
|
||||
except BaseException as e:
|
||||
logger.error(f"聊天增强 err: {e}")
|
||||
|
||||
self.proc_llm_req = ProcessLLMRequest(self.context)
|
||||
|
||||
def ltm_enabled(self, event: AstrMessageEvent):
|
||||
ltmse = self.context.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_ltm_settings"
|
||||
@@ -91,8 +88,6 @@ class Main(star.Star):
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
await self.proc_llm_req.process_llm_request(event, req)
|
||||
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.on_req_llm(event, req)
|
||||
|
||||
@@ -1,401 +0,0 @@
|
||||
import builtins
|
||||
import copy
|
||||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
)
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
|
||||
|
||||
|
||||
class ProcessLLMRequest:
|
||||
def __init__(self, context: star.Context):
|
||||
self.ctx = context
|
||||
cfg = context.get_config()
|
||||
self.timezone = cfg.get("timezone")
|
||||
if not self.timezone:
|
||||
# 系统默认时区
|
||||
self.timezone = None
|
||||
else:
|
||||
logger.info(f"Timezone set to: {self.timezone}")
|
||||
|
||||
self.skill_manager = SkillManager()
|
||||
|
||||
def _apply_local_env_tools(self, req: ProviderRequest) -> None:
|
||||
"""Add local environment tools to the provider request."""
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)
|
||||
|
||||
async def _ensure_persona(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
cfg: dict,
|
||||
umo: str,
|
||||
platform_type: str,
|
||||
event: AstrMessageEvent,
|
||||
):
|
||||
"""确保用户人格已加载"""
|
||||
if not req.conversation:
|
||||
return
|
||||
# persona inject
|
||||
|
||||
# custom rule is preferred
|
||||
persona_id = (
|
||||
await sp.get_async(
|
||||
scope="umo", scope_id=umo, key="session_service_config", default={}
|
||||
)
|
||||
).get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
default_persona = self.ctx.persona_manager.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_id = default_persona["name"]
|
||||
|
||||
# ChatUI special default persona
|
||||
if platform_type == "webchat":
|
||||
# non-existent persona_id to let following codes not working
|
||||
persona_id = "_chatui_default_"
|
||||
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
|
||||
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == persona_id,
|
||||
self.ctx.persona_manager.personas_v3,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if persona:
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += prompt
|
||||
if begin_dialogs := copy.deepcopy(persona["_begin_dialogs_processed"]):
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# skills select and prompt
|
||||
runtime = self.skills_cfg.get("runtime", "local")
|
||||
skills = self.skill_manager.list_skills(active_only=True, runtime=runtime)
|
||||
if runtime == "sandbox" and not self.sandbox_cfg.get("enable", False):
|
||||
logger.warning(
|
||||
"Skills runtime is set to sandbox, but sandbox mode is disabled, will skip skills prompt injection.",
|
||||
)
|
||||
req.system_prompt += "\n[Background: User added some skills, and skills runtime is set to sandbox, but sandbox mode is disabled. So skills will be unavailable.]\n"
|
||||
elif skills:
|
||||
# persona.skills == None means all skills are allowed
|
||||
if persona and persona.get("skills") is not None:
|
||||
if not persona["skills"]:
|
||||
return
|
||||
allowed = set(persona["skills"])
|
||||
skills = [skill for skill in skills if skill.name in allowed]
|
||||
if skills:
|
||||
req.system_prompt += f"\n{build_skills_prompt(skills)}\n"
|
||||
|
||||
# if user wants to use skills in non-sandbox mode, apply local env tools
|
||||
runtime = self.skills_cfg.get("runtime", "local")
|
||||
sandbox_enabled = self.sandbox_cfg.get("enable", False)
|
||||
if runtime == "local" and not sandbox_enabled:
|
||||
self._apply_local_env_tools(req)
|
||||
|
||||
# tools select
|
||||
tmgr = self.ctx.get_llm_tool_manager()
|
||||
|
||||
# SubAgent orchestrator mode: main LLM only sees handoff tools.
|
||||
# NOTE: subagent_orchestrator config lives at top-level now.
|
||||
orch_cfg = self.ctx.get_config().get("subagent_orchestrator", {})
|
||||
if orch_cfg.get("main_enable", False):
|
||||
policy = str(orch_cfg.get("main_tools_policy", "handoff_only")).strip()
|
||||
if policy not in {"handoff_only", "unassigned_to_main"}:
|
||||
# Prefer the safer default when config contains unknown values.
|
||||
policy = "handoff_only"
|
||||
|
||||
assigned_tools: set[str] = set()
|
||||
agents = orch_cfg.get("agents", [])
|
||||
if isinstance(agents, list):
|
||||
for a in agents:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if a.get("enabled", True) is False:
|
||||
continue
|
||||
persona_tools = None
|
||||
persona_id = a.get("persona_id")
|
||||
if persona_id:
|
||||
persona_tools = next(
|
||||
(
|
||||
p.get("tools")
|
||||
for p in self.ctx.persona_manager.personas_v3
|
||||
if p["name"] == persona_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
tools = a.get("tools", [])
|
||||
if persona_tools is not None:
|
||||
tools = persona_tools
|
||||
if tools is None:
|
||||
assigned_tools.update(
|
||||
[
|
||||
tool.name
|
||||
for tool in tmgr.func_list
|
||||
if not isinstance(tool, HandoffTool)
|
||||
]
|
||||
)
|
||||
continue
|
||||
if not isinstance(tools, list):
|
||||
continue
|
||||
for t in tools:
|
||||
name = str(t).strip()
|
||||
if name:
|
||||
assigned_tools.add(name)
|
||||
|
||||
toolset = ToolSet()
|
||||
|
||||
# Always expose handoff tools (transfer_to_*) when orchestrator is enabled.
|
||||
for tool in tmgr.func_list:
|
||||
if isinstance(tool, HandoffTool) and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
|
||||
# Optional mode: keep tools that are not assigned to any subagent on the main LLM.
|
||||
if policy == "unassigned_to_main":
|
||||
for tool in tmgr.func_list:
|
||||
if not tool.active:
|
||||
continue
|
||||
if isinstance(tool, HandoffTool):
|
||||
continue
|
||||
if tool.handler_module_path == "core.subagent_orchestrator":
|
||||
continue
|
||||
if tool.name in assigned_tools:
|
||||
continue
|
||||
toolset.add_tool(tool)
|
||||
|
||||
# Override any earlier tool injection (e.g. skills local env tools) to keep
|
||||
# main-LLM tool visibility predictable under subagent orchestrator.
|
||||
req.func_tool = toolset
|
||||
|
||||
# Encourage the model to delegate to subagents.
|
||||
# Use the built-in default router prompt; user overrides are disabled for now.
|
||||
router_prompt = (
|
||||
self.ctx.get_config()
|
||||
.get("subagent_orchestrator", {})
|
||||
.get("router_system_prompt", "")
|
||||
).strip()
|
||||
if router_prompt:
|
||||
req.system_prompt += f"\n{router_prompt}\n"
|
||||
|
||||
if policy == "unassigned_to_main":
|
||||
req.system_prompt += (
|
||||
"\n[Note: You may directly call the tools visible to the main LLM "
|
||||
"if they are not assigned to any subagent; otherwise prefer delegating "
|
||||
"to subagents via transfer_to_*.]\n"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Default behavior: follow persona tool selection.
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
# select all
|
||||
toolset = tmgr.get_full_tool_set()
|
||||
for tool in toolset:
|
||||
if not tool.active:
|
||||
toolset.remove_tool(tool.name)
|
||||
else:
|
||||
toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
tool = tmgr.get_func(tool_name)
|
||||
if tool and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
if not req.func_tool:
|
||||
req.func_tool = toolset
|
||||
else:
|
||||
req.func_tool.merge(toolset)
|
||||
event.trace.record(
|
||||
"sel_persona", persona_id=persona_id, persona_toolset=toolset.names()
|
||||
)
|
||||
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
||||
|
||||
async def _ensure_img_caption(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
cfg: dict,
|
||||
img_cap_prov_id: str,
|
||||
):
|
||||
try:
|
||||
caption = await self._request_img_caption(
|
||||
img_cap_prov_id,
|
||||
cfg,
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
|
||||
async def _request_img_caption(
|
||||
self,
|
||||
provider_id: str,
|
||||
cfg: dict,
|
||||
image_urls: list[str],
|
||||
) -> str:
|
||||
if prov := self.ctx.get_provider_by_id(provider_id):
|
||||
if isinstance(prov, Provider):
|
||||
img_cap_prompt = cfg.get(
|
||||
"image_caption_prompt",
|
||||
"Please describe the image.",
|
||||
)
|
||||
logger.debug(f"Processing image caption with provider: {provider_id}")
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=img_cap_prompt,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
return llm_resp.completion_text
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.",
|
||||
)
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not exist.",
|
||||
)
|
||||
|
||||
async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[
|
||||
"provider_settings"
|
||||
]
|
||||
self.skills_cfg = cfg.get("skills", {})
|
||||
self.sandbox_cfg = cfg.get("sandbox", {})
|
||||
|
||||
# prompt prefix
|
||||
if prefix := cfg.get("prompt_prefix"):
|
||||
# 支持 {{prompt}} 作为用户输入的占位符
|
||||
if "{{prompt}}" in prefix:
|
||||
req.prompt = prefix.replace("{{prompt}}", req.prompt)
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
if not event.message_obj.group:
|
||||
logger.error(
|
||||
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
|
||||
)
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
current_time = None
|
||||
if self.timezone:
|
||||
# 启用时区
|
||||
try:
|
||||
now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone))
|
||||
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
except Exception as e:
|
||||
logger.error(f"时区设置错误: {e}, 使用本地时区")
|
||||
if not current_time:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
# inject persona for this request
|
||||
platform_type = event.get_platform_name()
|
||||
await self._ensure_persona(
|
||||
req, cfg, event.unified_msg_origin, platform_type, event
|
||||
)
|
||||
|
||||
# image caption
|
||||
if img_cap_prov_id and req.image_urls:
|
||||
await self._ensure_img_caption(req, cfg, img_cap_prov_id)
|
||||
|
||||
# quote message processing
|
||||
# 解析引用内容
|
||||
quote = None
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Reply):
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
content_parts = []
|
||||
|
||||
# 1. 处理引用的文本
|
||||
sender_info = (
|
||||
f"({quote.sender_nickname}): " if quote.sender_nickname else ""
|
||||
)
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
content_parts.append(f"{sender_info}{message_str}")
|
||||
|
||||
# 2. 处理引用的图片 (保留原有逻辑,但改变输出目标)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
|
||||
if image_seg:
|
||||
try:
|
||||
# 找到可以生成图片描述的 provider
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
|
||||
# 调用 provider 生成图片描述
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
# 将图片描述作为文本添加到 content_parts
|
||||
content_parts.append(
|
||||
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No provider found for image captioning in quote."
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
@@ -3,6 +3,7 @@ import inspect
|
||||
import traceback
|
||||
import typing as T
|
||||
import uuid
|
||||
import json
|
||||
|
||||
import mcp
|
||||
|
||||
@@ -20,9 +21,13 @@ from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
)
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.message.components import Plain
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT,
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
)
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@@ -148,7 +153,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
task_id: str,
|
||||
**tool_args,
|
||||
):
|
||||
from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig
|
||||
from astrbot.core.astr_main_agent import (
|
||||
build_main_agent,
|
||||
MainAgentBuildConfig,
|
||||
_get_session_conv,
|
||||
)
|
||||
|
||||
# run the tool
|
||||
result_text = ""
|
||||
@@ -191,47 +200,47 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
message_type=session.message_type,
|
||||
)
|
||||
config = MainAgentBuildConfig(tool_call_timeout=3600)
|
||||
result = await build_main_agent(
|
||||
event=cron_event, plugin_context=ctx, config=config
|
||||
)
|
||||
if not result:
|
||||
logger.error("Failed to build main agent for cron job.")
|
||||
return
|
||||
runner = result.agent_runner
|
||||
req = result.provider_request
|
||||
|
||||
bg = extras["background_task_result"]
|
||||
result_text = bg["result"] or "Empty Response"
|
||||
if req.contexts:
|
||||
req = ProviderRequest()
|
||||
conv = await _get_session_conv(event=cron_event, plugin_context=ctx)
|
||||
req.conversation = conv
|
||||
context = json.loads(conv.history)
|
||||
if context:
|
||||
req.contexts = context
|
||||
context_dump = req._print_friendly_context()
|
||||
req.contexts = []
|
||||
req.system_prompt += (
|
||||
"\n\nBellow is you and user previous conversation history:\n"
|
||||
f"{context_dump}"
|
||||
)
|
||||
req.system_prompt += (
|
||||
"You now have a new background task result:\n"
|
||||
f"- Task ID: {bg['task_id']}\n"
|
||||
f"- Executed Tool: {tool.name}\n"
|
||||
f"- Tool Args: {tool_args}\n"
|
||||
f"- Result: {result_text}\n"
|
||||
f"- Note: {note}\n"
|
||||
"Please tell the user the result of the background task in your next response."
|
||||
)
|
||||
|
||||
bg = json.dumps(extras["background_task_result"], ensure_ascii=False)
|
||||
req.system_prompt += BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT.format(
|
||||
background_task_result=bg
|
||||
)
|
||||
req.prompt = (
|
||||
"You have a new background task result to report to the user."
|
||||
" Please include the result in your next response."
|
||||
" Using same language as previous conversation."
|
||||
"Proceed according to your system instructions. "
|
||||
"Output using same language as previous conversation."
|
||||
)
|
||||
if not req.func_tool:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
|
||||
|
||||
result = await build_main_agent(
|
||||
event=cron_event, plugin_context=ctx, config=config, req=req
|
||||
)
|
||||
if not result:
|
||||
logger.error("Failed to build main agent for background task job.")
|
||||
return
|
||||
|
||||
runner = result.agent_runner
|
||||
async for _ in runner.step_until_done(30):
|
||||
# agent will send message to user via using tools
|
||||
pass
|
||||
llm_resp = runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
logger.warning("Cron job agent got no response")
|
||||
logger.warning("background task agent got no response")
|
||||
return
|
||||
message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)])
|
||||
await ctx.send_message(session=session, message_chain=message_chain)
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
|
||||
+445
-19
@@ -1,22 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import zoneinfo
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.api import sp
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
CHATUI_EXTRA_PROMPT,
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
PYTHON_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
retrieve_knowledge_base,
|
||||
)
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.tools.cron_tools import (
|
||||
@@ -27,42 +51,59 @@ from astrbot.core.tools.cron_tools import (
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
|
||||
from .astr_main_agent_resources import (
|
||||
CHATUI_EXTRA_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
PYTHON_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
retrieve_knowledge_base,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MainAgentBuildConfig:
|
||||
"""The main agent build configuration.
|
||||
Most of the configs can be found in the cmd_config.json"""
|
||||
|
||||
tool_call_timeout: int
|
||||
"""The timeout (in seconds) for a tool call.
|
||||
When the tool call exceeds this time,
|
||||
a timeout error as a tool result will be returned.
|
||||
"""
|
||||
tool_schema_mode: str = "full"
|
||||
"""The tool schema mode, can be 'full' or 'skills-like'."""
|
||||
provider_wake_prefix: str = ""
|
||||
"""The wake prefix for the provider. If the user message does not start with this prefix,
|
||||
the main agent will not be triggered."""
|
||||
streaming_response: bool = True
|
||||
"""Whether to use streaming response."""
|
||||
sanitize_context_by_modalities: bool = False
|
||||
"""Whether to sanitize the context based on the provider's supported modalities.
|
||||
This will remove unsupported message types(e.g. image) from the context to prevent issues."""
|
||||
kb_agentic_mode: bool = False
|
||||
"""Whether to use agentic mode for knowledge base retrieval.
|
||||
This will inject the knowledge base query tool into the main agent's toolset to allow dynamic querying."""
|
||||
file_extract_enabled: bool = False
|
||||
"""Whether to enable file content extraction for uploaded files."""
|
||||
file_extract_prov: str = "moonshotai"
|
||||
"""The file extraction provider."""
|
||||
file_extract_msh_api_key: str = ""
|
||||
"""The API key for Moonshot AI file extraction provider."""
|
||||
context_limit_reached_strategy: str = "truncate_by_turns"
|
||||
"""The strategy to handle context length limit reached."""
|
||||
llm_compress_instruction: str = ""
|
||||
llm_compress_keep_recent: int = 4
|
||||
"""The instruction for compression in llm_compress strategy."""
|
||||
llm_compress_keep_recent: int = 6
|
||||
"""The number of most recent turns to keep during llm_compress strategy."""
|
||||
llm_compress_provider_id: str = ""
|
||||
max_context_length: int = 0
|
||||
"""The provider ID for the LLM used in context compression."""
|
||||
max_context_length: int = -1
|
||||
"""The maximum number of turns to keep in context. -1 means no limit.
|
||||
This enforce max turns before compression"""
|
||||
dequeue_context_length: int = 1
|
||||
"""The number of oldest turns to remove when context length limit is reached."""
|
||||
llm_safety_mode: bool = True
|
||||
"""This will inject healthy and safe system prompt into the main agent,
|
||||
to prevent LLM output harmful information"""
|
||||
safety_mode_strategy: str = "system_prompt"
|
||||
sandbox_cfg: dict = field(default_factory=dict)
|
||||
add_cron_tools: bool = True
|
||||
"""This will add cron job management tools to the main agent for proactive cron job execution."""
|
||||
provider_settings: dict = field(default_factory=dict)
|
||||
subagent_orchestrator: dict = field(default_factory=dict)
|
||||
timezone: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -189,6 +230,388 @@ async def _apply_file_extract(
|
||||
)
|
||||
|
||||
|
||||
def _apply_prompt_prefix(req: ProviderRequest, cfg: dict) -> None:
|
||||
prefix = cfg.get("prompt_prefix")
|
||||
if not prefix:
|
||||
return
|
||||
if "{{prompt}}" in prefix:
|
||||
req.prompt = prefix.replace("{{prompt}}", req.prompt)
|
||||
else:
|
||||
req.prompt = f"{prefix}{req.prompt}"
|
||||
|
||||
|
||||
def _apply_local_env_tools(req: ProviderRequest) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(LOCAL_EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(LOCAL_PYTHON_TOOL)
|
||||
|
||||
|
||||
async def _ensure_persona_and_skills(
|
||||
req: ProviderRequest,
|
||||
cfg: dict,
|
||||
plugin_context: Context,
|
||||
event: AstrMessageEvent,
|
||||
) -> None:
|
||||
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
|
||||
if not req.conversation:
|
||||
return
|
||||
|
||||
# get persona ID
|
||||
persona_id = (
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=event.unified_msg_origin,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
).get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||
if persona_id is None or persona_id != "[%None]":
|
||||
default_persona = plugin_context.persona_manager.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_id = default_persona["name"]
|
||||
if event.get_platform_name() == "webchat":
|
||||
persona_id = "_chatui_default_"
|
||||
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
|
||||
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == persona_id,
|
||||
plugin_context.persona_manager.personas_v3,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if persona:
|
||||
# Inject persona system prompt
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += f"\n# Persona Instructions\n\n{prompt}\n"
|
||||
if begin_dialogs := copy.deepcopy(persona.get("_begin_dialogs_processed")):
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# Inject skills prompt
|
||||
skills_cfg = cfg.get("skills", {})
|
||||
sandbox_cfg = cfg.get("sandbox", {})
|
||||
skill_manager = SkillManager()
|
||||
runtime = skills_cfg.get("runtime", "local")
|
||||
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
|
||||
|
||||
if runtime == "sandbox" and not sandbox_cfg.get("enable", False):
|
||||
logger.warning(
|
||||
"Skills runtime is set to sandbox, but sandbox mode is disabled, will skip skills prompt injection.",
|
||||
)
|
||||
req.system_prompt += (
|
||||
"\n[Background: User added some skills, and skills runtime is set to sandbox, "
|
||||
"but sandbox mode is disabled. So skills will be unavailable.]\n"
|
||||
)
|
||||
elif skills:
|
||||
if persona and persona.get("skills") is not None:
|
||||
if not persona["skills"]:
|
||||
skills = []
|
||||
else:
|
||||
allowed = set(persona["skills"])
|
||||
skills = [skill for skill in skills if skill.name in allowed]
|
||||
if skills:
|
||||
req.system_prompt += f"\n{build_skills_prompt(skills)}\n"
|
||||
|
||||
runtime = skills_cfg.get("runtime", "local")
|
||||
sandbox_enabled = sandbox_cfg.get("enable", False)
|
||||
if runtime == "local" and not sandbox_enabled:
|
||||
_apply_local_env_tools(req)
|
||||
|
||||
tmgr = plugin_context.get_llm_tool_manager()
|
||||
|
||||
orch_cfg = plugin_context.get_config().get("subagent_orchestrator", {})
|
||||
if orch_cfg.get("main_enable", False):
|
||||
policy = str(orch_cfg.get("main_tools_policy", "handoff_only")).strip()
|
||||
if policy not in {"handoff_only", "unassigned_to_main"}:
|
||||
policy = "handoff_only"
|
||||
|
||||
assigned_tools: set[str] = set()
|
||||
agents = orch_cfg.get("agents", [])
|
||||
if isinstance(agents, list):
|
||||
for a in agents:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if a.get("enabled", True) is False:
|
||||
continue
|
||||
persona_tools = None
|
||||
pid = a.get("persona_id")
|
||||
if pid:
|
||||
persona_tools = next(
|
||||
(
|
||||
p.get("tools")
|
||||
for p in plugin_context.persona_manager.personas_v3
|
||||
if p["name"] == pid
|
||||
),
|
||||
None,
|
||||
)
|
||||
tools = a.get("tools", [])
|
||||
if persona_tools is not None:
|
||||
tools = persona_tools
|
||||
if tools is None:
|
||||
assigned_tools.update(
|
||||
[
|
||||
tool.name
|
||||
for tool in tmgr.func_list
|
||||
if not isinstance(tool, HandoffTool)
|
||||
]
|
||||
)
|
||||
continue
|
||||
if not isinstance(tools, list):
|
||||
continue
|
||||
for t in tools:
|
||||
name = str(t).strip()
|
||||
if name:
|
||||
assigned_tools.add(name)
|
||||
|
||||
toolset = ToolSet()
|
||||
for tool in tmgr.func_list:
|
||||
if isinstance(tool, HandoffTool) and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
|
||||
if policy == "unassigned_to_main":
|
||||
for tool in tmgr.func_list:
|
||||
if not tool.active:
|
||||
continue
|
||||
if isinstance(tool, HandoffTool):
|
||||
continue
|
||||
if tool.handler_module_path == "core.subagent_orchestrator":
|
||||
continue
|
||||
if tool.name in assigned_tools:
|
||||
continue
|
||||
toolset.add_tool(tool)
|
||||
|
||||
req.func_tool = toolset
|
||||
|
||||
router_prompt = (
|
||||
plugin_context.get_config()
|
||||
.get("subagent_orchestrator", {})
|
||||
.get("router_system_prompt", "")
|
||||
).strip()
|
||||
if router_prompt:
|
||||
req.system_prompt += f"\n{router_prompt}\n"
|
||||
if policy == "unassigned_to_main":
|
||||
req.system_prompt += (
|
||||
"\n[Note: You may directly call the tools visible to the main LLM "
|
||||
"if they are not assigned to any subagent; otherwise prefer delegating "
|
||||
"to subagents via transfer_to_*.]\n"
|
||||
)
|
||||
return
|
||||
|
||||
# inject toolset in the persona
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
toolset = tmgr.get_full_tool_set()
|
||||
for tool in list(toolset):
|
||||
if not tool.active:
|
||||
toolset.remove_tool(tool.name)
|
||||
else:
|
||||
toolset = ToolSet()
|
||||
if persona["tools"]:
|
||||
for tool_name in persona["tools"]:
|
||||
tool = tmgr.get_func(tool_name)
|
||||
if tool and tool.active:
|
||||
toolset.add_tool(tool)
|
||||
if not req.func_tool:
|
||||
req.func_tool = toolset
|
||||
else:
|
||||
req.func_tool.merge(toolset)
|
||||
try:
|
||||
event.trace.record(
|
||||
"sel_persona", persona_id=persona_id, persona_toolset=toolset.names()
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Tool set for persona %s: %s", persona_id, toolset.names())
|
||||
|
||||
|
||||
async def _request_img_caption(
|
||||
provider_id: str,
|
||||
cfg: dict,
|
||||
image_urls: list[str],
|
||||
plugin_context: Context,
|
||||
) -> str:
|
||||
prov = plugin_context.get_provider_by_id(provider_id)
|
||||
if prov is None:
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not exist.",
|
||||
)
|
||||
if not isinstance(prov, Provider):
|
||||
raise ValueError(
|
||||
f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}.",
|
||||
)
|
||||
|
||||
img_cap_prompt = cfg.get(
|
||||
"image_caption_prompt",
|
||||
"Please describe the image.",
|
||||
)
|
||||
logger.debug("Processing image caption with provider: %s", provider_id)
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt=img_cap_prompt,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
return llm_resp.completion_text
|
||||
|
||||
|
||||
async def _ensure_img_caption(
|
||||
req: ProviderRequest,
|
||||
cfg: dict,
|
||||
plugin_context: Context,
|
||||
image_caption_provider: str,
|
||||
) -> None:
|
||||
try:
|
||||
caption = await _request_img_caption(
|
||||
image_caption_provider,
|
||||
cfg,
|
||||
req.image_urls,
|
||||
plugin_context,
|
||||
)
|
||||
if caption:
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("处理图片描述失败: %s", exc)
|
||||
|
||||
|
||||
async def _process_quote_message(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
img_cap_prov_id: str,
|
||||
plugin_context: Context,
|
||||
) -> None:
|
||||
quote = None
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Reply):
|
||||
quote = comp
|
||||
break
|
||||
if not quote:
|
||||
return
|
||||
|
||||
content_parts = []
|
||||
sender_info = f"({quote.sender_nickname}): " if quote.sender_nickname else ""
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
content_parts.append(f"{sender_info}{message_str}")
|
||||
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
|
||||
if image_seg:
|
||||
try:
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = plugin_context.get_using_provider(event.unified_msg_origin)
|
||||
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
content_parts.append(
|
||||
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning in quote.")
|
||||
except BaseException as exc:
|
||||
logger.error("处理引用图片失败: %s", exc)
|
||||
|
||||
quoted_content = "\n".join(content_parts)
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
|
||||
def _append_system_reminders(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
cfg: dict,
|
||||
timezone: str | None,
|
||||
) -> None:
|
||||
system_parts: list[str] = []
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
if not event.message_obj.group:
|
||||
logger.error(
|
||||
"Group name display enabled but group object is None. Group ID: %s",
|
||||
event.message_obj.group_id,
|
||||
)
|
||||
else:
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
current_time = None
|
||||
if timezone:
|
||||
try:
|
||||
now = datetime.datetime.now(zoneinfo.ZoneInfo(timezone))
|
||||
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("时区设置错误: %s, 使用本地时区", exc)
|
||||
if not current_time:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
|
||||
|
||||
async def _decorate_llm_request(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
plugin_context: Context,
|
||||
config: MainAgentBuildConfig,
|
||||
) -> None:
|
||||
cfg = config.provider_settings or plugin_context.get_config(
|
||||
umo=event.unified_msg_origin
|
||||
).get("provider_settings", {})
|
||||
|
||||
_apply_prompt_prefix(req, cfg)
|
||||
|
||||
if req.conversation:
|
||||
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if img_cap_prov_id and req.image_urls:
|
||||
await _ensure_img_caption(
|
||||
req,
|
||||
cfg,
|
||||
plugin_context,
|
||||
img_cap_prov_id,
|
||||
)
|
||||
|
||||
img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
|
||||
await _process_quote_message(
|
||||
event,
|
||||
req,
|
||||
img_cap_prov_id,
|
||||
plugin_context,
|
||||
)
|
||||
|
||||
tz = config.timezone
|
||||
if tz is None:
|
||||
tz = plugin_context.get_config().get("timezone")
|
||||
_append_system_reminders(event, req, cfg, tz)
|
||||
|
||||
|
||||
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
@@ -373,7 +796,7 @@ def _apply_sandbox_tools(
|
||||
req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n"
|
||||
|
||||
|
||||
def _proactive_cron_job_tools(req: ProviderRequest, event: AstrMessageEvent) -> None:
|
||||
def _proactive_cron_job_tools(req: ProviderRequest) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(CREATE_CRON_JOB_TOOL)
|
||||
@@ -474,6 +897,8 @@ async def build_main_agent(
|
||||
else:
|
||||
return None
|
||||
|
||||
await _decorate_llm_request(event, req, plugin_context, config)
|
||||
|
||||
await _apply_kb(event, req, plugin_context, config)
|
||||
|
||||
if not req.session_id:
|
||||
@@ -495,7 +920,8 @@ async def build_main_agent(
|
||||
event=event,
|
||||
)
|
||||
|
||||
_proactive_cron_job_tools(req, event)
|
||||
if config.add_cron_tools:
|
||||
_proactive_cron_job_tools(req)
|
||||
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
|
||||
@@ -41,11 +41,12 @@ SANDBOX_MODE_PROMPT = (
|
||||
)
|
||||
|
||||
TOOL_CALL_PROMPT = (
|
||||
"You MUST NOT return an empty response, especially after invoking a tool."
|
||||
" Before calling any tool, provide a brief explanatory message to the user stating the purpose of the tool call."
|
||||
" Use the provided tool schema to format arguments and do not guess parameters that are not defined."
|
||||
" After the tool call is completed, you must briefly summarize the results returned by the tool for the user."
|
||||
" Keep the role-play and style consistent throughout the conversation."
|
||||
"When using tools: "
|
||||
"never return an empty response; "
|
||||
"briefly explain the purpose before calling a tool; "
|
||||
"follow the tool schema exactly and do not invent parameters; "
|
||||
"after execution, briefly summarize the result for the user; "
|
||||
"keep the conversation style consistent."
|
||||
)
|
||||
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE = (
|
||||
@@ -91,6 +92,43 @@ LIVE_MODE_SYSTEM_PROMPT = (
|
||||
"Sound like a real conversation, not a Q&A system."
|
||||
)
|
||||
|
||||
PROACTIVE_AGENT_CRON_WOKE_SYSTEM_PROMPT = (
|
||||
"You are an autonomous proactive agent.\n\n"
|
||||
"You are awakened by a scheduled cron job, not by a user message.\n"
|
||||
"You are given:"
|
||||
"1. A cron job description explaining why you are activated.\n"
|
||||
"2. Historical conversation context between you and the user.\n"
|
||||
"3. Your available tools and skills.\n"
|
||||
"# IMPORTANT RULES\n"
|
||||
"1. This is NOT a chat turn. Do NOT greet the user. Do NOT ask the user questions unless strictly necessary.\n"
|
||||
"2. Use historical conversation and memory to understand you and user's relationship, preferences, and context.\n"
|
||||
"3. If messaging the user: Explain WHY you are contacting them; Reference the cron task implicitly (not technical details).\n"
|
||||
"4. You can use your available tools and skills to finish the task if needed.\n"
|
||||
"5. Use `send_message_to_user` tool to send message to user if needed."
|
||||
"# CRON JOB CONTEXT\n"
|
||||
"The following object describes the scheduled task that triggered you:\n"
|
||||
"{cron_job}"
|
||||
)
|
||||
|
||||
BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT = (
|
||||
"You are an autonomous proactive agent.\n\n"
|
||||
"You are awakened by the completion of a background task you initiated earlier.\n"
|
||||
"You are given:"
|
||||
"1. A description of the background task you initiated.\n"
|
||||
"2. The result of the background task.\n"
|
||||
"3. Historical conversation context between you and the user.\n"
|
||||
"4. Your available tools and skills.\n"
|
||||
"# IMPORTANT RULES\n"
|
||||
"1. This is NOT a chat turn. Do NOT greet the user. Do NOT ask the user questions unless strictly necessary. Do NOT respond if no meaningful action is required."
|
||||
"2. Use historical conversation and memory to understand you and user's relationship, preferences, and context."
|
||||
"3. If messaging the user: Explain WHY you are contacting them; Reference the background task implicitly (not technical details)."
|
||||
"4. You can use your available tools and skills to finish the task if needed.\n"
|
||||
"5. Use `send_message_to_user` tool to send message to user if needed."
|
||||
"# BACKGROUND TASK CONTEXT\n"
|
||||
"The following object describes the background task that completed:\n"
|
||||
"{background_task_result}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]):
|
||||
|
||||
@@ -91,7 +91,7 @@ DEFAULT_CONFIG = {
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_keep_recent": 6,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Awaitable, Callable
|
||||
from zoneinfo import ZoneInfo
|
||||
@@ -7,12 +8,12 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.cron.events import CronMessageEvent
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import CronJob
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.message.components import Plain
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -239,7 +240,16 @@ class CronJobManager:
|
||||
session_str: str,
|
||||
extras: dict,
|
||||
):
|
||||
from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig
|
||||
"""Woke the main agent to handle the cron job message."""
|
||||
from astrbot.core.astr_main_agent import (
|
||||
build_main_agent,
|
||||
MainAgentBuildConfig,
|
||||
_get_session_conv,
|
||||
)
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
PROACTIVE_AGENT_CRON_WOKE_SYSTEM_PROMPT,
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
)
|
||||
|
||||
try:
|
||||
session = (
|
||||
@@ -259,43 +269,53 @@ class CronJobManager:
|
||||
message_type=session.message_type,
|
||||
)
|
||||
|
||||
config = MainAgentBuildConfig(tool_call_timeout=3600)
|
||||
config = MainAgentBuildConfig(
|
||||
tool_call_timeout=3600,
|
||||
llm_safety_mode=False,
|
||||
)
|
||||
req = ProviderRequest()
|
||||
conv = await _get_session_conv(event=cron_event, plugin_context=self.ctx)
|
||||
req.conversation = conv
|
||||
# finetine the messages
|
||||
context = json.loads(conv.history)
|
||||
if context:
|
||||
req.contexts = context
|
||||
context_dump = req._print_friendly_context()
|
||||
req.contexts = []
|
||||
req.system_prompt += (
|
||||
"\n\nBellow is you and user previous conversation history:\n"
|
||||
f"---\n"
|
||||
f"{context_dump}\n"
|
||||
f"---\n"
|
||||
)
|
||||
cron_job_str = json.dumps(extras.get("cron_job", {}), ensure_ascii=False)
|
||||
req.system_prompt += PROACTIVE_AGENT_CRON_WOKE_SYSTEM_PROMPT.format(
|
||||
cron_job=cron_job_str
|
||||
)
|
||||
req.prompt = (
|
||||
"You are now responding to a scheduled task"
|
||||
"Proceed according to your system instructions. "
|
||||
"Output using same language as previous conversation."
|
||||
)
|
||||
if not req.func_tool:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
|
||||
|
||||
result = await build_main_agent(
|
||||
event=cron_event, plugin_context=self.ctx, config=config
|
||||
event=cron_event, plugin_context=self.ctx, config=config, req=req
|
||||
)
|
||||
if not result:
|
||||
logger.error("Failed to build main agent for cron job.")
|
||||
return
|
||||
req = result.provider_request
|
||||
|
||||
runner = result.agent_runner
|
||||
|
||||
# finetine the messages
|
||||
job_name = extras.get("name", "scheduled task")
|
||||
note = extras.get("note") or extras.get("description") or ""
|
||||
if req.contexts:
|
||||
context_dump = req._print_friendly_context()
|
||||
req.system_prompt += (
|
||||
"\n\nBellow is you and user previous conversation history:\n"
|
||||
f"{context_dump}"
|
||||
)
|
||||
req.system_prompt += (
|
||||
"\n[Scheduler Context] This turn is triggered automatically by cron job "
|
||||
f'"{job_name}" (type: {extras.get("type", "unknown")}). '
|
||||
"Act proactively based on the provided note and current context. "
|
||||
)
|
||||
if note:
|
||||
req.system_prompt += f"[Scheduler Note]: {note}\n"
|
||||
|
||||
req.prompt = "You are now responding to a scheduled task. Output using same language as previous conversation."
|
||||
|
||||
async for _ in runner.step_until_done(30):
|
||||
# agent will send message to user via using tools
|
||||
pass
|
||||
llm_resp = runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
logger.warning("Cron job agent got no response")
|
||||
return
|
||||
message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)])
|
||||
await self.ctx.send_message(session=session, message_chain=message_chain)
|
||||
|
||||
|
||||
__all__ = ["CronJobManager"]
|
||||
|
||||
@@ -113,6 +113,9 @@ class InternalAgentSubStage(Stage):
|
||||
llm_safety_mode=self.llm_safety_mode,
|
||||
safety_mode_strategy=self.safety_mode_strategy,
|
||||
sandbox_cfg=self.sandbox_cfg,
|
||||
provider_settings=settings,
|
||||
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
|
||||
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
|
||||
)
|
||||
|
||||
async def process(
|
||||
|
||||
@@ -8,10 +8,10 @@ from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
|
||||
@dataclass
|
||||
class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "create_cron_job"
|
||||
name: str = "create_future_task"
|
||||
description: str = (
|
||||
"Create a scheduled active agent task using a cron expression. "
|
||||
"Use this when the user asks for recurring tasks (e.g., daily reports)."
|
||||
"Create a future task for your future using a cron expression. "
|
||||
"Use this when you or the user want recurring follow-up (e.g., daily report to self)."
|
||||
)
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
@@ -19,15 +19,15 @@ class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
|
||||
"properties": {
|
||||
"cron_expression": {
|
||||
"type": "string",
|
||||
"description": "Cron expression defining when to trigger (e.g., '0 8 * * *').",
|
||||
"description": "Cron expression defining when your future agent should wake (e.g., '0 8 * * *').",
|
||||
},
|
||||
"note": {
|
||||
"type": "string",
|
||||
"description": "Instruction for the future agent run when the job triggers.",
|
||||
"description": "Detailed instructions for your future agent to execute when it wakes.",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional job name for identification.",
|
||||
"description": "Optional label to recognize this future task.",
|
||||
},
|
||||
},
|
||||
"required": ["cron_expression", "note"],
|
||||
@@ -61,15 +61,15 @@ class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
|
||||
)
|
||||
next_run = job.next_run_time
|
||||
return (
|
||||
f"Scheduled cron job {job.job_id} ({job.name}) with expression '{cron_expression}'. "
|
||||
f"Next run: {next_run}"
|
||||
f"Scheduled future task {job.job_id} ({job.name}) with expression '{cron_expression}'. "
|
||||
f"Your future agent will wake at: {next_run}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteCronJobTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "delete_cron_job"
|
||||
description: str = "Delete a cron job by its job_id."
|
||||
name: str = "delete_future_task"
|
||||
description: str = "Delete a future task (cron job) by its job_id."
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
@@ -98,8 +98,8 @@ class DeleteCronJobTool(FunctionTool[AstrAgentContext]):
|
||||
|
||||
@dataclass
|
||||
class ListCronJobsTool(FunctionTool[AstrAgentContext]):
|
||||
name: str = "list_cron_jobs"
|
||||
description: str = "List existing cron jobs for inspection."
|
||||
name: str = "list_future_tasks"
|
||||
description: str = "List existing future tasks (cron jobs) for inspection."
|
||||
parameters: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"type": "object",
|
||||
|
||||
Reference in New Issue
Block a user