Compare commits
4 Commits
dev
...
feat/stop-agent
| Author | SHA1 | Date | |
|---|---|---|---|
| 90602dd97f | |||
| 1ab2fa1788 | |||
| 650a092cc1 | |||
| 6240125440 |
@@ -102,6 +102,30 @@ class ConversationCommands:
|
|||||||
|
|
||||||
message.set_result(MessageEventResult().message(ret))
|
message.set_result(MessageEventResult().message(ret))
|
||||||
|
|
||||||
|
async def stop(self, message: AstrMessageEvent) -> None:
|
||||||
|
"""停止当前会话正在运行的 Agent"""
|
||||||
|
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||||
|
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||||
|
umo = message.unified_msg_origin
|
||||||
|
|
||||||
|
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||||
|
stopped_count = active_event_registry.stop_all(umo, exclude=message)
|
||||||
|
else:
|
||||||
|
stopped_count = active_event_registry.request_agent_stop_all(
|
||||||
|
umo,
|
||||||
|
exclude=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stopped_count > 0:
|
||||||
|
message.set_result(
|
||||||
|
MessageEventResult().message(
|
||||||
|
f"已请求停止 {stopped_count} 个运行中的任务。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
|
||||||
|
|
||||||
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||||
"""查看对话记录"""
|
"""查看对话记录"""
|
||||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||||
|
|||||||
@@ -132,6 +132,11 @@ class Main(star.Star):
|
|||||||
"""重置 LLM 会话"""
|
"""重置 LLM 会话"""
|
||||||
await self.conversation_c.reset(message)
|
await self.conversation_c.reset(message)
|
||||||
|
|
||||||
|
@filter.command("stop")
|
||||||
|
async def stop(self, message: AstrMessageEvent) -> None:
|
||||||
|
"""停止当前会话中正在运行的 Agent"""
|
||||||
|
await self.conversation_c.stop(message)
|
||||||
|
|
||||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||||
@filter.command("model")
|
@filter.command("model")
|
||||||
async def model_ls(
|
async def model_ls(
|
||||||
|
|||||||
@@ -137,6 +137,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
self.tool_executor = tool_executor
|
self.tool_executor = tool_executor
|
||||||
self.agent_hooks = agent_hooks
|
self.agent_hooks = agent_hooks
|
||||||
self.run_context = run_context
|
self.run_context = run_context
|
||||||
|
self._stop_requested = False
|
||||||
|
self._aborted = False
|
||||||
|
|
||||||
# These two are used for tool schema mode handling
|
# These two are used for tool schema mode handling
|
||||||
# We now have two modes:
|
# We now have two modes:
|
||||||
@@ -328,6 +330,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
if self._stop_requested:
|
||||||
|
llm_resp_result = LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
|
||||||
|
reasoning_content=llm_response.reasoning_content,
|
||||||
|
reasoning_signature=llm_response.reasoning_signature,
|
||||||
|
)
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
llm_resp_result = llm_response
|
llm_resp_result = llm_response
|
||||||
|
|
||||||
@@ -339,6 +349,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
break # got final response
|
break # got final response
|
||||||
|
|
||||||
if not llm_resp_result:
|
if not llm_resp_result:
|
||||||
|
if self._stop_requested:
|
||||||
|
llm_resp_result = LLMResponse(role="assistant", completion_text="")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stop_requested:
|
||||||
|
logger.info("Agent execution was requested to stop by user.")
|
||||||
|
llm_resp = llm_resp_result
|
||||||
|
if llm_resp.role != "assistant":
|
||||||
|
llm_resp = LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
|
||||||
|
)
|
||||||
|
self.final_llm_resp = llm_resp
|
||||||
|
self._aborted = True
|
||||||
|
self._transition_state(AgentState.DONE)
|
||||||
|
self.stats.end_time = time.time()
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||||
|
parts.append(
|
||||||
|
ThinkPart(
|
||||||
|
think=llm_resp.reasoning_content,
|
||||||
|
encrypted=llm_resp.reasoning_signature,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if llm_resp.completion_text:
|
||||||
|
parts.append(TextPart(text=llm_resp.completion_text))
|
||||||
|
if parts:
|
||||||
|
self.run_context.messages.append(
|
||||||
|
Message(role="assistant", content=parts)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||||
|
|
||||||
|
yield AgentResponse(
|
||||||
|
type="aborted",
|
||||||
|
data=AgentResponseData(chain=MessageChain(type="aborted")),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 处理 LLM 响应
|
# 处理 LLM 响应
|
||||||
@@ -848,5 +900,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
"""检查 Agent 是否已完成工作"""
|
"""检查 Agent 是否已完成工作"""
|
||||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||||
|
|
||||||
|
def request_stop(self) -> None:
|
||||||
|
self._stop_requested = True
|
||||||
|
|
||||||
|
def was_aborted(self) -> bool:
|
||||||
|
return self._aborted
|
||||||
|
|
||||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||||
return self.final_llm_resp
|
return self.final_llm_resp
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ from astrbot.core.provider.provider import TTSProvider
|
|||||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||||
|
|
||||||
|
|
||||||
|
def _should_stop_agent(astr_event) -> bool:
|
||||||
|
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
|
||||||
|
|
||||||
|
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
agent_runner: AgentRunner,
|
agent_runner: AgentRunner,
|
||||||
max_step: int = 30,
|
max_step: int = 30,
|
||||||
@@ -48,10 +52,28 @@ async def run_agent(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stop_watcher = asyncio.create_task(
|
||||||
|
_watch_agent_stop_signal(agent_runner, astr_event),
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
async for resp in agent_runner.step():
|
async for resp in agent_runner.step():
|
||||||
if astr_event.is_stopped():
|
if _should_stop_agent(astr_event):
|
||||||
|
agent_runner.request_stop()
|
||||||
|
|
||||||
|
if resp.type == "aborted":
|
||||||
|
if not stop_watcher.done():
|
||||||
|
stop_watcher.cancel()
|
||||||
|
try:
|
||||||
|
await stop_watcher
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
astr_event.set_extra("agent_user_aborted", True)
|
||||||
|
astr_event.set_extra("agent_stop_requested", False)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if _should_stop_agent(astr_event):
|
||||||
|
continue
|
||||||
|
|
||||||
if resp.type == "tool_call_result":
|
if resp.type == "tool_call_result":
|
||||||
msg_chain = resp.data["chain"]
|
msg_chain = resp.data["chain"]
|
||||||
|
|
||||||
@@ -120,6 +142,12 @@ async def run_agent(
|
|||||||
# display the reasoning content only when configured
|
# display the reasoning content only when configured
|
||||||
continue
|
continue
|
||||||
yield resp.data["chain"] # MessageChain
|
yield resp.data["chain"] # MessageChain
|
||||||
|
if not stop_watcher.done():
|
||||||
|
stop_watcher.cancel()
|
||||||
|
try:
|
||||||
|
await stop_watcher
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
# send agent stats to webchat
|
# send agent stats to webchat
|
||||||
if astr_event.get_platform_name() == "webchat":
|
if astr_event.get_platform_name() == "webchat":
|
||||||
@@ -133,6 +161,12 @@ async def run_agent(
|
|||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if "stop_watcher" in locals() and not stop_watcher.done():
|
||||||
|
stop_watcher.cancel()
|
||||||
|
try:
|
||||||
|
await stop_watcher
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||||
@@ -155,6 +189,14 @@ async def run_agent(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
|
||||||
|
while not agent_runner.done():
|
||||||
|
if _should_stop_agent(astr_event):
|
||||||
|
agent_runner.request_stop()
|
||||||
|
return
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
async def run_live_agent(
|
async def run_live_agent(
|
||||||
agent_runner: AgentRunner,
|
agent_runner: AgentRunner,
|
||||||
tts_provider: TTSProvider | None = None,
|
tts_provider: TTSProvider | None = None,
|
||||||
|
|||||||
@@ -247,13 +247,16 @@ class InternalAgentSubStage(Stage):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
# 保存历史记录
|
# 保存历史记录
|
||||||
if not event.is_stopped() and agent_runner.done():
|
if agent_runner.done() and (
|
||||||
|
not event.is_stopped() or agent_runner.was_aborted()
|
||||||
|
):
|
||||||
await self._save_to_history(
|
await self._save_to_history(
|
||||||
event,
|
event,
|
||||||
req,
|
req,
|
||||||
agent_runner.get_final_llm_resp(),
|
agent_runner.get_final_llm_resp(),
|
||||||
agent_runner.run_context.messages,
|
agent_runner.run_context.messages,
|
||||||
agent_runner.stats,
|
agent_runner.stats,
|
||||||
|
user_aborted=agent_runner.was_aborted(),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif streaming_response and not stream_to_general:
|
elif streaming_response and not stream_to_general:
|
||||||
@@ -308,13 +311,14 @@ class InternalAgentSubStage(Stage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 检查事件是否被停止,如果被停止则不保存历史记录
|
# 检查事件是否被停止,如果被停止则不保存历史记录
|
||||||
if not event.is_stopped():
|
if not event.is_stopped() or agent_runner.was_aborted():
|
||||||
await self._save_to_history(
|
await self._save_to_history(
|
||||||
event,
|
event,
|
||||||
req,
|
req,
|
||||||
final_resp,
|
final_resp,
|
||||||
agent_runner.run_context.messages,
|
agent_runner.run_context.messages,
|
||||||
agent_runner.stats,
|
agent_runner.stats,
|
||||||
|
user_aborted=agent_runner.was_aborted(),
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -340,16 +344,29 @@ class InternalAgentSubStage(Stage):
|
|||||||
llm_response: LLMResponse | None,
|
llm_response: LLMResponse | None,
|
||||||
all_messages: list[Message],
|
all_messages: list[Message],
|
||||||
runner_stats: AgentStats | None,
|
runner_stats: AgentStats | None,
|
||||||
|
user_aborted: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if (
|
if not req or not req.conversation:
|
||||||
not req
|
|
||||||
or not req.conversation
|
|
||||||
or not llm_response
|
|
||||||
or llm_response.role != "assistant"
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if not llm_response.completion_text and not req.tool_calls_result:
|
if not llm_response and not user_aborted:
|
||||||
|
return
|
||||||
|
|
||||||
|
if llm_response and llm_response.role != "assistant":
|
||||||
|
if not user_aborted:
|
||||||
|
return
|
||||||
|
llm_response = LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text=llm_response.completion_text or "",
|
||||||
|
)
|
||||||
|
elif llm_response is None:
|
||||||
|
llm_response = LLMResponse(role="assistant", completion_text="")
|
||||||
|
|
||||||
|
if (
|
||||||
|
not llm_response.completion_text
|
||||||
|
and not req.tool_calls_result
|
||||||
|
and not user_aborted
|
||||||
|
):
|
||||||
logger.debug("LLM 响应为空,不保存记录。")
|
logger.debug("LLM 响应为空,不保存记录。")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -363,6 +380,14 @@ class InternalAgentSubStage(Stage):
|
|||||||
continue
|
continue
|
||||||
message_to_save.append(message.model_dump())
|
message_to_save.append(message.model_dump())
|
||||||
|
|
||||||
|
# if user_aborted:
|
||||||
|
# message_to_save.append(
|
||||||
|
# Message(
|
||||||
|
# role="assistant",
|
||||||
|
# content="[User aborted this request. Partial output before abort was preserved.]",
|
||||||
|
# ).model_dump()
|
||||||
|
# )
|
||||||
|
|
||||||
token_usage = None
|
token_usage = None
|
||||||
if runner_stats:
|
if runner_stats:
|
||||||
# token_usage = runner_stats.token_usage.total
|
# token_usage = runner_stats.token_usage.total
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
|
|
||||||
from .webchat_queue_mgr import webchat_queue_mgr
|
from .webchat_queue_mgr import webchat_queue_mgr
|
||||||
|
|
||||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||||
|
|
||||||
|
|
||||||
class WebChatMessageEvent(AstrMessageEvent):
|
class WebChatMessageEvent(AstrMessageEvent):
|
||||||
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
|
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
os.makedirs(imgs_dir, exist_ok=True)
|
os.makedirs(attachments_dir, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _send(
|
async def _send(
|
||||||
@@ -69,7 +69,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = f"{str(uuid.uuid4())}.jpg"
|
filename = f"{str(uuid.uuid4())}.jpg"
|
||||||
path = os.path.join(imgs_dir, filename)
|
path = os.path.join(attachments_dir, filename)
|
||||||
image_base64 = await comp.convert_to_base64()
|
image_base64 = await comp.convert_to_base64()
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
f.write(base64.b64decode(image_base64))
|
f.write(base64.b64decode(image_base64))
|
||||||
@@ -85,7 +85,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
elif isinstance(comp, Record):
|
elif isinstance(comp, Record):
|
||||||
# save record to local
|
# save record to local
|
||||||
filename = f"{str(uuid.uuid4())}.wav"
|
filename = f"{str(uuid.uuid4())}.wav"
|
||||||
path = os.path.join(imgs_dir, filename)
|
path = os.path.join(attachments_dir, filename)
|
||||||
record_base64 = await comp.convert_to_base64()
|
record_base64 = await comp.convert_to_base64()
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
f.write(base64.b64decode(record_base64))
|
f.write(base64.b64decode(record_base64))
|
||||||
@@ -104,7 +104,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
original_name = comp.name or os.path.basename(file_path)
|
original_name = comp.name or os.path.basename(file_path)
|
||||||
ext = os.path.splitext(original_name)[1] or ""
|
ext = os.path.splitext(original_name)[1] or ""
|
||||||
filename = f"{uuid.uuid4()!s}{ext}"
|
filename = f"{uuid.uuid4()!s}{ext}"
|
||||||
dest_path = os.path.join(imgs_dir, filename)
|
dest_path = os.path.join(attachments_dir, filename)
|
||||||
shutil.copy2(file_path, dest_path)
|
shutil.copy2(file_path, dest_path)
|
||||||
data = f"[FILE]{filename}"
|
data = f"[FILE]{filename}"
|
||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
|
|||||||
@@ -46,5 +46,22 @@ class ActiveEventRegistry:
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
def request_agent_stop_all(
|
||||||
|
self,
|
||||||
|
umo: str,
|
||||||
|
exclude: AstrMessageEvent | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""请求停止指定 UMO 的所有活跃事件中的 Agent 运行。
|
||||||
|
|
||||||
|
与 stop_all 不同,这里不会调用 event.stop_event(),
|
||||||
|
因此不会中断事件传播,后续流程(如历史记录保存)仍可继续。
|
||||||
|
"""
|
||||||
|
count = 0
|
||||||
|
for event in list(self._events.get(umo, [])):
|
||||||
|
if event is not exclude:
|
||||||
|
event.set_extra("agent_stop_requested", True)
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
active_event_registry = ActiveEventRegistry()
|
active_event_registry = ActiveEventRegistry()
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ from quart import g, make_response, request, send_file
|
|||||||
from astrbot.core import logger, sp
|
from astrbot.core import logger, sp
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||||
|
from astrbot.core.utils.active_event_registry import active_event_registry
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
@@ -41,6 +43,7 @@ class ChatRoute(Route):
|
|||||||
"/chat/new_session": ("GET", self.new_session),
|
"/chat/new_session": ("GET", self.new_session),
|
||||||
"/chat/sessions": ("GET", self.get_sessions),
|
"/chat/sessions": ("GET", self.get_sessions),
|
||||||
"/chat/get_session": ("GET", self.get_session),
|
"/chat/get_session": ("GET", self.get_session),
|
||||||
|
"/chat/stop": ("POST", self.stop_session),
|
||||||
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
||||||
"/chat/update_session_display_name": (
|
"/chat/update_session_display_name": (
|
||||||
"POST",
|
"POST",
|
||||||
@@ -212,8 +215,13 @@ class ChatRoute(Route):
|
|||||||
filename: 存储的文件名
|
filename: 存储的文件名
|
||||||
attach_type: 附件类型 (image, record, file, video)
|
attach_type: 附件类型 (image, record, file, video)
|
||||||
"""
|
"""
|
||||||
file_path = os.path.join(self.attachments_dir, os.path.basename(filename))
|
basename = os.path.basename(filename)
|
||||||
if not os.path.exists(file_path):
|
candidate_paths = [
|
||||||
|
os.path.join(self.attachments_dir, basename),
|
||||||
|
os.path.join(self.legacy_img_dir, basename),
|
||||||
|
]
|
||||||
|
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
|
||||||
|
if not file_path:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# guess mime type
|
# guess mime type
|
||||||
@@ -466,13 +474,13 @@ class ChatRoute(Route):
|
|||||||
if tc_id in tool_calls:
|
if tc_id in tool_calls:
|
||||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||||
accumulated_parts.append(
|
accumulated_parts.append(
|
||||||
{
|
{
|
||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
"tool_calls": [tool_calls[tc_id]],
|
"tool_calls": [tool_calls[tc_id]],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tool_calls.pop(tc_id, None)
|
tool_calls.pop(tc_id, None)
|
||||||
elif chain_type == "reasoning":
|
elif chain_type == "reasoning":
|
||||||
accumulated_reasoning += result_text
|
accumulated_reasoning += result_text
|
||||||
elif streaming:
|
elif streaming:
|
||||||
@@ -603,6 +611,36 @@ class ChatRoute(Route):
|
|||||||
response.timeout = None # fix SSE auto disconnect issue
|
response.timeout = None # fix SSE auto disconnect issue
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def stop_session(self):
|
||||||
|
"""Stop active agent runs for a session."""
|
||||||
|
post_data = await request.json
|
||||||
|
if post_data is None:
|
||||||
|
return Response().error("Missing JSON body").__dict__
|
||||||
|
|
||||||
|
session_id = post_data.get("session_id")
|
||||||
|
if not session_id:
|
||||||
|
return Response().error("Missing key: session_id").__dict__
|
||||||
|
|
||||||
|
username = g.get("username", "guest")
|
||||||
|
session = await self.db.get_platform_session_by_id(session_id)
|
||||||
|
if not session:
|
||||||
|
return Response().error(f"Session {session_id} not found").__dict__
|
||||||
|
if session.creator != username:
|
||||||
|
return Response().error("Permission denied").__dict__
|
||||||
|
|
||||||
|
message_type = (
|
||||||
|
MessageType.GROUP_MESSAGE.value
|
||||||
|
if session.is_group
|
||||||
|
else MessageType.FRIEND_MESSAGE.value
|
||||||
|
)
|
||||||
|
umo = (
|
||||||
|
f"{session.platform_id}:{message_type}:"
|
||||||
|
f"{session.platform_id}!{username}!{session_id}"
|
||||||
|
)
|
||||||
|
stopped_count = active_event_registry.request_agent_stop_all(umo)
|
||||||
|
|
||||||
|
return Response().ok(data={"stopped_count": stopped_count}).__dict__
|
||||||
|
|
||||||
async def delete_webchat_session(self):
|
async def delete_webchat_session(self):
|
||||||
"""Delete a Platform session and all its related data."""
|
"""Delete a Platform session and all its related data."""
|
||||||
session_id = request.args.get("session_id")
|
session_id = request.args.get("session_id")
|
||||||
|
|||||||
@@ -77,12 +77,14 @@
|
|||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="isStreaming"
|
:disabled="isStreaming"
|
||||||
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@removeImage="removeImage"
|
@removeImage="removeImage"
|
||||||
@removeAudio="removeAudio"
|
@removeAudio="removeAudio"
|
||||||
@@ -106,12 +108,14 @@
|
|||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="isStreaming"
|
:disabled="isStreaming"
|
||||||
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@removeImage="removeImage"
|
@removeImage="removeImage"
|
||||||
@removeAudio="removeAudio"
|
@removeAudio="removeAudio"
|
||||||
@@ -134,12 +138,14 @@
|
|||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:stagedFiles="stagedNonImageFiles"
|
:stagedFiles="stagedNonImageFiles"
|
||||||
:disabled="isStreaming"
|
:disabled="isStreaming"
|
||||||
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:replyTo="replyTo"
|
:replyTo="replyTo"
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@removeImage="removeImage"
|
@removeImage="removeImage"
|
||||||
@removeAudio="removeAudio"
|
@removeAudio="removeAudio"
|
||||||
@@ -298,6 +304,7 @@ const {
|
|||||||
currentSessionProject,
|
currentSessionProject,
|
||||||
getSessionMessages: getSessionMsg,
|
getSessionMessages: getSessionMsg,
|
||||||
sendMessage: sendMsg,
|
sendMessage: sendMsg,
|
||||||
|
stopMessage: stopMsg,
|
||||||
toggleStreaming
|
toggleStreaming
|
||||||
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
||||||
|
|
||||||
@@ -631,6 +638,10 @@ async function handleSendMessage() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleStopMessage() {
|
||||||
|
await stopMsg();
|
||||||
|
}
|
||||||
|
|
||||||
// 路由变化监听
|
// 路由变化监听
|
||||||
watch(
|
watch(
|
||||||
() => route.path,
|
() => route.path,
|
||||||
|
|||||||
@@ -94,8 +94,29 @@
|
|||||||
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
||||||
</v-tooltip>
|
</v-tooltip>
|
||||||
</v-btn>
|
</v-btn>
|
||||||
<v-btn @click="$emit('send')" icon="mdi-send" variant="text" color="deep-purple"
|
<v-btn
|
||||||
:disabled="!canSend" class="send-btn" size="small" />
|
icon
|
||||||
|
v-if="isRunning"
|
||||||
|
@click="$emit('stop')"
|
||||||
|
variant="text"
|
||||||
|
class="send-btn"
|
||||||
|
size="small"
|
||||||
|
>
|
||||||
|
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
|
||||||
|
<v-tooltip activator="parent" location="top">
|
||||||
|
{{ tm('input.stopGenerating') }}
|
||||||
|
</v-tooltip>
|
||||||
|
</v-btn>
|
||||||
|
<v-btn
|
||||||
|
v-else
|
||||||
|
@click="$emit('send')"
|
||||||
|
icon="mdi-send"
|
||||||
|
variant="text"
|
||||||
|
color="deep-purple"
|
||||||
|
:disabled="!canSend"
|
||||||
|
class="send-btn"
|
||||||
|
size="small"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -160,6 +181,7 @@ interface Props {
|
|||||||
disabled: boolean;
|
disabled: boolean;
|
||||||
enableStreaming: boolean;
|
enableStreaming: boolean;
|
||||||
isRecording: boolean;
|
isRecording: boolean;
|
||||||
|
isRunning: boolean;
|
||||||
sessionId?: string | null;
|
sessionId?: string | null;
|
||||||
currentSession?: Session | null;
|
currentSession?: Session | null;
|
||||||
configId?: string | null;
|
configId?: string | null;
|
||||||
@@ -177,6 +199,7 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
'update:prompt': [value: string];
|
'update:prompt': [value: string];
|
||||||
send: [];
|
send: [];
|
||||||
|
stop: [];
|
||||||
toggleStreaming: [];
|
toggleStreaming: [];
|
||||||
removeImage: [index: number];
|
removeImage: [index: number];
|
||||||
removeAudio: [];
|
removeAudio: [];
|
||||||
|
|||||||
@@ -23,12 +23,14 @@
|
|||||||
:stagedImagesUrl="stagedImagesUrl"
|
:stagedImagesUrl="stagedImagesUrl"
|
||||||
:stagedAudioUrl="stagedAudioUrl"
|
:stagedAudioUrl="stagedAudioUrl"
|
||||||
:disabled="isStreaming"
|
:disabled="isStreaming"
|
||||||
|
:is-running="isStreaming || isConvRunning"
|
||||||
:enableStreaming="enableStreaming"
|
:enableStreaming="enableStreaming"
|
||||||
:isRecording="isRecording"
|
:isRecording="isRecording"
|
||||||
:session-id="currSessionId || null"
|
:session-id="currSessionId || null"
|
||||||
:current-session="getCurrentSession"
|
:current-session="getCurrentSession"
|
||||||
:config-id="configId"
|
:config-id="configId"
|
||||||
@send="handleSendMessage"
|
@send="handleSendMessage"
|
||||||
|
@stop="handleStopMessage"
|
||||||
@toggleStreaming="toggleStreaming"
|
@toggleStreaming="toggleStreaming"
|
||||||
@removeImage="removeImage"
|
@removeImage="removeImage"
|
||||||
@removeAudio="removeAudio"
|
@removeAudio="removeAudio"
|
||||||
@@ -156,6 +158,7 @@ const {
|
|||||||
enableStreaming,
|
enableStreaming,
|
||||||
getSessionMessages: getSessionMsg,
|
getSessionMessages: getSessionMsg,
|
||||||
sendMessage: sendMsg,
|
sendMessage: sendMsg,
|
||||||
|
stopMessage: stopMsg,
|
||||||
toggleStreaming
|
toggleStreaming
|
||||||
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
||||||
|
|
||||||
@@ -236,6 +239,10 @@ async function handleSendMessage() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleStopMessage() {
|
||||||
|
await stopMsg();
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
// 独立模式在挂载时创建新会话
|
// 独立模式在挂载时创建新会话
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ export function useMessages(
|
|||||||
const activeSSECount = ref(0);
|
const activeSSECount = ref(0);
|
||||||
const enableStreaming = ref(true);
|
const enableStreaming = ref(true);
|
||||||
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
|
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
|
||||||
|
const currentRequestController = ref<AbortController | null>(null);
|
||||||
|
const currentReader = ref<ReadableStreamDefaultReader<Uint8Array> | null>(null);
|
||||||
|
const currentRunningSessionId = ref('');
|
||||||
|
const userStopRequested = ref(false);
|
||||||
|
|
||||||
// 当前会话的项目信息
|
// 当前会话的项目信息
|
||||||
const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null);
|
const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null);
|
||||||
@@ -289,6 +293,8 @@ export function useMessages(
|
|||||||
if (activeSSECount.value === 1) {
|
if (activeSSECount.value === 1) {
|
||||||
isConvRunning.value = true;
|
isConvRunning.value = true;
|
||||||
}
|
}
|
||||||
|
userStopRequested.value = false;
|
||||||
|
currentRunningSessionId.value = currSessionId.value;
|
||||||
|
|
||||||
// 收集所有 attachment_id
|
// 收集所有 attachment_id
|
||||||
const files = stagedFiles.map(f => f.attachment_id);
|
const files = stagedFiles.map(f => f.attachment_id);
|
||||||
@@ -330,12 +336,15 @@ export function useMessages(
|
|||||||
messageToSend = prompt;
|
messageToSend = prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const controller = new AbortController();
|
||||||
|
currentRequestController.value = controller;
|
||||||
const response = await fetch('/api/chat/send', {
|
const response = await fetch('/api/chat/send', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||||
},
|
},
|
||||||
|
signal: controller.signal,
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
message: messageToSend,
|
message: messageToSend,
|
||||||
session_id: currSessionId.value,
|
session_id: currSessionId.value,
|
||||||
@@ -350,6 +359,7 @@ export function useMessages(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const reader = response.body!.getReader();
|
const reader = response.body!.getReader();
|
||||||
|
currentReader.value = reader;
|
||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
let in_streaming = false;
|
let in_streaming = false;
|
||||||
let message_obj: MessageContent | null = null;
|
let message_obj: MessageContent | null = null;
|
||||||
@@ -560,7 +570,9 @@ export function useMessages(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (readError) {
|
} catch (readError) {
|
||||||
console.error('SSE读取错误:', readError);
|
if (!userStopRequested.value) {
|
||||||
|
console.error('SSE读取错误:', readError);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -569,7 +581,9 @@ export function useMessages(
|
|||||||
onSessionsUpdate();
|
onSessionsUpdate();
|
||||||
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('发送消息失败:', err);
|
if (!userStopRequested.value) {
|
||||||
|
console.error('发送消息失败:', err);
|
||||||
|
}
|
||||||
// 移除加载占位符
|
// 移除加载占位符
|
||||||
const lastMsg = messages.value[messages.value.length - 1];
|
const lastMsg = messages.value[messages.value.length - 1];
|
||||||
if (lastMsg?.content?.isLoading) {
|
if (lastMsg?.content?.isLoading) {
|
||||||
@@ -577,6 +591,10 @@ export function useMessages(
|
|||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
isStreaming.value = false;
|
isStreaming.value = false;
|
||||||
|
currentReader.value = null;
|
||||||
|
currentRequestController.value = null;
|
||||||
|
currentRunningSessionId.value = '';
|
||||||
|
userStopRequested.value = false;
|
||||||
activeSSECount.value--;
|
activeSSECount.value--;
|
||||||
if (activeSSECount.value === 0) {
|
if (activeSSECount.value === 0) {
|
||||||
isConvRunning.value = false;
|
isConvRunning.value = false;
|
||||||
@@ -584,6 +602,33 @@ export function useMessages(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function stopMessage() {
|
||||||
|
const sessionId = currentRunningSessionId.value || currSessionId.value;
|
||||||
|
if (!sessionId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
userStopRequested.value = true;
|
||||||
|
try {
|
||||||
|
await axios.post('/api/chat/stop', {
|
||||||
|
session_id: sessionId
|
||||||
|
});
|
||||||
|
} catch (err) {
|
||||||
|
console.error('停止会话失败:', err);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await currentReader.value?.cancel();
|
||||||
|
} catch (err) {
|
||||||
|
// ignore reader cancel failures
|
||||||
|
}
|
||||||
|
currentReader.value = null;
|
||||||
|
currentRequestController.value?.abort();
|
||||||
|
currentRequestController.value = null;
|
||||||
|
|
||||||
|
isStreaming.value = false;
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
messages,
|
messages,
|
||||||
isStreaming,
|
isStreaming,
|
||||||
@@ -592,6 +637,7 @@ export function useMessages(
|
|||||||
currentSessionProject,
|
currentSessionProject,
|
||||||
getSessionMessages,
|
getSessionMessages,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
|
stopMessage,
|
||||||
toggleStreaming,
|
toggleStreaming,
|
||||||
getAttachment
|
getAttachment
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,7 +9,8 @@
|
|||||||
"voice": "Voice Input",
|
"voice": "Voice Input",
|
||||||
"recordingPrompt": "Recording, please speak...",
|
"recordingPrompt": "Recording, please speak...",
|
||||||
"chatPrompt": "Let's chat!",
|
"chatPrompt": "Let's chat!",
|
||||||
"dropToUpload": "Drop files to upload"
|
"dropToUpload": "Drop files to upload",
|
||||||
|
"stopGenerating": "Stop generating"
|
||||||
},
|
},
|
||||||
"message": {
|
"message": {
|
||||||
"user": "User",
|
"user": "User",
|
||||||
|
|||||||
@@ -9,7 +9,8 @@
|
|||||||
"voice": "语音输入",
|
"voice": "语音输入",
|
||||||
"recordingPrompt": "录音中,请说话...",
|
"recordingPrompt": "录音中,请说话...",
|
||||||
"chatPrompt": "聊天吧!",
|
"chatPrompt": "聊天吧!",
|
||||||
"dropToUpload": "松开鼠标上传文件"
|
"dropToUpload": "松开鼠标上传文件",
|
||||||
|
"stopGenerating": "停止生成"
|
||||||
},
|
},
|
||||||
"message": {
|
"message": {
|
||||||
"user": "用户",
|
"user": "用户",
|
||||||
|
|||||||
@@ -105,6 +105,28 @@ class MockErrProvider(MockProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAbortableStreamProvider(MockProvider):
|
||||||
|
async def text_chat_stream(self, **kwargs):
|
||||||
|
abort_signal = kwargs.get("abort_signal")
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="partial ",
|
||||||
|
is_chunk=True,
|
||||||
|
)
|
||||||
|
if abort_signal and abort_signal.is_set():
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="partial ",
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
yield LLMResponse(
|
||||||
|
role="assistant",
|
||||||
|
completion_text="partial final",
|
||||||
|
is_chunk=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MockHooks(BaseAgentRunHooks):
|
class MockHooks(BaseAgentRunHooks):
|
||||||
"""模拟钩子函数"""
|
"""模拟钩子函数"""
|
||||||
|
|
||||||
@@ -394,6 +416,41 @@ async def test_fallback_provider_used_when_primary_returns_err(
|
|||||||
assert fallback_provider.call_count == 1
|
assert fallback_provider.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
||||||
|
runner, provider_request, mock_tool_executor, mock_hooks
|
||||||
|
):
|
||||||
|
provider = MockAbortableStreamProvider()
|
||||||
|
|
||||||
|
await runner.reset(
|
||||||
|
provider=provider,
|
||||||
|
request=provider_request,
|
||||||
|
run_context=ContextWrapper(context=None),
|
||||||
|
tool_executor=mock_tool_executor,
|
||||||
|
agent_hooks=mock_hooks,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
step_iter = runner.step()
|
||||||
|
first_resp = await step_iter.__anext__()
|
||||||
|
assert first_resp.type == "streaming_delta"
|
||||||
|
|
||||||
|
runner.request_stop()
|
||||||
|
|
||||||
|
rest_responses = []
|
||||||
|
async for response in step_iter:
|
||||||
|
rest_responses.append(response)
|
||||||
|
|
||||||
|
assert any(resp.type == "aborted" for resp in rest_responses)
|
||||||
|
assert runner.was_aborted() is True
|
||||||
|
|
||||||
|
final_resp = runner.get_final_llm_resp()
|
||||||
|
assert final_resp is not None
|
||||||
|
assert final_resp.role == "assistant"
|
||||||
|
assert final_resp.completion_text == "partial "
|
||||||
|
assert runner.run_context.messages[-1].role == "assistant"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 运行测试
|
# 运行测试
|
||||||
pytest.main([__file__, "-v"])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
Reference in New Issue
Block a user