diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py index 60d15c0d5..068376473 100644 --- a/astrbot/cli/__init__.py +++ b/astrbot/cli/__init__.py @@ -1 +1 @@ -__version__ = "4.18.2" +__version__ = "4.18.3" diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 10cf2e96c..94069089d 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,9 +1,10 @@ +import asyncio import copy import sys import time import traceback import typing as T -from dataclasses import dataclass +from dataclasses import dataclass, field from mcp.types import ( BlobResourceContents, @@ -68,6 +69,14 @@ class _HandleFunctionToolsResult: return cls(kind="cached_image", cached_image=image) +@dataclass(slots=True) +class FollowUpTicket: + seq: int + text: str + consumed: bool = False + resolved: asyncio.Event = field(default_factory=asyncio.Event) + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override async def reset( @@ -139,6 +148,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): self.run_context = run_context self._stop_requested = False self._aborted = False + self._pending_follow_ups: list[FollowUpTicket] = [] + self._follow_up_seq = 0 # These two are used for tool schema mode handling # We now have two modes: @@ -277,6 +288,55 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): roles.append(message.role) logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}") + def follow_up( + self, + *, + message_text: str, + ) -> FollowUpTicket | None: + """Queue a follow-up message for the next tool result.""" + if self.done(): + return None + text = (message_text or "").strip() + if not text: + return None + ticket = FollowUpTicket(seq=self._follow_up_seq, text=text) + self._follow_up_seq += 1 + self._pending_follow_ups.append(ticket) + return ticket + + def _resolve_unconsumed_follow_ups(self) -> None: + if not self._pending_follow_ups: + return + follow_ups = self._pending_follow_ups + self._pending_follow_ups = [] + for ticket in follow_ups: + ticket.resolved.set() + + def _consume_follow_up_notice(self) -> str: + if not self._pending_follow_ups: + return "" + follow_ups = self._pending_follow_ups + self._pending_follow_ups = [] + for ticket in follow_ups: + ticket.consumed = True + ticket.resolved.set() + follow_up_lines = "\n".join( + f"{idx}. {ticket.text}" for idx, ticket in enumerate(follow_ups, start=1) + ) + return ( + "\n\n[SYSTEM NOTICE] User sent follow-up messages while tool execution " + "was in progress. Prioritize these follow-up instructions in your next " + "actions. In your very next action, briefly acknowledge to the user " + "that their follow-up message(s) were received before continuing.\n" + f"{follow_up_lines}" + ) + + def _merge_follow_up_notice(self, content: str) -> str: + notice = self._consume_follow_up_notice() + if not notice: + return content + return f"{content}{notice}" + @override async def step(self): """Process a single step of the agent. @@ -391,6 +451,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): type="aborted", data=AgentResponseData(chain=MessageChain(type="aborted")), ) + self._resolve_unconsumed_follow_ups() return # 处理 LLM 响应 @@ -401,6 +462,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): self.final_llm_resp = llm_resp self.stats.end_time = time.time() self._transition_state(AgentState.ERROR) + self._resolve_unconsumed_follow_ups() yield AgentResponse( type="err", data=AgentResponseData( @@ -439,6 +501,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): await self.agent_hooks.on_agent_done(self.run_context, llm_resp) except Exception as e: logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + self._resolve_unconsumed_follow_ups() # 返回 LLM 结果 if llm_resp.result_chain: @@ -583,6 +646,15 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): tool_call_result_blocks: list[ToolCallMessageSegment] = [] logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") + def _append_tool_call_result(tool_call_id: str, content: str) -> None: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=tool_call_id, + content=self._merge_follow_up_notice(content), + ), + ) + # 执行函数调用 for func_tool_name, func_tool_args, func_tool_id in zip( llm_response.tools_call_name, @@ -622,12 +694,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): if not func_tool: logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。") - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: Tool {func_tool_name} not found.", - ), + _append_tool_call_result( + func_tool_id, + f"error: Tool {func_tool_name} not found.", ) continue @@ -680,12 +749,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): res = resp _final_resp = resp if isinstance(res.content[0], TextContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, - ), + _append_tool_call_result( + func_tool_id, + res.content[0].text, ) elif isinstance(res.content[0], ImageContent): # Cache the image instead of sending directly @@ -696,15 +762,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): index=0, mime_type=res.content[0].mimeType or "image/png", ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=( - f"Image returned and cached at path='{cached_img.file_path}'. " - f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " - f"with type='image' and path='{cached_img.file_path}'." - ), + _append_tool_call_result( + func_tool_id, + ( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ) # Yield image info for LLM visibility (will be handled in step()) @@ -714,12 +777,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource if isinstance(resource, TextResourceContents): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resource.text, - ), + _append_tool_call_result( + func_tool_id, + resource.text, ) elif ( isinstance(resource, BlobResourceContents) @@ -734,15 +794,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): index=0, mime_type=resource.mimeType, ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=( - f"Image returned and cached at path='{cached_img.file_path}'. " - f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " - f"with type='image' and path='{cached_img.file_path}'." - ), + _append_tool_call_result( + func_tool_id, + ( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ) # Yield image info for LLM visibility @@ -750,12 +807,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): cached_img ) else: - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="The tool has returned a data type that is not supported.", - ), + _append_tool_call_result( + func_tool_id, + "The tool has returned a data type that is not supported.", ) elif resp is None: @@ -767,24 +821,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self._transition_state(AgentState.DONE) self.stats.end_time = time.time() - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="The tool has no return value, or has sent the result directly to the user.", - ), + _append_tool_call_result( + func_tool_id, + "The tool has no return value, or has sent the result directly to the user.", ) else: # 不应该出现其他类型 logger.warning( f"Tool 返回了不支持的类型: {type(resp)}。", ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*", - ), + _append_tool_call_result( + func_tool_id, + "*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*", ) try: @@ -798,12 +846,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) except Exception as e: logger.warning(traceback.format_exc()) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {e!s}", - ), + _append_tool_call_result( + func_tool_id, + f"error: {e!s}", ) # yield the last tool call result diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 2e070f827..a0f662e61 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -5,7 +5,7 @@ from typing import Any, TypedDict from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "4.18.2" +VERSION = "4.18.3" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") WEBHOOK_SUPPORTED_PLATFORMS = [ @@ -429,7 +429,15 @@ CONFIG_METADATA_2 = { "slack_webhook_port": 6197, "slack_webhook_path": "/astrbot-slack-webhook/callback", }, - # LINE's config is located in line_adapter.py + "Line": { + "id": "line", + "type": "line", + "enable": False, + "channel_access_token": "", + "channel_secret": "", + "unified_webhook_mode": True, + "webhook_uuid": "", + }, "Satori": { "id": "satori", "type": "satori", diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 0363d4692..2fced806d 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -8,7 +8,7 @@ resolution for backward compatibility. from __future__ import annotations from importlib import import_module -from typing import Any +from typing import TYPE_CHECKING, Any from astrbot.core.message.message_event_result import ( EventResultType, @@ -17,6 +17,17 @@ from astrbot.core.message.message_event_result import ( from .stage_order import STAGES_ORDER +if TYPE_CHECKING: + from .content_safety_check.stage import ContentSafetyCheckStage + from .preprocess_stage.stage import PreProcessStage + from .process_stage.stage import ProcessStage + from .rate_limit_check.stage import RateLimitStage + from .respond.stage import RespondStage + from .result_decorate.stage import ResultDecorateStage + from .session_status_check.stage import SessionStatusCheckStage + from .waking_check.stage import WakingCheckStage + from .whitelist_check.stage import WhitelistCheckStage + _LAZY_EXPORTS = { "ContentSafetyCheckStage": ( "astrbot.core.pipeline.content_safety_check.stage", diff --git a/astrbot/core/pipeline/process_stage/follow_up.py b/astrbot/core/pipeline/process_stage/follow_up.py new file mode 100644 index 000000000..6c1a4fa06 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/follow_up.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +from astrbot import logger +from astrbot.core.agent.runners.tool_loop_agent_runner import FollowUpTicket +from astrbot.core.astr_agent_run_util import AgentRunner +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +_ACTIVE_AGENT_RUNNERS: dict[str, AgentRunner] = {} +_FOLLOW_UP_ORDER_STATE: dict[str, dict[str, object]] = {} +"""UMO-level follow-up order state. + +State fields: +- `statuses`: seq -> {"pending"|"active"|"consumed"|"finished"} +- `next_order`: monotonically increasing sequence allocator +- `next_turn`: next sequence allowed to proceed when not consumed +""" + + +@dataclass(slots=True) +class FollowUpCapture: + umo: str + ticket: FollowUpTicket + order_seq: int + monitor_task: asyncio.Task[None] + + +def _event_follow_up_text(event: AstrMessageEvent) -> str: + text = (event.get_message_str() or "").strip() + if text: + return text + return event.get_message_outline().strip() + + +def register_active_runner(umo: str, runner: AgentRunner) -> None: + _ACTIVE_AGENT_RUNNERS[umo] = runner + + +def unregister_active_runner(umo: str, runner: AgentRunner) -> None: + if _ACTIVE_AGENT_RUNNERS.get(umo) is runner: + _ACTIVE_AGENT_RUNNERS.pop(umo, None) + + +def _get_follow_up_order_state(umo: str) -> dict[str, object]: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if state is None: + state = { + "condition": asyncio.Condition(), + # Sequence status map for strict in-order resume after unresolved follow-ups. + "statuses": {}, + # Stable allocator for arrival order; never decreases for the same UMO state. + "next_order": 0, + # The sequence currently allowed to continue main internal flow. + "next_turn": 0, + } + _FOLLOW_UP_ORDER_STATE[umo] = state + return state + + +def _advance_follow_up_turn_locked(state: dict[str, object]) -> None: + # Skip slots that are already handled, and stop at the first unfinished slot. + statuses = state["statuses"] + assert isinstance(statuses, dict) + next_turn = state["next_turn"] + assert isinstance(next_turn, int) + + while True: + curr = statuses.get(next_turn) + if curr in ("consumed", "finished"): + statuses.pop(next_turn, None) + next_turn += 1 + continue + break + + state["next_turn"] = next_turn + + +def _allocate_follow_up_order(umo: str) -> int: + state = _get_follow_up_order_state(umo) + next_order = state["next_order"] + assert isinstance(next_order, int) + seq = next_order + state["next_order"] = seq + 1 + statuses = state["statuses"] + assert isinstance(statuses, dict) + statuses[seq] = "pending" + return seq + + +async def _mark_follow_up_consumed(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses and statuses[seq] != "finished": + statuses[seq] = "consumed" + _advance_follow_up_turn_locked(state) + condition.notify_all() + + # Release state only when this UMO has no pending statuses and no active runner. + if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None: + _FOLLOW_UP_ORDER_STATE.pop(umo, None) + + +async def _activate_and_wait_follow_up_turn(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses: + statuses[seq] = "active" + + # Strict ordering: only the head (`next_turn`) can continue. + while True: + next_turn = state["next_turn"] + assert isinstance(next_turn, int) + if next_turn == seq: + break + await condition.wait() + + +async def _finish_follow_up_turn(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses: + statuses[seq] = "finished" + _advance_follow_up_turn_locked(state) + condition.notify_all() + + if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None: + _FOLLOW_UP_ORDER_STATE.pop(umo, None) + + +async def _monitor_follow_up_ticket( + umo: str, + ticket: FollowUpTicket, + order_seq: int, +) -> None: + """Advance consumed slots immediately on resolution to avoid wake-order drift.""" + await ticket.resolved.wait() + if ticket.consumed: + await _mark_follow_up_consumed(umo, order_seq) + + +def try_capture_follow_up(event: AstrMessageEvent) -> FollowUpCapture | None: + sender_id = event.get_sender_id() + if not sender_id: + return None + runner = _ACTIVE_AGENT_RUNNERS.get(event.unified_msg_origin) + if not runner: + return None + runner_event = getattr(getattr(runner.run_context, "context", None), "event", None) + if runner_event is None: + return None + active_sender_id = runner_event.get_sender_id() + if not active_sender_id or active_sender_id != sender_id: + return None + + ticket = runner.follow_up(message_text=_event_follow_up_text(event)) + if not ticket: + return None + # Allocate strict order at capture time (arrival order), not at wake time. + order_seq = _allocate_follow_up_order(event.unified_msg_origin) + monitor_task = asyncio.create_task( + _monitor_follow_up_ticket( + event.unified_msg_origin, + ticket, + order_seq, + ) + ) + logger.info( + "Captured follow-up message for active agent run, umo=%s, order_seq=%s", + event.unified_msg_origin, + order_seq, + ) + return FollowUpCapture( + umo=event.unified_msg_origin, + ticket=ticket, + order_seq=order_seq, + monitor_task=monitor_task, + ) + + +async def prepare_follow_up_capture(capture: FollowUpCapture) -> tuple[bool, bool]: + """Return `(consumed_marked, activated)` for internal stage branch handling.""" + await capture.ticket.resolved.wait() + if capture.ticket.consumed: + await _mark_follow_up_consumed(capture.umo, capture.order_seq) + return True, False + await _activate_and_wait_follow_up_turn(capture.umo, capture.order_seq) + return False, True + + +async def finalize_follow_up_capture( + capture: FollowUpCapture, + *, + activated: bool, + consumed_marked: bool, +) -> None: + # Best-effort cancellation: monitor task is auxiliary and should not leak. + if not capture.monitor_task.done(): + capture.monitor_task.cancel() + try: + await capture.monitor_task + except asyncio.CancelledError: + pass + + if activated: + await _finish_follow_up_turn(capture.umo, capture.order_seq) + elif not consumed_marked: + await _mark_follow_up_consumed(capture.umo, capture.order_seq) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 98cf77fcc..d95f7f86c 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -29,8 +29,16 @@ from astrbot.core.star.star_handler import EventType from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from .....astr_agent_run_util import run_agent, run_live_agent +from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent from ....context import PipelineContext, call_event_hook +from ...follow_up import ( + FollowUpCapture, + finalize_follow_up_capture, + prepare_follow_up_capture, + register_active_runner, + try_capture_follow_up, + unregister_active_runner, +) class InternalAgentSubStage(Stage): @@ -130,6 +138,9 @@ class InternalAgentSubStage(Stage): async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: + follow_up_capture: FollowUpCapture | None = None + follow_up_consumed_marked = False + follow_up_activated = False try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -150,188 +161,208 @@ class InternalAgentSubStage(Stage): return logger.debug("ready to request llm provider") + follow_up_capture = try_capture_follow_up(event) + if follow_up_capture: + ( + follow_up_consumed_marked, + follow_up_activated, + ) = await prepare_follow_up_capture(follow_up_capture) + if follow_up_consumed_marked: + logger.info( + "Follow-up ticket already consumed, stopping processing. umo=%s, seq=%s", + event.unified_msg_origin, + follow_up_capture.ticket.seq, + ) + return await event.send_typing() await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): logger.debug("acquired session lock for llm request") + agent_runner: AgentRunner | None = None + runner_registered = False + try: + build_cfg = replace( + self.main_agent_cfg, + provider_wake_prefix=provider_wake_prefix, + streaming_response=streaming_response, + ) - build_cfg = replace( - self.main_agent_cfg, - provider_wake_prefix=provider_wake_prefix, - streaming_response=streaming_response, - ) + build_result: MainAgentBuildResult | None = await build_main_agent( + event=event, + plugin_context=self.ctx.plugin_manager.context, + config=build_cfg, + apply_reset=False, + ) - build_result: MainAgentBuildResult | None = await build_main_agent( - event=event, - plugin_context=self.ctx.plugin_manager.context, - config=build_cfg, - apply_reset=False, - ) - - if build_result is None: - return - - agent_runner = build_result.agent_runner - req = build_result.provider_request - provider = build_result.provider - reset_coro = build_result.reset_coro - - api_base = provider.provider_config.get("api_base", "") - for host in decoded_blocked: - if host in api_base: - logger.error( - "Provider API base %s is blocked due to security reasons. Please use another ai provider.", - api_base, - ) + if build_result is None: return - stream_to_general = ( - self.unsupported_streaming_strategy == "turn_off" - and not event.platform_meta.support_streaming_message - ) + agent_runner = build_result.agent_runner + req = build_result.provider_request + provider = build_result.provider + reset_coro = build_result.reset_coro - if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + api_base = provider.provider_config.get("api_base", "") + for host in decoded_blocked: + if host in api_base: + logger.error( + "Provider API base %s is blocked due to security reasons. Please use another ai provider.", + api_base, + ) + return + + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) + + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + if reset_coro: + reset_coro.close() + return + + # apply reset if reset_coro: - reset_coro.close() - return + await reset_coro - # apply reset - if reset_coro: - await reset_coro + register_active_runner(event.unified_msg_origin, agent_runner) + runner_registered = True + action_type = event.get_extra("action_type") - action_type = event.get_extra("action_type") - - event.trace.record( - "astr_agent_prepare", - system_prompt=req.system_prompt, - tools=req.func_tool.names() if req.func_tool else [], - stream=streaming_response, - chat_provider={ - "id": provider.provider_config.get("id", ""), - "model": provider.get_model(), - }, - ) - - # 检测 Live Mode - if action_type == "live": - # Live Mode: 使用 run_live_agent - logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") - - # 获取 TTS Provider - tts_provider = ( - self.ctx.plugin_manager.context.get_using_tts_provider( - event.unified_msg_origin - ) + event.trace.record( + "astr_agent_prepare", + system_prompt=req.system_prompt, + tools=req.func_tool.names() if req.func_tool else [], + stream=streaming_response, + chat_provider={ + "id": provider.provider_config.get("id", ""), + "model": provider.get_model(), + }, ) - if not tts_provider: - logger.warning( - "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + # 检测 Live Mode + if action_type == "live": + # Live Mode: 使用 run_live_agent + logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") + + # 获取 TTS Provider + tts_provider = ( + self.ctx.plugin_manager.context.get_using_tts_provider( + event.unified_msg_origin + ) ) - # 使用 run_live_agent,总是使用流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_live_agent( - agent_runner, - tts_provider, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - show_reasoning=self.show_reasoning, + if not tts_provider: + logger.warning( + "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + ) + + # 使用 run_live_agent,总是使用流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_live_agent( + agent_runner, + tts_provider, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + show_reasoning=self.show_reasoning, + ), ), - ), - ) - yield + ) + yield - # 保存历史记录 - if agent_runner.done() and ( - not event.is_stopped() or agent_runner.was_aborted() - ): + # 保存历史记录 + if agent_runner.done() and ( + not event.is_stopped() or agent_runner.was_aborted() + ): + await self._save_to_history( + event, + req, + agent_runner.get_final_llm_resp(), + agent_runner.run_context.messages, + agent_runner.stats, + user_aborted=agent_runner.was_aborted(), + ) + + elif streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + show_reasoning=self.show_reasoning, + ), + ), + ) + yield + if agent_runner.done(): + if final_llm_resp := agent_runner.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain() + .message(final_llm_resp.completion_text) + .chain + ) + elif final_llm_resp.result_chain: + chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + async for _ in run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + stream_to_general, + show_reasoning=self.show_reasoning, + ): + yield + + final_resp = agent_runner.get_final_llm_resp() + + event.trace.record( + "astr_agent_complete", + stats=agent_runner.stats.to_dict(), + resp=final_resp.completion_text if final_resp else None, + ) + + # 检查事件是否被停止,如果被停止则不保存历史记录 + if not event.is_stopped() or agent_runner.was_aborted(): await self._save_to_history( event, req, - agent_runner.get_final_llm_resp(), + final_resp, agent_runner.run_context.messages, agent_runner.stats, user_aborted=agent_runner.was_aborted(), ) - elif streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - show_reasoning=self.show_reasoning, - ), + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, ), ) - yield - if agent_runner.done(): - if final_llm_resp := agent_runner.get_final_llm_resp(): - if final_llm_resp.completion_text: - chain = ( - MessageChain() - .message(final_llm_resp.completion_text) - .chain - ) - elif final_llm_resp.result_chain: - chain = final_llm_resp.result_chain.chain - else: - chain = MessageChain().chain - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - async for _ in run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - stream_to_general, - show_reasoning=self.show_reasoning, - ): - yield - - final_resp = agent_runner.get_final_llm_resp() - - event.trace.record( - "astr_agent_complete", - stats=agent_runner.stats.to_dict(), - resp=final_resp.completion_text if final_resp else None, - ) - - # 检查事件是否被停止,如果被停止则不保存历史记录 - if not event.is_stopped() or agent_runner.was_aborted(): - await self._save_to_history( - event, - req, - final_resp, - agent_runner.run_context.messages, - agent_runner.stats, - user_aborted=agent_runner.was_aborted(), - ) - - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=agent_runner.provider.get_model(), - provider_type=agent_runner.provider.meta().type, - ), - ) + finally: + if runner_registered and agent_runner is not None: + unregister_active_runner(event.unified_msg_origin, agent_runner) except Exception as e: logger.error(f"Error occurred while processing agent: {e}") @@ -340,6 +371,13 @@ class InternalAgentSubStage(Stage): f"Error occurred while processing agent request: {e}" ) ) + finally: + if follow_up_capture: + await finalize_follow_up_capture( + follow_up_capture, + activated=follow_up_activated, + consumed_marked=follow_up_consumed_marked, + ) async def _save_to_history( self, diff --git a/astrbot/core/platform/sources/line/line_adapter.py b/astrbot/core/platform/sources/line/line_adapter.py index 9348ff100..c13677b13 100644 --- a/astrbot/core/platform/sources/line/line_adapter.py +++ b/astrbot/core/platform/sources/line/line_adapter.py @@ -65,15 +65,6 @@ LINE_I18N_RESOURCES = { "line", "LINE Messaging API 适配器", support_streaming_message=False, - default_config_tmpl={ - "id": "line", - "type": "line", - "enable": False, - "channel_access_token": "", - "channel_secret": "", - "unified_webhook_mode": True, - "webhook_uuid": "", - }, config_metadata=LINE_CONFIG_METADATA, i18n_resources=LINE_I18N_RESOURCES, ) diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py new file mode 100644 index 000000000..43072ec1c --- /dev/null +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -0,0 +1,465 @@ +import json +import mimetypes +import shutil +import uuid +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path +from typing import Any + +from astrbot.core.db.po import Attachment +from astrbot.core.message.components import ( + File, + Image, + Json, + Plain, + Record, + Reply, + Video, +) +from astrbot.core.message.message_event_result import MessageChain + +AttachmentGetter = Callable[[str], Awaitable[Attachment | None]] +AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]] +ReplyHistoryGetter = Callable[ + [Any], + Awaitable[tuple[list[dict], str | None, str | None] | None], +] + +MEDIA_PART_TYPES = {"image", "record", "file", "video"} + + +def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]: + return [{k: v for k, v in part.items() if k != "path"} for part in message_parts] + + +def webchat_message_parts_have_content(message_parts: list[dict]) -> bool: + return any( + part.get("type") in ("plain", "image", "record", "file", "video") + and (part.get("text") or part.get("attachment_id") or part.get("filename")) + for part in message_parts + ) + + +async def parse_webchat_message_parts( + message_parts: list, + *, + strict: bool = False, + include_empty_plain: bool = False, + verify_media_path_exists: bool = True, + reply_history_getter: ReplyHistoryGetter | None = None, + current_depth: int = 0, + max_reply_depth: int = 0, + cast_reply_id_to_str: bool = True, +) -> tuple[list, list[str], bool]: + """Parse webchat message parts into components/text parts. + + Returns: + tuple[list, list[str], bool]: + (components, plain_text_parts, has_non_reply_content) + """ + components = [] + text_parts: list[str] = [] + has_content = False + + for part in message_parts: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text or include_empty_plain: + components.append(Plain(text=text)) + text_parts.append(text) + if text: + has_content = True + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + + reply_chain = [] + reply_message_str = str(part.get("selected_text", "")) + sender_id = None + sender_name = None + + if reply_message_str: + reply_chain = [Plain(text=reply_message_str)] + elif ( + reply_history_getter + and current_depth < max_reply_depth + and message_id is not None + ): + reply_info = await reply_history_getter(message_id) + if reply_info: + reply_parts, sender_id, sender_name = reply_info + ( + reply_chain, + reply_text_parts, + _, + ) = await parse_webchat_message_parts( + reply_parts, + strict=strict, + include_empty_plain=include_empty_plain, + verify_media_path_exists=verify_media_path_exists, + reply_history_getter=reply_history_getter, + current_depth=current_depth + 1, + max_reply_depth=max_reply_depth, + cast_reply_id_to_str=cast_reply_id_to_str, + ) + reply_message_str = "".join(reply_text_parts) + + reply_id = str(message_id) if cast_reply_id_to_str else message_id + components.append( + Reply( + id=reply_id, + message_str=reply_message_str, + chain=reply_chain, + sender_id=sender_id, + sender_nickname=sender_name, + ) + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + path = part.get("path") + if not path: + if strict: + raise ValueError(f"{part_type} part missing path") + continue + + file_path = Path(str(path)) + if verify_media_path_exists and not file_path.exists(): + if strict: + raise ValueError(f"file not found: {file_path!s}") + continue + + file_path_str = ( + str(file_path.resolve()) if verify_media_path_exists else str(file_path) + ) + has_content = True + if part_type == "image": + components.append(Image.fromFileSystem(file_path_str)) + elif part_type == "record": + components.append(Record.fromFileSystem(file_path_str)) + elif part_type == "video": + components.append(Video.fromFileSystem(file_path_str)) + else: + filename = str(part.get("filename", "")).strip() or file_path.name + components.append(File(name=filename, file=file_path_str)) + + return components, text_parts, has_content + + +async def build_webchat_message_parts( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = False, +) -> list[dict]: + if isinstance(message_payload, str): + text = message_payload.strip() + return [{"type": "plain", "text": text}] if text else [] + + if not isinstance(message_payload, list): + if strict: + raise ValueError("message must be a string or list") + return [] + + message_parts: list[dict] = [] + for part in message_payload: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + message_parts.append({"type": "plain", "text": text}) + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + message_parts.append( + { + "type": "reply", + "message_id": message_id, + "selected_text": str(part.get("selected_text", "")), + } + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + attachment_id = part.get("attachment_id") + if not attachment_id: + if strict: + raise ValueError(f"{part_type} part missing attachment_id") + continue + + attachment = await get_attachment_by_id(str(attachment_id)) + if not attachment: + if strict: + raise ValueError(f"attachment not found: {attachment_id}") + continue + + attachment_path = Path(attachment.path) + message_parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": attachment_path.name, + "path": str(attachment_path), + } + ) + + return message_parts + + +def webchat_message_parts_to_message_chain( + message_parts: list[dict], + *, + strict: bool = False, +) -> MessageChain: + components = [] + has_content = False + + for part in message_parts: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + components.append(Plain(text=text)) + has_content = True + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + components.append( + Reply( + id=str(message_id), + message_str=str(part.get("selected_text", "")), + chain=[], + ) + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + path = part.get("path") + if not path: + if strict: + raise ValueError(f"{part_type} part missing path") + continue + + file_path = Path(str(path)) + if not file_path.exists(): + if strict: + raise ValueError(f"file not found: {file_path!s}") + continue + + file_path_str = str(file_path.resolve()) + has_content = True + if part_type == "image": + components.append(Image.fromFileSystem(file_path_str)) + elif part_type == "record": + components.append(Record.fromFileSystem(file_path_str)) + elif part_type == "video": + components.append(Video.fromFileSystem(file_path_str)) + else: + filename = str(part.get("filename", "")).strip() or file_path.name + components.append(File(name=filename, file=file_path_str)) + + if strict and (not components or not has_content): + raise ValueError("Message content is empty (reply only is not allowed)") + + return MessageChain(chain=components) + + +async def build_message_chain_from_payload( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = True, +) -> MessageChain: + message_parts = await build_webchat_message_parts( + message_payload, + get_attachment_by_id=get_attachment_by_id, + strict=strict, + ) + components, _, has_content = await parse_webchat_message_parts( + message_parts, + strict=strict, + ) + if strict and (not components or not has_content): + raise ValueError("Message content is empty (reply only is not allowed)") + return MessageChain(chain=components) + + +async def create_attachment_part_from_existing_file( + filename: str, + *, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, + fallback_dirs: Sequence[str | Path] = (), +) -> dict | None: + basename = Path(filename).name + candidate_paths = [Path(attachments_dir) / basename] + candidate_paths.extend(Path(p) / basename for p in fallback_dirs) + + file_path = next((path for path in candidate_paths if path.exists()), None) + if not file_path: + return None + + mime_type, _ = mimetypes.guess_type(str(file_path)) + attachment = await insert_attachment( + str(file_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": file_path.name, + } + + +async def message_chain_to_storage_message_parts( + message_chain: MessageChain, + *, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, +) -> list[dict]: + target_dir = Path(attachments_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + parts: list[dict] = [] + for comp in message_chain.chain: + if isinstance(comp, Plain): + if comp.text: + parts.append({"type": "plain", "text": comp.text}) + continue + + if isinstance(comp, Json): + parts.append( + {"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)} + ) + continue + + if isinstance(comp, Image): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="image", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Record): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="record", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Video): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="video", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, File): + file_path = await comp.get_file() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="file", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + display_name=comp.name, + ) + if attachment_part: + parts.append(attachment_part) + continue + + return parts + + +async def _copy_file_to_attachment_part( + *, + file_path: str, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: Path, + display_name: str | None = None, +) -> dict | None: + src_path = Path(file_path) + if not src_path.exists() or not src_path.is_file(): + return None + + suffix = src_path.suffix + target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}" + shutil.copy2(src_path, target_path) + + mime_type, _ = mimetypes.guess_type(target_path.name) + attachment = await insert_attachment( + str(target_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": display_name or src_path.name, + } diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 047417aaa..54718fefb 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -3,12 +3,12 @@ import os import time import uuid from collections.abc import Callable, Coroutine +from pathlib import Path from typing import Any from astrbot import logger from astrbot.core import db_helper from astrbot.core.db.po import PlatformMessageHistory -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( AstrBotMessage, @@ -21,10 +21,23 @@ from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.utils.astrbot_path import get_astrbot_data_path from ...register import register_platform_adapter +from .message_parts_helper import ( + message_chain_to_storage_message_parts, + parse_webchat_message_parts, +) from .webchat_event import WebChatMessageEvent from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class QueueListener: def __init__( self, @@ -57,13 +70,15 @@ class WebChatAdapter(Platform): self.settings = platform_settings self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + self.attachments_dir = Path(get_astrbot_data_path()) / "attachments" os.makedirs(self.imgs_dir, exist_ok=True) + self.attachments_dir.mkdir(parents=True, exist_ok=True) self.metadata = PlatformMetadata( name="webchat", description="webchat", id="webchat", - support_proactive_message=False, + support_proactive_message=True, ) self._shutdown_event = asyncio.Event() self._webchat_queue_mgr = webchat_queue_mgr @@ -73,10 +88,67 @@ class WebChatAdapter(Platform): session: MessageSesion, message_chain: MessageChain, ) -> None: - message_id = f"active_{str(uuid.uuid4())}" - await WebChatMessageEvent._send(message_id, message_chain, session.session_id) + conversation_id = _extract_conversation_id(session.session_id) + active_request_ids = self._webchat_queue_mgr.list_back_request_ids( + conversation_id + ) + subscription_request_ids = [ + req_id for req_id in active_request_ids if req_id.startswith("ws_sub_") + ] + target_request_ids = subscription_request_ids or active_request_ids + + if target_request_ids: + for request_id in target_request_ids: + await WebChatMessageEvent._send( + request_id, + message_chain, + session.session_id, + ) + else: + message_id = f"active_{uuid.uuid4()!s}" + await WebChatMessageEvent._send( + message_id, + message_chain, + session.session_id, + ) + + should_persist = ( + bool(subscription_request_ids) + or not active_request_ids + or all(req_id.startswith("active_") for req_id in active_request_ids) + ) + if should_persist: + try: + await self._save_proactive_message(conversation_id, message_chain) + except Exception as e: + logger.error( + f"[WebChatAdapter] Failed to save proactive message: {e}", + exc_info=True, + ) + await super().send_by_session(session, message_chain) + async def _save_proactive_message( + self, + conversation_id: str, + message_chain: MessageChain, + ) -> None: + message_parts = await message_chain_to_storage_message_parts( + message_chain, + insert_attachment=db_helper.insert_attachment, + attachments_dir=self.attachments_dir, + ) + if not message_parts: + return + + await db_helper.insert_platform_message_history( + platform_id="webchat", + user_id=conversation_id, + content={"type": "bot", "message": message_parts}, + sender_id="bot", + sender_name="bot", + ) + async def _get_message_history( self, message_id: int ) -> PlatformMessageHistory | None: @@ -98,72 +170,30 @@ class WebChatAdapter(Platform): Returns: tuple[list, list[str]]: (消息组件列表, 纯文本列表) """ - components = [] - text_parts = [] - for part in message_parts: - part_type = part.get("type") - if part_type == "plain": - text = part.get("text", "") - components.append(Plain(text=text)) - text_parts.append(text) - elif part_type == "reply": - message_id = part.get("message_id") - reply_chain = [] - reply_message_str = part.get("selected_text", "") - sender_id = None - sender_name = None + async def get_reply_parts( + message_id: Any, + ) -> tuple[list[dict], str | None, str | None] | None: + history = await self._get_message_history(message_id) + if not history or not history.content: + return None - if reply_message_str: - reply_chain = [Plain(text=reply_message_str)] + reply_parts = history.content.get("message", []) + if not isinstance(reply_parts, list): + return None - # recursively get the content of the referenced message, if selected_text is empty - if not reply_message_str and depth < max_depth and message_id: - history = await self._get_message_history(message_id) - if history and history.content: - reply_parts = history.content.get("message", []) - if isinstance(reply_parts, list): - ( - reply_chain, - reply_text_parts, - ) = await self._parse_message_parts( - reply_parts, - depth=depth + 1, - max_depth=max_depth, - ) - reply_message_str = "".join(reply_text_parts) - sender_id = history.sender_id - sender_name = history.sender_name - - components.append( - Reply( - id=message_id, - chain=reply_chain, - message_str=reply_message_str, - sender_id=sender_id, - sender_nickname=sender_name, - ) - ) - elif part_type == "image": - path = part.get("path") - if path: - components.append(Image.fromFileSystem(path)) - elif part_type == "record": - path = part.get("path") - if path: - components.append(Record.fromFileSystem(path)) - elif part_type == "file": - path = part.get("path") - if path: - filename = part.get("filename") or ( - os.path.basename(path) if path else "file" - ) - components.append(File(name=filename, file=path)) - elif part_type == "video": - path = part.get("path") - if path: - components.append(Video.fromFileSystem(path)) + return reply_parts, history.sender_id, history.sender_name + components, text_parts, _ = await parse_webchat_message_parts( + message_parts, + strict=False, + include_empty_plain=True, + verify_media_path_exists=False, + reply_history_getter=get_reply_parts, + current_depth=depth, + max_reply_depth=max_depth, + cast_reply_id_to_str=False, + ) return components, text_parts async def convert_message(self, data: tuple) -> AstrBotMessage: diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index a680f7617..b7da864aa 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -14,6 +14,15 @@ from .webchat_queue_mgr import webchat_queue_mgr attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class WebChatMessageEvent(AstrMessageEvent): def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) @@ -27,7 +36,7 @@ class WebChatMessageEvent(AstrMessageEvent): streaming: bool = False, ) -> str | None: request_id = str(message_id) - conversation_id = session_id.split("!")[-1] + conversation_id = _extract_conversation_id(session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, @@ -130,7 +139,7 @@ class WebChatMessageEvent(AstrMessageEvent): reasoning_content = "" message_id = self.message_obj.message_id request_id = str(message_id) - conversation_id = self.session_id.split("!")[-1] + conversation_id = _extract_conversation_id(self.session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index fd35e837c..f3ade1589 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -75,6 +75,10 @@ class WebChatQueueMgr: if task is not None: task.cancel() + def list_back_request_ids(self, conversation_id: str) -> list[str]: + """List active back-queue request IDs for a conversation.""" + return list(self._conversation_back_requests.get(conversation_id, set())) + def has_queue(self, conversation_id: str) -> bool: """Check if a queue exists for the given conversation ID""" return conversation_id in self.queues diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 815b306aa..13251d2ba 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -388,6 +388,33 @@ class PluginManager: except KeyError: logger.warning(f"模块 {module_name} 未载入") + def _cleanup_plugin_state(self, dir_name: str) -> None: + plugin_root_name = "data.plugins." + + # 清理 sys.modules + for key in list(sys.modules.keys()): + if key.startswith(f"{plugin_root_name}{dir_name}"): + logger.info(f"清除了插件{dir_name}中的{key}模块") + del sys.modules[key] + + possible_paths = [ + f"{plugin_root_name}{dir_name}.main", + f"{plugin_root_name}{dir_name}.{dir_name}", + ] + + # 清理 handlers + for path in possible_paths: + handlers = star_handlers_registry.get_handlers_by_module_name(path) + for handler in handlers: + star_handlers_registry.remove(handler) + logger.info(f"清理处理器: {handler.handler_name}") + + # 清理工具 + for tool in list(llm_tools.func_list): + if tool.handler_module_path in possible_paths: + llm_tools.func_list.remove(tool) + logger.info(f"清理工具: {tool.name}") + async def reload_failed_plugin(self, dir_name): """ 重新加载未注册(加载失败)的插件 @@ -398,17 +425,21 @@ class PluginManager: - success (bool): 重载是否成功 - error_message (str|None): 错误信息,成功时为 None """ + async with self._pm_lock: - if dir_name in self.failed_plugin_dict: - success, error = await self.load(specified_dir_name=dir_name) - if success: - self.failed_plugin_dict.pop(dir_name, None) - if not self.failed_plugin_dict: - self.failed_plugin_info = "" - return success, None - else: - return False, error - return False, "插件不存在于失败列表中" + if dir_name not in self.failed_plugin_dict: + return False, "插件不存在于失败列表中" + + self._cleanup_plugin_state(dir_name) + + success, error = await self.load(specified_dir_name=dir_name) + if success: + self.failed_plugin_dict.pop(dir_name, None) + if not self.failed_plugin_dict: + self.failed_plugin_info = "" + return success, None + else: + return False, error async def reload(self, specified_plugin_name=None): """重新加载插件 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 1235dd381..0602cc074 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,6 +1,5 @@ import asyncio import json -import mimetypes import os import re import uuid @@ -14,6 +13,12 @@ from astrbot.core import logger, sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.active_event_registry import active_event_registry from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -166,83 +171,24 @@ class ChatRoute(Route): ) async def _build_user_message_parts(self, message: str | list) -> list[dict]: - """构建用户消息的部分列表 - - Args: - message: 文本消息 (str) 或消息段列表 (list) - """ - parts = [] - - if isinstance(message, list): - for part in message: - part_type = part.get("type") - if part_type == "plain": - parts.append({"type": "plain", "text": part.get("text", "")}) - elif part_type == "reply": - parts.append( - { - "type": "reply", - "message_id": part.get("message_id"), - "selected_text": part.get("selected_text", ""), - } - ) - elif attachment_id := part.get("attachment_id"): - attachment = await self.db.get_attachment_by_id(attachment_id) - if attachment: - parts.append( - { - "type": attachment.type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(attachment.path), - "path": attachment.path, # will be deleted - } - ) - return parts - - if message: - parts.append({"type": "plain", "text": message}) - - return parts + """构建用户消息的部分列表。""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) async def _create_attachment_from_file( self, filename: str, attach_type: str ) -> dict | None: - """从本地文件创建 attachment 并返回消息部分 - - 用于处理 bot 回复中的媒体文件 - - Args: - filename: 存储的文件名 - attach_type: 附件类型 (image, record, file, video) - """ - basename = os.path.basename(filename) - candidate_paths = [ - os.path.join(self.attachments_dir, basename), - os.path.join(self.legacy_img_dir, basename), - ] - file_path = next((p for p in candidate_paths if os.path.exists(p)), None) - if not file_path: - return None - - # guess mime type - mime_type, _ = mimetypes.guess_type(filename) - if not mime_type: - mime_type = "application/octet-stream" - - # insert attachment - attachment = await self.db.insert_attachment( - path=file_path, - type=attach_type, - mime_type=mime_type, + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], ) - if not attachment: - return None - - return { - "type": attach_type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(file_path), - } def _extract_web_search_refs( self, accumulated_text: str, accumulated_parts: list @@ -356,21 +302,6 @@ class ChatRoute(Route): selected_model = post_data.get("selected_model") enable_streaming = post_data.get("enable_streaming", True) - # 检查消息是否为空 - if isinstance(message, list): - has_content = any( - part.get("type") in ("plain", "image", "record", "file", "video") - for part in message - ) - if not has_content: - return ( - Response() - .error("Message content is empty (reply only is not allowed)") - .__dict__ - ) - elif not message: - return Response().error("Message are both empty").__dict__ - if not session_id: return Response().error("session_id is empty").__dict__ @@ -378,6 +309,12 @@ class ChatRoute(Route): # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) message_id = str(uuid.uuid4()) back_queue = webchat_queue_mgr.get_or_create_back_queue( @@ -583,10 +520,7 @@ class ChatRoute(Route): ), ) - message_parts_for_storage = [] - for part in message_parts: - part_copy = {k: v for k, v in part.items() if k != "path"} - message_parts_for_storage.append(part_copy) + message_parts_for_storage = strip_message_parts_path_fields(message_parts) await self.platform_history_mgr.insert( platform_id="webchat", diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8c922ab69..25438565e 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -1,6 +1,7 @@ import asyncio import json import os +import re import time import uuid import wave @@ -10,9 +11,16 @@ import jwt from quart import websocket from astrbot import logger +from astrbot.core import sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path from .route import Route, RouteContext @@ -30,6 +38,9 @@ class LiveChatSession: self.audio_frames: list[bytes] = [] self.current_stamp: str | None = None self.temp_audio_path: str | None = None + self.chat_subscriptions: dict[str, str] = {} + self.chat_subscription_tasks: dict[str, asyncio.Task] = {} + self.ws_send_lock = asyncio.Lock() def start_speaking(self, stamp: str) -> None: """开始说话""" @@ -106,13 +117,26 @@ class LiveChatRoute(Route): self.core_lifecycle = core_lifecycle self.db = db self.plugin_manager = core_lifecycle.plugin_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.sessions: dict[str, LiveChatSession] = {} + self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") + self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.attachments_dir, exist_ok=True) # 注册 WebSocket 路由 self.app.websocket("/api/live_chat/ws")(self.live_chat_ws) + self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws) async def live_chat_ws(self) -> None: - """Live Chat WebSocket 处理器""" + """Legacy Live Chat WebSocket 处理器(默认 ct=live)""" + await self._unified_ws_loop(force_ct="live") + + async def unified_chat_ws(self) -> None: + """Unified Chat WebSocket 处理器(支持 ct=live/chat)""" + await self._unified_ws_loop(force_ct=None) + + async def _unified_ws_loop(self, force_ct: str | None = None) -> None: + """统一 WebSocket 循环""" # WebSocket 不能通过 header 传递 token,需要从 query 参数获取 # 注意:WebSocket 上下文使用 websocket.args 而不是 request.args token = websocket.args.get("token") @@ -140,7 +164,11 @@ class LiveChatRoute(Route): try: while True: message = await websocket.receive_json() - await self._handle_message(live_session, message) + ct = force_ct or message.get("ct", "live") + if ct == "chat": + await self._handle_chat_message(live_session, message) + else: + await self._handle_message(live_session, message) except Exception as e: logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True) @@ -148,10 +176,488 @@ class LiveChatRoute(Route): finally: # 清理会话 if session_id in self.sessions: + await self._cleanup_chat_subscriptions(live_session) live_session.cleanup() del self.sessions[session_id] logger.info(f"[Live Chat] WebSocket 连接关闭: {username}") + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], + ) + + def _extract_web_search_refs( + self, accumulated_text: str, accumulated_parts: list + ) -> dict: + """从消息中提取 web_search 引用。""" + supported = ["web_search_tavily", "web_search_bocha"] + web_search_results = {} + tool_call_parts = [ + p + for p in accumulated_parts + if p.get("type") == "tool_call" and p.get("tool_calls") + ] + + for part in tool_call_parts: + for tool_call in part["tool_calls"]: + if tool_call.get("name") not in supported or not tool_call.get( + "result" + ): + continue + try: + result_data = json.loads(tool_call["result"]) + for item in result_data.get("results", []): + if idx := item.get("index"): + web_search_results[idx] = { + "url": item.get("url"), + "title": item.get("title"), + "snippet": item.get("snippet"), + } + except (json.JSONDecodeError, KeyError): + pass + + if not web_search_results: + return {} + + ref_indices = { + m.strip() for m in re.findall(r"(.*?)", accumulated_text) + } + + used_refs = [] + for ref_index in ref_indices: + if ref_index not in web_search_results: + continue + payload = {"index": ref_index, **web_search_results[ref_index]} + if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]): + payload["favicon"] = favicon + used_refs.append(payload) + + return {"used": used_refs} if used_refs else {} + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + agent_stats: dict, + refs: dict, + ): + """保存 bot 消息到历史记录。""" + bot_message_parts = [] + bot_message_parts.extend(media_parts) + if text: + bot_message_parts.append({"type": "plain", "text": text}) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + if agent_stats: + new_his["agent_stats"] = agent_stats + if refs: + new_his["refs"] = refs + + return await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + + async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None: + async with session.ws_send_lock: + await websocket.send_json(payload) + + async def _forward_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + request_id: str, + ) -> None: + back_queue = webchat_queue_mgr.get_or_create_back_queue( + request_id, chat_session_id + ) + try: + while True: + result = await back_queue.get() + if not result: + continue + await self._send_chat_payload(session, {"ct": "chat", **result}) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error( + f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}", + exc_info=True, + ) + finally: + webchat_queue_mgr.remove_back_queue(request_id) + if session.chat_subscriptions.get(chat_session_id) == request_id: + session.chat_subscriptions.pop(chat_session_id, None) + session.chat_subscription_tasks.pop(chat_session_id, None) + + async def _ensure_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + ) -> str: + existing_request_id = session.chat_subscriptions.get(chat_session_id) + existing_task = session.chat_subscription_tasks.get(chat_session_id) + if existing_request_id and existing_task and not existing_task.done(): + return existing_request_id + + request_id = f"ws_sub_{uuid.uuid4().hex}" + session.chat_subscriptions[chat_session_id] = request_id + task = asyncio.create_task( + self._forward_chat_subscription(session, chat_session_id, request_id), + name=f"chat_ws_sub_{chat_session_id}", + ) + session.chat_subscription_tasks[chat_session_id] = task + return request_id + + async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None: + tasks = list(session.chat_subscription_tasks.values()) + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + for request_id in list(session.chat_subscriptions.values()): + webchat_queue_mgr.remove_back_queue(request_id) + session.chat_subscriptions.clear() + session.chat_subscription_tasks.clear() + + async def _handle_chat_message( + self, session: LiveChatSession, message: dict + ) -> None: + """处理 Chat Mode 消息(ct=chat)""" + msg_type = message.get("t") + + if msg_type == "bind": + chat_session_id = message.get("session_id") + if not isinstance(chat_session_id, str) or not chat_session_id: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "session_id is required", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + request_id = await self._ensure_chat_subscription(session, chat_session_id) + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "session_bound", + "session_id": chat_session_id, + "message_id": request_id, + }, + ) + return + + if msg_type == "interrupt": + session.should_interrupt = True + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "INTERRUPTED", + "code": "INTERRUPTED", + }, + ) + return + + if msg_type != "send": + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"Unsupported message type: {msg_type}", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + if session.is_processing: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Session is busy", + "code": "PROCESSING_ERROR", + }, + ) + return + + payload = message.get("message") + session_id = message.get("session_id") or session.session_id + message_id = message.get("message_id") or str(uuid.uuid4()) + selected_provider = message.get("selected_provider") + selected_model = message.get("selected_model") + selected_stt_provider = message.get("selected_stt_provider") + selected_tts_provider = message.get("selected_tts_provider") + persona_prompt = message.get("persona_prompt") + show_reasoning = message.get("show_reasoning") + enable_streaming = message.get("enable_streaming", True) + + if not isinstance(payload, list): + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "message must be list", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + message_parts = await self._build_chat_message_parts(payload) + has_content = webchat_message_parts_have_content(message_parts) + if not has_content: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Message content is empty", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + await self._ensure_chat_subscription(session, session_id) + + session.is_processing = True + session.should_interrupt = False + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + session.username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "selected_stt_provider": selected_stt_provider, + "selected_tts_provider": selected_tts_provider, + "persona_prompt": persona_prompt, + "show_reasoning": show_reasoning, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ), + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=session.username, + sender_name=session.username, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + + while True: + if session.should_interrupt: + session.should_interrupt = False + break + + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + if result.get("message_id") and result.get("message_id") != message_id: + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + if chain_type == "agent_stats": + try: + parsed_agent_stats = json.loads(result_text) + agent_stats = parsed_agent_stats + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "agent_stats", + "data": parsed_agent_stats, + }, + ) + except Exception: + pass + continue + + outgoing = {"ct": "chat", **result} + await self._send_chat_payload(session, outgoing) + + if msg_type == "plain": + if chain_type == "tool_call": + try: + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + except Exception: + pass + elif chain_type == "tool_call_result": + try: + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + { + "type": "tool_call", + "tool_calls": [tool_calls[tc_id]], + } + ) + tool_calls.pop(tc_id, None) + except Exception: + pass + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self._create_attachment_from_file(filename, "image") + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self._create_attachment_from_file(filename, "record") + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "file") + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "video") + if part: + accumulated_parts.append(part) + + should_save = False + if msg_type == "end": + should_save = bool( + accumulated_parts + or accumulated_text + or accumulated_reasoning + or refs + or agent_stats + ) + elif (streaming and msg_type == "complete") or not streaming: + if chain_type not in ( + "tool_call", + "tool_call_result", + "agent_stats", + ): + should_save = True + + if should_save: + try: + refs = self._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"[Live Chat] Failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + }, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + + if msg_type == "end": + break + + except Exception as e: + logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True) + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"处理失败: {str(e)}", + "code": "PROCESSING_ERROR", + }, + ) + finally: + session.is_processing = False + webchat_queue_mgr.remove_back_queue(message_id) + + async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]: + """构建 chat websocket 用户消息段(复用 webchat 逻辑)""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) + async def _handle_message(self, session: LiveChatSession, message: dict) -> None: """处理 WebSocket 消息""" msg_type = message.get("t") # 使用 t 代替 type diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index c25870ebb..653e22cbf 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -1,15 +1,22 @@ -from pathlib import Path +import asyncio +import hashlib +import json from uuid import uuid4 -from quart import g, request +from quart import g, request, websocket from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video -from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.message_session import MessageSesion +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_message_chain_from_payload, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr +from .api_key import ALL_OPEN_API_SCOPES from .chat import ChatRoute from .route import Response, Route, RouteContext @@ -37,6 +44,7 @@ class OpenApiRoute(Route): "/v1/im/bots": ("GET", self.get_bots), } self.register_routes() + self.app.websocket("/api/v1/chat/ws")(self.chat_ws) @staticmethod def _resolve_open_username( @@ -181,6 +189,348 @@ class OpenApiRoute(Route): finally: g.username = original_username + @staticmethod + def _extract_ws_api_key() -> str | None: + if key := websocket.args.get("api_key"): + return key.strip() + if key := websocket.args.get("key"): + return key.strip() + if key := websocket.headers.get("X-API-Key"): + return key.strip() + + auth_header = websocket.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return auth_header.removeprefix("Bearer ").strip() + if auth_header.startswith("ApiKey "): + return auth_header.removeprefix("ApiKey ").strip() + return None + + async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: + raw_key = self._extract_ws_api_key() + if not raw_key: + return False, "Missing API key" + + key_hash = hashlib.pbkdf2_hmac( + "sha256", + raw_key.encode("utf-8"), + b"astrbot_api_key", + 100_000, + ).hex() + api_key = await self.db.get_active_api_key_by_hash(key_hash) + if not api_key: + return False, "Invalid API key" + + if isinstance(api_key.scopes, list): + scopes = api_key.scopes + else: + scopes = list(ALL_OPEN_API_SCOPES) + + if "*" not in scopes and "chat" not in scopes: + return False, "Insufficient API key scope" + + await self.db.touch_api_key(api_key.key_id) + return True, None + + async def _send_chat_ws_error(self, message: str, code: str) -> None: + await websocket.send_json( + { + "type": "error", + "code": code, + "data": message, + } + ) + + async def _update_session_config_route( + self, + *, + username: str, + session_id: str, + config_id: str | None, + ) -> str | None: + if not config_id: + return None + + umo = f"webchat:FriendMessage:webchat!{username}!{session_id}" + try: + if config_id == "default": + await self.core_lifecycle.umop_config_router.delete_route(umo) + else: + await self.core_lifecycle.umop_config_router.update_route( + umo, config_id + ) + except Exception as e: + logger.error( + "Failed to update chat config route for %s with %s: %s", + umo, + config_id, + e, + exc_info=True, + ) + return f"Failed to update chat config route: {e}" + return None + + async def _handle_chat_ws_send(self, post_data: dict) -> None: + effective_username, username_err = self._resolve_open_username( + post_data.get("username") + ) + if username_err or not effective_username: + await self._send_chat_ws_error( + username_err or "Invalid username", "BAD_USER" + ) + return + + message = post_data.get("message") + if message is None: + await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE") + return + + raw_session_id = post_data.get("session_id", post_data.get("conversation_id")) + session_id = str(raw_session_id).strip() if raw_session_id is not None else "" + if not session_id: + session_id = str(uuid4()) + + ensure_session_err = await self._ensure_chat_session( + effective_username, + session_id, + ) + if ensure_session_err: + await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR") + return + + config_id, resolve_err = self._resolve_chat_config_id(post_data) + if resolve_err: + await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR") + return + + config_err = await self._update_session_config_route( + username=effective_username, + session_id=session_id, + config_id=config_id, + ) + if config_err: + await self._send_chat_ws_error(config_err, "CONFIG_ERROR") + return + + message_parts = await self.chat_route._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + await self._send_chat_ws_error( + "Message content is empty (reply only is not allowed)", + "INVALID_MESSAGE", + ) + return + + message_id = str(post_data.get("message_id") or uuid4()) + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") + enable_streaming = post_data.get("enable_streaming", True) + + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + effective_username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ) + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.chat_route.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=effective_username, + sender_name=effective_username, + ) + + await websocket.send_json( + { + "type": "session_id", + "data": None, + "session_id": session_id, + "message_id": message_id, + } + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + while True: + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + + if "message_id" in result and result["message_id"] != message_id: + logger.warning("openapi ws stream message_id mismatch") + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + + if chain_type == "agent_stats": + try: + stats_info = { + "type": "agent_stats", + "data": json.loads(result_text), + } + await websocket.send_json(stats_info) + agent_stats = stats_info["data"] + except Exception: + pass + continue + + await websocket.send_json(result) + + if msg_type == "plain": + if chain_type == "tool_call": + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + elif chain_type == "tool_call_result": + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + {"type": "tool_call", "tool_calls": [tool_calls[tc_id]]} + ) + tool_calls.pop(tc_id, None) + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "record" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "video" + ) + if part: + accumulated_parts.append(part) + + if msg_type == "end": + break + if (streaming and msg_type == "complete") or not streaming: + if chain_type in ("tool_call", "tool_call_result"): + continue + try: + refs = self.chat_route._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"Open API WS failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self.chat_route._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await websocket.send_json( + { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + "session_id": session_id, + } + ) + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + except Exception as e: + logger.exception(f"Open API WS chat failed: {e}", exc_info=True) + await self._send_chat_ws_error( + f"Failed to process message: {e}", "PROCESSING_ERROR" + ) + finally: + webchat_queue_mgr.remove_back_queue(message_id) + + async def chat_ws(self) -> None: + authed, auth_err = await self._authenticate_chat_ws_api_key() + if not authed: + await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED") + await websocket.close(1008, auth_err or "Unauthorized") + return + + try: + while True: + message = await websocket.receive_json() + if not isinstance(message, dict): + await self._send_chat_ws_error( + "message must be an object", + "INVALID_MESSAGE", + ) + continue + + msg_type = message.get("t", "send") + if msg_type == "ping": + await websocket.send_json({"type": "pong"}) + continue + if msg_type != "send": + await self._send_chat_ws_error( + f"Unsupported message type: {msg_type}", + "INVALID_MESSAGE", + ) + continue + + await self._handle_chat_ws_send(message) + except Exception as e: + logger.debug("Open API WS connection closed: %s", e) + async def upload_file(self): return await self.chat_route.post_file() @@ -254,83 +604,12 @@ class OpenApiRoute(Route): async def _build_message_chain_from_payload( self, message_payload: str | list, - ) -> MessageChain: - if isinstance(message_payload, str): - text = message_payload.strip() - if not text: - raise ValueError("Message is empty") - return MessageChain(chain=[Plain(text=text)]) - - if not isinstance(message_payload, list): - raise ValueError("message must be a string or list") - - components = [] - has_content = False - - for part in message_payload: - if not isinstance(part, dict): - raise ValueError("message part must be an object") - - part_type = str(part.get("type", "")).strip() - if part_type == "plain": - text = str(part.get("text", "")) - if text: - has_content = True - components.append(Plain(text=text)) - continue - - if part_type == "reply": - message_id = part.get("message_id") - if message_id is None: - raise ValueError("reply part missing message_id") - components.append( - Reply( - id=str(message_id), - message_str=str(part.get("selected_text", "")), - chain=[], - ) - ) - continue - - if part_type not in {"image", "record", "file", "video"}: - raise ValueError(f"unsupported message part type: {part_type}") - - has_content = True - file_path: Path | None = None - resolved_type = part_type - filename = str(part.get("filename", "")).strip() - - attachment_id = part.get("attachment_id") - if attachment_id: - attachment = await self.db.get_attachment_by_id(str(attachment_id)) - if not attachment: - raise ValueError(f"attachment not found: {attachment_id}") - file_path = Path(attachment.path) - resolved_type = attachment.type - if not filename: - filename = file_path.name - else: - raise ValueError(f"{part_type} part missing attachment_id") - - if not file_path.exists(): - raise ValueError(f"file not found: {file_path!s}") - - file_path_str = str(file_path.resolve()) - if resolved_type == "image": - components.append(Image.fromFileSystem(file_path_str)) - elif resolved_type == "record": - components.append(Record.fromFileSystem(file_path_str)) - elif resolved_type == "video": - components.append(Video.fromFileSystem(file_path_str)) - else: - components.append( - File(name=filename or file_path.name, file=file_path_str) - ) - - if not components or not has_content: - raise ValueError("Message content is empty (reply only is not allowed)") - - return MessageChain(chain=components) + ): + return await build_message_chain_from_payload( + message_payload, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=True, + ) async def send_message(self): post_data = await request.json or {} diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a9631fc09..a9650cd06 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -204,6 +204,10 @@ class AstrBotDashboard: @staticmethod def _extract_raw_api_key() -> str | None: + if key := request.args.get("api_key"): + return key.strip() + if key := request.args.get("key"): + return key.strip() if key := request.headers.get("X-API-Key"): return key.strip() auth_header = request.headers.get("Authorization", "").strip() @@ -217,6 +221,7 @@ class AstrBotDashboard: def _get_required_open_api_scope(path: str) -> str | None: scope_map = { "/api/v1/chat": "chat", + "/api/v1/chat/ws": "chat", "/api/v1/chat/sessions": "chat", "/api/v1/configs": "config", "/api/v1/file": "file", diff --git a/changelogs/v4.18.3.md b/changelogs/v4.18.3.md new file mode 100644 index 000000000..e2b426b4c --- /dev/null +++ b/changelogs/v4.18.3.md @@ -0,0 +1,49 @@ +## What's Changed + +### 新增 + +- 新增桌面端通用更新桥接能力,便于接入客户端内更新流程 ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424))。 + +### 修复 + +- 修复新增平台对话框中 Line 适配器未显示的问题。 +- 修复 Telegram 无法发送 Video 的问题 ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430))。 +- 修复创建 embedding provider 时无法自动识别向量维度的问题 ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442))。 +- 修复 QQ 官方平台发送媒体消息时 markdown 字段未清理的问题 ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445))。 +- 修复上下文管理策略 -> 上下文截断时 tool call / response 配对丢失的问题 ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417))。 +- 修复会话更新时 `persona_id` 被覆盖的问题,并增强 persona 解析逻辑。 +- 修复 WebUI 中 GitHub 代理地址显示异常的问题 ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438))。 +- 修复设置页新建开发者 API Key 后复制失败的问题 ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439))。 +- 修复 Telegram 语音消息格式与 OpenAI STT 兼容性问题(使用 OGG) ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389))。 + +### 优化 + +- 优化知识库检索流程,改为批量查询元数据,修复 N+1 查询性能问题 ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463))。 +- 优化 Cron 未来任务执行的会话隔离能力,提升并发稳定性。 +- 优化 WebUI 插件页的交互。 + +## What's Changed (EN) + +### New Features + +- Added `useExtensionPage` composable for unified plugin extension page state management. +- Added a generic desktop app updater bridge to support in-app update workflows ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424)). + +### Bug Fixes + +- Fixed the Line adapter not appearing in the "Add Platform" dialog. +- Fixed Telegram video sending issues ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430)). +- Fixed Pyright static type checking errors ([#5437](https://github.com/AstrBotDevs/AstrBot/issues/5437)). +- Fixed embedding dimension auto-detection when creating embedding providers ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442)). +- Fixed stale markdown fields when sending media messages via QQ Official Platform ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445)). +- Fixed tool call/response pairing loss during context truncation ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417)). +- Fixed `persona_id` being overwritten during conversation updates and improved persona resolution logic. +- Fixed incorrect GitHub proxy display in WebUI ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438)). +- Fixed API key copy failure after creating a new key in settings ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439)). +- Fixed Telegram voice format compatibility with OpenAI STT by using OGG ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389)). + +### Improvements + +- Improved knowledge base retrieval by batching metadata queries to eliminate the N+1 query pattern ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463)). +- Improved session isolation for future cron tasks to increase stability under concurrency. +- Improved WebUI plugin page interactions. \ No newline at end of file diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 803c5d826..054a18662 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -10,6 +10,7 @@ :selectedSessions="selectedSessions" :currSessionId="currSessionId" :selectedProjectId="selectedProjectId" + :transportMode="transportMode" :isDark="isDark" :chatboxMode="chatboxMode" :isMobile="isMobile" @@ -26,6 +27,7 @@ @createProject="showCreateProjectDialog" @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" + @updateTransportMode="setTransportMode" /> @@ -301,11 +303,14 @@ const { isStreaming, isConvRunning, enableStreaming, + transportMode, currentSessionProject, getSessionMessages: getSessionMsg, sendMessage: sendMsg, stopMessage: stopMsg, - toggleStreaming + toggleStreaming, + setTransportMode, + cleanupTransport } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); // 组件引用 @@ -695,6 +700,7 @@ onMounted(() => { onBeforeUnmount(() => { window.removeEventListener('resize', checkMobile); cleanupMediaCache(); + cleanupTransport(); }); diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index a728930d9..97f2179e7 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -117,6 +117,27 @@ {{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }} + + + + {{ tm('transport.title') }} + + +