refactor: extract main agent
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
.PHONY: worktree worktree-add worktree-rm
|
||||
|
||||
WORKTREE_DIR ?= ../astrbot_worktree
|
||||
BRANCH ?= $(word 2,$(MAKECMDGOALS))
|
||||
BASE ?= $(word 3,$(MAKECMDGOALS))
|
||||
BASE ?= master
|
||||
|
||||
worktree:
|
||||
@echo "Usage:"
|
||||
@echo " make worktree-add <branch> [base-branch]"
|
||||
@echo " make worktree-rm <branch>"
|
||||
|
||||
worktree-add:
|
||||
ifeq ($(strip $(BRANCH)),)
|
||||
$(error Branch name required. Usage: make worktree-add <branch> [base-branch])
|
||||
endif
|
||||
@mkdir -p $(WORKTREE_DIR)
|
||||
git worktree add $(WORKTREE_DIR)/$(BRANCH) -b $(BRANCH) $(BASE)
|
||||
|
||||
worktree-rm:
|
||||
ifeq ($(strip $(BRANCH)),)
|
||||
$(error Branch name required. Usage: make worktree-rm <branch>)
|
||||
endif
|
||||
@if [ -d "$(WORKTREE_DIR)/$(BRANCH)" ]; then \
|
||||
git worktree remove $(WORKTREE_DIR)/$(BRANCH); \
|
||||
else \
|
||||
echo "Worktree $(WORKTREE_DIR)/$(BRANCH) not found."; \
|
||||
fi
|
||||
|
||||
# Swallow extra args (branch/base) so make doesn't treat them as targets
|
||||
%:
|
||||
@true
|
||||
@@ -9,7 +9,7 @@ from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.pipeline.process_stage.utils import (
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
LOCAL_EXECUTE_SHELL_TOOL,
|
||||
LOCAL_PYTHON_TOOL,
|
||||
|
||||
@@ -22,6 +22,7 @@ from astrbot.core.message.message_event_result import (
|
||||
)
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@@ -147,6 +148,8 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
task_id: str,
|
||||
**tool_args,
|
||||
):
|
||||
from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig
|
||||
|
||||
# run the tool
|
||||
result_text = ""
|
||||
try:
|
||||
@@ -187,7 +190,48 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
extras=extras,
|
||||
message_type=session.message_type,
|
||||
)
|
||||
ctx.get_event_queue().put_nowait(cron_event)
|
||||
config = MainAgentBuildConfig(tool_call_timeout=3600)
|
||||
result = await build_main_agent(
|
||||
event=cron_event, plugin_context=ctx, config=config
|
||||
)
|
||||
if not result:
|
||||
logger.error("Failed to build main agent for cron job.")
|
||||
return
|
||||
runner = result.agent_runner
|
||||
req = result.provider_request
|
||||
|
||||
bg = extras["background_task_result"]
|
||||
result_text = bg["result"] or "Empty Response"
|
||||
if req.contexts:
|
||||
context_dump = req._print_friendly_context()
|
||||
req.system_prompt += (
|
||||
"\n\nBellow is you and user previous conversation history:\n"
|
||||
f"{context_dump}"
|
||||
)
|
||||
req.system_prompt += (
|
||||
"You now have a new background task result:\n"
|
||||
f"- Task ID: {bg['task_id']}\n"
|
||||
f"- Executed Tool: {tool.name}\n"
|
||||
f"- Tool Args: {tool_args}\n"
|
||||
f"- Result: {result_text}\n"
|
||||
f"- Note: {note}\n"
|
||||
"Please tell the user the result of the background task in your next response."
|
||||
)
|
||||
|
||||
req.prompt = (
|
||||
"You have a new background task result to report to the user."
|
||||
" Please include the result in your next response."
|
||||
" Using same language as previous conversation."
|
||||
)
|
||||
|
||||
async for _ in runner.step_until_done(30):
|
||||
pass
|
||||
llm_resp = runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
logger.warning("Cron job agent got no response")
|
||||
return
|
||||
message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)])
|
||||
await ctx.send_message(session=session, message_chain=message_chain)
|
||||
|
||||
@classmethod
|
||||
async def _execute_local(
|
||||
|
||||
@@ -0,0 +1,545 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.tools.cron_tools import (
|
||||
CREATE_CRON_JOB_TOOL,
|
||||
DELETE_CRON_JOB_TOOL,
|
||||
LIST_CRON_JOBS_TOOL,
|
||||
)
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
|
||||
from .astr_main_agent_resources import (
|
||||
CHATUI_EXTRA_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
PYTHON_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
retrieve_knowledge_base,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MainAgentBuildConfig:
|
||||
tool_call_timeout: int
|
||||
tool_schema_mode: str = "full"
|
||||
provider_wake_prefix: str = ""
|
||||
streaming_response: bool = True
|
||||
sanitize_context_by_modalities: bool = False
|
||||
kb_agentic_mode: bool = False
|
||||
file_extract_enabled: bool = False
|
||||
file_extract_prov: str = "moonshotai"
|
||||
file_extract_msh_api_key: str = ""
|
||||
context_limit_reached_strategy: str = "truncate_by_turns"
|
||||
llm_compress_instruction: str = ""
|
||||
llm_compress_keep_recent: int = 4
|
||||
llm_compress_provider_id: str = ""
|
||||
max_context_length: int = 0
|
||||
dequeue_context_length: int = 1
|
||||
llm_safety_mode: bool = True
|
||||
safety_mode_strategy: str = "system_prompt"
|
||||
sandbox_cfg: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MainAgentBuildResult:
|
||||
agent_runner: AgentRunner
|
||||
provider_request: ProviderRequest
|
||||
provider: Provider
|
||||
|
||||
|
||||
def _select_provider(
|
||||
event: AstrMessageEvent, plugin_context: Context
|
||||
) -> Provider | None:
|
||||
"""Select chat provider for the event."""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
if sel_provider and isinstance(sel_provider, str):
|
||||
provider = plugin_context.get_provider_by_id(sel_provider)
|
||||
if not provider:
|
||||
logger.error("未找到指定的提供商: %s。", sel_provider)
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
"选择的提供商类型无效(%s),跳过 LLM 请求处理。", type(provider)
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
try:
|
||||
return plugin_context.get_using_provider(umo=event.unified_msg_origin)
|
||||
except ValueError as exc:
|
||||
logger.error("Error occurred while selecting provider: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
async def _get_session_conv(
|
||||
event: AstrMessageEvent, plugin_context: Context
|
||||
) -> Conversation:
|
||||
conv_mgr = plugin_context.conversation_manager
|
||||
umo = event.unified_msg_origin
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
|
||||
async def _apply_kb(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
plugin_context: Context,
|
||||
config: MainAgentBuildConfig,
|
||||
) -> None:
|
||||
if not config.kb_agentic_mode:
|
||||
if req.prompt is None:
|
||||
return
|
||||
try:
|
||||
kb_result = await retrieve_knowledge_base(
|
||||
query=req.prompt,
|
||||
umo=event.unified_msg_origin,
|
||||
context=plugin_context,
|
||||
)
|
||||
if not kb_result:
|
||||
return
|
||||
if req.system_prompt is not None:
|
||||
req.system_prompt += (
|
||||
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Error occurred while retrieving knowledge base: %s", exc)
|
||||
else:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
|
||||
async def _apply_file_extract(
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
config: MainAgentBuildConfig,
|
||||
) -> None:
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if config.file_extract_prov == "moonshotai":
|
||||
if not config.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(
|
||||
file_path,
|
||||
config.file_extract_msh_api_key,
|
||||
)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error("Unsupported file extract provider: %s", config.file_extract_prov)
|
||||
return
|
||||
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"File Extract Results of user uploaded files:\n"
|
||||
f"{file_content}\nFile Name: {file_name or 'Unknown'}"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(
|
||||
"Provider %s does not support image, using placeholder.", provider
|
||||
)
|
||||
image_count = len(req.image_urls)
|
||||
placeholder = " ".join(["[图片]"] * image_count)
|
||||
if req.prompt:
|
||||
req.prompt = f"{placeholder} {req.prompt}"
|
||||
else:
|
||||
req.prompt = placeholder
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(
|
||||
"Provider %s does not support tool_use, clearing tools.", provider
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
|
||||
def _sanitize_context_by_modalities(
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
if not config.sanitize_context_by_modalities:
|
||||
return
|
||||
if not isinstance(req.contexts, list) or not req.contexts:
|
||||
return
|
||||
modalities = provider.provider_config.get("modalities", None)
|
||||
if not modalities or not isinstance(modalities, list):
|
||||
return
|
||||
supports_image = bool("image" in modalities)
|
||||
supports_tool_use = bool("tool_use" in modalities)
|
||||
if supports_image and supports_tool_use:
|
||||
return
|
||||
|
||||
sanitized_contexts: list[dict] = []
|
||||
removed_image_blocks = 0
|
||||
removed_tool_messages = 0
|
||||
removed_tool_calls = 0
|
||||
|
||||
for msg in req.contexts:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
|
||||
new_msg = msg
|
||||
if not supports_tool_use:
|
||||
if role == "tool":
|
||||
removed_tool_messages += 1
|
||||
continue
|
||||
if role == "assistant" and "tool_calls" in new_msg:
|
||||
if "tool_calls" in new_msg:
|
||||
removed_tool_calls += 1
|
||||
new_msg.pop("tool_calls", None)
|
||||
new_msg.pop("tool_call_id", None)
|
||||
|
||||
if not supports_image:
|
||||
content = new_msg.get("content")
|
||||
if isinstance(content, list):
|
||||
filtered_parts: list = []
|
||||
removed_any_image = False
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = str(part.get("type", "")).lower()
|
||||
if part_type in {"image_url", "image"}:
|
||||
removed_any_image = True
|
||||
removed_image_blocks += 1
|
||||
continue
|
||||
filtered_parts.append(part)
|
||||
if removed_any_image:
|
||||
new_msg["content"] = filtered_parts
|
||||
|
||||
if role == "assistant":
|
||||
content = new_msg.get("content")
|
||||
has_tool_calls = bool(new_msg.get("tool_calls"))
|
||||
if not has_tool_calls:
|
||||
if not content:
|
||||
continue
|
||||
if isinstance(content, str) and not content.strip():
|
||||
continue
|
||||
|
||||
sanitized_contexts.append(new_msg)
|
||||
|
||||
if removed_image_blocks or removed_tool_messages or removed_tool_calls:
|
||||
logger.debug(
|
||||
"sanitize_context_by_modalities applied: "
|
||||
"removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s",
|
||||
removed_image_blocks,
|
||||
removed_tool_messages,
|
||||
removed_tool_calls,
|
||||
)
|
||||
req.contexts = sanitized_contexts
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
|
||||
async def _handle_webchat(
|
||||
event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
) -> None:
|
||||
from astrbot.core import db_helper
|
||||
|
||||
chatui_session_id = event.session_id.split("!")[-1]
|
||||
user_prompt = req.prompt
|
||||
session = await db_helper.get_platform_session_by_id(chatui_session_id)
|
||||
|
||||
if not user_prompt or not chatui_session_id or not session or session.display_name:
|
||||
return
|
||||
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt=(
|
||||
"You are a conversation title generator. "
|
||||
"Generate a concise title in the same language as the user’s input, "
|
||||
"no more than 10 words, capturing only the core topic."
|
||||
"If the input is a greeting, small talk, or has no clear topic, "
|
||||
"(e.g., “hi”, “hello”, “haha”), return <None>. "
|
||||
"Output only the title itself or <None>, with no explanations."
|
||||
),
|
||||
prompt=f"Generate a concise title for the following user query:\n{user_prompt}",
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
logger.info(
|
||||
"Generated chatui title for session %s: %s", chatui_session_id, title
|
||||
)
|
||||
await db_helper.update_platform_session(
|
||||
session_id=chatui_session_id,
|
||||
display_name=title,
|
||||
)
|
||||
|
||||
|
||||
def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) -> None:
|
||||
if config.safety_mode_strategy == "system_prompt":
|
||||
req.system_prompt = (
|
||||
f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n{req.system_prompt or ''}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported llm_safety_mode strategy: %s.",
|
||||
config.safety_mode_strategy,
|
||||
)
|
||||
|
||||
|
||||
def _apply_sandbox_tools(
|
||||
config: MainAgentBuildConfig, req: ProviderRequest, session_id: str
|
||||
) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
if config.sandbox_cfg.get("booter") == "shipyard":
|
||||
ep = config.sandbox_cfg.get("shipyard_endpoint", "")
|
||||
at = config.sandbox_cfg.get("shipyard_access_token", "")
|
||||
if not ep or not at:
|
||||
logger.error("Shipyard sandbox configuration is incomplete.")
|
||||
return
|
||||
os.environ["SHIPYARD_ENDPOINT"] = ep
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
|
||||
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(PYTHON_TOOL)
|
||||
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
|
||||
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
|
||||
req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n"
|
||||
|
||||
|
||||
def _proactive_cron_job_tools(req: ProviderRequest, event: AstrMessageEvent) -> None:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(CREATE_CRON_JOB_TOOL)
|
||||
req.func_tool.add_tool(DELETE_CRON_JOB_TOOL)
|
||||
req.func_tool.add_tool(LIST_CRON_JOBS_TOOL)
|
||||
|
||||
|
||||
def _get_compress_provider(
|
||||
config: MainAgentBuildConfig, plugin_context: Context
|
||||
) -> Provider | None:
|
||||
if not config.llm_compress_provider_id:
|
||||
return None
|
||||
if config.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = plugin_context.get_provider_by_id(config.llm_compress_provider_id)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
"未找到指定的上下文压缩模型 %s,将跳过压缩。",
|
||||
config.llm_compress_provider_id,
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
"指定的上下文压缩模型 %s 不是对话模型,将跳过压缩。",
|
||||
config.llm_compress_provider_id,
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
|
||||
async def build_main_agent(
|
||||
*,
|
||||
event: AstrMessageEvent,
|
||||
plugin_context: Context,
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider | None = None,
|
||||
req: ProviderRequest | None = None,
|
||||
) -> MainAgentBuildResult | None:
|
||||
"""构建主对话代理(Main Agent),并且自动 reset。"""
|
||||
provider = provider or _select_provider(event, plugin_context)
|
||||
if provider is None:
|
||||
logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。")
|
||||
return None
|
||||
|
||||
if req is None:
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if config.provider_wake_prefix and not event.message_str.startswith(
|
||||
config.provider_wake_prefix
|
||||
):
|
||||
return None
|
||||
|
||||
req.prompt = event.message_str[len(config.provider_wake_prefix) :]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"[Image Attachment: path {image_path}]")
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
file_path = await comp.get_file()
|
||||
file_name = comp.name or os.path.basename(file_path)
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(
|
||||
text=f"[File Attachment: name {file_name}, path {file_path}]"
|
||||
)
|
||||
)
|
||||
|
||||
conversation = await _get_session_conv(event, plugin_context)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
if config.file_extract_enabled:
|
||||
try:
|
||||
await _apply_file_extract(event, req, config)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Error occurred while applying file extract: %s", exc)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
if not event.get_group_id() and req.extra_user_content_parts:
|
||||
req.prompt = "<attachment>"
|
||||
else:
|
||||
return None
|
||||
|
||||
await _apply_kb(event, req, plugin_context, config)
|
||||
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
_modalities_fix(provider, req)
|
||||
_plugin_tool_fix(event, req)
|
||||
_sanitize_context_by_modalities(config, provider, req)
|
||||
|
||||
if config.llm_safety_mode:
|
||||
_apply_llm_safety_mode(config, req)
|
||||
|
||||
if config.sandbox_cfg.get("enable", False):
|
||||
_apply_sandbox_tools(config, req, req.session_id)
|
||||
|
||||
agent_runner = AgentRunner()
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=plugin_context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
_proactive_cron_job_tools(req, event)
|
||||
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info["limit"][
|
||||
"context"
|
||||
]
|
||||
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(_handle_webchat(event, req, provider))
|
||||
req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n"
|
||||
|
||||
if req.func_tool and req.func_tool.tools:
|
||||
tool_prompt = (
|
||||
TOOL_CALL_PROMPT
|
||||
if config.tool_schema_mode == "full"
|
||||
else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE
|
||||
)
|
||||
req.system_prompt += f"\n{tool_prompt}\n"
|
||||
|
||||
action_type = event.get_extra("action_type")
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=config.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=config.streaming_response,
|
||||
llm_compress_instruction=config.llm_compress_instruction,
|
||||
llm_compress_keep_recent=config.llm_compress_keep_recent,
|
||||
llm_compress_provider=_get_compress_provider(config, plugin_context),
|
||||
truncate_turns=config.dequeue_context_length,
|
||||
enforce_max_turns=config.max_context_length,
|
||||
tool_schema_mode=config.tool_schema_mode,
|
||||
)
|
||||
|
||||
return MainAgentBuildResult(
|
||||
agent_runner=agent_runner,
|
||||
provider_request=req,
|
||||
provider=provider,
|
||||
)
|
||||
+3
-1
@@ -165,7 +165,9 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
|
||||
|
||||
try:
|
||||
target_session = (
|
||||
MessageSession.from_str(session) if isinstance(session, str) else session
|
||||
MessageSession.from_str(session)
|
||||
if isinstance(session, str)
|
||||
else session
|
||||
)
|
||||
except Exception as e:
|
||||
return f"error: invalid session: {e}"
|
||||
@@ -163,7 +163,7 @@ class AstrBotCoreLifecycle:
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
|
||||
# 初始化 CronJob 管理器
|
||||
self.cron_manager = CronJobManager(self.star_context, self.db)
|
||||
self.cron_manager = CronJobManager(self.db)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -231,7 +231,7 @@ class AstrBotCoreLifecycle:
|
||||
cron_task = None
|
||||
if self.cron_manager:
|
||||
cron_task = asyncio.create_task(
|
||||
self.cron_manager.start(),
|
||||
self.cron_manager.start(self.star_context),
|
||||
name="cron_manager",
|
||||
)
|
||||
|
||||
|
||||
@@ -11,20 +11,27 @@ from astrbot.core.cron.events import CronMessageEvent
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import CronJob
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
class CronJobManager:
|
||||
"""Central scheduler for BasicCronJob and ActiveAgentCronJob."""
|
||||
|
||||
def __init__(self, ctx, db: BaseDatabase):
|
||||
self.ctx = ctx
|
||||
def __init__(self, db: BaseDatabase):
|
||||
self.db = db
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self._basic_handlers: dict[str, Callable[..., Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._started = False
|
||||
|
||||
async def start(self):
|
||||
async def start(self, ctx: "Context"):
|
||||
self.ctx: Context = ctx # star context
|
||||
async with self._lock:
|
||||
if self._started:
|
||||
return
|
||||
@@ -219,19 +226,21 @@ class CronJobManager:
|
||||
"cron_payload": payload,
|
||||
}
|
||||
|
||||
await self._dispatch_agent_event(
|
||||
await self._woke_main_agent(
|
||||
message=note,
|
||||
session_str=session_str,
|
||||
extras=extras,
|
||||
)
|
||||
|
||||
async def _dispatch_agent_event(
|
||||
async def _woke_main_agent(
|
||||
self,
|
||||
*,
|
||||
message: str,
|
||||
session_str: str,
|
||||
extras: dict | None = None,
|
||||
extras: dict,
|
||||
):
|
||||
from astrbot.core.astr_main_agent import build_main_agent, MainAgentBuildConfig
|
||||
|
||||
try:
|
||||
session = (
|
||||
session_str
|
||||
@@ -250,7 +259,43 @@ class CronJobManager:
|
||||
message_type=session.message_type,
|
||||
)
|
||||
|
||||
await self.ctx.get_event_queue().put(cron_event)
|
||||
config = MainAgentBuildConfig(tool_call_timeout=3600)
|
||||
result = await build_main_agent(
|
||||
event=cron_event, plugin_context=self.ctx, config=config
|
||||
)
|
||||
if not result:
|
||||
logger.error("Failed to build main agent for cron job.")
|
||||
return
|
||||
req = result.provider_request
|
||||
runner = result.agent_runner
|
||||
|
||||
# finetine the messages
|
||||
job_name = extras.get("name", "scheduled task")
|
||||
note = extras.get("note") or extras.get("description") or ""
|
||||
if req.contexts:
|
||||
context_dump = req._print_friendly_context()
|
||||
req.system_prompt += (
|
||||
"\n\nBellow is you and user previous conversation history:\n"
|
||||
f"{context_dump}"
|
||||
)
|
||||
req.system_prompt += (
|
||||
"\n[Scheduler Context] This turn is triggered automatically by cron job "
|
||||
f'"{job_name}" (type: {extras.get("type", "unknown")}). '
|
||||
"Act proactively based on the provided note and current context. "
|
||||
)
|
||||
if note:
|
||||
req.system_prompt += f"[Scheduler Note]: {note}\n"
|
||||
|
||||
req.prompt = "You are now responding to a scheduled task. Output using same language as previous conversation."
|
||||
|
||||
async for _ in runner.step_until_done(30):
|
||||
pass
|
||||
llm_resp = runner.get_final_llm_resp()
|
||||
if not llm_resp:
|
||||
logger.warning("Cron job agent got no response")
|
||||
return
|
||||
message_chain = MessageChain(chain=[Plain(text=llm_resp.completion_text)])
|
||||
await self.ctx.send_message(session=session, message_chain=message_chain)
|
||||
|
||||
|
||||
__all__ = ["CronJobManager"]
|
||||
|
||||
@@ -1,61 +1,36 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core.message.components import File, Image, Reply
|
||||
from astrbot.core.message.components import File, Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from .....astr_agent_context import AgentContextWrapper
|
||||
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent
|
||||
from .....astr_agent_tool_exec import FunctionToolExecutor
|
||||
from .....astr_agent_run_util import run_agent, run_live_agent
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...stage import Stage
|
||||
from ...utils import (
|
||||
CHATUI_EXTRA_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
FILE_UPLOAD_TOOL,
|
||||
KNOWLEDGE_BASE_QUERY_TOOL,
|
||||
LIVE_MODE_SYSTEM_PROMPT,
|
||||
LLM_SAFETY_MODE_SYSTEM_PROMPT,
|
||||
PYTHON_TOOL,
|
||||
SANDBOX_MODE_PROMPT,
|
||||
TOOL_CALL_PROMPT,
|
||||
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
|
||||
SEND_MESSAGE_TO_USER_TOOL,
|
||||
decoded_blocked,
|
||||
retrieve_knowledge_base,
|
||||
)
|
||||
from astrbot.core.tools.cron_tools import (
|
||||
CREATE_CRON_JOB_TOOL,
|
||||
DELETE_CRON_JOB_TOOL,
|
||||
LIST_CRON_JOBS_TOOL,
|
||||
from astrbot.core.astr_main_agent import (
|
||||
MainAgentBuildConfig,
|
||||
MainAgentBuildResult,
|
||||
build_main_agent,
|
||||
)
|
||||
from dataclasses import replace
|
||||
|
||||
|
||||
class InternalAgentSubStage(Stage):
|
||||
@@ -121,453 +96,35 @@ class InternalAgentSubStage(Stage):
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
if sel_provider and isinstance(sel_provider, str):
|
||||
provider = _ctx.get_provider_by_id(sel_provider)
|
||||
if not provider:
|
||||
logger.error(f"未找到指定的提供商: {sel_provider}。")
|
||||
return provider
|
||||
try:
|
||||
prov = _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error occurred while selecting provider: {e}")
|
||||
return None
|
||||
return prov
|
||||
|
||||
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||
umo = event.unified_msg_origin
|
||||
conv_mgr = self.conv_manager
|
||||
|
||||
# 获取对话上下文
|
||||
cid = await conv_mgr.get_curr_conversation_id(umo)
|
||||
if not cid:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||
if not conversation:
|
||||
raise RuntimeError("无法创建新的对话。")
|
||||
return conversation
|
||||
|
||||
async def _apply_kb(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply knowledge base context to the provider request"""
|
||||
if not self.kb_agentic_mode:
|
||||
if req.prompt is None:
|
||||
return
|
||||
try:
|
||||
kb_result = await retrieve_knowledge_base(
|
||||
query=req.prompt,
|
||||
umo=event.unified_msg_origin,
|
||||
context=self.ctx.plugin_manager.context,
|
||||
)
|
||||
if not kb_result:
|
||||
return
|
||||
if req.system_prompt is not None:
|
||||
req.system_prompt += (
|
||||
f"\n\n[Related Knowledge Base Results]:\n{kb_result}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while retrieving knowledge base: {e}")
|
||||
else:
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
|
||||
|
||||
async def _apply_file_extract(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""Apply file extract to the provider request"""
|
||||
file_paths = []
|
||||
file_names = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
file_paths.append(await comp.get_file())
|
||||
file_names.append(comp.name)
|
||||
elif isinstance(comp, Reply) and comp.chain:
|
||||
for reply_comp in comp.chain:
|
||||
if isinstance(reply_comp, File):
|
||||
file_paths.append(await reply_comp.get_file())
|
||||
file_names.append(reply_comp.name)
|
||||
if not file_paths:
|
||||
return
|
||||
if not req.prompt:
|
||||
req.prompt = "总结一下文件里面讲了什么?"
|
||||
if self.file_extract_prov == "moonshotai":
|
||||
if not self.file_extract_msh_api_key:
|
||||
logger.error("Moonshot AI API key for file extract is not set")
|
||||
return
|
||||
file_contents = await asyncio.gather(
|
||||
*[
|
||||
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
|
||||
for file_path in file_paths
|
||||
]
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
|
||||
return
|
||||
|
||||
# add file extract results to contexts
|
||||
for file_content, file_name in zip(file_contents, file_names):
|
||||
req.contexts.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
|
||||
},
|
||||
)
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
||||
if req.image_urls:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
||||
if "image" not in provider_cfg:
|
||||
logger.debug(
|
||||
f"用户设置提供商 {provider} 不支持图像,将图像替换为占位符。"
|
||||
)
|
||||
# 为每个图片添加占位符到 prompt
|
||||
image_count = len(req.image_urls)
|
||||
placeholder = " ".join(["[图片]"] * image_count)
|
||||
if req.prompt:
|
||||
req.prompt = f"{placeholder} {req.prompt}"
|
||||
else:
|
||||
req.prompt = placeholder
|
||||
req.image_urls = []
|
||||
if req.func_tool:
|
||||
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
|
||||
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
||||
if "tool_use" not in provider_cfg:
|
||||
logger.debug(
|
||||
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
||||
)
|
||||
req.func_tool = None
|
||||
|
||||
def _sanitize_context_by_modalities(
|
||||
self,
|
||||
provider: Provider,
|
||||
req: ProviderRequest,
|
||||
) -> None:
|
||||
"""Sanitize `req.contexts` (including history) by current provider modalities."""
|
||||
if not self.sanitize_context_by_modalities:
|
||||
return
|
||||
|
||||
if not isinstance(req.contexts, list) or not req.contexts:
|
||||
return
|
||||
|
||||
modalities = provider.provider_config.get("modalities", None)
|
||||
# if modalities is not configured, do not sanitize.
|
||||
if not modalities or not isinstance(modalities, list):
|
||||
return
|
||||
|
||||
supports_image = bool("image" in modalities)
|
||||
supports_tool_use = bool("tool_use" in modalities)
|
||||
|
||||
if supports_image and supports_tool_use:
|
||||
return
|
||||
|
||||
sanitized_contexts: list[dict] = []
|
||||
removed_image_blocks = 0
|
||||
removed_tool_messages = 0
|
||||
removed_tool_calls = 0
|
||||
|
||||
for msg in req.contexts:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
|
||||
new_msg: dict = msg
|
||||
|
||||
# tool_use sanitize
|
||||
if not supports_tool_use:
|
||||
if role == "tool":
|
||||
# tool response block
|
||||
removed_tool_messages += 1
|
||||
continue
|
||||
if role == "assistant" and "tool_calls" in new_msg:
|
||||
# assistant message with tool calls
|
||||
if "tool_calls" in new_msg:
|
||||
removed_tool_calls += 1
|
||||
new_msg.pop("tool_calls", None)
|
||||
new_msg.pop("tool_call_id", None)
|
||||
|
||||
# image sanitize
|
||||
if not supports_image:
|
||||
content = new_msg.get("content")
|
||||
if isinstance(content, list):
|
||||
filtered_parts: list = []
|
||||
removed_any_image = False
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = str(part.get("type", "")).lower()
|
||||
if part_type in {"image_url", "image"}:
|
||||
removed_any_image = True
|
||||
removed_image_blocks += 1
|
||||
continue
|
||||
filtered_parts.append(part)
|
||||
|
||||
if removed_any_image:
|
||||
new_msg["content"] = filtered_parts
|
||||
|
||||
# drop empty assistant messages (e.g. only tool_calls without content)
|
||||
if role == "assistant":
|
||||
content = new_msg.get("content")
|
||||
has_tool_calls = bool(new_msg.get("tool_calls"))
|
||||
if not has_tool_calls:
|
||||
if not content:
|
||||
continue
|
||||
if isinstance(content, str) and not content.strip():
|
||||
continue
|
||||
|
||||
sanitized_contexts.append(new_msg)
|
||||
|
||||
if removed_image_blocks or removed_tool_messages or removed_tool_calls:
|
||||
logger.debug(
|
||||
"sanitize_context_by_modalities applied: "
|
||||
f"removed_image_blocks={removed_image_blocks}, "
|
||||
f"removed_tool_messages={removed_tool_messages}, "
|
||||
f"removed_tool_calls={removed_tool_calls}"
|
||||
)
|
||||
|
||||
req.contexts = sanitized_contexts
|
||||
|
||||
def _plugin_tool_fix(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
plugin = star_map.get(mp)
|
||||
if not plugin:
|
||||
continue
|
||||
if plugin.name in event.plugins_name or plugin.reserved:
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
async def _handle_webchat(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
prov: Provider,
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
from astrbot.core import db_helper
|
||||
|
||||
chatui_session_id = event.session_id.split("!")[-1]
|
||||
user_prompt = req.prompt
|
||||
|
||||
session = await db_helper.get_platform_session_by_id(chatui_session_id)
|
||||
|
||||
if (
|
||||
not user_prompt
|
||||
or not chatui_session_id
|
||||
or not session
|
||||
or session.display_name
|
||||
):
|
||||
return
|
||||
|
||||
llm_resp = await prov.text_chat(
|
||||
system_prompt=(
|
||||
"You are a conversation title generator. "
|
||||
"Generate a concise title in the same language as the user’s input, "
|
||||
"no more than 10 words, capturing only the core topic."
|
||||
"If the input is a greeting, small talk, or has no clear topic, "
|
||||
"(e.g., “hi”, “hello”, “haha”), return <None>. "
|
||||
"Output only the title itself or <None>, with no explanations."
|
||||
),
|
||||
prompt=(
|
||||
f"Generate a concise title for the following user query:\n{user_prompt}"
|
||||
),
|
||||
self.main_agent_cfg = MainAgentBuildConfig(
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
tool_schema_mode=self.tool_schema_mode,
|
||||
sanitize_context_by_modalities=self.sanitize_context_by_modalities,
|
||||
kb_agentic_mode=self.kb_agentic_mode,
|
||||
file_extract_enabled=self.file_extract_enabled,
|
||||
file_extract_prov=self.file_extract_prov,
|
||||
file_extract_msh_api_key=self.file_extract_msh_api_key,
|
||||
context_limit_reached_strategy=self.context_limit_reached_strategy,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider_id=self.llm_compress_provider_id,
|
||||
max_context_length=self.max_context_length,
|
||||
dequeue_context_length=self.dequeue_context_length,
|
||||
llm_safety_mode=self.llm_safety_mode,
|
||||
safety_mode_strategy=self.safety_mode_strategy,
|
||||
sandbox_cfg=self.sandbox_cfg,
|
||||
)
|
||||
if llm_resp and llm_resp.completion_text:
|
||||
title = llm_resp.completion_text.strip()
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
logger.info(
|
||||
f"Generated chatui title for session {chatui_session_id}: {title}"
|
||||
)
|
||||
await db_helper.update_platform_session(
|
||||
session_id=chatui_session_id,
|
||||
display_name=title,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
# using agent context messages to save to history
|
||||
message_to_save = []
|
||||
skipped_initial_system = False
|
||||
for message in all_messages:
|
||||
if message.role == "system" and not skipped_initial_system:
|
||||
skipped_initial_system = True
|
||||
continue # skip first system message
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
# we do not save user and assistant messages that are marked as _no_save
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
def _apply_llm_safety_mode(self, req: ProviderRequest) -> None:
|
||||
"""Apply LLM safety mode to the provider request."""
|
||||
if self.safety_mode_strategy == "system_prompt":
|
||||
req.system_prompt = (
|
||||
f"{LLM_SAFETY_MODE_SYSTEM_PROMPT}\n\n{req.system_prompt or ''}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unsupported llm_safety_mode strategy: {self.safety_mode_strategy}.",
|
||||
)
|
||||
|
||||
def _apply_sandbox_tools(self, req: ProviderRequest, session_id: str) -> None:
|
||||
"""Add sandbox tools to the provider request."""
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
if self.sandbox_cfg.get("booter") == "shipyard":
|
||||
ep = self.sandbox_cfg.get("shipyard_endpoint", "")
|
||||
at = self.sandbox_cfg.get("shipyard_access_token", "")
|
||||
if not ep or not at:
|
||||
logger.error("Shipyard sandbox configuration is incomplete.")
|
||||
return
|
||||
os.environ["SHIPYARD_ENDPOINT"] = ep
|
||||
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
|
||||
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
|
||||
req.func_tool.add_tool(PYTHON_TOOL)
|
||||
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
|
||||
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
|
||||
req.system_prompt += f"\n{SANDBOX_MODE_PROMPT}\n"
|
||||
|
||||
def _proactive_cron_job_tools(
|
||||
self, req: ProviderRequest, event: AstrMessageEvent
|
||||
) -> None:
|
||||
"""Inject cron job context and tools into the provider request for proactive scheduling."""
|
||||
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(CREATE_CRON_JOB_TOOL)
|
||||
req.func_tool.add_tool(DELETE_CRON_JOB_TOOL)
|
||||
req.func_tool.add_tool(LIST_CRON_JOBS_TOOL)
|
||||
|
||||
cron_meta = event.get_extra("cron_job")
|
||||
if cron_meta:
|
||||
# The message event is triggered by a known cron job
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
|
||||
|
||||
job_name = cron_meta.get("name", "scheduled task")
|
||||
note = cron_meta.get("note") or cron_meta.get("description") or ""
|
||||
req.system_prompt += (
|
||||
f"\n[Scheduler Context] This turn is triggered automatically by cron job "
|
||||
f'"{job_name}" (type: {cron_meta.get("type", "unknown")}). '
|
||||
"Act proactively based on the provided note and current context. "
|
||||
"If you want to proactively notify the user, call `send_message_to_user` with a concise message.\n"
|
||||
)
|
||||
if note:
|
||||
req.system_prompt += f"[Scheduler Note]: {note}\n"
|
||||
|
||||
if bg := event.get_extra("background_task_result"):
|
||||
# The message event is triggered after a background task done
|
||||
result_text = bg.get("result") or ""
|
||||
if req.func_tool is None:
|
||||
req.func_tool = ToolSet()
|
||||
req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL)
|
||||
if result_text:
|
||||
req.system_prompt += f"\n[Background Task Result] {result_text}\n"
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
try:
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。")
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
|
||||
)
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
# 检查消息内容是否有效,避免空消息触发钩子
|
||||
has_provider_request = event.get_extra("provider_request") is not None
|
||||
has_valid_message = bool(event.message_str and event.message_str.strip())
|
||||
# 检查是否有图片或其他媒体内容
|
||||
has_media_content = any(
|
||||
isinstance(comp, Image | File) for comp in event.message_obj.message
|
||||
)
|
||||
@@ -580,179 +137,50 @@ class InternalAgentSubStage(Stage):
|
||||
logger.debug("skip llm request: empty message and no provider_request")
|
||||
return
|
||||
|
||||
api_base = provider.provider_config.get("api_base", "")
|
||||
for host in decoded_blocked:
|
||||
if host in api_base:
|
||||
logger.error(
|
||||
f"Provider API base {api_base} is blocked due to security reasons. Please use another ai provider."
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
|
||||
# 通知等待调用 LLM(在获取锁之前)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
build_cfg = replace(
|
||||
self.main_agent_cfg,
|
||||
provider_wake_prefix=provider_wake_prefix,
|
||||
streaming_response=streaming_response,
|
||||
)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
build_result: MainAgentBuildResult | None = await build_main_agent(
|
||||
event=event,
|
||||
plugin_context=self.ctx.plugin_manager.context,
|
||||
config=build_cfg,
|
||||
)
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"[Image Attachment: path {image_path}]")
|
||||
)
|
||||
elif isinstance(comp, File):
|
||||
file_path = await comp.get_file()
|
||||
file_name = comp.name or os.path.basename(file_path)
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(
|
||||
text=f"[File Attachment: name {file_name}, path {file_path}]"
|
||||
)
|
||||
)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
if not event.get_group_id() and req.extra_user_content_parts:
|
||||
req.prompt = "<attachment>"
|
||||
else:
|
||||
return
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
if build_result is None:
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
agent_runner = build_result.agent_runner
|
||||
req = build_result.provider_request
|
||||
provider = build_result.provider
|
||||
|
||||
# truncate contexts to fit max length
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
# sanitize contexts (including history) by provider modalities
|
||||
self._sanitize_context_by_modalities(provider, req)
|
||||
|
||||
# apply llm safety mode
|
||||
if self.llm_safety_mode:
|
||||
self._apply_llm_safety_mode(req)
|
||||
|
||||
# apply sandbox tools
|
||||
if self.sandbox_cfg.get("enable", False):
|
||||
self._apply_sandbox_tools(req, req.session_id)
|
||||
api_base = provider.provider_config.get("api_base", "")
|
||||
for host in decoded_blocked:
|
||||
if host in api_base:
|
||||
logger.error(
|
||||
"Provider API base %s is blocked due to security reasons. Please use another ai provider.",
|
||||
api_base,
|
||||
)
|
||||
return
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
# ChatUI 对话的标题生成
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
# 注入 ChatUI 额外 prompt
|
||||
# 比如 follow-up questions 提示等
|
||||
req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n"
|
||||
|
||||
# 注入基本 prompt
|
||||
if req.func_tool and req.func_tool.tools:
|
||||
tool_prompt = (
|
||||
TOOL_CALL_PROMPT
|
||||
if self.tool_schema_mode == "full"
|
||||
else TOOL_CALL_PROMPT_SKILLS_LIKE_MODE
|
||||
)
|
||||
req.system_prompt += f"\n{tool_prompt}\n"
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
action_type = event.get_extra("action_type")
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
tool_schema_mode=self.tool_schema_mode,
|
||||
)
|
||||
|
||||
# 检测 Live Mode
|
||||
if action_type == "live":
|
||||
@@ -865,3 +293,52 @@ class InternalAgentSubStage(Stage):
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
message_to_save = []
|
||||
skipped_initial_system = False
|
||||
for message in all_messages:
|
||||
if message.role == "system" and not skipped_initial_system:
|
||||
skipped_initial_system = True
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
|
||||
# we prevent astrbot from connecting to known malicious hosts
|
||||
# these hosts are base64 encoded
|
||||
BLOCKED = {"dGZid2h2d3IuY2xvdWQuc2VhbG9zLmlv", "a291cmljaGF0"}
|
||||
decoded_blocked = [base64.b64decode(b).decode("utf-8") for b in BLOCKED]
|
||||
|
||||
@@ -165,7 +165,7 @@ class ProviderRequest:
|
||||
|
||||
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||
|
||||
return result_parts
|
||||
return "\n".join(result_parts)
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
|
||||
Reference in New Issue
Block a user