Compare commits

...

4 Commits

15 changed files with 385 additions and 30 deletions
@@ -102,6 +102,30 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret)) message.set_result(MessageEventResult().message(ret))
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话正在运行的 Agent"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
umo = message.unified_msg_origin
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
stopped_count = active_event_registry.stop_all(umo, exclude=message)
else:
stopped_count = active_event_registry.request_agent_stop_all(
umo,
exclude=message,
)
if stopped_count > 0:
message.set_result(
MessageEventResult().message(
f"已请求停止 {stopped_count} 个运行中的任务。"
)
)
return
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
async def his(self, message: AstrMessageEvent, page: int = 1) -> None: async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录""" """查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin): if not self.context.get_using_provider(message.unified_msg_origin):
@@ -132,6 +132,11 @@ class Main(star.Star):
"""重置 LLM 会话""" """重置 LLM 会话"""
await self.conversation_c.reset(message) await self.conversation_c.reset(message)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话中正在运行的 Agent"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN) @filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("model") @filter.command("model")
async def model_ls( async def model_ls(
@@ -137,6 +137,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.tool_executor = tool_executor self.tool_executor = tool_executor
self.agent_hooks = agent_hooks self.agent_hooks = agent_hooks
self.run_context = run_context self.run_context = run_context
self._stop_requested = False
self._aborted = False
# These two are used for tool schema mode handling # These two are used for tool schema mode handling
# We now have two modes: # We now have two modes:
@@ -328,6 +330,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
), ),
), ),
) )
if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
reasoning_content=llm_response.reasoning_content,
reasoning_signature=llm_response.reasoning_signature,
)
break
continue continue
llm_resp_result = llm_response llm_resp_result = llm_response
@@ -339,6 +349,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
break # got final response break # got final response
if not llm_resp_result: if not llm_resp_result:
if self._stop_requested:
llm_resp_result = LLMResponse(role="assistant", completion_text="")
else:
return
if self._stop_requested:
logger.info("Agent execution was requested to stop by user.")
llm_resp = llm_resp_result
if llm_resp.role != "assistant":
llm_resp = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
)
self.final_llm_resp = llm_resp
self._aborted = True
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if parts:
self.run_context.messages.append(
Message(role="assistant", content=parts)
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
yield AgentResponse(
type="aborted",
data=AgentResponseData(chain=MessageChain(type="aborted")),
)
return return
# 处理 LLM 响应 # 处理 LLM 响应
@@ -848,5 +900,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"""检查 Agent 是否已完成工作""" """检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR) return self._state in (AgentState.DONE, AgentState.ERROR)
def request_stop(self) -> None:
self._stop_requested = True
def was_aborted(self) -> bool:
return self._aborted
def get_final_llm_resp(self) -> LLMResponse | None: def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp return self.final_llm_resp
+43 -1
View File
@@ -20,6 +20,10 @@ from astrbot.core.provider.provider import TTSProvider
AgentRunner = ToolLoopAgentRunner[AstrAgentContext] AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
async def run_agent( async def run_agent(
agent_runner: AgentRunner, agent_runner: AgentRunner,
max_step: int = 30, max_step: int = 30,
@@ -48,10 +52,28 @@ async def run_agent(
) )
) )
stop_watcher = asyncio.create_task(
_watch_agent_stop_signal(agent_runner, astr_event),
)
try: try:
async for resp in agent_runner.step(): async for resp in agent_runner.step():
if astr_event.is_stopped(): if _should_stop_agent(astr_event):
agent_runner.request_stop()
if resp.type == "aborted":
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
astr_event.set_extra("agent_user_aborted", True)
astr_event.set_extra("agent_stop_requested", False)
return return
if _should_stop_agent(astr_event):
continue
if resp.type == "tool_call_result": if resp.type == "tool_call_result":
msg_chain = resp.data["chain"] msg_chain = resp.data["chain"]
@@ -120,6 +142,12 @@ async def run_agent(
# display the reasoning content only when configured # display the reasoning content only when configured
continue continue
yield resp.data["chain"] # MessageChain yield resp.data["chain"] # MessageChain
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
if agent_runner.done(): if agent_runner.done():
# send agent stats to webchat # send agent stats to webchat
if astr_event.get_platform_name() == "webchat": if astr_event.get_platform_name() == "webchat":
@@ -133,6 +161,12 @@ async def run_agent(
break break
except Exception as e: except Exception as e:
if "stop_watcher" in locals() and not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n" err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
@@ -155,6 +189,14 @@ async def run_agent(
return return
async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
while not agent_runner.done():
if _should_stop_agent(astr_event):
agent_runner.request_stop()
return
await asyncio.sleep(0.5)
async def run_live_agent( async def run_live_agent(
agent_runner: AgentRunner, agent_runner: AgentRunner,
tts_provider: TTSProvider | None = None, tts_provider: TTSProvider | None = None,
@@ -247,13 +247,16 @@ class InternalAgentSubStage(Stage):
yield yield
# 保存历史记录 # 保存历史记录
if not event.is_stopped() and agent_runner.done(): if agent_runner.done() and (
not event.is_stopped() or agent_runner.was_aborted()
):
await self._save_to_history( await self._save_to_history(
event, event,
req, req,
agent_runner.get_final_llm_resp(), agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages, agent_runner.run_context.messages,
agent_runner.stats, agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
) )
elif streaming_response and not stream_to_general: elif streaming_response and not stream_to_general:
@@ -308,13 +311,14 @@ class InternalAgentSubStage(Stage):
) )
# 检查事件是否被停止,如果被停止则不保存历史记录 # 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped(): if not event.is_stopped() or agent_runner.was_aborted():
await self._save_to_history( await self._save_to_history(
event, event,
req, req,
final_resp, final_resp,
agent_runner.run_context.messages, agent_runner.run_context.messages,
agent_runner.stats, agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
) )
asyncio.create_task( asyncio.create_task(
@@ -340,16 +344,29 @@ class InternalAgentSubStage(Stage):
llm_response: LLMResponse | None, llm_response: LLMResponse | None,
all_messages: list[Message], all_messages: list[Message],
runner_stats: AgentStats | None, runner_stats: AgentStats | None,
user_aborted: bool = False,
) -> None: ) -> None:
if ( if not req or not req.conversation:
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
return return
if not llm_response.completion_text and not req.tool_calls_result: if not llm_response and not user_aborted:
return
if llm_response and llm_response.role != "assistant":
if not user_aborted:
return
llm_response = LLMResponse(
role="assistant",
completion_text=llm_response.completion_text or "",
)
elif llm_response is None:
llm_response = LLMResponse(role="assistant", completion_text="")
if (
not llm_response.completion_text
and not req.tool_calls_result
and not user_aborted
):
logger.debug("LLM 响应为空,不保存记录。") logger.debug("LLM 响应为空,不保存记录。")
return return
@@ -363,6 +380,14 @@ class InternalAgentSubStage(Stage):
continue continue
message_to_save.append(message.model_dump()) message_to_save.append(message.model_dump())
# if user_aborted:
# message_to_save.append(
# Message(
# role="assistant",
# content="[User aborted this request. Partial output before abort was preserved.]",
# ).model_dump()
# )
token_usage = None token_usage = None
if runner_stats: if runner_stats:
# token_usage = runner_stats.token_usage.total # token_usage = runner_stats.token_usage.total
@@ -11,13 +11,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr from .webchat_queue_mgr import webchat_queue_mgr
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
class WebChatMessageEvent(AstrMessageEvent): class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id) super().__init__(message_str, message_obj, platform_meta, session_id)
os.makedirs(imgs_dir, exist_ok=True) os.makedirs(attachments_dir, exist_ok=True)
@staticmethod @staticmethod
async def _send( async def _send(
@@ -69,7 +69,7 @@ class WebChatMessageEvent(AstrMessageEvent):
elif isinstance(comp, Image): elif isinstance(comp, Image):
# save image to local # save image to local
filename = f"{str(uuid.uuid4())}.jpg" filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename) path = os.path.join(attachments_dir, filename)
image_base64 = await comp.convert_to_base64() image_base64 = await comp.convert_to_base64()
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(base64.b64decode(image_base64)) f.write(base64.b64decode(image_base64))
@@ -85,7 +85,7 @@ class WebChatMessageEvent(AstrMessageEvent):
elif isinstance(comp, Record): elif isinstance(comp, Record):
# save record to local # save record to local
filename = f"{str(uuid.uuid4())}.wav" filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename) path = os.path.join(attachments_dir, filename)
record_base64 = await comp.convert_to_base64() record_base64 = await comp.convert_to_base64()
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(base64.b64decode(record_base64)) f.write(base64.b64decode(record_base64))
@@ -104,7 +104,7 @@ class WebChatMessageEvent(AstrMessageEvent):
original_name = comp.name or os.path.basename(file_path) original_name = comp.name or os.path.basename(file_path)
ext = os.path.splitext(original_name)[1] or "" ext = os.path.splitext(original_name)[1] or ""
filename = f"{uuid.uuid4()!s}{ext}" filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename) dest_path = os.path.join(attachments_dir, filename)
shutil.copy2(file_path, dest_path) shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}" data = f"[FILE]{filename}"
await web_chat_back_queue.put( await web_chat_back_queue.put(
@@ -46,5 +46,22 @@ class ActiveEventRegistry:
count += 1 count += 1
return count return count
def request_agent_stop_all(
self,
umo: str,
exclude: AstrMessageEvent | None = None,
) -> int:
"""请求停止指定 UMO 的所有活跃事件中的 Agent 运行。
与 stop_all 不同,这里不会调用 event.stop_event()
因此不会中断事件传播,后续流程(如历史记录保存)仍可继续。
"""
count = 0
for event in list(self._events.get(umo, [])):
if event is not exclude:
event.set_extra("agent_stop_requested", True)
count += 1
return count
active_event_registry = ActiveEventRegistry() active_event_registry = ActiveEventRegistry()
+47 -9
View File
@@ -13,7 +13,9 @@ from quart import g, make_response, request, send_file
from astrbot.core import logger, sp from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -41,6 +43,7 @@ class ChatRoute(Route):
"/chat/new_session": ("GET", self.new_session), "/chat/new_session": ("GET", self.new_session),
"/chat/sessions": ("GET", self.get_sessions), "/chat/sessions": ("GET", self.get_sessions),
"/chat/get_session": ("GET", self.get_session), "/chat/get_session": ("GET", self.get_session),
"/chat/stop": ("POST", self.stop_session),
"/chat/delete_session": ("GET", self.delete_webchat_session), "/chat/delete_session": ("GET", self.delete_webchat_session),
"/chat/update_session_display_name": ( "/chat/update_session_display_name": (
"POST", "POST",
@@ -212,8 +215,13 @@ class ChatRoute(Route):
filename: 存储的文件名 filename: 存储的文件名
attach_type: 附件类型 (image, record, file, video) attach_type: 附件类型 (image, record, file, video)
""" """
file_path = os.path.join(self.attachments_dir, os.path.basename(filename)) basename = os.path.basename(filename)
if not os.path.exists(file_path): candidate_paths = [
os.path.join(self.attachments_dir, basename),
os.path.join(self.legacy_img_dir, basename),
]
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if not file_path:
return None return None
# guess mime type # guess mime type
@@ -466,13 +474,13 @@ class ChatRoute(Route):
if tc_id in tool_calls: if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result") tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts") tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append( accumulated_parts.append(
{ {
"type": "tool_call", "type": "tool_call",
"tool_calls": [tool_calls[tc_id]], "tool_calls": [tool_calls[tc_id]],
} }
) )
tool_calls.pop(tc_id, None) tool_calls.pop(tc_id, None)
elif chain_type == "reasoning": elif chain_type == "reasoning":
accumulated_reasoning += result_text accumulated_reasoning += result_text
elif streaming: elif streaming:
@@ -603,6 +611,36 @@ class ChatRoute(Route):
response.timeout = None # fix SSE auto disconnect issue response.timeout = None # fix SSE auto disconnect issue
return response return response
async def stop_session(self):
"""Stop active agent runs for a session."""
post_data = await request.json
if post_data is None:
return Response().error("Missing JSON body").__dict__
session_id = post_data.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
message_type = (
MessageType.GROUP_MESSAGE.value
if session.is_group
else MessageType.FRIEND_MESSAGE.value
)
umo = (
f"{session.platform_id}:{message_type}:"
f"{session.platform_id}!{username}!{session_id}"
)
stopped_count = active_event_registry.request_agent_stop_all(umo)
return Response().ok(data={"stopped_count": stopped_count}).__dict__
async def delete_webchat_session(self): async def delete_webchat_session(self):
"""Delete a Platform session and all its related data.""" """Delete a Platform session and all its related data."""
session_id = request.args.get("session_id") session_id = request.args.get("session_id")
+11
View File
@@ -77,12 +77,14 @@
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="isStreaming" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@removeImage="removeImage" @removeImage="removeImage"
@removeAudio="removeAudio" @removeAudio="removeAudio"
@@ -106,12 +108,14 @@
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="isStreaming" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@removeImage="removeImage" @removeImage="removeImage"
@removeAudio="removeAudio" @removeAudio="removeAudio"
@@ -134,12 +138,14 @@
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles" :stagedFiles="stagedNonImageFiles"
:disabled="isStreaming" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:replyTo="replyTo" :replyTo="replyTo"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@removeImage="removeImage" @removeImage="removeImage"
@removeAudio="removeAudio" @removeAudio="removeAudio"
@@ -298,6 +304,7 @@ const {
currentSessionProject, currentSessionProject,
getSessionMessages: getSessionMsg, getSessionMessages: getSessionMsg,
sendMessage: sendMsg, sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming toggleStreaming
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
@@ -631,6 +638,10 @@ async function handleSendMessage() {
} }
} }
async function handleStopMessage() {
await stopMsg();
}
// //
watch( watch(
() => route.path, () => route.path,
+25 -2
View File
@@ -94,8 +94,29 @@
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }} {{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
</v-tooltip> </v-tooltip>
</v-btn> </v-btn>
<v-btn @click="$emit('send')" icon="mdi-send" variant="text" color="deep-purple" <v-btn
:disabled="!canSend" class="send-btn" size="small" /> icon
v-if="isRunning"
@click="$emit('stop')"
variant="text"
class="send-btn"
size="small"
>
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
<v-tooltip activator="parent" location="top">
{{ tm('input.stopGenerating') }}
</v-tooltip>
</v-btn>
<v-btn
v-else
@click="$emit('send')"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!canSend"
class="send-btn"
size="small"
/>
</div> </div>
</div> </div>
</div> </div>
@@ -160,6 +181,7 @@ interface Props {
disabled: boolean; disabled: boolean;
enableStreaming: boolean; enableStreaming: boolean;
isRecording: boolean; isRecording: boolean;
isRunning: boolean;
sessionId?: string | null; sessionId?: string | null;
currentSession?: Session | null; currentSession?: Session | null;
configId?: string | null; configId?: string | null;
@@ -177,6 +199,7 @@ const props = withDefaults(defineProps<Props>(), {
const emit = defineEmits<{ const emit = defineEmits<{
'update:prompt': [value: string]; 'update:prompt': [value: string];
send: []; send: [];
stop: [];
toggleStreaming: []; toggleStreaming: [];
removeImage: [index: number]; removeImage: [index: number];
removeAudio: []; removeAudio: [];
@@ -23,12 +23,14 @@
:stagedImagesUrl="stagedImagesUrl" :stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl" :stagedAudioUrl="stagedAudioUrl"
:disabled="isStreaming" :disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming" :enableStreaming="enableStreaming"
:isRecording="isRecording" :isRecording="isRecording"
:session-id="currSessionId || null" :session-id="currSessionId || null"
:current-session="getCurrentSession" :current-session="getCurrentSession"
:config-id="configId" :config-id="configId"
@send="handleSendMessage" @send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming" @toggleStreaming="toggleStreaming"
@removeImage="removeImage" @removeImage="removeImage"
@removeAudio="removeAudio" @removeAudio="removeAudio"
@@ -156,6 +158,7 @@ const {
enableStreaming, enableStreaming,
getSessionMessages: getSessionMsg, getSessionMessages: getSessionMsg,
sendMessage: sendMsg, sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming toggleStreaming
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
@@ -236,6 +239,10 @@ async function handleSendMessage() {
} }
} }
async function handleStopMessage() {
await stopMsg();
}
onMounted(async () => { onMounted(async () => {
// //
try { try {
+48 -2
View File
@@ -82,6 +82,10 @@ export function useMessages(
const activeSSECount = ref(0); const activeSSECount = ref(0);
const enableStreaming = ref(true); const enableStreaming = ref(true);
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
const currentRequestController = ref<AbortController | null>(null);
const currentReader = ref<ReadableStreamDefaultReader<Uint8Array> | null>(null);
const currentRunningSessionId = ref('');
const userStopRequested = ref(false);
// 当前会话的项目信息 // 当前会话的项目信息
const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null); const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null);
@@ -289,6 +293,8 @@ export function useMessages(
if (activeSSECount.value === 1) { if (activeSSECount.value === 1) {
isConvRunning.value = true; isConvRunning.value = true;
} }
userStopRequested.value = false;
currentRunningSessionId.value = currSessionId.value;
// 收集所有 attachment_id // 收集所有 attachment_id
const files = stagedFiles.map(f => f.attachment_id); const files = stagedFiles.map(f => f.attachment_id);
@@ -330,12 +336,15 @@ export function useMessages(
messageToSend = prompt; messageToSend = prompt;
} }
const controller = new AbortController();
currentRequestController.value = controller;
const response = await fetch('/api/chat/send', { const response = await fetch('/api/chat/send', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token') 'Authorization': 'Bearer ' + localStorage.getItem('token')
}, },
signal: controller.signal,
body: JSON.stringify({ body: JSON.stringify({
message: messageToSend, message: messageToSend,
session_id: currSessionId.value, session_id: currSessionId.value,
@@ -350,6 +359,7 @@ export function useMessages(
} }
const reader = response.body!.getReader(); const reader = response.body!.getReader();
currentReader.value = reader;
const decoder = new TextDecoder(); const decoder = new TextDecoder();
let in_streaming = false; let in_streaming = false;
let message_obj: MessageContent | null = null; let message_obj: MessageContent | null = null;
@@ -560,7 +570,9 @@ export function useMessages(
} }
} }
} catch (readError) { } catch (readError) {
console.error('SSE读取错误:', readError); if (!userStopRequested.value) {
console.error('SSE读取错误:', readError);
}
break; break;
} }
} }
@@ -569,7 +581,9 @@ export function useMessages(
onSessionsUpdate(); onSessionsUpdate();
} catch (err) { } catch (err) {
console.error('发送消息失败:', err); if (!userStopRequested.value) {
console.error('发送消息失败:', err);
}
// 移除加载占位符 // 移除加载占位符
const lastMsg = messages.value[messages.value.length - 1]; const lastMsg = messages.value[messages.value.length - 1];
if (lastMsg?.content?.isLoading) { if (lastMsg?.content?.isLoading) {
@@ -577,6 +591,10 @@ export function useMessages(
} }
} finally { } finally {
isStreaming.value = false; isStreaming.value = false;
currentReader.value = null;
currentRequestController.value = null;
currentRunningSessionId.value = '';
userStopRequested.value = false;
activeSSECount.value--; activeSSECount.value--;
if (activeSSECount.value === 0) { if (activeSSECount.value === 0) {
isConvRunning.value = false; isConvRunning.value = false;
@@ -584,6 +602,33 @@ export function useMessages(
} }
} }
async function stopMessage() {
const sessionId = currentRunningSessionId.value || currSessionId.value;
if (!sessionId) {
return;
}
userStopRequested.value = true;
try {
await axios.post('/api/chat/stop', {
session_id: sessionId
});
} catch (err) {
console.error('停止会话失败:', err);
}
try {
await currentReader.value?.cancel();
} catch (err) {
// ignore reader cancel failures
}
currentReader.value = null;
currentRequestController.value?.abort();
currentRequestController.value = null;
isStreaming.value = false;
}
return { return {
messages, messages,
isStreaming, isStreaming,
@@ -592,6 +637,7 @@ export function useMessages(
currentSessionProject, currentSessionProject,
getSessionMessages, getSessionMessages,
sendMessage, sendMessage,
stopMessage,
toggleStreaming, toggleStreaming,
getAttachment getAttachment
}; };
@@ -9,7 +9,8 @@
"voice": "Voice Input", "voice": "Voice Input",
"recordingPrompt": "Recording, please speak...", "recordingPrompt": "Recording, please speak...",
"chatPrompt": "Let's chat!", "chatPrompt": "Let's chat!",
"dropToUpload": "Drop files to upload" "dropToUpload": "Drop files to upload",
"stopGenerating": "Stop generating"
}, },
"message": { "message": {
"user": "User", "user": "User",
@@ -9,7 +9,8 @@
"voice": "语音输入", "voice": "语音输入",
"recordingPrompt": "录音中,请说话...", "recordingPrompt": "录音中,请说话...",
"chatPrompt": "聊天吧!", "chatPrompt": "聊天吧!",
"dropToUpload": "松开鼠标上传文件" "dropToUpload": "松开鼠标上传文件",
"stopGenerating": "停止生成"
}, },
"message": { "message": {
"user": "用户", "user": "用户",
+57
View File
@@ -105,6 +105,28 @@ class MockErrProvider(MockProvider):
) )
class MockAbortableStreamProvider(MockProvider):
async def text_chat_stream(self, **kwargs):
abort_signal = kwargs.get("abort_signal")
yield LLMResponse(
role="assistant",
completion_text="partial ",
is_chunk=True,
)
if abort_signal and abort_signal.is_set():
yield LLMResponse(
role="assistant",
completion_text="partial ",
is_chunk=False,
)
return
yield LLMResponse(
role="assistant",
completion_text="partial final",
is_chunk=False,
)
class MockHooks(BaseAgentRunHooks): class MockHooks(BaseAgentRunHooks):
"""模拟钩子函数""" """模拟钩子函数"""
@@ -394,6 +416,41 @@ async def test_fallback_provider_used_when_primary_returns_err(
assert fallback_provider.call_count == 1 assert fallback_provider.call_count == 1
@pytest.mark.asyncio
async def test_stop_signal_returns_aborted_and_persists_partial_message(
runner, provider_request, mock_tool_executor, mock_hooks
):
provider = MockAbortableStreamProvider()
await runner.reset(
provider=provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=True,
)
step_iter = runner.step()
first_resp = await step_iter.__anext__()
assert first_resp.type == "streaming_delta"
runner.request_stop()
rest_responses = []
async for response in step_iter:
rest_responses.append(response)
assert any(resp.type == "aborted" for resp in rest_responses)
assert runner.was_aborted() is True
final_resp = runner.get_final_llm_resp()
assert final_resp is not None
assert final_resp.role == "assistant"
assert final_resp.completion_text == "partial "
assert runner.run_context.messages[-1].role == "assistant"
if __name__ == "__main__": if __name__ == "__main__":
# 运行测试 # 运行测试
pytest.main([__file__, "-v"]) pytest.main([__file__, "-v"])