diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..d8fdb04ba --- /dev/null +++ b/Makefile @@ -0,0 +1,32 @@ +.PHONY: worktree worktree-add worktree-rm + +WORKTREE_DIR ?= ../astrbot_worktree +BRANCH ?= $(word 2,$(MAKECMDGOALS)) +BASE ?= $(word 3,$(MAKECMDGOALS)) +BASE ?= master + +worktree: + @echo "Usage:" + @echo " make worktree-add [base-branch]" + @echo " make worktree-rm " + +worktree-add: +ifeq ($(strip $(BRANCH)),) + $(error Branch name required. Usage: make worktree-add [base-branch]) +endif + @mkdir -p $(WORKTREE_DIR) + git worktree add $(WORKTREE_DIR)/$(BRANCH) -b $(BRANCH) $(BASE) + +worktree-rm: +ifeq ($(strip $(BRANCH)),) + $(error Branch name required. Usage: make worktree-rm ) +endif + @if [ -d "$(WORKTREE_DIR)/$(BRANCH)" ]; then \ + git worktree remove $(WORKTREE_DIR)/$(BRANCH); \ + else \ + echo "Worktree $(WORKTREE_DIR)/$(BRANCH) not found."; \ + fi + +# Swallow extra args (branch/base) so make doesn't treat them as targets +%: + @true diff --git a/astrbot/builtin_stars/astrbot/process_llm_request.py b/astrbot/builtin_stars/astrbot/process_llm_request.py index fb8639c65..06bfe790a 100644 --- a/astrbot/builtin_stars/astrbot/process_llm_request.py +++ b/astrbot/builtin_stars/astrbot/process_llm_request.py @@ -9,7 +9,7 @@ from astrbot.api.message_components import Image, Reply from astrbot.api.provider import Provider, ProviderRequest from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.message import TextPart -from astrbot.core.pipeline.process_stage.utils import ( +from astrbot.core.astr_main_agent_resources import ( CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT, LOCAL_EXECUTE_SHELL_TOOL, LOCAL_PYTHON_TOOL, diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 2a50a00bb..91f8b3e21 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -22,6 +22,7 @@ from astrbot.core.message.message_event_result import ( ) from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.register import llm_tools +from astrbot.core.message.components import Plain class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @@ -147,6 +148,8 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): task_id: str, **tool_args, ): + from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig + # run the tool result_text = "" try: @@ -187,7 +190,48 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): extras=extras, message_type=session.message_type, ) - ctx.get_event_queue().put_nowait(cron_event) + config = MainAgentBuildConfig(tool_call_timeout=3600) + result = await build_main_agent( + event=cron_event, plugin_context=ctx, config=config + ) + if not result: + logger.error("Failed to build main agent for cron job.") + return + runner = result.agent_runner + req = result.provider_request + + bg = extras["background_task_result"] + result_text = bg["result"] or "Empty Response" + if req.contexts: + context_dump = req._print_friendly_context() + req.system_prompt += ( + "\n\nBellow is you and user previous conversation history:\n" + f"{context_dump}" + ) + req.system_prompt += ( + "You now have a new background task result:\n" + f"- Task ID: {bg['task_id']}\n" + f"- Executed Tool: {tool.name}\n" + f"- Tool Args: {tool_args}\n" + f"- Result: {result_text}\n" + f"- Note: {note}\n" + "Please tell the user the result of the background task in your next response." + ) + + req.prompt = ( + "You have a new background task result to report to the user." + " Please include the result in your next response." + " Using same language as previous conversation." + ) + + async for _ in runner.step_until_done(30): + pass + llm_resp = runner.get_final_llm_resp() + if not llm_resp: + logger.warning("Cron job agent got no response") + return + message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)]) + await ctx.send_message(session=session, message_chain=message_chain) @classmethod async def _execute_local( diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py new file mode 100644 index 000000000..b35bd971b --- /dev/null +++ b/astrbot/core/astr_main_agent.py @@ -0,0 +1,545 @@ +from __future__ import annotations + +import asyncio +import json +import os +from dataclasses import dataclass, field + +from astrbot.core import logger +from astrbot.core.agent.message import TextPart +from astrbot.core.agent.tool import ToolSet +from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext +from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS +from astrbot.core.astr_agent_run_util import AgentRunner +from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor +from astrbot.core.conversation_mgr import Conversation +from astrbot.core.message.components import File, Image, Reply +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider import Provider +from astrbot.core.provider.entities import ProviderRequest +from astrbot.core.star.context import Context +from astrbot.core.star.star_handler import star_map +from astrbot.core.tools.cron_tools import ( + CREATE_CRON_JOB_TOOL, + DELETE_CRON_JOB_TOOL, + LIST_CRON_JOBS_TOOL, +) +from astrbot.core.utils.file_extract import extract_file_moonshotai +from astrbot.core.utils.llm_metadata import LLM_METADATAS + +from .astr_main_agent_resources import ( + CHATUI_EXTRA_PROMPT, + EXECUTE_SHELL_TOOL, + FILE_DOWNLOAD_TOOL, + FILE_UPLOAD_TOOL, + KNOWLEDGE_BASE_QUERY_TOOL, + LIVE_MODE_SYSTEM_PROMPT, + LLM_SAFETY_MODE_SYSTEM_PROMPT, + PYTHON_TOOL, + SANDBOX_MODE_PROMPT, + TOOL_CALL_PROMPT, + TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, + retrieve_knowledge_base, +) + + +@dataclass(slots=True) +class MainAgentBuildConfig: + tool_call_timeout: int + tool_schema_mode: str = "full" + provider_wake_prefix: str = "" + streaming_response: bool = True + sanitize_context_by_modalities: bool = False + kb_agentic_mode: bool = False + file_extract_enabled: bool = False + file_extract_prov: str = "moonshotai" + file_extract_msh_api_key: str = "" + context_limit_reached_strategy: str = "truncate_by_turns" + llm_compress_instruction: str = "" + llm_compress_keep_recent: int = 4 + llm_compress_provider_id: str = "" + max_context_length: int = 0 + dequeue_context_length: int = 1 + llm_safety_mode: bool = True + safety_mode_strategy: str = "system_prompt" + sandbox_cfg: dict = field(default_factory=dict) + + +@dataclass(slots=True) +class MainAgentBuildResult: + agent_runner: AgentRunner + provider_request: ProviderRequest + provider: Provider + + +def _select_provider( + event: AstrMessageEvent, plugin_context: Context +) -> Provider | None: + """Select chat provider for the event.""" + sel_provider = event.get_extra("selected_provider") + if sel_provider and isinstance(sel_provider, str): + provider = plugin_context.get_provider_by_id(sel_provider) + if not provider: + logger.error("未找到指定的提供商: %s。", sel_provider) + if not isinstance(provider, Provider): + logger.error( + "选择的提供商类型无效(%s),跳过 LLM 请求处理。", type(provider) + ) + return None + return provider + try: + return plugin_context.get_using_provider(umo=event.unified_msg_origin) + except ValueError as exc: + logger.error("Error occurred while selecting provider: %s", exc) + return None + + +async def _get_session_conv( + event: AstrMessageEvent, plugin_context: Context +) -> Conversation: + conv_mgr = plugin_context.conversation_manager + umo = event.unified_msg_origin + cid = await conv_mgr.get_curr_conversation_id(umo) + if not cid: + cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + conversation = await conv_mgr.get_conversation(umo, cid) + if not conversation: + cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) + conversation = await conv_mgr.get_conversation(umo, cid) + if not conversation: + raise RuntimeError("无法创建新的对话。") + return conversation + + +async def _apply_kb( + event: AstrMessageEvent, + req: ProviderRequest, + plugin_context: Context, + config: MainAgentBuildConfig, +) -> None: + if not config.kb_agentic_mode: + if req.prompt is None: + return + try: + kb_result = await retrieve_knowledge_base( + query=req.prompt, + umo=event.unified_msg_origin, + context=plugin_context, + ) + if not kb_result: + return + if req.system_prompt is not None: + req.system_prompt += ( + f"\n\n[Related Knowledge Base Results]:\n{kb_result}" + ) + except Exception as exc: # noqa: BLE001 + logger.error("Error occurred while retrieving knowledge base: %s", exc) + else: + if req.func_tool is None: + req.func_tool = ToolSet() + req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) + + +async def _apply_file_extract( + event: AstrMessageEvent, + req: ProviderRequest, + config: MainAgentBuildConfig, +) -> None: + file_paths = [] + file_names = [] + for comp in event.message_obj.message: + if isinstance(comp, File): + file_paths.append(await comp.get_file()) + file_names.append(comp.name) + elif isinstance(comp, Reply) and comp.chain: + for reply_comp in comp.chain: + if isinstance(reply_comp, File): + file_paths.append(await reply_comp.get_file()) + file_names.append(reply_comp.name) + if not file_paths: + return + if not req.prompt: + req.prompt = "总结一下文件里面讲了什么?" + if config.file_extract_prov == "moonshotai": + if not config.file_extract_msh_api_key: + logger.error("Moonshot AI API key for file extract is not set") + return + file_contents = await asyncio.gather( + *[ + extract_file_moonshotai( + file_path, + config.file_extract_msh_api_key, + ) + for file_path in file_paths + ] + ) + else: + logger.error("Unsupported file extract provider: %s", config.file_extract_prov) + return + + for file_content, file_name in zip(file_contents, file_names): + req.contexts.append( + { + "role": "system", + "content": ( + "File Extract Results of user uploaded files:\n" + f"{file_content}\nFile Name: {file_name or 'Unknown'}" + ), + }, + ) + + +def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: + if req.image_urls: + provider_cfg = provider.provider_config.get("modalities", ["image"]) + if "image" not in provider_cfg: + logger.debug( + "Provider %s does not support image, using placeholder.", provider + ) + image_count = len(req.image_urls) + placeholder = " ".join(["[图片]"] * image_count) + if req.prompt: + req.prompt = f"{placeholder} {req.prompt}" + else: + req.prompt = placeholder + req.image_urls = [] + if req.func_tool: + provider_cfg = provider.provider_config.get("modalities", ["tool_use"]) + if "tool_use" not in provider_cfg: + logger.debug( + "Provider %s does not support tool_use, clearing tools.", provider + ) + req.func_tool = None + + +def _sanitize_context_by_modalities( + config: MainAgentBuildConfig, + provider: Provider, + req: ProviderRequest, +) -> None: + if not config.sanitize_context_by_modalities: + return + if not isinstance(req.contexts, list) or not req.contexts: + return + modalities = provider.provider_config.get("modalities", None) + if not modalities or not isinstance(modalities, list): + return + supports_image = bool("image" in modalities) + supports_tool_use = bool("tool_use" in modalities) + if supports_image and supports_tool_use: + return + + sanitized_contexts: list[dict] = [] + removed_image_blocks = 0 + removed_tool_messages = 0 + removed_tool_calls = 0 + + for msg in req.contexts: + if not isinstance(msg, dict): + continue + role = msg.get("role") + if not role: + continue + + new_msg = msg + if not supports_tool_use: + if role == "tool": + removed_tool_messages += 1 + continue + if role == "assistant" and "tool_calls" in new_msg: + if "tool_calls" in new_msg: + removed_tool_calls += 1 + new_msg.pop("tool_calls", None) + new_msg.pop("tool_call_id", None) + + if not supports_image: + content = new_msg.get("content") + if isinstance(content, list): + filtered_parts: list = [] + removed_any_image = False + for part in content: + if isinstance(part, dict): + part_type = str(part.get("type", "")).lower() + if part_type in {"image_url", "image"}: + removed_any_image = True + removed_image_blocks += 1 + continue + filtered_parts.append(part) + if removed_any_image: + new_msg["content"] = filtered_parts + + if role == "assistant": + content = new_msg.get("content") + has_tool_calls = bool(new_msg.get("tool_calls")) + if not has_tool_calls: + if not content: + continue + if isinstance(content, str) and not content.strip(): + continue + + sanitized_contexts.append(new_msg) + + if removed_image_blocks or removed_tool_messages or removed_tool_calls: + logger.debug( + "sanitize_context_by_modalities applied: " + "removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s", + removed_image_blocks, + removed_tool_messages, + removed_tool_calls, + ) + req.contexts = sanitized_contexts + + +def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: + if event.plugins_name is not None and req.func_tool: + new_tool_set = ToolSet() + for tool in req.func_tool.tools: + mp = tool.handler_module_path + if not mp: + continue + plugin = star_map.get(mp) + if not plugin: + continue + if plugin.name in event.plugins_name or plugin.reserved: + new_tool_set.add_tool(tool) + req.func_tool = new_tool_set + + +async def _handle_webchat( + event: AstrMessageEvent, req: ProviderRequest, prov: Provider +) -> None: + from astrbot.core import db_helper + + chatui_session_id = event.session_id.split("!")[-1] + user_prompt = req.prompt + session = await db_helper.get_platform_session_by_id(chatui_session_id) + + if not user_prompt or not chatui_session_id or not session or session.display_name: + return + + llm_resp = await prov.text_chat( + system_prompt=( + "You are a conversation title generator. " + "Generate a concise title in the same language as the user’s input, " + "no more than 10 words, capturing only the core topic." + "If the input is a greeting, small talk, or has no clear topic, " + "(e.g., “hi”, “hello”, “haha”), return . " + "Output only the title itself or , with no explanations." + ), + prompt=f"Generate a concise title for the following user query:\n{user_prompt}", + ) + if llm_resp and llm_resp.completion_text: + title = llm_resp.completion_text.strip() + if not title or "" in title: + return + logger.info( + "Generated chatui title for session %s: %s", chatui_session_id, title + ) + await db_helper.update_platform_session( + session_id=chatui_session_id, + display_name=title, + ) + + +def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -> None: + if config.safety_mode_strategy == "system_prompt": + req.system_prompt = ( + f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n{req.system_prompt or ''}" + ) + else: + logger.warning( + "Unsupported llm_safety_mode strategy: %s.", + config.safety_mode_strategy, + ) + + +def _apply_sandbox_tools( + config: MainAgentBuildConfig, req: ProviderRequest, session_id: str +) -> None: + if req.func_tool is None: + req.func_tool = ToolSet() + if config.sandbox_cfg.get("booter") == "shipyard": + ep = config.sandbox_cfg.get("shipyard_endpoint", "") + at = config.sandbox_cfg.get("shipyard_access_token", "") + if not ep or not at: + logger.error("Shipyard sandbox configuration is incomplete.") + return + os.environ["SHIPYARD_ENDPOINT"] = ep + os.environ["SHIPYARD_ACCESS_TOKEN"] = at + req.func_tool.add_tool(EXECUTE_SHELL_TOOL) + req.func_tool.add_tool(PYTHON_TOOL) + req.func_tool.add_tool(FILE_UPLOAD_TOOL) + req.func_tool.add_tool(FILE_DOWNLOAD_TOOL) + req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n" + + +def _proactive_cron_job_tools(req: ProviderRequest, event: AstrMessageEvent) -> None: + if req.func_tool is None: + req.func_tool = ToolSet() + req.func_tool.add_tool(CREATE_CRON_JOB_TOOL) + req.func_tool.add_tool(DELETE_CRON_JOB_TOOL) + req.func_tool.add_tool(LIST_CRON_JOBS_TOOL) + + +def _get_compress_provider( + config: MainAgentBuildConfig, plugin_context: Context +) -> Provider | None: + if not config.llm_compress_provider_id: + return None + if config.context_limit_reached_strategy != "llm_compress": + return None + provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id) + if provider is None: + logger.warning( + "未找到指定的上下文压缩模型 %s,将跳过压缩。", + config.llm_compress_provider_id, + ) + return None + if not isinstance(provider, Provider): + logger.warning( + "指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。", + config.llm_compress_provider_id, + ) + return None + return provider + + +async def build_main_agent( + *, + event: AstrMessageEvent, + plugin_context: Context, + config: MainAgentBuildConfig, + provider: Provider | None = None, + req: ProviderRequest | None = None, +) -> MainAgentBuildResult | None: + """构建主对话代理(Main Agent),并且自动 reset。""" + provider = provider or _select_provider(event, plugin_context) + if provider is None: + logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。") + return None + + if req is None: + if event.get_extra("provider_request"): + req = event.get_extra("provider_request") + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) + if req.conversation: + req.contexts = json.loads(req.conversation.history) + else: + req = ProviderRequest() + req.prompt = "" + req.image_urls = [] + if sel_model := event.get_extra("selected_model"): + req.model = sel_model + if config.provider_wake_prefix and not event.message_str.startswith( + config.provider_wake_prefix + ): + return None + + req.prompt = event.message_str[len(config.provider_wake_prefix) :] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) + req.extra_user_content_parts.append( + TextPart(text=f"[Image Attachment: path {image_path}]") + ) + elif isinstance(comp, File): + file_path = await comp.get_file() + file_name = comp.name or os.path.basename(file_path) + req.extra_user_content_parts.append( + TextPart( + text=f"[File Attachment: name {file_name}, path {file_path}]" + ) + ) + + conversation = await _get_session_conv(event, plugin_context) + req.conversation = conversation + req.contexts = json.loads(conversation.history) + event.set_extra("provider_request", req) + + if isinstance(req.contexts, str): + req.contexts = json.loads(req.contexts) + + if config.file_extract_enabled: + try: + await _apply_file_extract(event, req, config) + except Exception as exc: # noqa: BLE001 + logger.error("Error occurred while applying file extract: %s", exc) + + if not req.prompt and not req.image_urls: + if not event.get_group_id() and req.extra_user_content_parts: + req.prompt = "" + else: + return None + + await _apply_kb(event, req, plugin_context, config) + + if not req.session_id: + req.session_id = event.unified_msg_origin + + _modalities_fix(provider, req) + _plugin_tool_fix(event, req) + _sanitize_context_by_modalities(config, provider, req) + + if config.llm_safety_mode: + _apply_llm_safety_mode(config, req) + + if config.sandbox_cfg.get("enable", False): + _apply_sandbox_tools(config, req, req.session_id) + + agent_runner = AgentRunner() + astr_agent_ctx = AstrAgentContext( + context=plugin_context, + event=event, + ) + + _proactive_cron_job_tools(req, event) + + if provider.provider_config.get("max_context_tokens", 0) <= 0: + model = provider.get_model() + if model_info := LLM_METADATAS.get(model): + provider.provider_config["max_context_tokens"] = model_info["limit"][ + "context" + ] + + if event.get_platform_name() == "webchat": + asyncio.create_task(_handle_webchat(event, req, provider)) + req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n" + + if req.func_tool and req.func_tool.tools: + tool_prompt = ( + TOOL_CALL_PROMPT + if config.tool_schema_mode == "full" + else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE + ) + req.system_prompt += f"\n{tool_prompt}\n" + + action_type = event.get_extra("action_type") + if action_type == "live": + req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n" + + await agent_runner.reset( + provider=provider, + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=config.tool_call_timeout, + ), + tool_executor=FunctionToolExecutor(), + agent_hooks=MAIN_AGENT_HOOKS, + streaming=config.streaming_response, + llm_compress_instruction=config.llm_compress_instruction, + llm_compress_keep_recent=config.llm_compress_keep_recent, + llm_compress_provider=_get_compress_provider(config, plugin_context), + truncate_turns=config.dequeue_context_length, + enforce_max_turns=config.max_context_length, + tool_schema_mode=config.tool_schema_mode, + ) + + return MainAgentBuildResult( + agent_runner=agent_runner, + provider_request=req, + provider=provider, + ) diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/astr_main_agent_resources.py similarity index 98% rename from astrbot/core/pipeline/process_stage/utils.py rename to astrbot/core/astr_main_agent_resources.py index 1b44f1752..10554cbae 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -165,7 +165,9 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]): try: target_session = ( - MessageSession.from_str(session) if isinstance(session, str) else session + MessageSession.from_str(session) + if isinstance(session, str) + else session ) except Exception as e: return f"error: invalid session: {e}" diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 6a0088aff..96ee1611e 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -163,7 +163,7 @@ class AstrBotCoreLifecycle: self.kb_manager = KnowledgeBaseManager(self.provider_manager) # 初始化 CronJob 管理器 - self.cron_manager = CronJobManager(self.star_context, self.db) + self.cron_manager = CronJobManager(self.db) # 初始化提供给插件的上下文 self.star_context = Context( @@ -231,7 +231,7 @@ class AstrBotCoreLifecycle: cron_task = None if self.cron_manager: cron_task = asyncio.create_task( - self.cron_manager.start(), + self.cron_manager.start(self.star_context), name="cron_manager", ) diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index 8a4ced6e7..a877d45ce 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -11,20 +11,27 @@ from astrbot.core.cron.events import CronMessageEvent from astrbot.core.db import BaseDatabase from astrbot.core.db.po import CronJob from astrbot.core.platform.message_session import MessageSession +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.message.components import Plain + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.star.context import Context class CronJobManager: """Central scheduler for BasicCronJob and ActiveAgentCronJob.""" - def __init__(self, ctx, db: BaseDatabase): - self.ctx = ctx + def __init__(self, db: BaseDatabase): self.db = db self.scheduler = AsyncIOScheduler() self._basic_handlers: dict[str, Callable[..., Any]] = {} self._lock = asyncio.Lock() self._started = False - async def start(self): + async def start(self, ctx: "Context"): + self.ctx: Context = ctx # star context async with self._lock: if self._started: return @@ -219,19 +226,21 @@ class CronJobManager: "cron_payload": payload, } - await self._dispatch_agent_event( + await self._woke_main_agent( message=note, session_str=session_str, extras=extras, ) - async def _dispatch_agent_event( + async def _woke_main_agent( self, *, message: str, session_str: str, - extras: dict | None = None, + extras: dict, ): + from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig + try: session = ( session_str @@ -250,7 +259,43 @@ class CronJobManager: message_type=session.message_type, ) - await self.ctx.get_event_queue().put(cron_event) + config = MainAgentBuildConfig(tool_call_timeout=3600) + result = await build_main_agent( + event=cron_event, plugin_context=self.ctx, config=config + ) + if not result: + logger.error("Failed to build main agent for cron job.") + return + req = result.provider_request + runner = result.agent_runner + + # finetine the messages + job_name = extras.get("name", "scheduled task") + note = extras.get("note") or extras.get("description") or "" + if req.contexts: + context_dump = req._print_friendly_context() + req.system_prompt += ( + "\n\nBellow is you and user previous conversation history:\n" + f"{context_dump}" + ) + req.system_prompt += ( + "\n[Scheduler Context] This turn is triggered automatically by cron job " + f'"{job_name}" (type: {extras.get("type", "unknown")}). ' + "Act proactively based on the provided note and current context. " + ) + if note: + req.system_prompt += f"[Scheduler Note]: {note}\n" + + req.prompt = "You are now responding to a scheduled task. Output using same language as previous conversation." + + async for _ in runner.step_until_done(30): + pass + llm_resp = runner.get_final_llm_resp() + if not llm_resp: + logger.warning("Cron job agent got no response") + return + message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)]) + await self.ctx.send_message(session=session, message_chain=message_chain) __all__ = ["CronJobManager"] diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 57da5193c..9a1fad0d2 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -1,61 +1,36 @@ """本地 Agent 模式的 LLM 调用 Stage""" import asyncio -import json -import os +import base64 from collections.abc import AsyncGenerator from astrbot.core import logger -from astrbot.core.agent.message import Message, TextPart +from astrbot.core.agent.message import Message from astrbot.core.agent.response import AgentStats -from astrbot.core.agent.tool import ToolSet -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.conversation_mgr import Conversation -from astrbot.core.message.components import File, Image, Reply +from astrbot.core.message.components import File, Image from astrbot.core.message.message_event_result import ( MessageChain, MessageEventResult, ResultContentType, ) from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.provider import Provider from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.star.star_handler import EventType, star_map -from astrbot.core.utils.file_extract import extract_file_moonshotai -from astrbot.core.utils.llm_metadata import LLM_METADATAS +from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from .....astr_agent_context import AgentContextWrapper -from .....astr_agent_hooks import MAIN_AGENT_HOOKS -from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent -from .....astr_agent_tool_exec import FunctionToolExecutor +from .....astr_agent_run_util import run_agent, run_live_agent from ....context import PipelineContext, call_event_hook from ...stage import Stage -from ...utils import ( - CHATUI_EXTRA_PROMPT, - EXECUTE_SHELL_TOOL, - FILE_DOWNLOAD_TOOL, - FILE_UPLOAD_TOOL, - KNOWLEDGE_BASE_QUERY_TOOL, - LIVE_MODE_SYSTEM_PROMPT, - LLM_SAFETY_MODE_SYSTEM_PROMPT, - PYTHON_TOOL, - SANDBOX_MODE_PROMPT, - TOOL_CALL_PROMPT, - TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, - SEND_MESSAGE_TO_USER_TOOL, - decoded_blocked, - retrieve_knowledge_base, -) -from astrbot.core.tools.cron_tools import ( - CREATE_CRON_JOB_TOOL, - DELETE_CRON_JOB_TOOL, - LIST_CRON_JOBS_TOOL, +from astrbot.core.astr_main_agent import ( + MainAgentBuildConfig, + MainAgentBuildResult, + build_main_agent, ) +from dataclasses import replace class InternalAgentSubStage(Stage): @@ -121,453 +96,35 @@ class InternalAgentSubStage(Stage): self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent): - """选择使用的 LLM 提供商""" - sel_provider = event.get_extra("selected_provider") - _ctx = self.ctx.plugin_manager.context - if sel_provider and isinstance(sel_provider, str): - provider = _ctx.get_provider_by_id(sel_provider) - if not provider: - logger.error(f"未找到指定的提供商: {sel_provider}。") - return provider - try: - prov = _ctx.get_using_provider(umo=event.unified_msg_origin) - except ValueError as e: - logger.error(f"Error occurred while selecting provider: {e}") - return None - return prov - - async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation: - umo = event.unified_msg_origin - conv_mgr = self.conv_manager - - # 获取对话上下文 - cid = await conv_mgr.get_curr_conversation_id(umo) - if not cid: - cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) - conversation = await conv_mgr.get_conversation(umo, cid) - if not conversation: - cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) - conversation = await conv_mgr.get_conversation(umo, cid) - if not conversation: - raise RuntimeError("无法创建新的对话。") - return conversation - - async def _apply_kb( - self, - event: AstrMessageEvent, - req: ProviderRequest, - ): - """Apply knowledge base context to the provider request""" - if not self.kb_agentic_mode: - if req.prompt is None: - return - try: - kb_result = await retrieve_knowledge_base( - query=req.prompt, - umo=event.unified_msg_origin, - context=self.ctx.plugin_manager.context, - ) - if not kb_result: - return - if req.system_prompt is not None: - req.system_prompt += ( - f"\n\n[Related Knowledge Base Results]:\n{kb_result}" - ) - except Exception as e: - logger.error(f"Error occurred while retrieving knowledge base: {e}") - else: - if req.func_tool is None: - req.func_tool = ToolSet() - req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL) - - async def _apply_file_extract( - self, - event: AstrMessageEvent, - req: ProviderRequest, - ): - """Apply file extract to the provider request""" - file_paths = [] - file_names = [] - for comp in event.message_obj.message: - if isinstance(comp, File): - file_paths.append(await comp.get_file()) - file_names.append(comp.name) - elif isinstance(comp, Reply) and comp.chain: - for reply_comp in comp.chain: - if isinstance(reply_comp, File): - file_paths.append(await reply_comp.get_file()) - file_names.append(reply_comp.name) - if not file_paths: - return - if not req.prompt: - req.prompt = "总结一下文件里面讲了什么?" - if self.file_extract_prov == "moonshotai": - if not self.file_extract_msh_api_key: - logger.error("Moonshot AI API key for file extract is not set") - return - file_contents = await asyncio.gather( - *[ - extract_file_moonshotai(file_path, self.file_extract_msh_api_key) - for file_path in file_paths - ] - ) - else: - logger.error(f"Unsupported file extract provider: {self.file_extract_prov}") - return - - # add file extract results to contexts - for file_content, file_name in zip(file_contents, file_names): - req.contexts.append( - { - "role": "system", - "content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}", - }, - ) - - def _modalities_fix( - self, - provider: Provider, - req: ProviderRequest, - ): - """检查提供商的模态能力,清理请求中的不支持内容""" - if req.image_urls: - provider_cfg = provider.provider_config.get("modalities", ["image"]) - if "image" not in provider_cfg: - logger.debug( - f"用户设置提供商 {provider} 不支持图像,将图像替换为占位符。" - ) - # 为每个图片添加占位符到 prompt - image_count = len(req.image_urls) - placeholder = " ".join(["[图片]"] * image_count) - if req.prompt: - req.prompt = f"{placeholder} {req.prompt}" - else: - req.prompt = placeholder - req.image_urls = [] - if req.func_tool: - provider_cfg = provider.provider_config.get("modalities", ["tool_use"]) - # 如果模型不支持工具使用,但请求中包含工具列表,则清空。 - if "tool_use" not in provider_cfg: - logger.debug( - f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。", - ) - req.func_tool = None - - def _sanitize_context_by_modalities( - self, - provider: Provider, - req: ProviderRequest, - ) -> None: - """Sanitize `req.contexts` (including history) by current provider modalities.""" - if not self.sanitize_context_by_modalities: - return - - if not isinstance(req.contexts, list) or not req.contexts: - return - - modalities = provider.provider_config.get("modalities", None) - # if modalities is not configured, do not sanitize. - if not modalities or not isinstance(modalities, list): - return - - supports_image = bool("image" in modalities) - supports_tool_use = bool("tool_use" in modalities) - - if supports_image and supports_tool_use: - return - - sanitized_contexts: list[dict] = [] - removed_image_blocks = 0 - removed_tool_messages = 0 - removed_tool_calls = 0 - - for msg in req.contexts: - if not isinstance(msg, dict): - continue - - role = msg.get("role") - if not role: - continue - - new_msg: dict = msg - - # tool_use sanitize - if not supports_tool_use: - if role == "tool": - # tool response block - removed_tool_messages += 1 - continue - if role == "assistant" and "tool_calls" in new_msg: - # assistant message with tool calls - if "tool_calls" in new_msg: - removed_tool_calls += 1 - new_msg.pop("tool_calls", None) - new_msg.pop("tool_call_id", None) - - # image sanitize - if not supports_image: - content = new_msg.get("content") - if isinstance(content, list): - filtered_parts: list = [] - removed_any_image = False - for part in content: - if isinstance(part, dict): - part_type = str(part.get("type", "")).lower() - if part_type in {"image_url", "image"}: - removed_any_image = True - removed_image_blocks += 1 - continue - filtered_parts.append(part) - - if removed_any_image: - new_msg["content"] = filtered_parts - - # drop empty assistant messages (e.g. only tool_calls without content) - if role == "assistant": - content = new_msg.get("content") - has_tool_calls = bool(new_msg.get("tool_calls")) - if not has_tool_calls: - if not content: - continue - if isinstance(content, str) and not content.strip(): - continue - - sanitized_contexts.append(new_msg) - - if removed_image_blocks or removed_tool_messages or removed_tool_calls: - logger.debug( - "sanitize_context_by_modalities applied: " - f"removed_image_blocks={removed_image_blocks}, " - f"removed_tool_messages={removed_tool_messages}, " - f"removed_tool_calls={removed_tool_calls}" - ) - - req.contexts = sanitized_contexts - - def _plugin_tool_fix( - self, - event: AstrMessageEvent, - req: ProviderRequest, - ): - """根据事件中的插件设置,过滤请求中的工具列表""" - if event.plugins_name is not None and req.func_tool: - new_tool_set = ToolSet() - for tool in req.func_tool.tools: - mp = tool.handler_module_path - if not mp: - continue - plugin = star_map.get(mp) - if not plugin: - continue - if plugin.name in event.plugins_name or plugin.reserved: - new_tool_set.add_tool(tool) - req.func_tool = new_tool_set - - async def _handle_webchat( - self, - event: AstrMessageEvent, - req: ProviderRequest, - prov: Provider, - ): - """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" - from astrbot.core import db_helper - - chatui_session_id = event.session_id.split("!")[-1] - user_prompt = req.prompt - - session = await db_helper.get_platform_session_by_id(chatui_session_id) - - if ( - not user_prompt - or not chatui_session_id - or not session - or session.display_name - ): - return - - llm_resp = await prov.text_chat( - system_prompt=( - "You are a conversation title generator. " - "Generate a concise title in the same language as the user’s input, " - "no more than 10 words, capturing only the core topic." - "If the input is a greeting, small talk, or has no clear topic, " - "(e.g., “hi”, “hello”, “haha”), return . " - "Output only the title itself or , with no explanations." - ), - prompt=( - f"Generate a concise title for the following user query:\n{user_prompt}" - ), + self.main_agent_cfg = MainAgentBuildConfig( + tool_call_timeout=self.tool_call_timeout, + tool_schema_mode=self.tool_schema_mode, + sanitize_context_by_modalities=self.sanitize_context_by_modalities, + kb_agentic_mode=self.kb_agentic_mode, + file_extract_enabled=self.file_extract_enabled, + file_extract_prov=self.file_extract_prov, + file_extract_msh_api_key=self.file_extract_msh_api_key, + context_limit_reached_strategy=self.context_limit_reached_strategy, + llm_compress_instruction=self.llm_compress_instruction, + llm_compress_keep_recent=self.llm_compress_keep_recent, + llm_compress_provider_id=self.llm_compress_provider_id, + max_context_length=self.max_context_length, + dequeue_context_length=self.dequeue_context_length, + llm_safety_mode=self.llm_safety_mode, + safety_mode_strategy=self.safety_mode_strategy, + sandbox_cfg=self.sandbox_cfg, ) - if llm_resp and llm_resp.completion_text: - title = llm_resp.completion_text.strip() - if not title or "" in title: - return - logger.info( - f"Generated chatui title for session {chatui_session_id}: {title}" - ) - await db_helper.update_platform_session( - session_id=chatui_session_id, - display_name=title, - ) - - async def _save_to_history( - self, - event: AstrMessageEvent, - req: ProviderRequest, - llm_response: LLMResponse | None, - all_messages: list[Message], - runner_stats: AgentStats | None, - ): - if ( - not req - or not req.conversation - or not llm_response - or llm_response.role != "assistant" - ): - return - - if not llm_response.completion_text and not req.tool_calls_result: - logger.debug("LLM 响应为空,不保存记录。") - return - - # using agent context messages to save to history - message_to_save = [] - skipped_initial_system = False - for message in all_messages: - if message.role == "system" and not skipped_initial_system: - skipped_initial_system = True - continue # skip first system message - if message.role in ["assistant", "user"] and getattr( - message, "_no_save", None - ): - # we do not save user and assistant messages that are marked as _no_save - continue - message_to_save.append(message.model_dump()) - - # get token usage from agent runner stats - token_usage = None - if runner_stats: - token_usage = runner_stats.token_usage.total - - await self.conv_manager.update_conversation( - event.unified_msg_origin, - req.conversation.cid, - history=message_to_save, - token_usage=token_usage, - ) - - def _get_compress_provider(self) -> Provider | None: - if not self.llm_compress_provider_id: - return None - if self.context_limit_reached_strategy != "llm_compress": - return None - provider = self.ctx.plugin_manager.context.get_provider_by_id( - self.llm_compress_provider_id, - ) - if provider is None: - logger.warning( - f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。", - ) - return None - if not isinstance(provider, Provider): - logger.warning( - f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。" - ) - return None - return provider - - def _apply_llm_safety_mode(self, req: ProviderRequest) -> None: - """Apply LLM safety mode to the provider request.""" - if self.safety_mode_strategy == "system_prompt": - req.system_prompt = ( - f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n{req.system_prompt or ''}" - ) - else: - logger.warning( - f"Unsupported llm_safety_mode strategy: {self.safety_mode_strategy}.", - ) - - def _apply_sandbox_tools(self, req: ProviderRequest, session_id: str) -> None: - """Add sandbox tools to the provider request.""" - if req.func_tool is None: - req.func_tool = ToolSet() - if self.sandbox_cfg.get("booter") == "shipyard": - ep = self.sandbox_cfg.get("shipyard_endpoint", "") - at = self.sandbox_cfg.get("shipyard_access_token", "") - if not ep or not at: - logger.error("Shipyard sandbox configuration is incomplete.") - return - os.environ["SHIPYARD_ENDPOINT"] = ep - os.environ["SHIPYARD_ACCESS_TOKEN"] = at - req.func_tool.add_tool(EXECUTE_SHELL_TOOL) - req.func_tool.add_tool(PYTHON_TOOL) - req.func_tool.add_tool(FILE_UPLOAD_TOOL) - req.func_tool.add_tool(FILE_DOWNLOAD_TOOL) - req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n" - - def _proactive_cron_job_tools( - self, req: ProviderRequest, event: AstrMessageEvent - ) -> None: - """Inject cron job context and tools into the provider request for proactive scheduling.""" - - if req.func_tool is None: - req.func_tool = ToolSet() - req.func_tool.add_tool(CREATE_CRON_JOB_TOOL) - req.func_tool.add_tool(DELETE_CRON_JOB_TOOL) - req.func_tool.add_tool(LIST_CRON_JOBS_TOOL) - - cron_meta = event.get_extra("cron_job") - if cron_meta: - # The message event is triggered by a known cron job - if req.func_tool is None: - req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) - - job_name = cron_meta.get("name", "scheduled task") - note = cron_meta.get("note") or cron_meta.get("description") or "" - req.system_prompt += ( - f"\n[Scheduler Context] This turn is triggered automatically by cron job " - f'"{job_name}" (type: {cron_meta.get("type", "unknown")}). ' - "Act proactively based on the provided note and current context. " - "If you want to proactively notify the user, call `send_message_to_user` with a concise message.\n" - ) - if note: - req.system_prompt += f"[Scheduler Note]: {note}\n" - - if bg := event.get_extra("background_task_result"): - # The message event is triggered after a background task done - result_text = bg.get("result") or "" - if req.func_tool is None: - req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) - if result_text: - req.system_prompt += f"\n[Background Task Result] {result_text}\n" async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: - req: ProviderRequest | None = None - try: - provider = self._select_provider(event) - if provider is None: - logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。") - return - if not isinstance(provider, Provider): - logger.error( - f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。" - ) - return - streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: streaming_response = bool(enable_streaming) - # 检查消息内容是否有效,避免空消息触发钩子 has_provider_request = event.get_extra("provider_request") is not None has_valid_message = bool(event.message_str and event.message_str.strip()) - # 检查是否有图片或其他媒体内容 has_media_content = any( isinstance(comp, Image | File) for comp in event.message_obj.message ) @@ -580,179 +137,50 @@ class InternalAgentSubStage(Stage): logger.debug("skip llm request: empty message and no provider_request") return - api_base = provider.provider_config.get("api_base", "") - for host in decoded_blocked: - if host in api_base: - logger.error( - f"Provider API base {api_base} is blocked due to security reasons. Please use another ai provider." - ) - return - logger.debug("ready to request llm provider") - # 通知等待调用 LLM(在获取锁之前) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): logger.debug("acquired session lock for llm request") - if event.get_extra("provider_request"): - req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), ( - "provider_request 必须是 ProviderRequest 类型。" - ) - if req.conversation: - req.contexts = json.loads(req.conversation.history) + build_cfg = replace( + self.main_agent_cfg, + provider_wake_prefix=provider_wake_prefix, + streaming_response=streaming_response, + ) - else: - req = ProviderRequest() - req.prompt = "" - req.image_urls = [] - if sel_model := event.get_extra("selected_model"): - req.model = sel_model - if provider_wake_prefix and not event.message_str.startswith( - provider_wake_prefix - ): - return + build_result: MainAgentBuildResult | None = await build_main_agent( + event=event, + plugin_context=self.ctx.plugin_manager.context, + config=build_cfg, + ) - req.prompt = event.message_str[len(provider_wake_prefix) :] - # func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。 - # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_path = await comp.convert_to_file_path() - req.image_urls.append(image_path) - - req.extra_user_content_parts.append( - TextPart(text=f"[Image Attachment: path {image_path}]") - ) - elif isinstance(comp, File): - file_path = await comp.get_file() - file_name = comp.name or os.path.basename(file_path) - req.extra_user_content_parts.append( - TextPart( - text=f"[File Attachment: name {file_name}, path {file_path}]" - ) - ) - - conversation = await self._get_session_conv(event) - req.conversation = conversation - req.contexts = json.loads(conversation.history) - - event.set_extra("provider_request", req) - - # fix contexts json str - if isinstance(req.contexts, str): - req.contexts = json.loads(req.contexts) - - # apply file extract - if self.file_extract_enabled: - try: - await self._apply_file_extract(event, req) - except Exception as e: - logger.error(f"Error occurred while applying file extract: {e}") - - if not req.prompt and not req.image_urls: - if not event.get_group_id() and req.extra_user_content_parts: - req.prompt = "" - else: - return - - # call event hook - if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + if build_result is None: return - # apply knowledge base feature - await self._apply_kb(event, req) + agent_runner = build_result.agent_runner + req = build_result.provider_request + provider = build_result.provider - # truncate contexts to fit max length - # NOW moved to ContextManager inside ToolLoopAgentRunner - # if req.contexts: - # req.contexts = self._truncate_contexts(req.contexts) - # self._fix_messages(req.contexts) - - # session_id - if not req.session_id: - req.session_id = event.unified_msg_origin - - # check provider modalities, if provider does not support image/tool_use, clear them in request. - self._modalities_fix(provider, req) - - # filter tools, only keep tools from this pipeline's selected plugins - self._plugin_tool_fix(event, req) - - # sanitize contexts (including history) by provider modalities - self._sanitize_context_by_modalities(provider, req) - - # apply llm safety mode - if self.llm_safety_mode: - self._apply_llm_safety_mode(req) - - # apply sandbox tools - if self.sandbox_cfg.get("enable", False): - self._apply_sandbox_tools(req, req.session_id) + api_base = provider.provider_config.get("api_base", "") + for host in decoded_blocked: + if host in api_base: + logger.error( + "Provider API base %s is blocked due to security reasons. Please use another ai provider.", + api_base, + ) + return stream_to_general = ( self.unsupported_streaming_strategy == "turn_off" and not event.platform_meta.support_streaming_message ) - # run agent - agent_runner = AgentRunner() - logger.debug( - f"handle provider[id: {provider.provider_config['id']}] request: {req}", - ) - astr_agent_ctx = AstrAgentContext( - context=self.ctx.plugin_manager.context, - event=event, - ) - - # inject model context length limit - if provider.provider_config.get("max_context_tokens", 0) <= 0: - model = provider.get_model() - if model_info := LLM_METADATAS.get(model): - provider.provider_config["max_context_tokens"] = model_info[ - "limit" - ]["context"] - - # ChatUI 对话的标题生成 - if event.get_platform_name() == "webchat": - asyncio.create_task(self._handle_webchat(event, req, provider)) - - # 注入 ChatUI 额外 prompt - # 比如 follow-up questions 提示等 - req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n" - - # 注入基本 prompt - if req.func_tool and req.func_tool.tools: - tool_prompt = ( - TOOL_CALL_PROMPT - if self.tool_schema_mode == "full" - else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE - ) - req.system_prompt += f"\n{tool_prompt}\n" + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return action_type = event.get_extra("action_type") - if action_type == "live": - req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n" - - await agent_runner.reset( - provider=provider, - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=self.tool_call_timeout, - ), - tool_executor=FunctionToolExecutor(), - agent_hooks=MAIN_AGENT_HOOKS, - streaming=streaming_response, - llm_compress_instruction=self.llm_compress_instruction, - llm_compress_keep_recent=self.llm_compress_keep_recent, - llm_compress_provider=self._get_compress_provider(), - truncate_turns=self.dequeue_context_length, - enforce_max_turns=self.max_context_length, - tool_schema_mode=self.tool_schema_mode, - ) # 检测 Live Mode if action_type == "live": @@ -865,3 +293,52 @@ class InternalAgentSubStage(Stage): f"Error occurred while processing agent request: {e}" ) ) + + async def _save_to_history( + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse | None, + all_messages: list[Message], + runner_stats: AgentStats | None, + ): + if ( + not req + or not req.conversation + or not llm_response + or llm_response.role != "assistant" + ): + return + + if not llm_response.completion_text and not req.tool_calls_result: + logger.debug("LLM 响应为空,不保存记录。") + return + + message_to_save = [] + skipped_initial_system = False + for message in all_messages: + if message.role == "system" and not skipped_initial_system: + skipped_initial_system = True + continue + if message.role in ["assistant", "user"] and getattr( + message, "_no_save", None + ): + continue + message_to_save.append(message.model_dump()) + + token_usage = None + if runner_stats: + token_usage = runner_stats.token_usage.total + + await self.conv_manager.update_conversation( + event.unified_msg_origin, + req.conversation.cid, + history=message_to_save, + token_usage=token_usage, + ) + + +# we prevent astrbot from connecting to known malicious hosts +# these hosts are base64 encoded +BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"} +decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED] diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index a1a6039f4..7c568626d 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -165,7 +165,7 @@ class ProviderRequest: result_parts.append(f"{role}: {''.join(msg_parts)}") - return result_parts + return "\n".join(result_parts) async def assemble_context(self) -> dict: """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""