Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 90602dd97f | |||
| 1ab2fa1788 | |||
| 650a092cc1 | |||
| 6240125440 |
@@ -102,6 +102,30 @@ class ConversationCommands:
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
async def stop(self, message: AstrMessageEvent) -> None:
|
||||
"""停止当前会话正在运行的 Agent"""
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
umo = message.unified_msg_origin
|
||||
|
||||
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
|
||||
stopped_count = active_event_registry.stop_all(umo, exclude=message)
|
||||
else:
|
||||
stopped_count = active_event_registry.request_agent_stop_all(
|
||||
umo,
|
||||
exclude=message,
|
||||
)
|
||||
|
||||
if stopped_count > 0:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
f"已请求停止 {stopped_count} 个运行中的任务。"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
|
||||
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话记录"""
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
|
||||
@@ -132,6 +132,11 @@ class Main(star.Star):
|
||||
"""重置 LLM 会话"""
|
||||
await self.conversation_c.reset(message)
|
||||
|
||||
@filter.command("stop")
|
||||
async def stop(self, message: AstrMessageEvent) -> None:
|
||||
"""停止当前会话中正在运行的 Agent"""
|
||||
await self.conversation_c.stop(message)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("model")
|
||||
async def model_ls(
|
||||
|
||||
@@ -137,6 +137,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.tool_executor = tool_executor
|
||||
self.agent_hooks = agent_hooks
|
||||
self.run_context = run_context
|
||||
self._stop_requested = False
|
||||
self._aborted = False
|
||||
|
||||
# These two are used for tool schema mode handling
|
||||
# We now have two modes:
|
||||
@@ -328,6 +330,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
),
|
||||
)
|
||||
if self._stop_requested:
|
||||
llm_resp_result = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
|
||||
reasoning_content=llm_response.reasoning_content,
|
||||
reasoning_signature=llm_response.reasoning_signature,
|
||||
)
|
||||
break
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
@@ -339,6 +349,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
if self._stop_requested:
|
||||
llm_resp_result = LLMResponse(role="assistant", completion_text="")
|
||||
else:
|
||||
return
|
||||
|
||||
if self._stop_requested:
|
||||
logger.info("Agent execution was requested to stop by user.")
|
||||
llm_resp = llm_resp_result
|
||||
if llm_resp.role != "assistant":
|
||||
llm_resp = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
|
||||
)
|
||||
self.final_llm_resp = llm_resp
|
||||
self._aborted = True
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
parts.append(TextPart(text=llm_resp.completion_text))
|
||||
if parts:
|
||||
self.run_context.messages.append(
|
||||
Message(role="assistant", content=parts)
|
||||
)
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
|
||||
|
||||
yield AgentResponse(
|
||||
type="aborted",
|
||||
data=AgentResponseData(chain=MessageChain(type="aborted")),
|
||||
)
|
||||
return
|
||||
|
||||
# 处理 LLM 响应
|
||||
@@ -848,5 +900,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
"""检查 Agent 是否已完成工作"""
|
||||
return self._state in (AgentState.DONE, AgentState.ERROR)
|
||||
|
||||
def request_stop(self) -> None:
|
||||
self._stop_requested = True
|
||||
|
||||
def was_aborted(self) -> bool:
|
||||
return self._aborted
|
||||
|
||||
def get_final_llm_resp(self) -> LLMResponse | None:
|
||||
return self.final_llm_resp
|
||||
|
||||
@@ -20,6 +20,10 @@ from astrbot.core.provider.provider import TTSProvider
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
def _should_stop_agent(astr_event) -> bool:
|
||||
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
|
||||
|
||||
|
||||
async def run_agent(
|
||||
agent_runner: AgentRunner,
|
||||
max_step: int = 30,
|
||||
@@ -48,10 +52,28 @@ async def run_agent(
|
||||
)
|
||||
)
|
||||
|
||||
stop_watcher = asyncio.create_task(
|
||||
_watch_agent_stop_signal(agent_runner, astr_event),
|
||||
)
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
if _should_stop_agent(astr_event):
|
||||
agent_runner.request_stop()
|
||||
|
||||
if resp.type == "aborted":
|
||||
if not stop_watcher.done():
|
||||
stop_watcher.cancel()
|
||||
try:
|
||||
await stop_watcher
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
astr_event.set_extra("agent_user_aborted", True)
|
||||
astr_event.set_extra("agent_stop_requested", False)
|
||||
return
|
||||
|
||||
if _should_stop_agent(astr_event):
|
||||
continue
|
||||
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
|
||||
@@ -120,6 +142,12 @@ async def run_agent(
|
||||
# display the reasoning content only when configured
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if not stop_watcher.done():
|
||||
stop_watcher.cancel()
|
||||
try:
|
||||
await stop_watcher
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
@@ -133,6 +161,12 @@ async def run_agent(
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
if "stop_watcher" in locals() and not stop_watcher.done():
|
||||
stop_watcher.cancel()
|
||||
try:
|
||||
await stop_watcher
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
|
||||
@@ -155,6 +189,14 @@ async def run_agent(
|
||||
return
|
||||
|
||||
|
||||
async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
|
||||
while not agent_runner.done():
|
||||
if _should_stop_agent(astr_event):
|
||||
agent_runner.request_stop()
|
||||
return
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
async def run_live_agent(
|
||||
agent_runner: AgentRunner,
|
||||
tts_provider: TTSProvider | None = None,
|
||||
|
||||
@@ -247,13 +247,16 @@ class InternalAgentSubStage(Stage):
|
||||
yield
|
||||
|
||||
# 保存历史记录
|
||||
if not event.is_stopped() and agent_runner.done():
|
||||
if agent_runner.done() and (
|
||||
not event.is_stopped() or agent_runner.was_aborted()
|
||||
):
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
elif streaming_response and not stream_to_general:
|
||||
@@ -308,13 +311,14 @@ class InternalAgentSubStage(Stage):
|
||||
)
|
||||
|
||||
# 检查事件是否被停止,如果被停止则不保存历史记录
|
||||
if not event.is_stopped():
|
||||
if not event.is_stopped() or agent_runner.was_aborted():
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
final_resp,
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
user_aborted=agent_runner.was_aborted(),
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
@@ -340,16 +344,29 @@ class InternalAgentSubStage(Stage):
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
user_aborted: bool = False,
|
||||
) -> None:
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
or not llm_response
|
||||
or llm_response.role != "assistant"
|
||||
):
|
||||
if not req or not req.conversation:
|
||||
return
|
||||
|
||||
if not llm_response.completion_text and not req.tool_calls_result:
|
||||
if not llm_response and not user_aborted:
|
||||
return
|
||||
|
||||
if llm_response and llm_response.role != "assistant":
|
||||
if not user_aborted:
|
||||
return
|
||||
llm_response = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=llm_response.completion_text or "",
|
||||
)
|
||||
elif llm_response is None:
|
||||
llm_response = LLMResponse(role="assistant", completion_text="")
|
||||
|
||||
if (
|
||||
not llm_response.completion_text
|
||||
and not req.tool_calls_result
|
||||
and not user_aborted
|
||||
):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
@@ -363,6 +380,14 @@ class InternalAgentSubStage(Stage):
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# if user_aborted:
|
||||
# message_to_save.append(
|
||||
# Message(
|
||||
# role="assistant",
|
||||
# content="[User aborted this request. Partial output before abort was preserved.]",
|
||||
# ).model_dump()
|
||||
# )
|
||||
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
# token_usage = runner_stats.token_usage.total
|
||||
|
||||
@@ -11,13 +11,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .webchat_queue_mgr import webchat_queue_mgr
|
||||
|
||||
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
|
||||
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
os.makedirs(imgs_dir, exist_ok=True)
|
||||
os.makedirs(attachments_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
@@ -69,7 +69,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
path = os.path.join(attachments_dir, filename)
|
||||
image_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(image_base64))
|
||||
@@ -85,7 +85,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = f"{str(uuid.uuid4())}.wav"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
path = os.path.join(attachments_dir, filename)
|
||||
record_base64 = await comp.convert_to_base64()
|
||||
with open(path, "wb") as f:
|
||||
f.write(base64.b64decode(record_base64))
|
||||
@@ -104,7 +104,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
original_name = comp.name or os.path.basename(file_path)
|
||||
ext = os.path.splitext(original_name)[1] or ""
|
||||
filename = f"{uuid.uuid4()!s}{ext}"
|
||||
dest_path = os.path.join(imgs_dir, filename)
|
||||
dest_path = os.path.join(attachments_dir, filename)
|
||||
shutil.copy2(file_path, dest_path)
|
||||
data = f"[FILE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
|
||||
@@ -46,5 +46,22 @@ class ActiveEventRegistry:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def request_agent_stop_all(
|
||||
self,
|
||||
umo: str,
|
||||
exclude: AstrMessageEvent | None = None,
|
||||
) -> int:
|
||||
"""请求停止指定 UMO 的所有活跃事件中的 Agent 运行。
|
||||
|
||||
与 stop_all 不同,这里不会调用 event.stop_event(),
|
||||
因此不会中断事件传播,后续流程(如历史记录保存)仍可继续。
|
||||
"""
|
||||
count = 0
|
||||
for event in list(self._events.get(umo, [])):
|
||||
if event is not exclude:
|
||||
event.set_extra("agent_stop_requested", True)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
active_event_registry = ActiveEventRegistry()
|
||||
|
||||
@@ -13,7 +13,9 @@ from quart import g, make_response, request, send_file
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.active_event_registry import active_event_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
@@ -41,6 +43,7 @@ class ChatRoute(Route):
|
||||
"/chat/new_session": ("GET", self.new_session),
|
||||
"/chat/sessions": ("GET", self.get_sessions),
|
||||
"/chat/get_session": ("GET", self.get_session),
|
||||
"/chat/stop": ("POST", self.stop_session),
|
||||
"/chat/delete_session": ("GET", self.delete_webchat_session),
|
||||
"/chat/update_session_display_name": (
|
||||
"POST",
|
||||
@@ -212,8 +215,13 @@ class ChatRoute(Route):
|
||||
filename: 存储的文件名
|
||||
attach_type: 附件类型 (image, record, file, video)
|
||||
"""
|
||||
file_path = os.path.join(self.attachments_dir, os.path.basename(filename))
|
||||
if not os.path.exists(file_path):
|
||||
basename = os.path.basename(filename)
|
||||
candidate_paths = [
|
||||
os.path.join(self.attachments_dir, basename),
|
||||
os.path.join(self.legacy_img_dir, basename),
|
||||
]
|
||||
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
|
||||
if not file_path:
|
||||
return None
|
||||
|
||||
# guess mime type
|
||||
@@ -466,13 +474,13 @@ class ChatRoute(Route):
|
||||
if tc_id in tool_calls:
|
||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||
accumulated_parts.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"tool_calls": [tool_calls[tc_id]],
|
||||
}
|
||||
)
|
||||
tool_calls.pop(tc_id, None)
|
||||
accumulated_parts.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"tool_calls": [tool_calls[tc_id]],
|
||||
}
|
||||
)
|
||||
tool_calls.pop(tc_id, None)
|
||||
elif chain_type == "reasoning":
|
||||
accumulated_reasoning += result_text
|
||||
elif streaming:
|
||||
@@ -603,6 +611,36 @@ class ChatRoute(Route):
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
return response
|
||||
|
||||
async def stop_session(self):
|
||||
"""Stop active agent runs for a session."""
|
||||
post_data = await request.json
|
||||
if post_data is None:
|
||||
return Response().error("Missing JSON body").__dict__
|
||||
|
||||
session_id = post_data.get("session_id")
|
||||
if not session_id:
|
||||
return Response().error("Missing key: session_id").__dict__
|
||||
|
||||
username = g.get("username", "guest")
|
||||
session = await self.db.get_platform_session_by_id(session_id)
|
||||
if not session:
|
||||
return Response().error(f"Session {session_id} not found").__dict__
|
||||
if session.creator != username:
|
||||
return Response().error("Permission denied").__dict__
|
||||
|
||||
message_type = (
|
||||
MessageType.GROUP_MESSAGE.value
|
||||
if session.is_group
|
||||
else MessageType.FRIEND_MESSAGE.value
|
||||
)
|
||||
umo = (
|
||||
f"{session.platform_id}:{message_type}:"
|
||||
f"{session.platform_id}!{username}!{session_id}"
|
||||
)
|
||||
stopped_count = active_event_registry.request_agent_stop_all(umo)
|
||||
|
||||
return Response().ok(data={"stopped_count": stopped_count}).__dict__
|
||||
|
||||
async def delete_webchat_session(self):
|
||||
"""Delete a Platform session and all its related data."""
|
||||
session_id = request.args.get("session_id")
|
||||
|
||||
@@ -77,12 +77,14 @@
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:stagedFiles="stagedNonImageFiles"
|
||||
:disabled="isStreaming"
|
||||
:is-running="isStreaming || isConvRunning"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
:current-session="getCurrentSession"
|
||||
:replyTo="replyTo"
|
||||
@send="handleSendMessage"
|
||||
@stop="handleStopMessage"
|
||||
@toggleStreaming="toggleStreaming"
|
||||
@removeImage="removeImage"
|
||||
@removeAudio="removeAudio"
|
||||
@@ -106,12 +108,14 @@
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:stagedFiles="stagedNonImageFiles"
|
||||
:disabled="isStreaming"
|
||||
:is-running="isStreaming || isConvRunning"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
:current-session="getCurrentSession"
|
||||
:replyTo="replyTo"
|
||||
@send="handleSendMessage"
|
||||
@stop="handleStopMessage"
|
||||
@toggleStreaming="toggleStreaming"
|
||||
@removeImage="removeImage"
|
||||
@removeAudio="removeAudio"
|
||||
@@ -134,12 +138,14 @@
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:stagedFiles="stagedNonImageFiles"
|
||||
:disabled="isStreaming"
|
||||
:is-running="isStreaming || isConvRunning"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
:current-session="getCurrentSession"
|
||||
:replyTo="replyTo"
|
||||
@send="handleSendMessage"
|
||||
@stop="handleStopMessage"
|
||||
@toggleStreaming="toggleStreaming"
|
||||
@removeImage="removeImage"
|
||||
@removeAudio="removeAudio"
|
||||
@@ -298,6 +304,7 @@ const {
|
||||
currentSessionProject,
|
||||
getSessionMessages: getSessionMsg,
|
||||
sendMessage: sendMsg,
|
||||
stopMessage: stopMsg,
|
||||
toggleStreaming
|
||||
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
||||
|
||||
@@ -631,6 +638,10 @@ async function handleSendMessage() {
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStopMessage() {
|
||||
await stopMsg();
|
||||
}
|
||||
|
||||
// 路由变化监听
|
||||
watch(
|
||||
() => route.path,
|
||||
|
||||
@@ -94,8 +94,29 @@
|
||||
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
<v-btn @click="$emit('send')" icon="mdi-send" variant="text" color="deep-purple"
|
||||
:disabled="!canSend" class="send-btn" size="small" />
|
||||
<v-btn
|
||||
icon
|
||||
v-if="isRunning"
|
||||
@click="$emit('stop')"
|
||||
variant="text"
|
||||
class="send-btn"
|
||||
size="small"
|
||||
>
|
||||
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
|
||||
<v-tooltip activator="parent" location="top">
|
||||
{{ tm('input.stopGenerating') }}
|
||||
</v-tooltip>
|
||||
</v-btn>
|
||||
<v-btn
|
||||
v-else
|
||||
@click="$emit('send')"
|
||||
icon="mdi-send"
|
||||
variant="text"
|
||||
color="deep-purple"
|
||||
:disabled="!canSend"
|
||||
class="send-btn"
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -160,6 +181,7 @@ interface Props {
|
||||
disabled: boolean;
|
||||
enableStreaming: boolean;
|
||||
isRecording: boolean;
|
||||
isRunning: boolean;
|
||||
sessionId?: string | null;
|
||||
currentSession?: Session | null;
|
||||
configId?: string | null;
|
||||
@@ -177,6 +199,7 @@ const props = withDefaults(defineProps<Props>(), {
|
||||
const emit = defineEmits<{
|
||||
'update:prompt': [value: string];
|
||||
send: [];
|
||||
stop: [];
|
||||
toggleStreaming: [];
|
||||
removeImage: [index: number];
|
||||
removeAudio: [];
|
||||
|
||||
@@ -23,12 +23,14 @@
|
||||
:stagedImagesUrl="stagedImagesUrl"
|
||||
:stagedAudioUrl="stagedAudioUrl"
|
||||
:disabled="isStreaming"
|
||||
:is-running="isStreaming || isConvRunning"
|
||||
:enableStreaming="enableStreaming"
|
||||
:isRecording="isRecording"
|
||||
:session-id="currSessionId || null"
|
||||
:current-session="getCurrentSession"
|
||||
:config-id="configId"
|
||||
@send="handleSendMessage"
|
||||
@stop="handleStopMessage"
|
||||
@toggleStreaming="toggleStreaming"
|
||||
@removeImage="removeImage"
|
||||
@removeAudio="removeAudio"
|
||||
@@ -156,6 +158,7 @@ const {
|
||||
enableStreaming,
|
||||
getSessionMessages: getSessionMsg,
|
||||
sendMessage: sendMsg,
|
||||
stopMessage: stopMsg,
|
||||
toggleStreaming
|
||||
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
|
||||
|
||||
@@ -236,6 +239,10 @@ async function handleSendMessage() {
|
||||
}
|
||||
}
|
||||
|
||||
async function handleStopMessage() {
|
||||
await stopMsg();
|
||||
}
|
||||
|
||||
onMounted(async () => {
|
||||
// 独立模式在挂载时创建新会话
|
||||
try {
|
||||
|
||||
@@ -82,6 +82,10 @@ export function useMessages(
|
||||
const activeSSECount = ref(0);
|
||||
const enableStreaming = ref(true);
|
||||
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
|
||||
const currentRequestController = ref<AbortController | null>(null);
|
||||
const currentReader = ref<ReadableStreamDefaultReader<Uint8Array> | null>(null);
|
||||
const currentRunningSessionId = ref('');
|
||||
const userStopRequested = ref(false);
|
||||
|
||||
// 当前会话的项目信息
|
||||
const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null);
|
||||
@@ -289,6 +293,8 @@ export function useMessages(
|
||||
if (activeSSECount.value === 1) {
|
||||
isConvRunning.value = true;
|
||||
}
|
||||
userStopRequested.value = false;
|
||||
currentRunningSessionId.value = currSessionId.value;
|
||||
|
||||
// 收集所有 attachment_id
|
||||
const files = stagedFiles.map(f => f.attachment_id);
|
||||
@@ -330,12 +336,15 @@ export function useMessages(
|
||||
messageToSend = prompt;
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
currentRequestController.value = controller;
|
||||
const response = await fetch('/api/chat/send', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
},
|
||||
signal: controller.signal,
|
||||
body: JSON.stringify({
|
||||
message: messageToSend,
|
||||
session_id: currSessionId.value,
|
||||
@@ -350,6 +359,7 @@ export function useMessages(
|
||||
}
|
||||
|
||||
const reader = response.body!.getReader();
|
||||
currentReader.value = reader;
|
||||
const decoder = new TextDecoder();
|
||||
let in_streaming = false;
|
||||
let message_obj: MessageContent | null = null;
|
||||
@@ -560,7 +570,9 @@ export function useMessages(
|
||||
}
|
||||
}
|
||||
} catch (readError) {
|
||||
console.error('SSE读取错误:', readError);
|
||||
if (!userStopRequested.value) {
|
||||
console.error('SSE读取错误:', readError);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -569,7 +581,9 @@ export function useMessages(
|
||||
onSessionsUpdate();
|
||||
|
||||
} catch (err) {
|
||||
console.error('发送消息失败:', err);
|
||||
if (!userStopRequested.value) {
|
||||
console.error('发送消息失败:', err);
|
||||
}
|
||||
// 移除加载占位符
|
||||
const lastMsg = messages.value[messages.value.length - 1];
|
||||
if (lastMsg?.content?.isLoading) {
|
||||
@@ -577,6 +591,10 @@ export function useMessages(
|
||||
}
|
||||
} finally {
|
||||
isStreaming.value = false;
|
||||
currentReader.value = null;
|
||||
currentRequestController.value = null;
|
||||
currentRunningSessionId.value = '';
|
||||
userStopRequested.value = false;
|
||||
activeSSECount.value--;
|
||||
if (activeSSECount.value === 0) {
|
||||
isConvRunning.value = false;
|
||||
@@ -584,6 +602,33 @@ export function useMessages(
|
||||
}
|
||||
}
|
||||
|
||||
async function stopMessage() {
|
||||
const sessionId = currentRunningSessionId.value || currSessionId.value;
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
userStopRequested.value = true;
|
||||
try {
|
||||
await axios.post('/api/chat/stop', {
|
||||
session_id: sessionId
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('停止会话失败:', err);
|
||||
}
|
||||
|
||||
try {
|
||||
await currentReader.value?.cancel();
|
||||
} catch (err) {
|
||||
// ignore reader cancel failures
|
||||
}
|
||||
currentReader.value = null;
|
||||
currentRequestController.value?.abort();
|
||||
currentRequestController.value = null;
|
||||
|
||||
isStreaming.value = false;
|
||||
}
|
||||
|
||||
return {
|
||||
messages,
|
||||
isStreaming,
|
||||
@@ -592,6 +637,7 @@ export function useMessages(
|
||||
currentSessionProject,
|
||||
getSessionMessages,
|
||||
sendMessage,
|
||||
stopMessage,
|
||||
toggleStreaming,
|
||||
getAttachment
|
||||
};
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
"voice": "Voice Input",
|
||||
"recordingPrompt": "Recording, please speak...",
|
||||
"chatPrompt": "Let's chat!",
|
||||
"dropToUpload": "Drop files to upload"
|
||||
"dropToUpload": "Drop files to upload",
|
||||
"stopGenerating": "Stop generating"
|
||||
},
|
||||
"message": {
|
||||
"user": "User",
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
"voice": "语音输入",
|
||||
"recordingPrompt": "录音中,请说话...",
|
||||
"chatPrompt": "聊天吧!",
|
||||
"dropToUpload": "松开鼠标上传文件"
|
||||
"dropToUpload": "松开鼠标上传文件",
|
||||
"stopGenerating": "停止生成"
|
||||
},
|
||||
"message": {
|
||||
"user": "用户",
|
||||
|
||||
@@ -105,6 +105,28 @@ class MockErrProvider(MockProvider):
|
||||
)
|
||||
|
||||
|
||||
class MockAbortableStreamProvider(MockProvider):
|
||||
async def text_chat_stream(self, **kwargs):
|
||||
abort_signal = kwargs.get("abort_signal")
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="partial ",
|
||||
is_chunk=True,
|
||||
)
|
||||
if abort_signal and abort_signal.is_set():
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="partial ",
|
||||
is_chunk=False,
|
||||
)
|
||||
return
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="partial final",
|
||||
is_chunk=False,
|
||||
)
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
@@ -394,6 +416,41 @@ async def test_fallback_provider_used_when_primary_returns_err(
|
||||
assert fallback_provider.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_signal_returns_aborted_and_persists_partial_message(
|
||||
runner, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
provider = MockAbortableStreamProvider()
|
||||
|
||||
await runner.reset(
|
||||
provider=provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
step_iter = runner.step()
|
||||
first_resp = await step_iter.__anext__()
|
||||
assert first_resp.type == "streaming_delta"
|
||||
|
||||
runner.request_stop()
|
||||
|
||||
rest_responses = []
|
||||
async for response in step_iter:
|
||||
rest_responses.append(response)
|
||||
|
||||
assert any(resp.type == "aborted" for resp in rest_responses)
|
||||
assert runner.was_aborted() is True
|
||||
|
||||
final_resp = runner.get_final_llm_resp()
|
||||
assert final_resp is not None
|
||||
assert final_resp.role == "assistant"
|
||||
assert final_resp.completion_text == "partial "
|
||||
assert runner.run_context.messages[-1].role == "assistant"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user