Merge remote-tracking branch 'origin/master' into feat/neo-skill-self-iteration
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "4.18.2"
|
||||
__version__ = "4.18.3"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
@@ -68,6 +69,14 @@ class _HandleFunctionToolsResult:
|
||||
return cls(kind="cached_image", cached_image=image)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FollowUpTicket:
|
||||
seq: int
|
||||
text: str
|
||||
consumed: bool = False
|
||||
resolved: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(
|
||||
@@ -139,6 +148,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.run_context = run_context
|
||||
self._stop_requested = False
|
||||
self._aborted = False
|
||||
self._pending_follow_ups: list[FollowUpTicket] = []
|
||||
self._follow_up_seq = 0
|
||||
|
||||
# These two are used for tool schema mode handling
|
||||
# We now have two modes:
|
||||
@@ -277,6 +288,55 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
roles.append(message.role)
|
||||
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")
|
||||
|
||||
def follow_up(
|
||||
self,
|
||||
*,
|
||||
message_text: str,
|
||||
) -> FollowUpTicket | None:
|
||||
"""Queue a follow-up message for the next tool result."""
|
||||
if self.done():
|
||||
return None
|
||||
text = (message_text or "").strip()
|
||||
if not text:
|
||||
return None
|
||||
ticket = FollowUpTicket(seq=self._follow_up_seq, text=text)
|
||||
self._follow_up_seq += 1
|
||||
self._pending_follow_ups.append(ticket)
|
||||
return ticket
|
||||
|
||||
def _resolve_unconsumed_follow_ups(self) -> None:
|
||||
if not self._pending_follow_ups:
|
||||
return
|
||||
follow_ups = self._pending_follow_ups
|
||||
self._pending_follow_ups = []
|
||||
for ticket in follow_ups:
|
||||
ticket.resolved.set()
|
||||
|
||||
def _consume_follow_up_notice(self) -> str:
|
||||
if not self._pending_follow_ups:
|
||||
return ""
|
||||
follow_ups = self._pending_follow_ups
|
||||
self._pending_follow_ups = []
|
||||
for ticket in follow_ups:
|
||||
ticket.consumed = True
|
||||
ticket.resolved.set()
|
||||
follow_up_lines = "\n".join(
|
||||
f"{idx}. {ticket.text}" for idx, ticket in enumerate(follow_ups, start=1)
|
||||
)
|
||||
return (
|
||||
"\n\n[SYSTEM NOTICE] User sent follow-up messages while tool execution "
|
||||
"was in progress. Prioritize these follow-up instructions in your next "
|
||||
"actions. In your very next action, briefly acknowledge to the user "
|
||||
"that their follow-up message(s) were received before continuing.\n"
|
||||
f"{follow_up_lines}"
|
||||
)
|
||||
|
||||
def _merge_follow_up_notice(self, content: str) -> str:
|
||||
notice = self._consume_follow_up_notice()
|
||||
if not notice:
|
||||
return content
|
||||
return f"{content}{notice}"
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
"""Process a single step of the agent.
|
||||
@@ -391,6 +451,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
type="aborted",
|
||||
data=AgentResponseData(chain=MessageChain(type="aborted")),
|
||||
)
|
||||
self._resolve_unconsumed_follow_ups()
|
||||
return
|
||||
|
||||
# 处理 LLM 响应
|
||||
@@ -401,6 +462,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
self._resolve_unconsumed_follow_ups()
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
data=AgentResponseData(
|
||||
@@ -439,6 +501,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
self._resolve_unconsumed_follow_ups()
|
||||
|
||||
# 返回 LLM 结果
|
||||
if llm_resp.result_chain:
|
||||
@@ -583,6 +646,15 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||
|
||||
def _append_tool_call_result(tool_call_id: str, content: str) -> None:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=tool_call_id,
|
||||
content=self._merge_follow_up_notice(content),
|
||||
),
|
||||
)
|
||||
|
||||
# 执行函数调用
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
@@ -622,12 +694,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
if not func_tool:
|
||||
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: Tool {func_tool_name} not found.",
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
f"error: Tool {func_tool_name} not found.",
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -680,12 +749,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
res = resp
|
||||
_final_resp = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
res.content[0].text,
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
# Cache the image instead of sending directly
|
||||
@@ -696,15 +762,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
index=0,
|
||||
mime_type=res.content[0].mimeType or "image/png",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
)
|
||||
# Yield image info for LLM visibility (will be handled in step())
|
||||
@@ -714,12 +777,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
resource.text,
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
@@ -734,15 +794,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
index=0,
|
||||
mime_type=resource.mimeType,
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
)
|
||||
# Yield image info for LLM visibility
|
||||
@@ -750,12 +807,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
cached_img
|
||||
)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="The tool has returned a data type that is not supported.",
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
"The tool has returned a data type that is not supported.",
|
||||
)
|
||||
|
||||
elif resp is None:
|
||||
@@ -767,24 +821,18 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="The tool has no return value, or has sent the result directly to the user.",
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
"The tool has no return value, or has sent the result directly to the user.",
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -798,12 +846,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {e!s}",
|
||||
),
|
||||
_append_tool_call_result(
|
||||
func_tool_id,
|
||||
f"error: {e!s}",
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.18.2"
|
||||
VERSION = "4.18.3"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -429,7 +429,15 @@ CONFIG_METADATA_2 = {
|
||||
"slack_webhook_port": 6197,
|
||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||
},
|
||||
# LINE's config is located in line_adapter.py
|
||||
"Line": {
|
||||
"id": "line",
|
||||
"type": "line",
|
||||
"enable": False,
|
||||
"channel_access_token": "",
|
||||
"channel_secret": "",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
},
|
||||
"Satori": {
|
||||
"id": "satori",
|
||||
"type": "satori",
|
||||
|
||||
@@ -8,7 +8,7 @@ resolution for backward compatibility.
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from astrbot.core.message.message_event_result import (
|
||||
EventResultType,
|
||||
@@ -17,6 +17,17 @@ from astrbot.core.message.message_event_result import (
|
||||
|
||||
from .stage_order import STAGES_ORDER
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
_LAZY_EXPORTS = {
|
||||
"ContentSafetyCheckStage": (
|
||||
"astrbot.core.pipeline.content_safety_check.stage",
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import FollowUpTicket
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
_ACTIVE_AGENT_RUNNERS: dict[str, AgentRunner] = {}
|
||||
_FOLLOW_UP_ORDER_STATE: dict[str, dict[str, object]] = {}
|
||||
"""UMO-level follow-up order state.
|
||||
|
||||
State fields:
|
||||
- `statuses`: seq -> {"pending"|"active"|"consumed"|"finished"}
|
||||
- `next_order`: monotonically increasing sequence allocator
|
||||
- `next_turn`: next sequence allowed to proceed when not consumed
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FollowUpCapture:
|
||||
umo: str
|
||||
ticket: FollowUpTicket
|
||||
order_seq: int
|
||||
monitor_task: asyncio.Task[None]
|
||||
|
||||
|
||||
def _event_follow_up_text(event: AstrMessageEvent) -> str:
|
||||
text = (event.get_message_str() or "").strip()
|
||||
if text:
|
||||
return text
|
||||
return event.get_message_outline().strip()
|
||||
|
||||
|
||||
def register_active_runner(umo: str, runner: AgentRunner) -> None:
|
||||
_ACTIVE_AGENT_RUNNERS[umo] = runner
|
||||
|
||||
|
||||
def unregister_active_runner(umo: str, runner: AgentRunner) -> None:
|
||||
if _ACTIVE_AGENT_RUNNERS.get(umo) is runner:
|
||||
_ACTIVE_AGENT_RUNNERS.pop(umo, None)
|
||||
|
||||
|
||||
def _get_follow_up_order_state(umo: str) -> dict[str, object]:
|
||||
state = _FOLLOW_UP_ORDER_STATE.get(umo)
|
||||
if state is None:
|
||||
state = {
|
||||
"condition": asyncio.Condition(),
|
||||
# Sequence status map for strict in-order resume after unresolved follow-ups.
|
||||
"statuses": {},
|
||||
# Stable allocator for arrival order; never decreases for the same UMO state.
|
||||
"next_order": 0,
|
||||
# The sequence currently allowed to continue main internal flow.
|
||||
"next_turn": 0,
|
||||
}
|
||||
_FOLLOW_UP_ORDER_STATE[umo] = state
|
||||
return state
|
||||
|
||||
|
||||
def _advance_follow_up_turn_locked(state: dict[str, object]) -> None:
|
||||
# Skip slots that are already handled, and stop at the first unfinished slot.
|
||||
statuses = state["statuses"]
|
||||
assert isinstance(statuses, dict)
|
||||
next_turn = state["next_turn"]
|
||||
assert isinstance(next_turn, int)
|
||||
|
||||
while True:
|
||||
curr = statuses.get(next_turn)
|
||||
if curr in ("consumed", "finished"):
|
||||
statuses.pop(next_turn, None)
|
||||
next_turn += 1
|
||||
continue
|
||||
break
|
||||
|
||||
state["next_turn"] = next_turn
|
||||
|
||||
|
||||
def _allocate_follow_up_order(umo: str) -> int:
|
||||
state = _get_follow_up_order_state(umo)
|
||||
next_order = state["next_order"]
|
||||
assert isinstance(next_order, int)
|
||||
seq = next_order
|
||||
state["next_order"] = seq + 1
|
||||
statuses = state["statuses"]
|
||||
assert isinstance(statuses, dict)
|
||||
statuses[seq] = "pending"
|
||||
return seq
|
||||
|
||||
|
||||
async def _mark_follow_up_consumed(umo: str, seq: int) -> None:
|
||||
state = _FOLLOW_UP_ORDER_STATE.get(umo)
|
||||
if not state:
|
||||
return
|
||||
condition = state["condition"]
|
||||
assert isinstance(condition, asyncio.Condition)
|
||||
async with condition:
|
||||
statuses = state["statuses"]
|
||||
assert isinstance(statuses, dict)
|
||||
if seq in statuses and statuses[seq] != "finished":
|
||||
statuses[seq] = "consumed"
|
||||
_advance_follow_up_turn_locked(state)
|
||||
condition.notify_all()
|
||||
|
||||
# Release state only when this UMO has no pending statuses and no active runner.
|
||||
if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None:
|
||||
_FOLLOW_UP_ORDER_STATE.pop(umo, None)
|
||||
|
||||
|
||||
async def _activate_and_wait_follow_up_turn(umo: str, seq: int) -> None:
|
||||
state = _FOLLOW_UP_ORDER_STATE.get(umo)
|
||||
if not state:
|
||||
return
|
||||
condition = state["condition"]
|
||||
assert isinstance(condition, asyncio.Condition)
|
||||
async with condition:
|
||||
statuses = state["statuses"]
|
||||
assert isinstance(statuses, dict)
|
||||
if seq in statuses:
|
||||
statuses[seq] = "active"
|
||||
|
||||
# Strict ordering: only the head (`next_turn`) can continue.
|
||||
while True:
|
||||
next_turn = state["next_turn"]
|
||||
assert isinstance(next_turn, int)
|
||||
if next_turn == seq:
|
||||
break
|
||||
await condition.wait()
|
||||
|
||||
|
||||
async def _finish_follow_up_turn(umo: str, seq: int) -> None:
|
||||
state = _FOLLOW_UP_ORDER_STATE.get(umo)
|
||||
if not state:
|
||||
return
|
||||
condition = state["condition"]
|
||||
assert isinstance(condition, asyncio.Condition)
|
||||
async with condition:
|
||||
statuses = state["statuses"]
|
||||
assert isinstance(statuses, dict)
|
||||
if seq in statuses:
|
||||
statuses[seq] = "finished"
|
||||
_advance_follow_up_turn_locked(state)
|
||||
condition.notify_all()
|
||||
|
||||
if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None:
|
||||
_FOLLOW_UP_ORDER_STATE.pop(umo, None)
|
||||
|
||||
|
||||
async def _monitor_follow_up_ticket(
|
||||
umo: str,
|
||||
ticket: FollowUpTicket,
|
||||
order_seq: int,
|
||||
) -> None:
|
||||
"""Advance consumed slots immediately on resolution to avoid wake-order drift."""
|
||||
await ticket.resolved.wait()
|
||||
if ticket.consumed:
|
||||
await _mark_follow_up_consumed(umo, order_seq)
|
||||
|
||||
|
||||
def try_capture_follow_up(event: AstrMessageEvent) -> FollowUpCapture | None:
|
||||
sender_id = event.get_sender_id()
|
||||
if not sender_id:
|
||||
return None
|
||||
runner = _ACTIVE_AGENT_RUNNERS.get(event.unified_msg_origin)
|
||||
if not runner:
|
||||
return None
|
||||
runner_event = getattr(getattr(runner.run_context, "context", None), "event", None)
|
||||
if runner_event is None:
|
||||
return None
|
||||
active_sender_id = runner_event.get_sender_id()
|
||||
if not active_sender_id or active_sender_id != sender_id:
|
||||
return None
|
||||
|
||||
ticket = runner.follow_up(message_text=_event_follow_up_text(event))
|
||||
if not ticket:
|
||||
return None
|
||||
# Allocate strict order at capture time (arrival order), not at wake time.
|
||||
order_seq = _allocate_follow_up_order(event.unified_msg_origin)
|
||||
monitor_task = asyncio.create_task(
|
||||
_monitor_follow_up_ticket(
|
||||
event.unified_msg_origin,
|
||||
ticket,
|
||||
order_seq,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Captured follow-up message for active agent run, umo=%s, order_seq=%s",
|
||||
event.unified_msg_origin,
|
||||
order_seq,
|
||||
)
|
||||
return FollowUpCapture(
|
||||
umo=event.unified_msg_origin,
|
||||
ticket=ticket,
|
||||
order_seq=order_seq,
|
||||
monitor_task=monitor_task,
|
||||
)
|
||||
|
||||
|
||||
async def prepare_follow_up_capture(capture: FollowUpCapture) -> tuple[bool, bool]:
|
||||
"""Return `(consumed_marked, activated)` for internal stage branch handling."""
|
||||
await capture.ticket.resolved.wait()
|
||||
if capture.ticket.consumed:
|
||||
await _mark_follow_up_consumed(capture.umo, capture.order_seq)
|
||||
return True, False
|
||||
await _activate_and_wait_follow_up_turn(capture.umo, capture.order_seq)
|
||||
return False, True
|
||||
|
||||
|
||||
async def finalize_follow_up_capture(
|
||||
capture: FollowUpCapture,
|
||||
*,
|
||||
activated: bool,
|
||||
consumed_marked: bool,
|
||||
) -> None:
|
||||
# Best-effort cancellation: monitor task is auxiliary and should not leak.
|
||||
if not capture.monitor_task.done():
|
||||
capture.monitor_task.cancel()
|
||||
try:
|
||||
await capture.monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if activated:
|
||||
await _finish_follow_up_turn(capture.umo, capture.order_seq)
|
||||
elif not consumed_marked:
|
||||
await _mark_follow_up_consumed(capture.umo, capture.order_seq)
|
||||
@@ -29,8 +29,16 @@ from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
from .....astr_agent_run_util import run_agent, run_live_agent
|
||||
from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent
|
||||
from ....context import PipelineContext, call_event_hook
|
||||
from ...follow_up import (
|
||||
FollowUpCapture,
|
||||
finalize_follow_up_capture,
|
||||
prepare_follow_up_capture,
|
||||
register_active_runner,
|
||||
try_capture_follow_up,
|
||||
unregister_active_runner,
|
||||
)
|
||||
|
||||
|
||||
class InternalAgentSubStage(Stage):
|
||||
@@ -130,6 +138,9 @@ class InternalAgentSubStage(Stage):
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
follow_up_capture: FollowUpCapture | None = None
|
||||
follow_up_consumed_marked = False
|
||||
follow_up_activated = False
|
||||
try:
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
@@ -150,188 +161,208 @@ class InternalAgentSubStage(Stage):
|
||||
return
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
follow_up_capture = try_capture_follow_up(event)
|
||||
if follow_up_capture:
|
||||
(
|
||||
follow_up_consumed_marked,
|
||||
follow_up_activated,
|
||||
) = await prepare_follow_up_capture(follow_up_capture)
|
||||
if follow_up_consumed_marked:
|
||||
logger.info(
|
||||
"Follow-up ticket already consumed, stopping processing. umo=%s, seq=%s",
|
||||
event.unified_msg_origin,
|
||||
follow_up_capture.ticket.seq,
|
||||
)
|
||||
return
|
||||
|
||||
await event.send_typing()
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
agent_runner: AgentRunner | None = None
|
||||
runner_registered = False
|
||||
try:
|
||||
build_cfg = replace(
|
||||
self.main_agent_cfg,
|
||||
provider_wake_prefix=provider_wake_prefix,
|
||||
streaming_response=streaming_response,
|
||||
)
|
||||
|
||||
build_cfg = replace(
|
||||
self.main_agent_cfg,
|
||||
provider_wake_prefix=provider_wake_prefix,
|
||||
streaming_response=streaming_response,
|
||||
)
|
||||
build_result: MainAgentBuildResult | None = await build_main_agent(
|
||||
event=event,
|
||||
plugin_context=self.ctx.plugin_manager.context,
|
||||
config=build_cfg,
|
||||
apply_reset=False,
|
||||
)
|
||||
|
||||
build_result: MainAgentBuildResult | None = await build_main_agent(
|
||||
event=event,
|
||||
plugin_context=self.ctx.plugin_manager.context,
|
||||
config=build_cfg,
|
||||
apply_reset=False,
|
||||
)
|
||||
|
||||
if build_result is None:
|
||||
return
|
||||
|
||||
agent_runner = build_result.agent_runner
|
||||
req = build_result.provider_request
|
||||
provider = build_result.provider
|
||||
reset_coro = build_result.reset_coro
|
||||
|
||||
api_base = provider.provider_config.get("api_base", "")
|
||||
for host in decoded_blocked:
|
||||
if host in api_base:
|
||||
logger.error(
|
||||
"Provider API base %s is blocked due to security reasons. Please use another ai provider.",
|
||||
api_base,
|
||||
)
|
||||
if build_result is None:
|
||||
return
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
agent_runner = build_result.agent_runner
|
||||
req = build_result.provider_request
|
||||
provider = build_result.provider
|
||||
reset_coro = build_result.reset_coro
|
||||
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
api_base = provider.provider_config.get("api_base", "")
|
||||
for host in decoded_blocked:
|
||||
if host in api_base:
|
||||
logger.error(
|
||||
"Provider API base %s is blocked due to security reasons. Please use another ai provider.",
|
||||
api_base,
|
||||
)
|
||||
return
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
if reset_coro:
|
||||
reset_coro.close()
|
||||
return
|
||||
|
||||
# apply reset
|
||||
if reset_coro:
|
||||
reset_coro.close()
|
||||
return
|
||||
await reset_coro
|
||||
|
||||
# apply reset
|
||||
if reset_coro:
|
||||
await reset_coro
|
||||
register_active_runner(event.unified_msg_origin, agent_runner)
|
||||
runner_registered = True
|
||||
action_type = event.get_extra("action_type")
|
||||
|
||||
action_type = event.get_extra("action_type")
|
||||
|
||||
event.trace.record(
|
||||
"astr_agent_prepare",
|
||||
system_prompt=req.system_prompt,
|
||||
tools=req.func_tool.names() if req.func_tool else [],
|
||||
stream=streaming_response,
|
||||
chat_provider={
|
||||
"id": provider.provider_config.get("id", ""),
|
||||
"model": provider.get_model(),
|
||||
},
|
||||
)
|
||||
|
||||
# 检测 Live Mode
|
||||
if action_type == "live":
|
||||
# Live Mode: 使用 run_live_agent
|
||||
logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理")
|
||||
|
||||
# 获取 TTS Provider
|
||||
tts_provider = (
|
||||
self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
event.trace.record(
|
||||
"astr_agent_prepare",
|
||||
system_prompt=req.system_prompt,
|
||||
tools=req.func_tool.names() if req.func_tool else [],
|
||||
stream=streaming_response,
|
||||
chat_provider={
|
||||
"id": provider.provider_config.get("id", ""),
|
||||
"model": provider.get_model(),
|
||||
},
|
||||
)
|
||||
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
"[Live Mode] TTS Provider 未配置,将使用普通流式模式"
|
||||
# 检测 Live Mode
|
||||
if action_type == "live":
|
||||
# Live Mode: 使用 run_live_agent
|
||||
logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理")
|
||||
|
||||
# 获取 TTS Provider
|
||||
tts_provider = (
|
||||
self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
)
|
||||
|
||||
# 使用 run_live_agent,总是使用流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_live_agent(
|
||||
agent_runner,
|
||||
tts_provider,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
if not tts_provider:
|
||||
logger.warning(
|
||||
"[Live Mode] TTS Provider 未配置,将使用普通流式模式"
|
||||
)
|
||||
|
||||
# 使用 run_live_agent,总是使用流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_live_agent(
|
||||
agent_runner,
|
||||
tts_provider,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
)
|
||||
yield
|
||||
|
||||
# 保存历史记录
|
||||
if agent_runner.done() and (
|
||||
not event.is_stopped() or agent_runner.was_aborted()
|
||||
):
|
||||
# 保存历史记录
|
||||
if agent_runner.done() and (
|
||||
not event.is_stopped() or agent_runner.was_aborted()
|
||||
):
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
elif streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
final_resp = agent_runner.get_final_llm_resp()
|
||||
|
||||
event.trace.record(
|
||||
"astr_agent_complete",
|
||||
stats=agent_runner.stats.to_dict(),
|
||||
resp=final_resp.completion_text if final_resp else None,
|
||||
)
|
||||
|
||||
# 检查事件是否被停止,如果被停止则不保存历史记录
|
||||
if not event.is_stopped() or agent_runner.was_aborted():
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
final_resp,
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
elif streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
self.show_tool_call_result,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
final_resp = agent_runner.get_final_llm_resp()
|
||||
|
||||
event.trace.record(
|
||||
"astr_agent_complete",
|
||||
stats=agent_runner.stats.to_dict(),
|
||||
resp=final_resp.completion_text if final_resp else None,
|
||||
)
|
||||
|
||||
# 检查事件是否被停止,如果被停止则不保存历史记录
|
||||
if not event.is_stopped() or agent_runner.was_aborted():
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
final_resp,
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
finally:
|
||||
if runner_registered and agent_runner is not None:
|
||||
unregister_active_runner(event.unified_msg_origin, agent_runner)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing agent: {e}")
|
||||
@@ -340,6 +371,13 @@ class InternalAgentSubStage(Stage):
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if follow_up_capture:
|
||||
await finalize_follow_up_capture(
|
||||
follow_up_capture,
|
||||
activated=follow_up_activated,
|
||||
consumed_marked=follow_up_consumed_marked,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
|
||||
@@ -65,15 +65,6 @@ LINE_I18N_RESOURCES = {
|
||||
"line",
|
||||
"LINE Messaging API 适配器",
|
||||
support_streaming_message=False,
|
||||
default_config_tmpl={
|
||||
"id": "line",
|
||||
"type": "line",
|
||||
"enable": False,
|
||||
"channel_access_token": "",
|
||||
"channel_secret": "",
|
||||
"unified_webhook_mode": True,
|
||||
"webhook_uuid": "",
|
||||
},
|
||||
config_metadata=LINE_CONFIG_METADATA,
|
||||
i18n_resources=LINE_I18N_RESOURCES,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,465 @@
|
||||
import json
|
||||
import mimetypes
|
||||
import shutil
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core.db.po import Attachment
|
||||
from astrbot.core.message.components import (
|
||||
File,
|
||||
Image,
|
||||
Json,
|
||||
Plain,
|
||||
Record,
|
||||
Reply,
|
||||
Video,
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
AttachmentGetter = Callable[[str], Awaitable[Attachment | None]]
|
||||
AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]]
|
||||
ReplyHistoryGetter = Callable[
|
||||
[Any],
|
||||
Awaitable[tuple[list[dict], str | None, str | None] | None],
|
||||
]
|
||||
|
||||
MEDIA_PART_TYPES = {"image", "record", "file", "video"}
|
||||
|
||||
|
||||
def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]:
|
||||
return [{k: v for k, v in part.items() if k != "path"} for part in message_parts]
|
||||
|
||||
|
||||
def webchat_message_parts_have_content(message_parts: list[dict]) -> bool:
|
||||
return any(
|
||||
part.get("type") in ("plain", "image", "record", "file", "video")
|
||||
and (part.get("text") or part.get("attachment_id") or part.get("filename"))
|
||||
for part in message_parts
|
||||
)
|
||||
|
||||
|
||||
async def parse_webchat_message_parts(
|
||||
message_parts: list,
|
||||
*,
|
||||
strict: bool = False,
|
||||
include_empty_plain: bool = False,
|
||||
verify_media_path_exists: bool = True,
|
||||
reply_history_getter: ReplyHistoryGetter | None = None,
|
||||
current_depth: int = 0,
|
||||
max_reply_depth: int = 0,
|
||||
cast_reply_id_to_str: bool = True,
|
||||
) -> tuple[list, list[str], bool]:
|
||||
"""Parse webchat message parts into components/text parts.
|
||||
|
||||
Returns:
|
||||
tuple[list, list[str], bool]:
|
||||
(components, plain_text_parts, has_non_reply_content)
|
||||
"""
|
||||
components = []
|
||||
text_parts: list[str] = []
|
||||
has_content = False
|
||||
|
||||
for part in message_parts:
|
||||
if not isinstance(part, dict):
|
||||
if strict:
|
||||
raise ValueError("message part must be an object")
|
||||
continue
|
||||
|
||||
part_type = str(part.get("type", "")).strip()
|
||||
if part_type == "plain":
|
||||
text = str(part.get("text", ""))
|
||||
if text or include_empty_plain:
|
||||
components.append(Plain(text=text))
|
||||
text_parts.append(text)
|
||||
if text:
|
||||
has_content = True
|
||||
continue
|
||||
|
||||
if part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
if message_id is None:
|
||||
if strict:
|
||||
raise ValueError("reply part missing message_id")
|
||||
continue
|
||||
|
||||
reply_chain = []
|
||||
reply_message_str = str(part.get("selected_text", ""))
|
||||
sender_id = None
|
||||
sender_name = None
|
||||
|
||||
if reply_message_str:
|
||||
reply_chain = [Plain(text=reply_message_str)]
|
||||
elif (
|
||||
reply_history_getter
|
||||
and current_depth < max_reply_depth
|
||||
and message_id is not None
|
||||
):
|
||||
reply_info = await reply_history_getter(message_id)
|
||||
if reply_info:
|
||||
reply_parts, sender_id, sender_name = reply_info
|
||||
(
|
||||
reply_chain,
|
||||
reply_text_parts,
|
||||
_,
|
||||
) = await parse_webchat_message_parts(
|
||||
reply_parts,
|
||||
strict=strict,
|
||||
include_empty_plain=include_empty_plain,
|
||||
verify_media_path_exists=verify_media_path_exists,
|
||||
reply_history_getter=reply_history_getter,
|
||||
current_depth=current_depth + 1,
|
||||
max_reply_depth=max_reply_depth,
|
||||
cast_reply_id_to_str=cast_reply_id_to_str,
|
||||
)
|
||||
reply_message_str = "".join(reply_text_parts)
|
||||
|
||||
reply_id = str(message_id) if cast_reply_id_to_str else message_id
|
||||
components.append(
|
||||
Reply(
|
||||
id=reply_id,
|
||||
message_str=reply_message_str,
|
||||
chain=reply_chain,
|
||||
sender_id=sender_id,
|
||||
sender_nickname=sender_name,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if part_type not in MEDIA_PART_TYPES:
|
||||
if strict:
|
||||
raise ValueError(f"unsupported message part type: {part_type}")
|
||||
continue
|
||||
|
||||
path = part.get("path")
|
||||
if not path:
|
||||
if strict:
|
||||
raise ValueError(f"{part_type} part missing path")
|
||||
continue
|
||||
|
||||
file_path = Path(str(path))
|
||||
if verify_media_path_exists and not file_path.exists():
|
||||
if strict:
|
||||
raise ValueError(f"file not found: {file_path!s}")
|
||||
continue
|
||||
|
||||
file_path_str = (
|
||||
str(file_path.resolve()) if verify_media_path_exists else str(file_path)
|
||||
)
|
||||
has_content = True
|
||||
if part_type == "image":
|
||||
components.append(Image.fromFileSystem(file_path_str))
|
||||
elif part_type == "record":
|
||||
components.append(Record.fromFileSystem(file_path_str))
|
||||
elif part_type == "video":
|
||||
components.append(Video.fromFileSystem(file_path_str))
|
||||
else:
|
||||
filename = str(part.get("filename", "")).strip() or file_path.name
|
||||
components.append(File(name=filename, file=file_path_str))
|
||||
|
||||
return components, text_parts, has_content
|
||||
|
||||
|
||||
async def build_webchat_message_parts(
|
||||
message_payload: str | list,
|
||||
*,
|
||||
get_attachment_by_id: AttachmentGetter,
|
||||
strict: bool = False,
|
||||
) -> list[dict]:
|
||||
if isinstance(message_payload, str):
|
||||
text = message_payload.strip()
|
||||
return [{"type": "plain", "text": text}] if text else []
|
||||
|
||||
if not isinstance(message_payload, list):
|
||||
if strict:
|
||||
raise ValueError("message must be a string or list")
|
||||
return []
|
||||
|
||||
message_parts: list[dict] = []
|
||||
for part in message_payload:
|
||||
if not isinstance(part, dict):
|
||||
if strict:
|
||||
raise ValueError("message part must be an object")
|
||||
continue
|
||||
|
||||
part_type = str(part.get("type", "")).strip()
|
||||
if part_type == "plain":
|
||||
text = str(part.get("text", ""))
|
||||
if text:
|
||||
message_parts.append({"type": "plain", "text": text})
|
||||
continue
|
||||
|
||||
if part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
if message_id is None:
|
||||
if strict:
|
||||
raise ValueError("reply part missing message_id")
|
||||
continue
|
||||
message_parts.append(
|
||||
{
|
||||
"type": "reply",
|
||||
"message_id": message_id,
|
||||
"selected_text": str(part.get("selected_text", "")),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if part_type not in MEDIA_PART_TYPES:
|
||||
if strict:
|
||||
raise ValueError(f"unsupported message part type: {part_type}")
|
||||
continue
|
||||
|
||||
attachment_id = part.get("attachment_id")
|
||||
if not attachment_id:
|
||||
if strict:
|
||||
raise ValueError(f"{part_type} part missing attachment_id")
|
||||
continue
|
||||
|
||||
attachment = await get_attachment_by_id(str(attachment_id))
|
||||
if not attachment:
|
||||
if strict:
|
||||
raise ValueError(f"attachment not found: {attachment_id}")
|
||||
continue
|
||||
|
||||
attachment_path = Path(attachment.path)
|
||||
message_parts.append(
|
||||
{
|
||||
"type": attachment.type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": attachment_path.name,
|
||||
"path": str(attachment_path),
|
||||
}
|
||||
)
|
||||
|
||||
return message_parts
|
||||
|
||||
|
||||
def webchat_message_parts_to_message_chain(
|
||||
message_parts: list[dict],
|
||||
*,
|
||||
strict: bool = False,
|
||||
) -> MessageChain:
|
||||
components = []
|
||||
has_content = False
|
||||
|
||||
for part in message_parts:
|
||||
if not isinstance(part, dict):
|
||||
if strict:
|
||||
raise ValueError("message part must be an object")
|
||||
continue
|
||||
|
||||
part_type = str(part.get("type", "")).strip()
|
||||
if part_type == "plain":
|
||||
text = str(part.get("text", ""))
|
||||
if text:
|
||||
components.append(Plain(text=text))
|
||||
has_content = True
|
||||
continue
|
||||
|
||||
if part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
if message_id is None:
|
||||
if strict:
|
||||
raise ValueError("reply part missing message_id")
|
||||
continue
|
||||
components.append(
|
||||
Reply(
|
||||
id=str(message_id),
|
||||
message_str=str(part.get("selected_text", "")),
|
||||
chain=[],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if part_type not in MEDIA_PART_TYPES:
|
||||
if strict:
|
||||
raise ValueError(f"unsupported message part type: {part_type}")
|
||||
continue
|
||||
|
||||
path = part.get("path")
|
||||
if not path:
|
||||
if strict:
|
||||
raise ValueError(f"{part_type} part missing path")
|
||||
continue
|
||||
|
||||
file_path = Path(str(path))
|
||||
if not file_path.exists():
|
||||
if strict:
|
||||
raise ValueError(f"file not found: {file_path!s}")
|
||||
continue
|
||||
|
||||
file_path_str = str(file_path.resolve())
|
||||
has_content = True
|
||||
if part_type == "image":
|
||||
components.append(Image.fromFileSystem(file_path_str))
|
||||
elif part_type == "record":
|
||||
components.append(Record.fromFileSystem(file_path_str))
|
||||
elif part_type == "video":
|
||||
components.append(Video.fromFileSystem(file_path_str))
|
||||
else:
|
||||
filename = str(part.get("filename", "")).strip() or file_path.name
|
||||
components.append(File(name=filename, file=file_path_str))
|
||||
|
||||
if strict and (not components or not has_content):
|
||||
raise ValueError("Message content is empty (reply only is not allowed)")
|
||||
|
||||
return MessageChain(chain=components)
|
||||
|
||||
|
||||
async def build_message_chain_from_payload(
|
||||
message_payload: str | list,
|
||||
*,
|
||||
get_attachment_by_id: AttachmentGetter,
|
||||
strict: bool = True,
|
||||
) -> MessageChain:
|
||||
message_parts = await build_webchat_message_parts(
|
||||
message_payload,
|
||||
get_attachment_by_id=get_attachment_by_id,
|
||||
strict=strict,
|
||||
)
|
||||
components, _, has_content = await parse_webchat_message_parts(
|
||||
message_parts,
|
||||
strict=strict,
|
||||
)
|
||||
if strict and (not components or not has_content):
|
||||
raise ValueError("Message content is empty (reply only is not allowed)")
|
||||
return MessageChain(chain=components)
|
||||
|
||||
|
||||
async def create_attachment_part_from_existing_file(
|
||||
filename: str,
|
||||
*,
|
||||
attach_type: str,
|
||||
insert_attachment: AttachmentInserter,
|
||||
attachments_dir: str | Path,
|
||||
fallback_dirs: Sequence[str | Path] = (),
|
||||
) -> dict | None:
|
||||
basename = Path(filename).name
|
||||
candidate_paths = [Path(attachments_dir) / basename]
|
||||
candidate_paths.extend(Path(p) / basename for p in fallback_dirs)
|
||||
|
||||
file_path = next((path for path in candidate_paths if path.exists()), None)
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(str(file_path))
|
||||
attachment = await insert_attachment(
|
||||
str(file_path),
|
||||
attach_type,
|
||||
mime_type or "application/octet-stream",
|
||||
)
|
||||
if not attachment:
|
||||
return None
|
||||
|
||||
return {
|
||||
"type": attach_type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": file_path.name,
|
||||
}
|
||||
|
||||
|
||||
async def message_chain_to_storage_message_parts(
|
||||
message_chain: MessageChain,
|
||||
*,
|
||||
insert_attachment: AttachmentInserter,
|
||||
attachments_dir: str | Path,
|
||||
) -> list[dict]:
|
||||
target_dir = Path(attachments_dir)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
parts: list[dict] = []
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
if comp.text:
|
||||
parts.append({"type": "plain", "text": comp.text})
|
||||
continue
|
||||
|
||||
if isinstance(comp, Json):
|
||||
parts.append(
|
||||
{"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)}
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(comp, Image):
|
||||
file_path = await comp.convert_to_file_path()
|
||||
attachment_part = await _copy_file_to_attachment_part(
|
||||
file_path=file_path,
|
||||
attach_type="image",
|
||||
insert_attachment=insert_attachment,
|
||||
attachments_dir=target_dir,
|
||||
)
|
||||
if attachment_part:
|
||||
parts.append(attachment_part)
|
||||
continue
|
||||
|
||||
if isinstance(comp, Record):
|
||||
file_path = await comp.convert_to_file_path()
|
||||
attachment_part = await _copy_file_to_attachment_part(
|
||||
file_path=file_path,
|
||||
attach_type="record",
|
||||
insert_attachment=insert_attachment,
|
||||
attachments_dir=target_dir,
|
||||
)
|
||||
if attachment_part:
|
||||
parts.append(attachment_part)
|
||||
continue
|
||||
|
||||
if isinstance(comp, Video):
|
||||
file_path = await comp.convert_to_file_path()
|
||||
attachment_part = await _copy_file_to_attachment_part(
|
||||
file_path=file_path,
|
||||
attach_type="video",
|
||||
insert_attachment=insert_attachment,
|
||||
attachments_dir=target_dir,
|
||||
)
|
||||
if attachment_part:
|
||||
parts.append(attachment_part)
|
||||
continue
|
||||
|
||||
if isinstance(comp, File):
|
||||
file_path = await comp.get_file()
|
||||
attachment_part = await _copy_file_to_attachment_part(
|
||||
file_path=file_path,
|
||||
attach_type="file",
|
||||
insert_attachment=insert_attachment,
|
||||
attachments_dir=target_dir,
|
||||
display_name=comp.name,
|
||||
)
|
||||
if attachment_part:
|
||||
parts.append(attachment_part)
|
||||
continue
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
async def _copy_file_to_attachment_part(
|
||||
*,
|
||||
file_path: str,
|
||||
attach_type: str,
|
||||
insert_attachment: AttachmentInserter,
|
||||
attachments_dir: Path,
|
||||
display_name: str | None = None,
|
||||
) -> dict | None:
|
||||
src_path = Path(file_path)
|
||||
if not src_path.exists() or not src_path.is_file():
|
||||
return None
|
||||
|
||||
suffix = src_path.suffix
|
||||
target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}"
|
||||
shutil.copy2(src_path, target_path)
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(target_path.name)
|
||||
attachment = await insert_attachment(
|
||||
str(target_path),
|
||||
attach_type,
|
||||
mime_type or "application/octet-stream",
|
||||
)
|
||||
if not attachment:
|
||||
return None
|
||||
|
||||
return {
|
||||
"type": attach_type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": display_name or src_path.name,
|
||||
}
|
||||
@@ -3,12 +3,12 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core.db.po import PlatformMessageHistory
|
||||
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform import (
|
||||
AstrBotMessage,
|
||||
@@ -21,10 +21,23 @@ from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .message_parts_helper import (
|
||||
message_chain_to_storage_message_parts,
|
||||
parse_webchat_message_parts,
|
||||
)
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
|
||||
|
||||
|
||||
def _extract_conversation_id(session_id: str) -> str:
|
||||
"""Extract raw webchat conversation id from event/session id."""
|
||||
if session_id.startswith("webchat!"):
|
||||
parts = session_id.split("!", 2)
|
||||
if len(parts) == 3:
|
||||
return parts[2]
|
||||
return session_id
|
||||
|
||||
|
||||
class QueueListener:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,13 +70,15 @@ class WebChatAdapter(Platform):
|
||||
|
||||
self.settings = platform_settings
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
self.attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="webchat",
|
||||
description="webchat",
|
||||
id="webchat",
|
||||
support_proactive_message=False,
|
||||
support_proactive_message=True,
|
||||
)
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._webchat_queue_mgr = webchat_queue_mgr
|
||||
@@ -73,10 +88,67 @@ class WebChatAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
) -> None:
|
||||
message_id = f"active_{str(uuid.uuid4())}"
|
||||
await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
|
||||
conversation_id = _extract_conversation_id(session.session_id)
|
||||
active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
|
||||
conversation_id
|
||||
)
|
||||
subscription_request_ids = [
|
||||
req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
|
||||
]
|
||||
target_request_ids = subscription_request_ids or active_request_ids
|
||||
|
||||
if target_request_ids:
|
||||
for request_id in target_request_ids:
|
||||
await WebChatMessageEvent._send(
|
||||
request_id,
|
||||
message_chain,
|
||||
session.session_id,
|
||||
)
|
||||
else:
|
||||
message_id = f"active_{uuid.uuid4()!s}"
|
||||
await WebChatMessageEvent._send(
|
||||
message_id,
|
||||
message_chain,
|
||||
session.session_id,
|
||||
)
|
||||
|
||||
should_persist = (
|
||||
bool(subscription_request_ids)
|
||||
or not active_request_ids
|
||||
or all(req_id.startswith("active_") for req_id in active_request_ids)
|
||||
)
|
||||
if should_persist:
|
||||
try:
|
||||
await self._save_proactive_message(conversation_id, message_chain)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[WebChatAdapter] Failed to save proactive message: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def _save_proactive_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
message_chain: MessageChain,
|
||||
) -> None:
|
||||
message_parts = await message_chain_to_storage_message_parts(
|
||||
message_chain,
|
||||
insert_attachment=db_helper.insert_attachment,
|
||||
attachments_dir=self.attachments_dir,
|
||||
)
|
||||
if not message_parts:
|
||||
return
|
||||
|
||||
await db_helper.insert_platform_message_history(
|
||||
platform_id="webchat",
|
||||
user_id=conversation_id,
|
||||
content={"type": "bot", "message": message_parts},
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
|
||||
async def _get_message_history(
|
||||
self, message_id: int
|
||||
) -> PlatformMessageHistory | None:
|
||||
@@ -98,72 +170,30 @@ class WebChatAdapter(Platform):
|
||||
Returns:
|
||||
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
|
||||
"""
|
||||
components = []
|
||||
text_parts = []
|
||||
|
||||
for part in message_parts:
|
||||
part_type = part.get("type")
|
||||
if part_type == "plain":
|
||||
text = part.get("text", "")
|
||||
components.append(Plain(text=text))
|
||||
text_parts.append(text)
|
||||
elif part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
reply_chain = []
|
||||
reply_message_str = part.get("selected_text", "")
|
||||
sender_id = None
|
||||
sender_name = None
|
||||
async def get_reply_parts(
|
||||
message_id: Any,
|
||||
) -> tuple[list[dict], str | None, str | None] | None:
|
||||
history = await self._get_message_history(message_id)
|
||||
if not history or not history.content:
|
||||
return None
|
||||
|
||||
if reply_message_str:
|
||||
reply_chain = [Plain(text=reply_message_str)]
|
||||
reply_parts = history.content.get("message", [])
|
||||
if not isinstance(reply_parts, list):
|
||||
return None
|
||||
|
||||
# recursively get the content of the referenced message, if selected_text is empty
|
||||
if not reply_message_str and depth < max_depth and message_id:
|
||||
history = await self._get_message_history(message_id)
|
||||
if history and history.content:
|
||||
reply_parts = history.content.get("message", [])
|
||||
if isinstance(reply_parts, list):
|
||||
(
|
||||
reply_chain,
|
||||
reply_text_parts,
|
||||
) = await self._parse_message_parts(
|
||||
reply_parts,
|
||||
depth=depth + 1,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
reply_message_str = "".join(reply_text_parts)
|
||||
sender_id = history.sender_id
|
||||
sender_name = history.sender_name
|
||||
|
||||
components.append(
|
||||
Reply(
|
||||
id=message_id,
|
||||
chain=reply_chain,
|
||||
message_str=reply_message_str,
|
||||
sender_id=sender_id,
|
||||
sender_nickname=sender_name,
|
||||
)
|
||||
)
|
||||
elif part_type == "image":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Image.fromFileSystem(path))
|
||||
elif part_type == "record":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Record.fromFileSystem(path))
|
||||
elif part_type == "file":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
filename = part.get("filename") or (
|
||||
os.path.basename(path) if path else "file"
|
||||
)
|
||||
components.append(File(name=filename, file=path))
|
||||
elif part_type == "video":
|
||||
path = part.get("path")
|
||||
if path:
|
||||
components.append(Video.fromFileSystem(path))
|
||||
return reply_parts, history.sender_id, history.sender_name
|
||||
|
||||
components, text_parts, _ = await parse_webchat_message_parts(
|
||||
message_parts,
|
||||
strict=False,
|
||||
include_empty_plain=True,
|
||||
verify_media_path_exists=False,
|
||||
reply_history_getter=get_reply_parts,
|
||||
current_depth=depth,
|
||||
max_reply_depth=max_depth,
|
||||
cast_reply_id_to_str=False,
|
||||
)
|
||||
return components, text_parts
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
|
||||
@@ -14,6 +14,15 @@ from .webchat_queue_mgr import webchat_queue_mgr
|
||||
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||
|
||||
|
||||
def _extract_conversation_id(session_id: str) -> str:
|
||||
"""Extract raw webchat conversation id from event/session id."""
|
||||
if session_id.startswith("webchat!"):
|
||||
parts = session_id.split("!", 2)
|
||||
if len(parts) == 3:
|
||||
return parts[2]
|
||||
return session_id
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
@@ -27,7 +36,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
streaming: bool = False,
|
||||
) -> str | None:
|
||||
request_id = str(message_id)
|
||||
conversation_id = session_id.split("!")[-1]
|
||||
conversation_id = _extract_conversation_id(session_id)
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
@@ -130,7 +139,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
reasoning_content = ""
|
||||
message_id = self.message_obj.message_id
|
||||
request_id = str(message_id)
|
||||
conversation_id = self.session_id.split("!")[-1]
|
||||
conversation_id = _extract_conversation_id(self.session_id)
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
|
||||
@@ -75,6 +75,10 @@ class WebChatQueueMgr:
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def list_back_request_ids(self, conversation_id: str) -> list[str]:
|
||||
"""List active back-queue request IDs for a conversation."""
|
||||
return list(self._conversation_back_requests.get(conversation_id, set()))
|
||||
|
||||
def has_queue(self, conversation_id: str) -> bool:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
@@ -388,6 +388,33 @@ class PluginManager:
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {module_name} 未载入")
|
||||
|
||||
def _cleanup_plugin_state(self, dir_name: str) -> None:
|
||||
plugin_root_name = "data.plugins."
|
||||
|
||||
# 清理 sys.modules
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith(f"{plugin_root_name}{dir_name}"):
|
||||
logger.info(f"清除了插件{dir_name}中的{key}模块")
|
||||
del sys.modules[key]
|
||||
|
||||
possible_paths = [
|
||||
f"{plugin_root_name}{dir_name}.main",
|
||||
f"{plugin_root_name}{dir_name}.{dir_name}",
|
||||
]
|
||||
|
||||
# 清理 handlers
|
||||
for path in possible_paths:
|
||||
handlers = star_handlers_registry.get_handlers_by_module_name(path)
|
||||
for handler in handlers:
|
||||
star_handlers_registry.remove(handler)
|
||||
logger.info(f"清理处理器: {handler.handler_name}")
|
||||
|
||||
# 清理工具
|
||||
for tool in list(llm_tools.func_list):
|
||||
if tool.handler_module_path in possible_paths:
|
||||
llm_tools.func_list.remove(tool)
|
||||
logger.info(f"清理工具: {tool.name}")
|
||||
|
||||
async def reload_failed_plugin(self, dir_name):
|
||||
"""
|
||||
重新加载未注册(加载失败)的插件
|
||||
@@ -398,17 +425,21 @@ class PluginManager:
|
||||
- success (bool): 重载是否成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
|
||||
async with self._pm_lock:
|
||||
if dir_name in self.failed_plugin_dict:
|
||||
success, error = await self.load(specified_dir_name=dir_name)
|
||||
if success:
|
||||
self.failed_plugin_dict.pop(dir_name, None)
|
||||
if not self.failed_plugin_dict:
|
||||
self.failed_plugin_info = ""
|
||||
return success, None
|
||||
else:
|
||||
return False, error
|
||||
return False, "插件不存在于失败列表中"
|
||||
if dir_name not in self.failed_plugin_dict:
|
||||
return False, "插件不存在于失败列表中"
|
||||
|
||||
self._cleanup_plugin_state(dir_name)
|
||||
|
||||
success, error = await self.load(specified_dir_name=dir_name)
|
||||
if success:
|
||||
self.failed_plugin_dict.pop(dir_name, None)
|
||||
if not self.failed_plugin_dict:
|
||||
self.failed_plugin_info = ""
|
||||
return success, None
|
||||
else:
|
||||
return False, error
|
||||
|
||||
async def reload(self, specified_plugin_name=None):
|
||||
"""重新加载插件
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
@@ -14,6 +13,12 @@ from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_webchat_message_parts,
|
||||
create_attachment_part_from_existing_file,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.active_event_registry import active_event_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
@@ -166,83 +171,24 @@ class ChatRoute(Route):
|
||||
)
|
||||
|
||||
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
|
||||
"""构建用户消息的部分列表
|
||||
|
||||
Args:
|
||||
message: 文本消息 (str) 或消息段列表 (list)
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if isinstance(message, list):
|
||||
for part in message:
|
||||
part_type = part.get("type")
|
||||
if part_type == "plain":
|
||||
parts.append({"type": "plain", "text": part.get("text", "")})
|
||||
elif part_type == "reply":
|
||||
parts.append(
|
||||
{
|
||||
"type": "reply",
|
||||
"message_id": part.get("message_id"),
|
||||
"selected_text": part.get("selected_text", ""),
|
||||
}
|
||||
)
|
||||
elif attachment_id := part.get("attachment_id"):
|
||||
attachment = await self.db.get_attachment_by_id(attachment_id)
|
||||
if attachment:
|
||||
parts.append(
|
||||
{
|
||||
"type": attachment.type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": os.path.basename(attachment.path),
|
||||
"path": attachment.path, # will be deleted
|
||||
}
|
||||
)
|
||||
return parts
|
||||
|
||||
if message:
|
||||
parts.append({"type": "plain", "text": message})
|
||||
|
||||
return parts
|
||||
"""构建用户消息的部分列表。"""
|
||||
return await build_webchat_message_parts(
|
||||
message,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
async def _create_attachment_from_file(
|
||||
self, filename: str, attach_type: str
|
||||
) -> dict | None:
|
||||
"""从本地文件创建 attachment 并返回消息部分
|
||||
|
||||
用于处理 bot 回复中的媒体文件
|
||||
|
||||
Args:
|
||||
filename: 存储的文件名
|
||||
attach_type: 附件类型 (image, record, file, video)
|
||||
"""
|
||||
basename = os.path.basename(filename)
|
||||
candidate_paths = [
|
||||
os.path.join(self.attachments_dir, basename),
|
||||
os.path.join(self.legacy_img_dir, basename),
|
||||
]
|
||||
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
# guess mime type
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# insert attachment
|
||||
attachment = await self.db.insert_attachment(
|
||||
path=file_path,
|
||||
type=attach_type,
|
||||
mime_type=mime_type,
|
||||
"""从本地文件创建 attachment 并返回消息部分。"""
|
||||
return await create_attachment_part_from_existing_file(
|
||||
filename,
|
||||
attach_type=attach_type,
|
||||
insert_attachment=self.db.insert_attachment,
|
||||
attachments_dir=self.attachments_dir,
|
||||
fallback_dirs=[self.legacy_img_dir],
|
||||
)
|
||||
if not attachment:
|
||||
return None
|
||||
|
||||
return {
|
||||
"type": attach_type,
|
||||
"attachment_id": attachment.attachment_id,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
|
||||
def _extract_web_search_refs(
|
||||
self, accumulated_text: str, accumulated_parts: list
|
||||
@@ -356,21 +302,6 @@ class ChatRoute(Route):
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True)
|
||||
|
||||
# 检查消息是否为空
|
||||
if isinstance(message, list):
|
||||
has_content = any(
|
||||
part.get("type") in ("plain", "image", "record", "file", "video")
|
||||
for part in message
|
||||
)
|
||||
if not has_content:
|
||||
return (
|
||||
Response()
|
||||
.error("Message content is empty (reply only is not allowed)")
|
||||
.__dict__
|
||||
)
|
||||
elif not message:
|
||||
return Response().error("Message are both empty").__dict__
|
||||
|
||||
if not session_id:
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
@@ -378,6 +309,12 @@ class ChatRoute(Route):
|
||||
|
||||
# 构建用户消息段(包含 path 用于传递给 adapter)
|
||||
message_parts = await self._build_user_message_parts(message)
|
||||
if not webchat_message_parts_have_content(message_parts):
|
||||
return (
|
||||
Response()
|
||||
.error("Message content is empty (reply only is not allowed)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
@@ -583,10 +520,7 @@ class ChatRoute(Route):
|
||||
),
|
||||
)
|
||||
|
||||
message_parts_for_storage = []
|
||||
for part in message_parts:
|
||||
part_copy = {k: v for k, v in part.items() if k != "path"}
|
||||
message_parts_for_storage.append(part_copy)
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import wave
|
||||
@@ -10,9 +11,16 @@ import jwt
|
||||
from quart import websocket
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_webchat_message_parts,
|
||||
create_attachment_part_from_existing_file,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
|
||||
|
||||
from .route import Route, RouteContext
|
||||
|
||||
@@ -30,6 +38,9 @@ class LiveChatSession:
|
||||
self.audio_frames: list[bytes] = []
|
||||
self.current_stamp: str | None = None
|
||||
self.temp_audio_path: str | None = None
|
||||
self.chat_subscriptions: dict[str, str] = {}
|
||||
self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
|
||||
self.ws_send_lock = asyncio.Lock()
|
||||
|
||||
def start_speaking(self, stamp: str) -> None:
|
||||
"""开始说话"""
|
||||
@@ -106,13 +117,26 @@ class LiveChatRoute(Route):
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.db = db
|
||||
self.plugin_manager = core_lifecycle.plugin_manager
|
||||
self.platform_history_mgr = core_lifecycle.platform_message_history_manager
|
||||
self.sessions: dict[str, LiveChatSession] = {}
|
||||
self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||
self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.attachments_dir, exist_ok=True)
|
||||
|
||||
# 注册 WebSocket 路由
|
||||
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
|
||||
self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
|
||||
|
||||
async def live_chat_ws(self) -> None:
|
||||
"""Live Chat WebSocket 处理器"""
|
||||
"""Legacy Live Chat WebSocket 处理器(默认 ct=live)"""
|
||||
await self._unified_ws_loop(force_ct="live")
|
||||
|
||||
async def unified_chat_ws(self) -> None:
|
||||
"""Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
|
||||
await self._unified_ws_loop(force_ct=None)
|
||||
|
||||
async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
|
||||
"""统一 WebSocket 循环"""
|
||||
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
|
||||
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
|
||||
token = websocket.args.get("token")
|
||||
@@ -140,7 +164,11 @@ class LiveChatRoute(Route):
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
await self._handle_message(live_session, message)
|
||||
ct = force_ct or message.get("ct", "live")
|
||||
if ct == "chat":
|
||||
await self._handle_chat_message(live_session, message)
|
||||
else:
|
||||
await self._handle_message(live_session, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
|
||||
@@ -148,10 +176,488 @@ class LiveChatRoute(Route):
|
||||
finally:
|
||||
# 清理会话
|
||||
if session_id in self.sessions:
|
||||
await self._cleanup_chat_subscriptions(live_session)
|
||||
live_session.cleanup()
|
||||
del self.sessions[session_id]
|
||||
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
|
||||
|
||||
async def _create_attachment_from_file(
|
||||
self, filename: str, attach_type: str
|
||||
) -> dict | None:
|
||||
"""从本地文件创建 attachment 并返回消息部分。"""
|
||||
return await create_attachment_part_from_existing_file(
|
||||
filename,
|
||||
attach_type=attach_type,
|
||||
insert_attachment=self.db.insert_attachment,
|
||||
attachments_dir=self.attachments_dir,
|
||||
fallback_dirs=[self.legacy_img_dir],
|
||||
)
|
||||
|
||||
def _extract_web_search_refs(
|
||||
self, accumulated_text: str, accumulated_parts: list
|
||||
) -> dict:
|
||||
"""从消息中提取 web_search 引用。"""
|
||||
supported = ["web_search_tavily", "web_search_bocha"]
|
||||
web_search_results = {}
|
||||
tool_call_parts = [
|
||||
p
|
||||
for p in accumulated_parts
|
||||
if p.get("type") == "tool_call" and p.get("tool_calls")
|
||||
]
|
||||
|
||||
for part in tool_call_parts:
|
||||
for tool_call in part["tool_calls"]:
|
||||
if tool_call.get("name") not in supported or not tool_call.get(
|
||||
"result"
|
||||
):
|
||||
continue
|
||||
try:
|
||||
result_data = json.loads(tool_call["result"])
|
||||
for item in result_data.get("results", []):
|
||||
if idx := item.get("index"):
|
||||
web_search_results[idx] = {
|
||||
"url": item.get("url"),
|
||||
"title": item.get("title"),
|
||||
"snippet": item.get("snippet"),
|
||||
}
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
if not web_search_results:
|
||||
return {}
|
||||
|
||||
ref_indices = {
|
||||
m.strip() for m in re.findall(r"<ref>(.*?)</ref>", accumulated_text)
|
||||
}
|
||||
|
||||
used_refs = []
|
||||
for ref_index in ref_indices:
|
||||
if ref_index not in web_search_results:
|
||||
continue
|
||||
payload = {"index": ref_index, **web_search_results[ref_index]}
|
||||
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
payload["favicon"] = favicon
|
||||
used_refs.append(payload)
|
||||
|
||||
return {"used": used_refs} if used_refs else {}
|
||||
|
||||
async def _save_bot_message(
|
||||
self,
|
||||
webchat_conv_id: str,
|
||||
text: str,
|
||||
media_parts: list,
|
||||
reasoning: str,
|
||||
agent_stats: dict,
|
||||
refs: dict,
|
||||
):
|
||||
"""保存 bot 消息到历史记录。"""
|
||||
bot_message_parts = []
|
||||
bot_message_parts.extend(media_parts)
|
||||
if text:
|
||||
bot_message_parts.append({"type": "plain", "text": text})
|
||||
|
||||
new_his = {"type": "bot", "message": bot_message_parts}
|
||||
if reasoning:
|
||||
new_his["reasoning"] = reasoning
|
||||
if agent_stats:
|
||||
new_his["agent_stats"] = agent_stats
|
||||
if refs:
|
||||
new_his["refs"] = refs
|
||||
|
||||
return await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=webchat_conv_id,
|
||||
content=new_his,
|
||||
sender_id="bot",
|
||||
sender_name="bot",
|
||||
)
|
||||
|
||||
async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
|
||||
async with session.ws_send_lock:
|
||||
await websocket.send_json(payload)
|
||||
|
||||
async def _forward_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
request_id: str,
|
||||
) -> None:
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id, chat_session_id
|
||||
)
|
||||
try:
|
||||
while True:
|
||||
result = await back_queue.get()
|
||||
if not result:
|
||||
continue
|
||||
await self._send_chat_payload(session, {"ct": "chat", **result})
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
if session.chat_subscriptions.get(chat_session_id) == request_id:
|
||||
session.chat_subscriptions.pop(chat_session_id, None)
|
||||
session.chat_subscription_tasks.pop(chat_session_id, None)
|
||||
|
||||
async def _ensure_chat_subscription(
|
||||
self,
|
||||
session: LiveChatSession,
|
||||
chat_session_id: str,
|
||||
) -> str:
|
||||
existing_request_id = session.chat_subscriptions.get(chat_session_id)
|
||||
existing_task = session.chat_subscription_tasks.get(chat_session_id)
|
||||
if existing_request_id and existing_task and not existing_task.done():
|
||||
return existing_request_id
|
||||
|
||||
request_id = f"ws_sub_{uuid.uuid4().hex}"
|
||||
session.chat_subscriptions[chat_session_id] = request_id
|
||||
task = asyncio.create_task(
|
||||
self._forward_chat_subscription(session, chat_session_id, request_id),
|
||||
name=f"chat_ws_sub_{chat_session_id}",
|
||||
)
|
||||
session.chat_subscription_tasks[chat_session_id] = task
|
||||
return request_id
|
||||
|
||||
async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
|
||||
tasks = list(session.chat_subscription_tasks.values())
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for request_id in list(session.chat_subscriptions.values()):
|
||||
webchat_queue_mgr.remove_back_queue(request_id)
|
||||
session.chat_subscriptions.clear()
|
||||
session.chat_subscription_tasks.clear()
|
||||
|
||||
async def _handle_chat_message(
|
||||
self, session: LiveChatSession, message: dict
|
||||
) -> None:
|
||||
"""处理 Chat Mode 消息(ct=chat)"""
|
||||
msg_type = message.get("t")
|
||||
|
||||
if msg_type == "bind":
|
||||
chat_session_id = message.get("session_id")
|
||||
if not isinstance(chat_session_id, str) or not chat_session_id:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "session_id is required",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
request_id = await self._ensure_chat_subscription(session, chat_session_id)
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "session_bound",
|
||||
"session_id": chat_session_id,
|
||||
"message_id": request_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type == "interrupt":
|
||||
session.should_interrupt = True
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "INTERRUPTED",
|
||||
"code": "INTERRUPTED",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if msg_type != "send":
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"Unsupported message type: {msg_type}",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if session.is_processing:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Session is busy",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
payload = message.get("message")
|
||||
session_id = message.get("session_id") or session.session_id
|
||||
message_id = message.get("message_id") or str(uuid.uuid4())
|
||||
selected_provider = message.get("selected_provider")
|
||||
selected_model = message.get("selected_model")
|
||||
selected_stt_provider = message.get("selected_stt_provider")
|
||||
selected_tts_provider = message.get("selected_tts_provider")
|
||||
persona_prompt = message.get("persona_prompt")
|
||||
show_reasoning = message.get("show_reasoning")
|
||||
enable_streaming = message.get("enable_streaming", True)
|
||||
|
||||
if not isinstance(payload, list):
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "message must be list",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
message_parts = await self._build_chat_message_parts(payload)
|
||||
has_content = webchat_message_parts_have_content(message_parts)
|
||||
if not has_content:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": "Message content is empty",
|
||||
"code": "INVALID_MESSAGE_FORMAT",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
await self._ensure_chat_subscription(session, session_id)
|
||||
|
||||
session.is_processing = True
|
||||
session.should_interrupt = False
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
|
||||
try:
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
session.username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"selected_stt_provider": selected_stt_provider,
|
||||
"selected_tts_provider": selected_tts_provider,
|
||||
"persona_prompt": persona_prompt,
|
||||
"show_reasoning": show_reasoning,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=session.username,
|
||||
sender_name=session.username,
|
||||
)
|
||||
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
tool_calls = {}
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
session.should_interrupt = False
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result.get("message_id") and result.get("message_id") != message_id:
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
parsed_agent_stats = json.loads(result_text)
|
||||
agent_stats = parsed_agent_stats
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "agent_stats",
|
||||
"data": parsed_agent_stats,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
outgoing = {"ct": "chat", **result}
|
||||
await self._send_chat_payload(session, outgoing)
|
||||
|
||||
if msg_type == "plain":
|
||||
if chain_type == "tool_call":
|
||||
try:
|
||||
tool_call = json.loads(result_text)
|
||||
tool_calls[tool_call.get("id")] = tool_call
|
||||
if accumulated_text:
|
||||
accumulated_parts.append(
|
||||
{"type": "plain", "text": accumulated_text}
|
||||
)
|
||||
accumulated_text = ""
|
||||
except Exception:
|
||||
pass
|
||||
elif chain_type == "tool_call_result":
|
||||
try:
|
||||
tcr = json.loads(result_text)
|
||||
tc_id = tcr.get("id")
|
||||
if tc_id in tool_calls:
|
||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||
accumulated_parts.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"tool_calls": [tool_calls[tc_id]],
|
||||
}
|
||||
)
|
||||
tool_calls.pop(tc_id, None)
|
||||
except Exception:
|
||||
pass
|
||||
elif chain_type == "reasoning":
|
||||
accumulated_reasoning += result_text
|
||||
elif streaming:
|
||||
accumulated_text += result_text
|
||||
else:
|
||||
accumulated_text = result_text
|
||||
elif msg_type == "image":
|
||||
filename = str(result_text).replace("[IMAGE]", "")
|
||||
part = await self._create_attachment_from_file(filename, "image")
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "record":
|
||||
filename = str(result_text).replace("[RECORD]", "")
|
||||
part = await self._create_attachment_from_file(filename, "record")
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "file":
|
||||
filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
|
||||
part = await self._create_attachment_from_file(filename, "file")
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "video":
|
||||
filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
|
||||
part = await self._create_attachment_from_file(filename, "video")
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
|
||||
should_save = False
|
||||
if msg_type == "end":
|
||||
should_save = bool(
|
||||
accumulated_parts
|
||||
or accumulated_text
|
||||
or accumulated_reasoning
|
||||
or refs
|
||||
or agent_stats
|
||||
)
|
||||
elif (streaming and msg_type == "complete") or not streaming:
|
||||
if chain_type not in (
|
||||
"tool_call",
|
||||
"tool_call_result",
|
||||
"agent_stats",
|
||||
):
|
||||
should_save = True
|
||||
|
||||
if should_save:
|
||||
try:
|
||||
refs = self._extract_web_search_refs(
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[Live Chat] Failed to extract web search refs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
saved_record = await self._save_bot_message(
|
||||
session_id,
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
accumulated_reasoning,
|
||||
agent_stats,
|
||||
refs,
|
||||
)
|
||||
if saved_record:
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": saved_record.created_at.astimezone().isoformat(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
|
||||
if msg_type == "end":
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
|
||||
await self._send_chat_payload(
|
||||
session,
|
||||
{
|
||||
"ct": "chat",
|
||||
"t": "error",
|
||||
"data": f"处理失败: {str(e)}",
|
||||
"code": "PROCESSING_ERROR",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
session.is_processing = False
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
|
||||
"""构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
|
||||
return await build_webchat_message_parts(
|
||||
message,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
|
||||
"""处理 WebSocket 消息"""
|
||||
msg_type = message.get("t") # 使用 t 代替 type
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from quart import g, request
|
||||
from quart import g, request, websocket
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.message_session import MessageSesion
|
||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||
build_message_chain_from_payload,
|
||||
strip_message_parts_path_fields,
|
||||
webchat_message_parts_have_content,
|
||||
)
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
|
||||
from .api_key import ALL_OPEN_API_SCOPES
|
||||
from .chat import ChatRoute
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -37,6 +44,7 @@ class OpenApiRoute(Route):
|
||||
"/v1/im/bots": ("GET", self.get_bots),
|
||||
}
|
||||
self.register_routes()
|
||||
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_open_username(
|
||||
@@ -181,6 +189,348 @@ class OpenApiRoute(Route):
|
||||
finally:
|
||||
g.username = original_username
|
||||
|
||||
@staticmethod
|
||||
def _extract_ws_api_key() -> str | None:
|
||||
if key := websocket.args.get("api_key"):
|
||||
return key.strip()
|
||||
if key := websocket.args.get("key"):
|
||||
return key.strip()
|
||||
if key := websocket.headers.get("X-API-Key"):
|
||||
return key.strip()
|
||||
|
||||
auth_header = websocket.headers.get("Authorization", "").strip()
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header.removeprefix("Bearer ").strip()
|
||||
if auth_header.startswith("ApiKey "):
|
||||
return auth_header.removeprefix("ApiKey ").strip()
|
||||
return None
|
||||
|
||||
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
|
||||
raw_key = self._extract_ws_api_key()
|
||||
if not raw_key:
|
||||
return False, "Missing API key"
|
||||
|
||||
key_hash = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
raw_key.encode("utf-8"),
|
||||
b"astrbot_api_key",
|
||||
100_000,
|
||||
).hex()
|
||||
api_key = await self.db.get_active_api_key_by_hash(key_hash)
|
||||
if not api_key:
|
||||
return False, "Invalid API key"
|
||||
|
||||
if isinstance(api_key.scopes, list):
|
||||
scopes = api_key.scopes
|
||||
else:
|
||||
scopes = list(ALL_OPEN_API_SCOPES)
|
||||
|
||||
if "*" not in scopes and "chat" not in scopes:
|
||||
return False, "Insufficient API key scope"
|
||||
|
||||
await self.db.touch_api_key(api_key.key_id)
|
||||
return True, None
|
||||
|
||||
async def _send_chat_ws_error(self, message: str, code: str) -> None:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"code": code,
|
||||
"data": message,
|
||||
}
|
||||
)
|
||||
|
||||
async def _update_session_config_route(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
session_id: str,
|
||||
config_id: str | None,
|
||||
) -> str | None:
|
||||
if not config_id:
|
||||
return None
|
||||
|
||||
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
|
||||
try:
|
||||
if config_id == "default":
|
||||
await self.core_lifecycle.umop_config_router.delete_route(umo)
|
||||
else:
|
||||
await self.core_lifecycle.umop_config_router.update_route(
|
||||
umo, config_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to update chat config route for %s with %s: %s",
|
||||
umo,
|
||||
config_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return f"Failed to update chat config route: {e}"
|
||||
return None
|
||||
|
||||
async def _handle_chat_ws_send(self, post_data: dict) -> None:
|
||||
effective_username, username_err = self._resolve_open_username(
|
||||
post_data.get("username")
|
||||
)
|
||||
if username_err or not effective_username:
|
||||
await self._send_chat_ws_error(
|
||||
username_err or "Invalid username", "BAD_USER"
|
||||
)
|
||||
return
|
||||
|
||||
message = post_data.get("message")
|
||||
if message is None:
|
||||
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
|
||||
return
|
||||
|
||||
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
|
||||
if not session_id:
|
||||
session_id = str(uuid4())
|
||||
|
||||
ensure_session_err = await self._ensure_chat_session(
|
||||
effective_username,
|
||||
session_id,
|
||||
)
|
||||
if ensure_session_err:
|
||||
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
|
||||
return
|
||||
|
||||
config_id, resolve_err = self._resolve_chat_config_id(post_data)
|
||||
if resolve_err:
|
||||
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
|
||||
return
|
||||
|
||||
config_err = await self._update_session_config_route(
|
||||
username=effective_username,
|
||||
session_id=session_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
if config_err:
|
||||
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
|
||||
return
|
||||
|
||||
message_parts = await self.chat_route._build_user_message_parts(message)
|
||||
if not webchat_message_parts_have_content(message_parts):
|
||||
await self._send_chat_ws_error(
|
||||
"Message content is empty (reply only is not allowed)",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
return
|
||||
|
||||
message_id = str(post_data.get("message_id") or uuid4())
|
||||
selected_provider = post_data.get("selected_provider")
|
||||
selected_model = post_data.get("selected_model")
|
||||
enable_streaming = post_data.get("enable_streaming", True)
|
||||
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||
try:
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||
await chat_queue.put(
|
||||
(
|
||||
effective_username,
|
||||
session_id,
|
||||
{
|
||||
"message": message_parts,
|
||||
"selected_provider": selected_provider,
|
||||
"selected_model": selected_model,
|
||||
"enable_streaming": enable_streaming,
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||
await self.chat_route.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
user_id=session_id,
|
||||
content={"type": "user", "message": message_parts_for_storage},
|
||||
sender_id=effective_username,
|
||||
sender_name=effective_username,
|
||||
)
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "session_id",
|
||||
"data": None,
|
||||
"session_id": session_id,
|
||||
"message_id": message_id,
|
||||
}
|
||||
)
|
||||
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
tool_calls = {}
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
if "message_id" in result and result["message_id"] != message_id:
|
||||
logger.warning("openapi ws stream message_id mismatch")
|
||||
continue
|
||||
|
||||
result_text = result.get("data", "")
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
|
||||
if chain_type == "agent_stats":
|
||||
try:
|
||||
stats_info = {
|
||||
"type": "agent_stats",
|
||||
"data": json.loads(result_text),
|
||||
}
|
||||
await websocket.send_json(stats_info)
|
||||
agent_stats = stats_info["data"]
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
await websocket.send_json(result)
|
||||
|
||||
if msg_type == "plain":
|
||||
if chain_type == "tool_call":
|
||||
tool_call = json.loads(result_text)
|
||||
tool_calls[tool_call.get("id")] = tool_call
|
||||
if accumulated_text:
|
||||
accumulated_parts.append(
|
||||
{"type": "plain", "text": accumulated_text}
|
||||
)
|
||||
accumulated_text = ""
|
||||
elif chain_type == "tool_call_result":
|
||||
tcr = json.loads(result_text)
|
||||
tc_id = tcr.get("id")
|
||||
if tc_id in tool_calls:
|
||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||
accumulated_parts.append(
|
||||
{"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
|
||||
)
|
||||
tool_calls.pop(tc_id, None)
|
||||
elif chain_type == "reasoning":
|
||||
accumulated_reasoning += result_text
|
||||
elif streaming:
|
||||
accumulated_text += result_text
|
||||
else:
|
||||
accumulated_text = result_text
|
||||
elif msg_type == "image":
|
||||
filename = str(result_text).replace("[IMAGE]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "image"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "record":
|
||||
filename = str(result_text).replace("[RECORD]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "record"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "file":
|
||||
filename = str(result_text).replace("[FILE]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "file"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
elif msg_type == "video":
|
||||
filename = str(result_text).replace("[VIDEO]", "")
|
||||
part = await self.chat_route._create_attachment_from_file(
|
||||
filename, "video"
|
||||
)
|
||||
if part:
|
||||
accumulated_parts.append(part)
|
||||
|
||||
if msg_type == "end":
|
||||
break
|
||||
if (streaming and msg_type == "complete") or not streaming:
|
||||
if chain_type in ("tool_call", "tool_call_result"):
|
||||
continue
|
||||
try:
|
||||
refs = self.chat_route._extract_web_search_refs(
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Open API WS failed to extract web search refs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
saved_record = await self.chat_route._save_bot_message(
|
||||
session_id,
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
accumulated_reasoning,
|
||||
agent_stats,
|
||||
refs,
|
||||
)
|
||||
if saved_record:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "message_saved",
|
||||
"data": {
|
||||
"id": saved_record.id,
|
||||
"created_at": saved_record.created_at.astimezone().isoformat(),
|
||||
},
|
||||
"session_id": session_id,
|
||||
}
|
||||
)
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
agent_stats = {}
|
||||
refs = {}
|
||||
except Exception as e:
|
||||
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
|
||||
await self._send_chat_ws_error(
|
||||
f"Failed to process message: {e}", "PROCESSING_ERROR"
|
||||
)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
async def chat_ws(self) -> None:
|
||||
authed, auth_err = await self._authenticate_chat_ws_api_key()
|
||||
if not authed:
|
||||
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
|
||||
await websocket.close(1008, auth_err or "Unauthorized")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
if not isinstance(message, dict):
|
||||
await self._send_chat_ws_error(
|
||||
"message must be an object",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
msg_type = message.get("t", "send")
|
||||
if msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
continue
|
||||
if msg_type != "send":
|
||||
await self._send_chat_ws_error(
|
||||
f"Unsupported message type: {msg_type}",
|
||||
"INVALID_MESSAGE",
|
||||
)
|
||||
continue
|
||||
|
||||
await self._handle_chat_ws_send(message)
|
||||
except Exception as e:
|
||||
logger.debug("Open API WS connection closed: %s", e)
|
||||
|
||||
async def upload_file(self):
|
||||
return await self.chat_route.post_file()
|
||||
|
||||
@@ -254,83 +604,12 @@ class OpenApiRoute(Route):
|
||||
async def _build_message_chain_from_payload(
|
||||
self,
|
||||
message_payload: str | list,
|
||||
) -> MessageChain:
|
||||
if isinstance(message_payload, str):
|
||||
text = message_payload.strip()
|
||||
if not text:
|
||||
raise ValueError("Message is empty")
|
||||
return MessageChain(chain=[Plain(text=text)])
|
||||
|
||||
if not isinstance(message_payload, list):
|
||||
raise ValueError("message must be a string or list")
|
||||
|
||||
components = []
|
||||
has_content = False
|
||||
|
||||
for part in message_payload:
|
||||
if not isinstance(part, dict):
|
||||
raise ValueError("message part must be an object")
|
||||
|
||||
part_type = str(part.get("type", "")).strip()
|
||||
if part_type == "plain":
|
||||
text = str(part.get("text", ""))
|
||||
if text:
|
||||
has_content = True
|
||||
components.append(Plain(text=text))
|
||||
continue
|
||||
|
||||
if part_type == "reply":
|
||||
message_id = part.get("message_id")
|
||||
if message_id is None:
|
||||
raise ValueError("reply part missing message_id")
|
||||
components.append(
|
||||
Reply(
|
||||
id=str(message_id),
|
||||
message_str=str(part.get("selected_text", "")),
|
||||
chain=[],
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if part_type not in {"image", "record", "file", "video"}:
|
||||
raise ValueError(f"unsupported message part type: {part_type}")
|
||||
|
||||
has_content = True
|
||||
file_path: Path | None = None
|
||||
resolved_type = part_type
|
||||
filename = str(part.get("filename", "")).strip()
|
||||
|
||||
attachment_id = part.get("attachment_id")
|
||||
if attachment_id:
|
||||
attachment = await self.db.get_attachment_by_id(str(attachment_id))
|
||||
if not attachment:
|
||||
raise ValueError(f"attachment not found: {attachment_id}")
|
||||
file_path = Path(attachment.path)
|
||||
resolved_type = attachment.type
|
||||
if not filename:
|
||||
filename = file_path.name
|
||||
else:
|
||||
raise ValueError(f"{part_type} part missing attachment_id")
|
||||
|
||||
if not file_path.exists():
|
||||
raise ValueError(f"file not found: {file_path!s}")
|
||||
|
||||
file_path_str = str(file_path.resolve())
|
||||
if resolved_type == "image":
|
||||
components.append(Image.fromFileSystem(file_path_str))
|
||||
elif resolved_type == "record":
|
||||
components.append(Record.fromFileSystem(file_path_str))
|
||||
elif resolved_type == "video":
|
||||
components.append(Video.fromFileSystem(file_path_str))
|
||||
else:
|
||||
components.append(
|
||||
File(name=filename or file_path.name, file=file_path_str)
|
||||
)
|
||||
|
||||
if not components or not has_content:
|
||||
raise ValueError("Message content is empty (reply only is not allowed)")
|
||||
|
||||
return MessageChain(chain=components)
|
||||
):
|
||||
return await build_message_chain_from_payload(
|
||||
message_payload,
|
||||
get_attachment_by_id=self.db.get_attachment_by_id,
|
||||
strict=True,
|
||||
)
|
||||
|
||||
async def send_message(self):
|
||||
post_data = await request.json or {}
|
||||
|
||||
@@ -204,6 +204,10 @@ class AstrBotDashboard:
|
||||
|
||||
@staticmethod
|
||||
def _extract_raw_api_key() -> str | None:
|
||||
if key := request.args.get("api_key"):
|
||||
return key.strip()
|
||||
if key := request.args.get("key"):
|
||||
return key.strip()
|
||||
if key := request.headers.get("X-API-Key"):
|
||||
return key.strip()
|
||||
auth_header = request.headers.get("Authorization", "").strip()
|
||||
@@ -217,6 +221,7 @@ class AstrBotDashboard:
|
||||
def _get_required_open_api_scope(path: str) -> str | None:
|
||||
scope_map = {
|
||||
"/api/v1/chat": "chat",
|
||||
"/api/v1/chat/ws": "chat",
|
||||
"/api/v1/chat/sessions": "chat",
|
||||
"/api/v1/configs": "config",
|
||||
"/api/v1/file": "file",
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
|
||||
- 新增桌面端通用更新桥接能力,便于接入客户端内更新流程 ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424))。
|
||||
|
||||
### 修复
|
||||
|
||||
- 修复新增平台对话框中 Line 适配器未显示的问题。
|
||||
- 修复 Telegram 无法发送 Video 的问题 ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430))。
|
||||
- 修复创建 embedding provider 时无法自动识别向量维度的问题 ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442))。
|
||||
- 修复 QQ 官方平台发送媒体消息时 markdown 字段未清理的问题 ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445))。
|
||||
- 修复上下文管理策略 -> 上下文截断时 tool call / response 配对丢失的问题 ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417))。
|
||||
- 修复会话更新时 `persona_id` 被覆盖的问题,并增强 persona 解析逻辑。
|
||||
- 修复 WebUI 中 GitHub 代理地址显示异常的问题 ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438))。
|
||||
- 修复设置页新建开发者 API Key 后复制失败的问题 ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439))。
|
||||
- 修复 Telegram 语音消息格式与 OpenAI STT 兼容性问题(使用 OGG) ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389))。
|
||||
|
||||
### 优化
|
||||
|
||||
- 优化知识库检索流程,改为批量查询元数据,修复 N+1 查询性能问题 ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463))。
|
||||
- 优化 Cron 未来任务执行的会话隔离能力,提升并发稳定性。
|
||||
- 优化 WebUI 插件页的交互。
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### New Features
|
||||
|
||||
- Added `useExtensionPage` composable for unified plugin extension page state management.
|
||||
- Added a generic desktop app updater bridge to support in-app update workflows ([#5424](https://github.com/AstrBotDevs/AstrBot/issues/5424)).
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fixed the Line adapter not appearing in the "Add Platform" dialog.
|
||||
- Fixed Telegram video sending issues ([#5430](https://github.com/AstrBotDevs/AstrBot/issues/5430)).
|
||||
- Fixed Pyright static type checking errors ([#5437](https://github.com/AstrBotDevs/AstrBot/issues/5437)).
|
||||
- Fixed embedding dimension auto-detection when creating embedding providers ([#5442](https://github.com/AstrBotDevs/AstrBot/issues/5442)).
|
||||
- Fixed stale markdown fields when sending media messages via QQ Official Platform ([#5445](https://github.com/AstrBotDevs/AstrBot/issues/5445)).
|
||||
- Fixed tool call/response pairing loss during context truncation ([#5417](https://github.com/AstrBotDevs/AstrBot/issues/5417)).
|
||||
- Fixed `persona_id` being overwritten during conversation updates and improved persona resolution logic.
|
||||
- Fixed incorrect GitHub proxy display in WebUI ([#5438](https://github.com/AstrBotDevs/AstrBot/issues/5438)).
|
||||
- Fixed API key copy failure after creating a new key in settings ([#5439](https://github.com/AstrBotDevs/AstrBot/issues/5439)).
|
||||
- Fixed Telegram voice format compatibility with OpenAI STT by using OGG ([#5389](https://github.com/AstrBotDevs/AstrBot/issues/5389)).
|
||||
|
||||
### Improvements
|
||||
|
||||
- Improved knowledge base retrieval by batching metadata queries to eliminate the N+1 query pattern ([#5463](https://github.com/AstrBotDevs/AstrBot/issues/5463)).
|
||||
- Improved session isolation for future cron tasks to increase stability under concurrency.
|
||||
- Improved WebUI plugin page interactions.
|
||||
@@ -10,6 +10,7 @@
|
||||
:selectedSessions="selectedSessions"
|
||||
:currSessionId="currSessionId"
|
||||
:selectedProjectId="selectedProjectId"
|
||||
:transportMode="transportMode"
|
||||
:isDark="isDark"
|
||||
:chatboxMode="chatboxMode"
|
||||
:isMobile="isMobile"
|
||||
@@ -26,6 +27,7 @@
|
||||
@createProject="showCreateProjectDialog"
|
||||
@editProject="showEditProjectDialog"
|
||||
@deleteProject="handleDeleteProject"
|
||||
@updateTransportMode="setTransportMode"
|
||||
/>
|
||||
|
||||
<!-- 右侧聊天内容区域 -->
|
||||
@@ -301,11 +303,14 @@ const {
|
||||
isStreaming,
|
||||
isConvRunning,
|
||||
enableStreaming,
|
||||
transportMode,
|
||||
currentSessionProject,
|
||||
getSessionMessages: getSessionMsg,
|
||||
sendMessage: sendMsg,
|
||||
stopMessage: stopMsg,
|
||||
toggleStreaming
|
||||
toggleStreaming,
|
||||
setTransportMode,
|
||||
cleanupTransport
|
||||
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
||||
|
||||
// 组件引用
|
||||
@@ -695,6 +700,7 @@ onMounted(() => {
|
||||
onBeforeUnmount(() => {
|
||||
window.removeEventListener('resize', checkMobile);
|
||||
cleanupMediaCache();
|
||||
cleanupTransport();
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
@@ -117,6 +117,27 @@
|
||||
<v-list-item-title>{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<!-- 通信传输模式 -->
|
||||
<v-list-item class="styled-menu-item">
|
||||
<template v-slot:prepend>
|
||||
<v-icon>mdi-lan-connect</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ tm('transport.title') }}</v-list-item-title>
|
||||
<template v-slot:append>
|
||||
<v-select
|
||||
:model-value="transportMode"
|
||||
:items="transportOptions"
|
||||
item-title="label"
|
||||
item-value="value"
|
||||
density="compact"
|
||||
variant="underlined"
|
||||
hide-details
|
||||
class="transport-mode-select"
|
||||
@update:model-value="handleTransportModeChange"
|
||||
/>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<!-- 全屏/退出全屏 -->
|
||||
<v-list-item class="styled-menu-item" @click="$emit('toggleFullscreen')">
|
||||
<template v-slot:prepend>
|
||||
@@ -156,6 +177,7 @@ interface Props {
|
||||
selectedSessions: string[];
|
||||
currSessionId: string;
|
||||
selectedProjectId?: string | null;
|
||||
transportMode: 'sse' | 'websocket';
|
||||
isDark: boolean;
|
||||
chatboxMode: boolean;
|
||||
isMobile: boolean;
|
||||
@@ -179,6 +201,7 @@ const emit = defineEmits<{
|
||||
createProject: [];
|
||||
editProject: [project: Project];
|
||||
deleteProject: [projectId: string];
|
||||
updateTransportMode: [mode: 'sse' | 'websocket'];
|
||||
}>();
|
||||
|
||||
const { t } = useI18n();
|
||||
@@ -188,6 +211,10 @@ const confirmDialog = useConfirmDialog();
|
||||
|
||||
const sidebarCollapsed = ref(true);
|
||||
const showProviderConfigDialog = ref(false);
|
||||
const transportOptions = [
|
||||
{ label: tm('transport.sse'), value: 'sse' as const },
|
||||
{ label: tm('transport.websocket'), value: 'websocket' as const }
|
||||
];
|
||||
|
||||
// 从 localStorage 读取侧边栏折叠状态
|
||||
const savedCollapsedState = localStorage.getItem('sidebarCollapsed');
|
||||
@@ -209,6 +236,12 @@ async function handleDeleteConversation(session: Session) {
|
||||
emit('deleteConversation', session.session_id);
|
||||
}
|
||||
}
|
||||
|
||||
function handleTransportModeChange(mode: string | null) {
|
||||
if (mode === 'sse' || mode === 'websocket') {
|
||||
emit('updateTransportMode', mode);
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
@@ -361,4 +394,8 @@ async function handleDeleteConversation(session: Session) {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.transport-mode-select {
|
||||
min-width: 120px;
|
||||
}
|
||||
</style>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -81,9 +81,16 @@
|
||||
"disabled": "Streaming disabled",
|
||||
"on": "Stream",
|
||||
"off": "Normal"
|
||||
}, "config": {
|
||||
},
|
||||
"transport": {
|
||||
"title": "Transport Mode",
|
||||
"sse": "SSE",
|
||||
"websocket": "WebSocket"
|
||||
},
|
||||
"config": {
|
||||
"title": "Config"
|
||||
}, "reasoning": {
|
||||
},
|
||||
"reasoning": {
|
||||
"thinking": "Thinking Process"
|
||||
},
|
||||
"reply": {
|
||||
|
||||
@@ -82,6 +82,11 @@
|
||||
"on": "流式",
|
||||
"off": "普通"
|
||||
},
|
||||
"transport": {
|
||||
"title": "通信传输模式",
|
||||
"sse": "SSE",
|
||||
"websocket": "WebSocket"
|
||||
},
|
||||
"config": {
|
||||
"title": "配置文件"
|
||||
},
|
||||
|
||||
@@ -61,6 +61,7 @@ export function getTutorialLink(platformType) {
|
||||
"vocechat": "https://docs.astrbot.app/deploy/platform/vocechat.html",
|
||||
"satori": "https://docs.astrbot.app/deploy/platform/satori/llonebot.html",
|
||||
"misskey": "https://docs.astrbot.app/deploy/platform/misskey.html",
|
||||
"line": "https://docs.astrbot.app/deploy/platform/line.html",
|
||||
}
|
||||
return tutorialMap[platformType] || "https://docs.astrbot.app";
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ export default defineConfig({
|
||||
'/api': {
|
||||
target: 'http://127.0.0.1:6185/',
|
||||
changeOrigin: true,
|
||||
ws: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.18.2"
|
||||
version = "4.18.3"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
@@ -149,6 +149,20 @@ class MockHooks(BaseAgentRunHooks):
|
||||
self.agent_done_called = True
|
||||
|
||||
|
||||
class MockEvent:
|
||||
def __init__(self, umo: str, sender_id: str):
|
||||
self.unified_msg_origin = umo
|
||||
self._sender_id = sender_id
|
||||
|
||||
def get_sender_id(self):
|
||||
return self._sender_id
|
||||
|
||||
|
||||
class MockAgentContext:
|
||||
def __init__(self, event):
|
||||
self.event = event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider():
|
||||
return MockProvider()
|
||||
@@ -451,6 +465,76 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
||||
assert runner.run_context.messages[-1].role == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_result_injects_follow_up_notice(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
mock_event = MockEvent("test:FriendMessage:follow_up", "u1")
|
||||
run_context = ContextWrapper(context=MockAgentContext(mock_event))
|
||||
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=run_context,
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
ticket1 = runner.follow_up(
|
||||
message_text="follow up 1",
|
||||
)
|
||||
ticket2 = runner.follow_up(
|
||||
message_text="follow up 2",
|
||||
)
|
||||
assert ticket1 is not None
|
||||
assert ticket2 is not None
|
||||
|
||||
async for _ in runner.step():
|
||||
pass
|
||||
|
||||
assert provider_request.tool_calls_result is not None
|
||||
assert isinstance(provider_request.tool_calls_result, list)
|
||||
assert provider_request.tool_calls_result
|
||||
tool_result = str(
|
||||
provider_request.tool_calls_result[0].tool_calls_result[0].content
|
||||
)
|
||||
assert "SYSTEM NOTICE" in tool_result
|
||||
assert "1. follow up 1" in tool_result
|
||||
assert "2. follow up 2" in tool_result
|
||||
assert ticket1.resolved.is_set() is True
|
||||
assert ticket2.resolved.is_set() is True
|
||||
assert ticket1.consumed is True
|
||||
assert ticket2.consumed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follow_up_ticket_not_consumed_when_no_next_tool_call(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
mock_provider.should_call_tools = False
|
||||
mock_event = MockEvent("test:FriendMessage:follow_up_no_tool", "u1")
|
||||
run_context = ContextWrapper(context=MockAgentContext(mock_event))
|
||||
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=run_context,
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
ticket = runner.follow_up(message_text="follow up without tool")
|
||||
assert ticket is not None
|
||||
|
||||
async for _ in runner.step():
|
||||
pass
|
||||
|
||||
assert ticket.resolved.is_set() is True
|
||||
assert ticket.consumed is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user