Compare commits

...

4 Commits

15 changed files with 385 additions and 30 deletions
@@ -102,6 +102,30 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret))
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话正在运行的 Agent"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
umo = message.unified_msg_origin
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
stopped_count = active_event_registry.stop_all(umo, exclude=message)
else:
stopped_count = active_event_registry.request_agent_stop_all(
umo,
exclude=message,
)
if stopped_count > 0:
message.set_result(
MessageEventResult().message(
f"已请求停止 {stopped_count} 个运行中的任务。"
)
)
return
message.set_result(MessageEventResult().message("当前会话没有运行中的任务。"))
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
@@ -132,6 +132,11 @@ class Main(star.Star):
"""重置 LLM 会话"""
await self.conversation_c.reset(message)
@filter.command("stop")
async def stop(self, message: AstrMessageEvent) -> None:
"""停止当前会话中正在运行的 Agent"""
await self.conversation_c.stop(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("model")
async def model_ls(
@@ -137,6 +137,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.tool_executor = tool_executor
self.agent_hooks = agent_hooks
self.run_context = run_context
self._stop_requested = False
self._aborted = False
# These two are used for tool schema mode handling
# We now have two modes:
@@ -328,6 +330,14 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
),
)
if self._stop_requested:
llm_resp_result = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
reasoning_content=llm_response.reasoning_content,
reasoning_signature=llm_response.reasoning_signature,
)
break
continue
llm_resp_result = llm_response
@@ -339,6 +349,48 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
break # got final response
if not llm_resp_result:
if self._stop_requested:
llm_resp_result = LLMResponse(role="assistant", completion_text="")
else:
return
if self._stop_requested:
logger.info("Agent execution was requested to stop by user.")
llm_resp = llm_resp_result
if llm_resp.role != "assistant":
llm_resp = LLMResponse(
role="assistant",
completion_text="[SYSTEM: User actively interrupted the response generation. Partial output before interruption is preserved.]",
)
self.final_llm_resp = llm_resp
self._aborted = True
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
if llm_resp.completion_text:
parts.append(TextPart(text=llm_resp.completion_text))
if parts:
self.run_context.messages.append(
Message(role="assistant", content=parts)
)
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
yield AgentResponse(
type="aborted",
data=AgentResponseData(chain=MessageChain(type="aborted")),
)
return
# 处理 LLM 响应
@@ -848,5 +900,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
def request_stop(self) -> None:
self._stop_requested = True
def was_aborted(self) -> bool:
return self._aborted
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
+43 -1
View File
@@ -20,6 +20,10 @@ from astrbot.core.provider.provider import TTSProvider
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
def _should_stop_agent(astr_event) -> bool:
return astr_event.is_stopped() or bool(astr_event.get_extra("agent_stop_requested"))
async def run_agent(
agent_runner: AgentRunner,
max_step: int = 30,
@@ -48,10 +52,28 @@ async def run_agent(
)
)
stop_watcher = asyncio.create_task(
_watch_agent_stop_signal(agent_runner, astr_event),
)
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
if _should_stop_agent(astr_event):
agent_runner.request_stop()
if resp.type == "aborted":
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
astr_event.set_extra("agent_user_aborted", True)
astr_event.set_extra("agent_stop_requested", False)
return
if _should_stop_agent(astr_event):
continue
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
@@ -120,6 +142,12 @@ async def run_agent(
# display the reasoning content only when configured
continue
yield resp.data["chain"] # MessageChain
if not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
if agent_runner.done():
# send agent stats to webchat
if astr_event.get_platform_name() == "webchat":
@@ -133,6 +161,12 @@ async def run_agent(
break
except Exception as e:
if "stop_watcher" in locals() and not stop_watcher.done():
stop_watcher.cancel()
try:
await stop_watcher
except asyncio.CancelledError:
pass
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
@@ -155,6 +189,14 @@ async def run_agent(
return
async def _watch_agent_stop_signal(agent_runner: AgentRunner, astr_event) -> None:
while not agent_runner.done():
if _should_stop_agent(astr_event):
agent_runner.request_stop()
return
await asyncio.sleep(0.5)
async def run_live_agent(
agent_runner: AgentRunner,
tts_provider: TTSProvider | None = None,
@@ -247,13 +247,16 @@ class InternalAgentSubStage(Stage):
yield
# 保存历史记录
if not event.is_stopped() and agent_runner.done():
if agent_runner.done() and (
not event.is_stopped() or agent_runner.was_aborted()
):
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
elif streaming_response and not stream_to_general:
@@ -308,13 +311,14 @@ class InternalAgentSubStage(Stage):
)
# 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped():
if not event.is_stopped() or agent_runner.was_aborted():
await self._save_to_history(
event,
req,
final_resp,
agent_runner.run_context.messages,
agent_runner.stats,
user_aborted=agent_runner.was_aborted(),
)
asyncio.create_task(
@@ -340,16 +344,29 @@ class InternalAgentSubStage(Stage):
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
user_aborted: bool = False,
) -> None:
if (
not req
or not req.conversation
or not llm_response
or llm_response.role != "assistant"
):
if not req or not req.conversation:
return
if not llm_response.completion_text and not req.tool_calls_result:
if not llm_response and not user_aborted:
return
if llm_response and llm_response.role != "assistant":
if not user_aborted:
return
llm_response = LLMResponse(
role="assistant",
completion_text=llm_response.completion_text or "",
)
elif llm_response is None:
llm_response = LLMResponse(role="assistant", completion_text="")
if (
not llm_response.completion_text
and not req.tool_calls_result
and not user_aborted
):
logger.debug("LLM 响应为空,不保存记录。")
return
@@ -363,6 +380,14 @@ class InternalAgentSubStage(Stage):
continue
message_to_save.append(message.model_dump())
# if user_aborted:
# message_to_save.append(
# Message(
# role="assistant",
# content="[User aborted this request. Partial output before abort was preserved.]",
# ).model_dump()
# )
token_usage = None
if runner_stats:
# token_usage = runner_stats.token_usage.total
@@ -11,13 +11,13 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr
imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
os.makedirs(imgs_dir, exist_ok=True)
os.makedirs(attachments_dir, exist_ok=True)
@staticmethod
async def _send(
@@ -69,7 +69,7 @@ class WebChatMessageEvent(AstrMessageEvent):
elif isinstance(comp, Image):
# save image to local
filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename)
path = os.path.join(attachments_dir, filename)
image_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(image_base64))
@@ -85,7 +85,7 @@ class WebChatMessageEvent(AstrMessageEvent):
elif isinstance(comp, Record):
# save record to local
filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename)
path = os.path.join(attachments_dir, filename)
record_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(record_base64))
@@ -104,7 +104,7 @@ class WebChatMessageEvent(AstrMessageEvent):
original_name = comp.name or os.path.basename(file_path)
ext = os.path.splitext(original_name)[1] or ""
filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename)
dest_path = os.path.join(attachments_dir, filename)
shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}"
await web_chat_back_queue.put(
@@ -46,5 +46,22 @@ class ActiveEventRegistry:
count += 1
return count
def request_agent_stop_all(
self,
umo: str,
exclude: AstrMessageEvent | None = None,
) -> int:
"""请求停止指定 UMO 的所有活跃事件中的 Agent 运行。
与 stop_all 不同,这里不会调用 event.stop_event()
因此不会中断事件传播,后续流程(如历史记录保存)仍可继续。
"""
count = 0
for event in list(self._events.get(umo, [])):
if event is not exclude:
event.set_extra("agent_stop_requested", True)
count += 1
return count
active_event_registry = ActiveEventRegistry()
+47 -9
View File
@@ -13,7 +13,9 @@ from quart import g, make_response, request, send_file
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .route import Response, Route, RouteContext
@@ -41,6 +43,7 @@ class ChatRoute(Route):
"/chat/new_session": ("GET", self.new_session),
"/chat/sessions": ("GET", self.get_sessions),
"/chat/get_session": ("GET", self.get_session),
"/chat/stop": ("POST", self.stop_session),
"/chat/delete_session": ("GET", self.delete_webchat_session),
"/chat/update_session_display_name": (
"POST",
@@ -212,8 +215,13 @@ class ChatRoute(Route):
filename: 存储的文件名
attach_type: 附件类型 (image, record, file, video)
"""
file_path = os.path.join(self.attachments_dir, os.path.basename(filename))
if not os.path.exists(file_path):
basename = os.path.basename(filename)
candidate_paths = [
os.path.join(self.attachments_dir, basename),
os.path.join(self.legacy_img_dir, basename),
]
file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
if not file_path:
return None
# guess mime type
@@ -466,13 +474,13 @@ class ChatRoute(Route):
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{
"type": "tool_call",
"tool_calls": [tool_calls[tc_id]],
}
)
tool_calls.pop(tc_id, None)
accumulated_parts.append(
{
"type": "tool_call",
"tool_calls": [tool_calls[tc_id]],
}
)
tool_calls.pop(tc_id, None)
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
@@ -603,6 +611,36 @@ class ChatRoute(Route):
response.timeout = None # fix SSE auto disconnect issue
return response
async def stop_session(self):
"""Stop active agent runs for a session."""
post_data = await request.json
if post_data is None:
return Response().error("Missing JSON body").__dict__
session_id = post_data.get("session_id")
if not session_id:
return Response().error("Missing key: session_id").__dict__
username = g.get("username", "guest")
session = await self.db.get_platform_session_by_id(session_id)
if not session:
return Response().error(f"Session {session_id} not found").__dict__
if session.creator != username:
return Response().error("Permission denied").__dict__
message_type = (
MessageType.GROUP_MESSAGE.value
if session.is_group
else MessageType.FRIEND_MESSAGE.value
)
umo = (
f"{session.platform_id}:{message_type}:"
f"{session.platform_id}!{username}!{session_id}"
)
stopped_count = active_event_registry.request_agent_stop_all(umo)
return Response().ok(data={"stopped_count": stopped_count}).__dict__
async def delete_webchat_session(self):
"""Delete a Platform session and all its related data."""
session_id = request.args.get("session_id")
+11
View File
@@ -77,12 +77,14 @@
:stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles"
:disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
:replyTo="replyTo"
@send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"
@removeAudio="removeAudio"
@@ -106,12 +108,14 @@
:stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles"
:disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
:replyTo="replyTo"
@send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"
@removeAudio="removeAudio"
@@ -134,12 +138,14 @@
:stagedAudioUrl="stagedAudioUrl"
:stagedFiles="stagedNonImageFiles"
:disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
:replyTo="replyTo"
@send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"
@removeAudio="removeAudio"
@@ -298,6 +304,7 @@ const {
currentSessionProject,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
@@ -631,6 +638,10 @@ async function handleSendMessage() {
}
}
async function handleStopMessage() {
await stopMsg();
}
// 路由变化监听
watch(
() => route.path,
+25 -2
View File
@@ -94,8 +94,29 @@
{{ isRecording ? tm('voice.speaking') : tm('voice.startRecording') }}
</v-tooltip>
</v-btn>
<v-btn @click="$emit('send')" icon="mdi-send" variant="text" color="deep-purple"
:disabled="!canSend" class="send-btn" size="small" />
<v-btn
icon
v-if="isRunning"
@click="$emit('stop')"
variant="text"
class="send-btn"
size="small"
>
<v-icon icon="mdi-stop" variant="text" plain></v-icon>
<v-tooltip activator="parent" location="top">
{{ tm('input.stopGenerating') }}
</v-tooltip>
</v-btn>
<v-btn
v-else
@click="$emit('send')"
icon="mdi-send"
variant="text"
color="deep-purple"
:disabled="!canSend"
class="send-btn"
size="small"
/>
</div>
</div>
</div>
@@ -160,6 +181,7 @@ interface Props {
disabled: boolean;
enableStreaming: boolean;
isRecording: boolean;
isRunning: boolean;
sessionId?: string | null;
currentSession?: Session | null;
configId?: string | null;
@@ -177,6 +199,7 @@ const props = withDefaults(defineProps<Props>(), {
const emit = defineEmits<{
'update:prompt': [value: string];
send: [];
stop: [];
toggleStreaming: [];
removeImage: [index: number];
removeAudio: [];
@@ -23,12 +23,14 @@
:stagedImagesUrl="stagedImagesUrl"
:stagedAudioUrl="stagedAudioUrl"
:disabled="isStreaming"
:is-running="isStreaming || isConvRunning"
:enableStreaming="enableStreaming"
:isRecording="isRecording"
:session-id="currSessionId || null"
:current-session="getCurrentSession"
:config-id="configId"
@send="handleSendMessage"
@stop="handleStopMessage"
@toggleStreaming="toggleStreaming"
@removeImage="removeImage"
@removeAudio="removeAudio"
@@ -156,6 +158,7 @@ const {
enableStreaming,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
toggleStreaming
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
@@ -236,6 +239,10 @@ async function handleSendMessage() {
}
}
async function handleStopMessage() {
await stopMsg();
}
onMounted(async () => {
// 独立模式在挂载时创建新会话
try {
+48 -2
View File
@@ -82,6 +82,10 @@ export function useMessages(
const activeSSECount = ref(0);
const enableStreaming = ref(true);
const attachmentCache = new Map<string, string>(); // attachment_id -> blob URL
const currentRequestController = ref<AbortController | null>(null);
const currentReader = ref<ReadableStreamDefaultReader<Uint8Array> | null>(null);
const currentRunningSessionId = ref('');
const userStopRequested = ref(false);
// 当前会话的项目信息
const currentSessionProject = ref<{ project_id: string; title: string; emoji: string } | null>(null);
@@ -289,6 +293,8 @@ export function useMessages(
if (activeSSECount.value === 1) {
isConvRunning.value = true;
}
userStopRequested.value = false;
currentRunningSessionId.value = currSessionId.value;
// 收集所有 attachment_id
const files = stagedFiles.map(f => f.attachment_id);
@@ -330,12 +336,15 @@ export function useMessages(
messageToSend = prompt;
}
const controller = new AbortController();
currentRequestController.value = controller;
const response = await fetch('/api/chat/send', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
signal: controller.signal,
body: JSON.stringify({
message: messageToSend,
session_id: currSessionId.value,
@@ -350,6 +359,7 @@ export function useMessages(
}
const reader = response.body!.getReader();
currentReader.value = reader;
const decoder = new TextDecoder();
let in_streaming = false;
let message_obj: MessageContent | null = null;
@@ -560,7 +570,9 @@ export function useMessages(
}
}
} catch (readError) {
console.error('SSE读取错误:', readError);
if (!userStopRequested.value) {
console.error('SSE读取错误:', readError);
}
break;
}
}
@@ -569,7 +581,9 @@ export function useMessages(
onSessionsUpdate();
} catch (err) {
console.error('发送消息失败:', err);
if (!userStopRequested.value) {
console.error('发送消息失败:', err);
}
// 移除加载占位符
const lastMsg = messages.value[messages.value.length - 1];
if (lastMsg?.content?.isLoading) {
@@ -577,6 +591,10 @@ export function useMessages(
}
} finally {
isStreaming.value = false;
currentReader.value = null;
currentRequestController.value = null;
currentRunningSessionId.value = '';
userStopRequested.value = false;
activeSSECount.value--;
if (activeSSECount.value === 0) {
isConvRunning.value = false;
@@ -584,6 +602,33 @@ export function useMessages(
}
}
async function stopMessage() {
const sessionId = currentRunningSessionId.value || currSessionId.value;
if (!sessionId) {
return;
}
userStopRequested.value = true;
try {
await axios.post('/api/chat/stop', {
session_id: sessionId
});
} catch (err) {
console.error('停止会话失败:', err);
}
try {
await currentReader.value?.cancel();
} catch (err) {
// ignore reader cancel failures
}
currentReader.value = null;
currentRequestController.value?.abort();
currentRequestController.value = null;
isStreaming.value = false;
}
return {
messages,
isStreaming,
@@ -592,6 +637,7 @@ export function useMessages(
currentSessionProject,
getSessionMessages,
sendMessage,
stopMessage,
toggleStreaming,
getAttachment
};
@@ -9,7 +9,8 @@
"voice": "Voice Input",
"recordingPrompt": "Recording, please speak...",
"chatPrompt": "Let's chat!",
"dropToUpload": "Drop files to upload"
"dropToUpload": "Drop files to upload",
"stopGenerating": "Stop generating"
},
"message": {
"user": "User",
@@ -9,7 +9,8 @@
"voice": "语音输入",
"recordingPrompt": "录音中,请说话...",
"chatPrompt": "聊天吧!",
"dropToUpload": "松开鼠标上传文件"
"dropToUpload": "松开鼠标上传文件",
"stopGenerating": "停止生成"
},
"message": {
"user": "用户",
+57
View File
@@ -105,6 +105,28 @@ class MockErrProvider(MockProvider):
)
class MockAbortableStreamProvider(MockProvider):
async def text_chat_stream(self, **kwargs):
abort_signal = kwargs.get("abort_signal")
yield LLMResponse(
role="assistant",
completion_text="partial ",
is_chunk=True,
)
if abort_signal and abort_signal.is_set():
yield LLMResponse(
role="assistant",
completion_text="partial ",
is_chunk=False,
)
return
yield LLMResponse(
role="assistant",
completion_text="partial final",
is_chunk=False,
)
class MockHooks(BaseAgentRunHooks):
"""模拟钩子函数"""
@@ -394,6 +416,41 @@ async def test_fallback_provider_used_when_primary_returns_err(
assert fallback_provider.call_count == 1
@pytest.mark.asyncio
async def test_stop_signal_returns_aborted_and_persists_partial_message(
runner, provider_request, mock_tool_executor, mock_hooks
):
provider = MockAbortableStreamProvider()
await runner.reset(
provider=provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=mock_tool_executor,
agent_hooks=mock_hooks,
streaming=True,
)
step_iter = runner.step()
first_resp = await step_iter.__anext__()
assert first_resp.type == "streaming_delta"
runner.request_stop()
rest_responses = []
async for response in step_iter:
rest_responses.append(response)
assert any(resp.type == "aborted" for resp in rest_responses)
assert runner.was_aborted() is True
final_resp = runner.get_final_llm_resp()
assert final_resp is not None
assert final_resp.role == "assistant"
assert final_resp.completion_text == "partial "
assert runner.run_context.messages[-1].role == "assistant"
if __name__ == "__main__":
# 运行测试
pytest.main([__file__, "-v"])