Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7cedf0d587 | |||
| aeb21f719e | |||
| 7c1dbecea5 | |||
| 05012af627 | |||
| 17b52ab5dd | |||
| 9449ff668b | |||
| c5a2827def | |||
| 701399c00c | |||
| eaee98d4b8 | |||
| 76c66000a7 | |||
| 4b365143c0 | |||
| 6e4e5011e2 | |||
| d853bfde84 | |||
| a0e856f80f | |||
| 8c94a0010c | |||
| a44fdaaec0 | |||
| 60105c76f5 | |||
| bcf87d3ce4 | |||
| 4d7c8c8453 | |||
| a064a9115f | |||
| 6ef99e1553 | |||
| c0dbe5cf65 | |||
| 3598c51eff | |||
| b5cdb8f650 | |||
| fc5b520f9b | |||
| 904f56b32f |
@@ -1,4 +1,4 @@
|
||||

|
||||

|
||||
|
||||
<div align="center">
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.10.0-alpha.1"
|
||||
__version__ = "4.10.2"
|
||||
|
||||
@@ -76,12 +76,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -165,7 +173,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
content=llm_resp.completion_text or "*No response*",
|
||||
),
|
||||
)
|
||||
try:
|
||||
@@ -230,6 +238,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -376,35 +403,33 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -426,6 +451,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -2,6 +2,7 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
@@ -24,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
|
||||
@@ -209,12 +209,42 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.10.0-alpha.1"
|
||||
VERSION = "4.10.2"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
|
||||
@@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage):
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text or "*No response*",
|
||||
}
|
||||
)
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
|
||||
@@ -385,10 +385,25 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
|
||||
message_str += "".join(at_parts)
|
||||
elif t == "markdown":
|
||||
text = m["data"].get("markdown") or m["data"].get("content", "")
|
||||
abm.message.append(Plain(text=text))
|
||||
message_str += text
|
||||
else:
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
try:
|
||||
if t not in ComponentTypes:
|
||||
logger.warning(
|
||||
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"消息段解析失败: type={t}, data={m['data']}. {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -14,6 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import (
|
||||
AssistantMessageSegment,
|
||||
ContentPart,
|
||||
ToolCall,
|
||||
ToolCallMessageSegment,
|
||||
)
|
||||
@@ -92,6 +93,8 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。"""
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
@@ -166,13 +169,23 @@ class ProviderRequest:
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if self.prompt and self.prompt.strip():
|
||||
content_blocks.append({"type": "text", "text": self.prompt})
|
||||
elif self.image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if self.extra_user_content_parts:
|
||||
for part in self.extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -185,11 +198,21 @@ class ProviderRequest:
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
and not self.extra_user_content_parts
|
||||
and not self.image_urls
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的用户内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
@@ -133,6 +135,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
@@ -144,6 +147,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的用户内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
|
||||
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -296,13 +297,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -350,13 +354,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
):
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -388,48 +395,116 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content.append({"type": "text", "text": " "})
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for block in extra_user_content_parts:
|
||||
block_type = block.get("type")
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
if block_type == "text":
|
||||
# 文本直接添加
|
||||
content.append(block)
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
elif block_type == "image_url":
|
||||
# 转换 OpenAI 格式的图片为 Anthropic 格式
|
||||
image_url_data = block.get("image_url", {})
|
||||
if isinstance(image_url_data, dict):
|
||||
url = image_url_data.get("url", "")
|
||||
else:
|
||||
# 兼容直接传 URL 字符串的情况
|
||||
url = str(image_url_data)
|
||||
|
||||
if url and url.startswith("data:"):
|
||||
try:
|
||||
# 提取 MIME 类型和 base64 数据
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
base64_data = (
|
||||
url.split("base64,")[1] if "base64," in url else url
|
||||
)
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": base64_data,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换 image_url 到 Anthropic 格式失败: {e}")
|
||||
else:
|
||||
logger.warning(f"image_url 不是有效的 data URI: {url[:50]}...")
|
||||
|
||||
else:
|
||||
# 其他类型(如 audio_url)Anthropic 不支持,记录警告
|
||||
logger.debug(f"Anthropic 不支持的内容类型 '{block_type}',已忽略")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content) == 1
|
||||
and content[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
|
||||
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
@@ -138,7 +139,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities = ["TEXT"]
|
||||
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = payloads.get("model", self.get_model())
|
||||
model_name = cast(str, payloads.get("model", self.get_model()))
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
@@ -199,7 +200,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
# oper thinking config
|
||||
thinking_config = None
|
||||
if model_name.startswith("gemini-2.5"):
|
||||
if model_name in [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-preview",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-lite-preview",
|
||||
"gemini-robotics-er-1.5-preview",
|
||||
"gemini-live-2.5-flash-preview-native-audio-09-2025",
|
||||
]:
|
||||
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
||||
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
@@ -208,7 +218,14 @@ class ProviderGoogleGenAI(Provider):
|
||||
thinking_config = types.ThinkingConfig(
|
||||
thinking_budget=thinking_budget,
|
||||
)
|
||||
elif model_name.startswith("gemini-3"):
|
||||
elif model_name in [
|
||||
"gemini-3-pro",
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-flash-lite",
|
||||
"gemini-3-flash-lite-preview",
|
||||
]:
|
||||
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
||||
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
||||
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
@@ -664,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -716,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -781,13 +804,33 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -800,14 +843,25 @@ class ProviderGoogleGenAI(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
@@ -348,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -355,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -476,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
@@ -485,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -539,6 +544,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
@@ -549,6 +555,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -624,13 +631,29 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -643,14 +666,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.db.po import CommandConfig
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
@@ -90,6 +90,7 @@ async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescri
|
||||
async def rename_command(
|
||||
handler_full_name: str,
|
||||
new_fragment: str,
|
||||
aliases: list[str] | None = None,
|
||||
) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
@@ -99,9 +100,24 @@ async def rename_command(
|
||||
if not new_fragment:
|
||||
raise ValueError("指令名不能为空。")
|
||||
|
||||
# 校验主指令名
|
||||
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
|
||||
if _is_command_in_use(handler_full_name, candidate_full):
|
||||
raise ValueError("新的指令名已被其他指令占用,请换一个名称。")
|
||||
raise ValueError(f"指令名 '{candidate_full}' 已被其他指令占用。")
|
||||
|
||||
# 校验别名
|
||||
if aliases:
|
||||
for alias in aliases:
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
continue
|
||||
alias_full = _compose_command(descriptor.parent_signature, alias)
|
||||
if _is_command_in_use(handler_full_name, alias_full):
|
||||
raise ValueError(f"别名 '{alias_full}' 已被其他指令占用。")
|
||||
|
||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||
merged_extra = dict(existing_cfg.extra_data or {}) if existing_cfg else {}
|
||||
merged_extra["resolved_aliases"] = aliases or []
|
||||
|
||||
config = await db_helper.upsert_command_config(
|
||||
handler_full_name=handler_full_name,
|
||||
@@ -114,7 +130,7 @@ async def rename_command(
|
||||
conflict_key=descriptor.original_command,
|
||||
resolution_strategy="manual_rename",
|
||||
note=None,
|
||||
extra_data=None,
|
||||
extra_data=merged_extra,
|
||||
auto_managed=False,
|
||||
)
|
||||
_bind_descriptor_with_config(descriptor, config)
|
||||
@@ -192,12 +208,18 @@ def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
|
||||
"""收集指令,按需包含子指令。"""
|
||||
descriptors: list[CommandDescriptor] = []
|
||||
for handler in star_handlers_registry:
|
||||
desc = _build_descriptor(handler)
|
||||
if not desc:
|
||||
try:
|
||||
desc = _build_descriptor(handler)
|
||||
if not desc:
|
||||
continue
|
||||
if not include_sub_commands and desc.is_sub_command:
|
||||
continue
|
||||
descriptors.append(desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"解析指令处理函数 {handler.handler_full_name} 失败,跳过该指令。原因: {e!s}"
|
||||
)
|
||||
continue
|
||||
if not include_sub_commands and desc.is_sub_command:
|
||||
continue
|
||||
descriptors.append(desc)
|
||||
return descriptors
|
||||
|
||||
|
||||
@@ -357,14 +379,27 @@ def _apply_config_to_descriptor(
|
||||
new_fragment,
|
||||
)
|
||||
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
descriptor.aliases = [str(x) for x in resolved_aliases if str(x).strip()]
|
||||
|
||||
|
||||
def _apply_config_to_runtime(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
descriptor.handler.enabled = config.enabled
|
||||
if descriptor.filter_ref and descriptor.current_fragment:
|
||||
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
||||
if descriptor.filter_ref:
|
||||
if descriptor.current_fragment:
|
||||
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
_set_filter_aliases(
|
||||
descriptor.filter_ref,
|
||||
[str(x) for x in resolved_aliases if str(x).strip()],
|
||||
)
|
||||
|
||||
|
||||
def _bind_configs_to_descriptors(
|
||||
@@ -403,6 +438,18 @@ def _set_filter_fragment(
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _set_filter_aliases(
|
||||
filter_ref: CommandFilter | CommandGroupFilter,
|
||||
aliases: list[str],
|
||||
) -> None:
|
||||
current_aliases = getattr(filter_ref, "alias", set())
|
||||
if set(aliases) == current_aliases:
|
||||
return
|
||||
setattr(filter_ref, "alias", set(aliases))
|
||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _is_command_in_use(
|
||||
target_handler_full_name: str,
|
||||
candidate_full_command: str,
|
||||
|
||||
@@ -631,7 +631,11 @@ class PluginManager:
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
await sync_command_configs()
|
||||
try:
|
||||
await sync_command_configs()
|
||||
except Exception as e:
|
||||
logger.error(f"同步指令配置失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
|
||||
@@ -436,7 +436,7 @@ class ChatRoute(Route):
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
tool_calls = {}
|
||||
# tool_calls = {}
|
||||
agent_stats = {}
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
|
||||
@@ -61,12 +61,13 @@ class CommandRoute(Route):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
new_name = data.get("new_name")
|
||||
aliases = data.get("aliases")
|
||||
|
||||
if not handler_full_name or not new_name:
|
||||
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
|
||||
|
||||
try:
|
||||
await rename_command_service(handler_full_name, new_name)
|
||||
await rename_command_service(handler_full_name, new_name, aliases=aliases)
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
|
||||
@@ -185,23 +185,30 @@ class ConfigRoute(Route):
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
|
||||
"/config/provider_sources/<provider_source_id>/models": (
|
||||
"/config/provider_sources/models": (
|
||||
"GET",
|
||||
self.get_provider_source_models,
|
||||
),
|
||||
"/config/provider_sources/<provider_source_id>/update": (
|
||||
"/config/provider_sources/update": (
|
||||
"POST",
|
||||
self.update_provider_source,
|
||||
),
|
||||
"/config/provider_sources/<provider_source_id>/delete": (
|
||||
"/config/provider_sources/delete": (
|
||||
"POST",
|
||||
self.delete_provider_source,
|
||||
),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def delete_provider_source(self, provider_source_id: str):
|
||||
async def delete_provider_source(self):
|
||||
"""删除 provider_source,并更新关联的 providers"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
provider_source_id = post_data.get("id")
|
||||
if not provider_source_id:
|
||||
return Response().error("缺少 provider_source_id").__dict__
|
||||
|
||||
provider_sources = self.config.get("provider_sources", [])
|
||||
target_idx = next(
|
||||
@@ -235,15 +242,16 @@ class ConfigRoute(Route):
|
||||
|
||||
return Response().ok(message="删除 provider source 成功").__dict__
|
||||
|
||||
async def update_provider_source(self, provider_source_id: str):
|
||||
async def update_provider_source(self):
|
||||
"""更新或新增 provider_source,并重载关联的 providers"""
|
||||
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
new_source_config = post_data.get("config") or post_data
|
||||
original_id = provider_source_id
|
||||
original_id = post_data.get("original_id")
|
||||
if not original_id:
|
||||
return Response().error("缺少 original_id").__dict__
|
||||
|
||||
if not isinstance(new_source_config, dict):
|
||||
return Response().error("缺少或错误的配置数据").__dict__
|
||||
@@ -684,11 +692,15 @@ class ConfigRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__
|
||||
|
||||
async def get_provider_source_models(self, provider_source_id: str):
|
||||
async def get_provider_source_models(self):
|
||||
"""获取指定 provider_source 支持的模型列表
|
||||
|
||||
本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例
|
||||
"""
|
||||
provider_source_id = request.args.get("source_id")
|
||||
if not provider_source_id:
|
||||
return Response().error("缺少参数 source_id").__dict__
|
||||
|
||||
try:
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 该版本为 alpha.2 预览版本。
|
||||
> 2. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 3. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
|
||||
## alpha.1 -> alpha.2
|
||||
|
||||
- 修复:“对话数据”页对话轨迹详情显示异常的问题
|
||||
- 优化:当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化:LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
- 优化:ChatUI 打开模型选择菜单时,会重新获取提供商配置。
|
||||
- 优化:ChatUI 新建对话并发送消息后,对话列表页自动选中该对话。
|
||||
|
||||
## 4.10.0 变化
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
@@ -0,0 +1,40 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 2. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
> 3. **升级后请务必确保 WebUI 和 AstrBot Core 版本一致**,否则会产生预期之外的情况。(判断方法:日志中出现 `WebUI 版本已是最新。` 即为一致的版本,`检测到 WebUI 版本 (xxx) 与当前 AstrBot 版本 (xxx) 不符。` 即为不一致的版本。此版本的判断方法也可通查看 WebUI 右上角是否出现 Bot / Chat 的切换按钮控件来判断是否是新版本的 WebUI)。
|
||||
> 4. 如果有任何问题请提交 [Issue](https://github.com/AstrBotDevs/AstrBot/issues) 并附带 `v4.10.0` tag。
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
- 优化当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化 LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
|
||||
Merry Christmas!
|
||||
@@ -0,0 +1,46 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 2. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
> 3. **升级后请务必确保 WebUI 和 AstrBot Core 版本一致**,否则会产生预期之外的情况。(判断方法:日志中出现 `WebUI 版本已是最新。` 即为一致的版本,`检测到 WebUI 版本 (xxx) 与当前 AstrBot 版本 (xxx) 不符。` 即为不一致的版本。此版本的判断方法也可通查看 WebUI 右上角是否出现 Bot / Chat 的切换按钮控件来判断是否是新版本的 WebUI)。
|
||||
> 4. 如果有任何问题请提交 [Issue](https://github.com/AstrBotDevs/AstrBot/issues) 并附带 `v4.10.0` tag。
|
||||
|
||||
## 4.10.0 -> 4.10.1
|
||||
|
||||
- fix(core): 修复极少数情况下由于指令管理导致的 AstrBot 启动失败的问题
|
||||
- fix(core): 修复当提供商源带有斜杠(“/”)时,无法删除 / 更新提供商源的问题(报错 405)
|
||||
- perf(core): 优化 OneBot 适配器的消息段解析逻辑,修复部分情况下无法正确解析消息段的问题
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
- 优化当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化 LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
|
||||
Merry Christmas!
|
||||
@@ -0,0 +1,9 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
1. ‼️‼️ 修复了由 `psutil` 新版本导致的启动时报错的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
1. 插件指令管理支持管理别名。
|
||||
@@ -310,7 +310,7 @@ async function handleSelectConversation(sessionIds: string[]) {
|
||||
isLoadingMessages.value = true;
|
||||
|
||||
try {
|
||||
await getSessionMsg(sessionIds[0], router);
|
||||
await getSessionMsg(sessionIds[0]);
|
||||
} finally {
|
||||
isLoadingMessages.value = false;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
<template>
|
||||
<div class="input-area fade-in">
|
||||
<div class="input-container"
|
||||
style="width: 85%; max-width: 900px; margin: 0 auto; border: 1px solid #e0e0e0; border-radius: 24px; box-shadow: 0px 2px 2px rgba(0, 0, 0, 0.1);">
|
||||
:style="{
|
||||
width: '85%',
|
||||
maxWidth: '900px',
|
||||
margin: '0 auto',
|
||||
border: isDark ? 'none' : '1px solid #e0e0e0',
|
||||
borderRadius: '24px',
|
||||
boxShadow: isDark ? 'none' : '0px 2px 2px rgba(0, 0, 0, 0.1)',
|
||||
backgroundColor: isDark ? '#2d2d2d' : 'transparent'
|
||||
}">
|
||||
<!-- 引用预览区 -->
|
||||
<div class="reply-preview" v-if="props.replyTo">
|
||||
<div class="reply-content">
|
||||
@@ -86,6 +94,7 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onBeforeUnmount } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { useCustomizerStore } from '@/stores/customizer';
|
||||
import ConfigSelector from './ConfigSelector.vue';
|
||||
import ProviderModelMenu from './ProviderModelMenu.vue';
|
||||
import type { Session } from '@/composables/useSessions';
|
||||
@@ -140,6 +149,7 @@ const emit = defineEmits<{
|
||||
}>();
|
||||
|
||||
const { tm } = useModuleI18n('features/chat');
|
||||
const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark');
|
||||
|
||||
const inputField = ref<HTMLTextAreaElement | null>(null);
|
||||
const imageInputRef = ref<HTMLInputElement | null>(null);
|
||||
@@ -261,7 +271,7 @@ defineExpose({
|
||||
<style scoped>
|
||||
.input-area {
|
||||
padding: 16px;
|
||||
background-color: var(--v-theme-surface);
|
||||
background-color: transparent;
|
||||
position: relative;
|
||||
border-top: 1px solid var(--v-theme-border);
|
||||
flex-shrink: 0;
|
||||
|
||||
@@ -35,7 +35,8 @@
|
||||
@update:selected="$emit('selectConversation', $event)">
|
||||
<v-list-item v-for="item in sessions" :key="item.session_id" :value="item.session_id"
|
||||
rounded="lg" class="conversation-item" active-color="secondary">
|
||||
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title">
|
||||
<v-list-item-title v-if="!sidebarCollapsed || isMobile" class="conversation-title"
|
||||
:style="{ color: isDark ? '#ffffff' : '#000000' }">
|
||||
{{ item.display_name || tm('conversation.newConversation') }}
|
||||
</v-list-item-title>
|
||||
<!-- <v-list-item-subtitle v-if="!sidebarCollapsed || isMobile" class="timestamp">
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<v-menu :close-on-content-click="false" location="top">
|
||||
<v-menu v-model="menuOpen" :close-on-content-click="false" location="top" @update:model-value="handleMenuToggle">
|
||||
<template v-slot:activator="{ props: menuProps }">
|
||||
<v-chip v-bind="menuProps" class="text-none provider-chip" variant="tonal" size="x-small">
|
||||
<v-icon start size="14">mdi-creation</v-icon>
|
||||
@@ -72,11 +72,13 @@ interface ProviderConfig {
|
||||
model: string;
|
||||
api_base?: string;
|
||||
model_metadata?: ModelMetadata;
|
||||
enable?: boolean;
|
||||
}
|
||||
|
||||
const providerConfigs = ref<ProviderConfig[]>([]);
|
||||
const selectedProviderId = ref('');
|
||||
const searchQuery = ref('');
|
||||
const menuOpen = ref(false);
|
||||
|
||||
const filteredProviders = computed(() => {
|
||||
if (!searchQuery.value) {
|
||||
@@ -107,7 +109,10 @@ function loadProviderConfigs() {
|
||||
params: { provider_type: 'chat_completion' }
|
||||
}).then(response => {
|
||||
if (response.data.status === 'ok') {
|
||||
providerConfigs.value = response.data.data || [];
|
||||
// 过滤掉 enable 为 false 的配置
|
||||
providerConfigs.value = (response.data.data || []).filter(
|
||||
(p: ProviderConfig) => p.enable !== false
|
||||
);
|
||||
}
|
||||
}).catch(error => {
|
||||
console.error('获取提供商列表失败:', error);
|
||||
@@ -140,6 +145,13 @@ function getCurrentSelection() {
|
||||
};
|
||||
}
|
||||
|
||||
function handleMenuToggle(isOpen: boolean) {
|
||||
if (isOpen) {
|
||||
// 每次打开菜单时重新获取数据
|
||||
loadProviderConfigs();
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadFromStorage();
|
||||
loadProviderConfigs();
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, ref, watch } from 'vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import type { CommandItem } from '../types';
|
||||
|
||||
const { tm } = useModuleI18n('features/command');
|
||||
|
||||
// Props
|
||||
defineProps<{
|
||||
const props = defineProps<{
|
||||
show: boolean;
|
||||
command: CommandItem | null;
|
||||
newName: string;
|
||||
aliases: string[];
|
||||
loading: boolean;
|
||||
}>();
|
||||
|
||||
@@ -16,8 +18,42 @@ defineProps<{
|
||||
const emit = defineEmits<{
|
||||
(e: 'update:show', value: boolean): void;
|
||||
(e: 'update:newName', value: string): void;
|
||||
(e: 'update:aliases', value: string[]): void;
|
||||
(e: 'confirm'): void;
|
||||
}>();
|
||||
|
||||
const addAlias = () => {
|
||||
emit('update:aliases', [...props.aliases, '']);
|
||||
};
|
||||
|
||||
const removeAlias = (index: number) => {
|
||||
const newAliases = [...props.aliases];
|
||||
newAliases.splice(index, 1);
|
||||
emit('update:aliases', newAliases);
|
||||
};
|
||||
|
||||
const updateAlias = (index: number, value: string) => {
|
||||
const newAliases = [...props.aliases];
|
||||
newAliases[index] = value;
|
||||
emit('update:aliases', newAliases);
|
||||
};
|
||||
|
||||
const hasAliases = computed(() => (props.aliases || []).some(a => (a ?? '').toString().trim()));
|
||||
const showAliasEditor = ref(false);
|
||||
const aliasEditorEverOpened = ref(false);
|
||||
|
||||
watch(
|
||||
() => props.show,
|
||||
(open) => {
|
||||
if (!open) return;
|
||||
// 如果已有别名则默认展开,否则默认收起
|
||||
showAliasEditor.value = hasAliases.value;
|
||||
},
|
||||
);
|
||||
|
||||
watch(showAliasEditor, (open) => {
|
||||
if (open) aliasEditorEverOpened.value = true;
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -32,7 +68,49 @@ const emit = defineEmits<{
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
autofocus
|
||||
class="mb-2"
|
||||
/>
|
||||
|
||||
<v-card variant="outlined" class="mt-2" elevation="0">
|
||||
<div
|
||||
class="d-flex align-center justify-space-between px-4 py-3"
|
||||
role="button"
|
||||
tabindex="0"
|
||||
@click="showAliasEditor = !showAliasEditor"
|
||||
@keydown.enter.prevent="showAliasEditor = !showAliasEditor"
|
||||
@keydown.space.prevent="showAliasEditor = !showAliasEditor"
|
||||
>
|
||||
<div class="text-subtitle-1">{{ tm('dialogs.rename.aliases') }}</div>
|
||||
<v-icon size="20">{{ showAliasEditor ? 'mdi-chevron-up' : 'mdi-chevron-down' }}</v-icon>
|
||||
</div>
|
||||
<v-divider v-if="showAliasEditor" />
|
||||
<v-slide-y-transition>
|
||||
<div v-if="aliasEditorEverOpened" v-show="showAliasEditor" class="px-4 py-3">
|
||||
<div v-for="(alias, index) in aliases" :key="index" class="d-flex align-center mb-2">
|
||||
<v-text-field
|
||||
:model-value="alias"
|
||||
@update:model-value="updateAlias(index, $event)"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1 mr-2"
|
||||
/>
|
||||
<v-btn icon="mdi-delete" variant="text" color="error" density="compact" @click="removeAlias(index)" />
|
||||
</div>
|
||||
<v-btn
|
||||
prepend-icon="mdi-plus"
|
||||
variant="outlined"
|
||||
color="primary"
|
||||
block
|
||||
size="small"
|
||||
class="mt-2"
|
||||
@click="addAlias"
|
||||
>
|
||||
{{ tm('dialogs.rename.addAlias') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-slide-y-transition>
|
||||
</v-card>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer />
|
||||
|
||||
@@ -14,6 +14,7 @@ export function useCommandActions(
|
||||
show: false,
|
||||
command: null,
|
||||
newName: '',
|
||||
aliases: [],
|
||||
loading: false
|
||||
});
|
||||
|
||||
@@ -53,6 +54,7 @@ export function useCommandActions(
|
||||
const openRenameDialog = (cmd: CommandItem) => {
|
||||
renameDialog.command = cmd;
|
||||
renameDialog.newName = cmd.current_fragment || '';
|
||||
renameDialog.aliases = [...(cmd.aliases || [])];
|
||||
renameDialog.show = true;
|
||||
};
|
||||
|
||||
@@ -66,7 +68,8 @@ export function useCommandActions(
|
||||
try {
|
||||
const res = await axios.post('/api/commands/rename', {
|
||||
handler_full_name: renameDialog.command.handler_full_name,
|
||||
new_name: renameDialog.newName.trim()
|
||||
new_name: renameDialog.newName.trim(),
|
||||
aliases: renameDialog.aliases.filter(a => a.trim())
|
||||
});
|
||||
if (res.data.status === 'ok') {
|
||||
toast(successMessage, 'success');
|
||||
|
||||
@@ -288,6 +288,8 @@ watch(viewMode, async (mode) => {
|
||||
@update:show="renameDialog.show = $event"
|
||||
:new-name="renameDialog.newName"
|
||||
@update:new-name="renameDialog.newName = $event"
|
||||
:aliases="renameDialog.aliases"
|
||||
@update:aliases="renameDialog.aliases = $event"
|
||||
:command="renameDialog.command"
|
||||
:loading="renameDialog.loading"
|
||||
@confirm="handleConfirmRename"
|
||||
|
||||
@@ -52,6 +52,7 @@ export interface RenameDialogState {
|
||||
show: boolean;
|
||||
command: CommandItem | null;
|
||||
newName: string;
|
||||
aliases: string[];
|
||||
loading: boolean;
|
||||
}
|
||||
|
||||
|
||||
@@ -148,3 +148,10 @@ const emitDeleteSource = (source) => emit('delete-provider-source', source)
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
<style>
|
||||
.v-theme--PurpleThemeDark .provider-source-list-item--active {
|
||||
background-color: #2d2d2d;
|
||||
border: none;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -172,7 +172,7 @@ export function useMessages(
|
||||
}
|
||||
}
|
||||
|
||||
async function getSessionMessages(sessionId: string, router: any) {
|
||||
async function getSessionMessages(sessionId: string) {
|
||||
if (!sessionId) return;
|
||||
|
||||
try {
|
||||
@@ -188,7 +188,7 @@ export function useMessages(
|
||||
|
||||
// 如果会话还在运行,3秒后重新获取消息
|
||||
setTimeout(() => {
|
||||
getSessionMessages(currSessionId.value, router);
|
||||
getSessionMessages(currSessionId.value);
|
||||
}, 3000);
|
||||
}
|
||||
|
||||
@@ -353,6 +353,10 @@ export function useMessages(
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
console.log('SSE stream completed');
|
||||
// 流式传输结束后,获取最终消息并重新渲染
|
||||
if (currSessionId.value) {
|
||||
await getSessionMessages(currSessionId.value);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -398,7 +398,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
if (!confirm(tm('providerSources.deleteConfirm', { id: source.id }))) return
|
||||
|
||||
try {
|
||||
await axios.post(`/api/config/provider_sources/${source.id}/delete`)
|
||||
await axios.post('/api/config/provider_sources/delete', { id: source.id })
|
||||
|
||||
providers.value = providers.value.filter((p) => p.provider_source_id !== source.id)
|
||||
providerSources.value = providerSources.value.filter((s) => s.id !== source.id)
|
||||
@@ -423,7 +423,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
savingSource.value = true
|
||||
const originalId = selectedProviderSourceOriginalId.value || selectedProviderSource.value.id
|
||||
try {
|
||||
const response = await axios.post(`/api/config/provider_sources/${originalId}/update`, {
|
||||
const response = await axios.post('/api/config/provider_sources/update', {
|
||||
config: editableProviderSource.value,
|
||||
original_id: originalId
|
||||
})
|
||||
@@ -478,7 +478,9 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
loadingModels.value = true
|
||||
try {
|
||||
const sourceId = editableProviderSource.value?.id || selectedProviderSource.value.id
|
||||
const response = await axios.get(`/api/config/provider_sources/${sourceId}/models`)
|
||||
const response = await axios.get('/api/config/provider_sources/models', {
|
||||
params: { source_id: sourceId }
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
const metadataMap = response.data.data.model_metadata || {}
|
||||
modelMetadata.value = metadataMap
|
||||
|
||||
@@ -41,7 +41,13 @@ export function useSessions(chatboxMode: boolean = false) {
|
||||
selectedSessions.value = [pendingSessionId.value];
|
||||
pendingSessionId.value = null;
|
||||
}
|
||||
} else if (!currSessionId.value && sessions.value.length > 0) {
|
||||
} else if (currSessionId.value) {
|
||||
// 如果当前有选中的会话,确保它在列表中并被选中
|
||||
const session = sessions.value.find(s => s.session_id === currSessionId.value);
|
||||
if (session) {
|
||||
selectedSessions.value = [currSessionId.value];
|
||||
}
|
||||
} else if (sessions.value.length > 0) {
|
||||
// 默认选择第一个会话
|
||||
const firstSession = sessions.value[0];
|
||||
selectedSessions.value = [firstSession.session_id];
|
||||
@@ -65,6 +71,10 @@ export function useSessions(chatboxMode: boolean = false) {
|
||||
router.push(`${basePath}/${sessionId}`);
|
||||
|
||||
await getSessions();
|
||||
|
||||
// 确保新创建的会话被选中高亮
|
||||
selectedSessions.value = [sessionId];
|
||||
|
||||
return sessionId;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
|
||||
@@ -45,6 +45,8 @@
|
||||
"rename": {
|
||||
"title": "Rename Command",
|
||||
"newName": "New command name",
|
||||
"aliases": "Manage aliases",
|
||||
"addAlias": "Add alias",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Confirm"
|
||||
},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"title": "Providers",
|
||||
"subtitle": "Manage model providers",
|
||||
"subtitle": "Can configure chat models in \"Chat Completion\". Additionally, \"Agent Runner\" includes integrations with third-party services like Dify, Coze, and Alibaba Bailian(DashScope).",
|
||||
"providers": {
|
||||
"title": "Service Providers",
|
||||
"settings": "Settings",
|
||||
|
||||
@@ -45,6 +45,8 @@
|
||||
"rename": {
|
||||
"title": "重命名指令",
|
||||
"newName": "新指令名",
|
||||
"aliases": "管理别名",
|
||||
"addAlias": "添加别名",
|
||||
"cancel": "取消",
|
||||
"confirm": "确认"
|
||||
},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"title": "模型提供商",
|
||||
"subtitle": "管理模型提供商",
|
||||
"subtitle": "可以在“对话”中配置对话模型。此外,“Agent 执行器”包含了 Dify、Coze、阿里云百炼应用等第三方服务的集成。",
|
||||
"providers": {
|
||||
"title": "模型提供商",
|
||||
"settings": "设置",
|
||||
|
||||
@@ -7,9 +7,11 @@ import VerticalHeaderVue from './vertical-header/VerticalHeader.vue';
|
||||
import MigrationDialog from '@/components/shared/MigrationDialog.vue';
|
||||
import Chat from '@/components/chat/Chat.vue';
|
||||
import { useCustomizerStore } from '@/stores/customizer';
|
||||
import { useRouterLoadingStore } from '@/stores/routerLoading';
|
||||
|
||||
const customizer = useCustomizerStore();
|
||||
const route = useRoute();
|
||||
const routerLoadingStore = useRouterLoadingStore();
|
||||
|
||||
// 计算是否在聊天页面(非全屏模式)
|
||||
const isChatPage = computed(() => {
|
||||
@@ -60,6 +62,16 @@ onMounted(() => {
|
||||
<v-app :theme="useCustomizerStore().uiTheme"
|
||||
:class="[customizer.fontTheme, customizer.mini_sidebar ? 'mini-sidebar' : '', customizer.inputBg ? 'inputWithbg' : '']"
|
||||
>
|
||||
<!-- 路由切换进度条 -->
|
||||
<v-progress-linear
|
||||
v-if="routerLoadingStore.isLoading"
|
||||
:model-value="routerLoadingStore.progress"
|
||||
color="primary"
|
||||
height="2"
|
||||
fixed
|
||||
top
|
||||
style="z-index: 9999; position: absolute; opacity: 0.3; "
|
||||
/>
|
||||
<VerticalHeaderVue />
|
||||
<VerticalSidebarVue v-if="showSidebar" />
|
||||
<v-main :style="{
|
||||
|
||||
@@ -3,6 +3,7 @@ import MainRoutes from './MainRoutes';
|
||||
import AuthRoutes from './AuthRoutes';
|
||||
import ChatBoxRoutes from './ChatBoxRoutes';
|
||||
import { useAuthStore } from '@/stores/auth';
|
||||
import { useRouterLoadingStore } from '@/stores/routerLoading';
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(import.meta.env.BASE_URL),
|
||||
@@ -22,6 +23,11 @@ interface AuthStore {
|
||||
}
|
||||
|
||||
router.beforeEach(async (to, from, next) => {
|
||||
if (from.name && from.path !== to.path) {
|
||||
const loadingStore = useRouterLoadingStore();
|
||||
loadingStore.start();
|
||||
}
|
||||
|
||||
const publicPages = ['/auth/login'];
|
||||
const authRequired = !publicPages.includes(to.path);
|
||||
const auth: AuthStore = useAuthStore();
|
||||
@@ -40,3 +46,8 @@ router.beforeEach(async (to, from, next) => {
|
||||
next();
|
||||
}
|
||||
});
|
||||
|
||||
router.afterEach(() => {
|
||||
const loadingStore = useRouterLoadingStore();
|
||||
loadingStore.finish();
|
||||
});
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
import { defineStore } from 'pinia';
|
||||
import { ref } from 'vue';
|
||||
|
||||
export const useRouterLoadingStore = defineStore('routerLoading', () => {
|
||||
const isLoading = ref(false);
|
||||
const progress = ref(0);
|
||||
let progressInterval: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
function start() {
|
||||
isLoading.value = true;
|
||||
progress.value = 0;
|
||||
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
}
|
||||
|
||||
let currentProgress = 0;
|
||||
progressInterval = setInterval(() => {
|
||||
if (currentProgress < 80) {
|
||||
// 快速阶段:0-80%
|
||||
currentProgress += Math.random() * 20 + 10;
|
||||
if (currentProgress > 80) {
|
||||
currentProgress = 80;
|
||||
}
|
||||
} else if (currentProgress < 90) {
|
||||
// 缓慢阶段:80-90%
|
||||
currentProgress += Math.random() * 3 + 1;
|
||||
if (currentProgress > 90) {
|
||||
currentProgress = 90;
|
||||
}
|
||||
}
|
||||
progress.value = Math.min(currentProgress, 90);
|
||||
}, 50);
|
||||
}
|
||||
|
||||
function finish() {
|
||||
// 清理interval
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
progressInterval = null;
|
||||
}
|
||||
|
||||
// 快速完成到100%
|
||||
progress.value = 100;
|
||||
|
||||
// 延迟隐藏,让用户看到100%
|
||||
setTimeout(() => {
|
||||
isLoading.value = false;
|
||||
progress.value = 0;
|
||||
}, 300);
|
||||
}
|
||||
|
||||
return {
|
||||
isLoading,
|
||||
progress,
|
||||
start,
|
||||
finish
|
||||
};
|
||||
});
|
||||
|
||||
@@ -499,21 +499,23 @@ export default {
|
||||
// 将对话历史转换为 MessageList 组件期望的格式
|
||||
formattedMessages() {
|
||||
return this.conversationHistory.map(msg => {
|
||||
console.log('处理消息:', msg.role, msg.image_url, msg.audio_url);
|
||||
console.log('处理消息:', msg.role, msg.content);
|
||||
|
||||
// 将消息内容转换为 MessagePart[] 格式
|
||||
const messageParts = this.convertContentToMessageParts(msg.content);
|
||||
|
||||
if (msg.role === 'user') {
|
||||
return {
|
||||
content: {
|
||||
type: 'user',
|
||||
message: this.extractTextFromContent(msg.content),
|
||||
image_url: this.extractImagesFromContent(msg.content),
|
||||
message: messageParts
|
||||
}
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
content: {
|
||||
type: 'bot',
|
||||
message: this.extractTextFromContent(msg.content),
|
||||
embedded_images: this.extractImagesFromContent(msg.content),
|
||||
message: messageParts
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -990,7 +992,61 @@ export default {
|
||||
this.showMessage = true;
|
||||
},
|
||||
|
||||
// 从内容中提取文本
|
||||
// 将消息内容转换为 MessagePart[] 格式
|
||||
convertContentToMessageParts(content) {
|
||||
const parts = [];
|
||||
|
||||
if (typeof content === 'string') {
|
||||
// 纯文本内容
|
||||
if (content.trim()) {
|
||||
parts.push({
|
||||
type: 'plain',
|
||||
text: content
|
||||
});
|
||||
}
|
||||
} else if (Array.isArray(content)) {
|
||||
// 数组格式(OpenAI 格式)
|
||||
content.forEach(item => {
|
||||
if (item.type === 'text' && item.text) {
|
||||
parts.push({
|
||||
type: 'plain',
|
||||
text: item.text
|
||||
});
|
||||
} else if (item.type === 'image_url' && item.image_url?.url) {
|
||||
parts.push({
|
||||
type: 'image',
|
||||
embedded_url: item.image_url.url
|
||||
});
|
||||
}
|
||||
});
|
||||
} else if (typeof content === 'object' && content !== null) {
|
||||
// 对象格式,尝试提取文本和图片
|
||||
const textParts = [];
|
||||
for (const [key, value] of Object.entries(content)) {
|
||||
if (typeof value === 'string' && value.trim()) {
|
||||
textParts.push(value);
|
||||
}
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
parts.push({
|
||||
type: 'plain',
|
||||
text: textParts.join('\n')
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有提取到任何内容,添加一个空文本
|
||||
if (parts.length === 0) {
|
||||
parts.push({
|
||||
type: 'plain',
|
||||
text: ''
|
||||
});
|
||||
}
|
||||
|
||||
return parts;
|
||||
},
|
||||
|
||||
// 从内容中提取文本(保留用于其他用途)
|
||||
extractTextFromContent(content) {
|
||||
if (typeof content === 'string') {
|
||||
return content;
|
||||
@@ -1004,7 +1060,7 @@ export default {
|
||||
return '';
|
||||
},
|
||||
|
||||
// 从内容中提取图片URL
|
||||
// 从内容中提取图片URL(保留用于其他用途)
|
||||
extractImagesFromContent(content) {
|
||||
if (Array.isArray(content)) {
|
||||
return content.filter(item => item.type === 'image_url')
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
@@ -225,10 +229,17 @@ class ProcessLLMRequest:
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并直接注入到当前消息中
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.prompt = f"{quoted_text}\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
|
||||
+2
-2
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.10.0-alpha.1"
|
||||
version = "4.10.2"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -34,7 +34,7 @@ dependencies = [
|
||||
"ormsgpack>=1.9.1",
|
||||
"pillow>=11.2.1",
|
||||
"pip>=25.1.1",
|
||||
"psutil>=5.8.0",
|
||||
"psutil>=5.8.0,<7.2.0",
|
||||
"py-cord>=2.6.1",
|
||||
"pydantic~=2.10.3",
|
||||
"pydub>=0.25.1",
|
||||
|
||||
+1
-1
@@ -27,7 +27,7 @@ openai>=1.78.0
|
||||
ormsgpack>=1.9.1
|
||||
pillow>=11.2.1
|
||||
pip>=25.1.1
|
||||
psutil>=5.8.0
|
||||
psutil>=5.8.0,<7.2.0
|
||||
py-cord>=2.6.1
|
||||
pydantic~=2.10.3
|
||||
pydub>=0.25.1
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.run_context import ContextWrapper
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
||||
from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
class MockProvider(Provider):
|
||||
"""模拟Provider用于测试"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__({}, {})
|
||||
self.call_count = 0
|
||||
self.should_call_tools = True
|
||||
self.max_calls_before_normal_response = 10
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return "test_key"
|
||||
|
||||
def set_key(self, key: str):
|
||||
pass
|
||||
|
||||
async def get_models(self) -> list[str]:
|
||||
return ["test_model"]
|
||||
|
||||
async def text_chat(self, **kwargs) -> LLMResponse:
|
||||
self.call_count += 1
|
||||
|
||||
# 检查工具是否被禁用
|
||||
func_tool = kwargs.get("func_tool")
|
||||
|
||||
# 如果工具被禁用或超过最大调用次数,返回正常响应
|
||||
if func_tool is None or self.call_count > self.max_calls_before_normal_response:
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="这是我的最终回答",
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
# 模拟工具调用响应
|
||||
if self.should_call_tools:
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="我需要使用工具来帮助您",
|
||||
tools_call_name=["test_tool"],
|
||||
tools_call_args=[{"query": "test"}],
|
||||
tools_call_ids=["call_123"],
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
# 默认返回正常响应
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text="这是我的最终回答",
|
||||
usage=TokenUsage(input_other=10, output=5),
|
||||
)
|
||||
|
||||
async def text_chat_stream(self, **kwargs):
|
||||
response = await self.text_chat(**kwargs)
|
||||
response.is_chunk = True
|
||||
yield response
|
||||
response.is_chunk = False
|
||||
yield response
|
||||
|
||||
|
||||
class MockToolExecutor:
|
||||
"""模拟工具执行器"""
|
||||
|
||||
@classmethod
|
||||
def execute(cls, tool, run_context, **tool_args):
|
||||
async def generator():
|
||||
# 模拟工具返回结果,使用正确的类型
|
||||
from mcp.types import CallToolResult, TextContent
|
||||
|
||||
result = CallToolResult(
|
||||
content=[TextContent(type="text", text="工具执行结果")]
|
||||
)
|
||||
yield result
|
||||
|
||||
return generator()
|
||||
|
||||
|
||||
class MockHooks(BaseAgentRunHooks):
|
||||
"""模拟钩子函数"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_begin_called = False
|
||||
self.agent_done_called = False
|
||||
self.tool_start_called = False
|
||||
self.tool_end_called = False
|
||||
|
||||
async def on_agent_begin(self, run_context):
|
||||
self.agent_begin_called = True
|
||||
|
||||
async def on_tool_start(self, run_context, tool, tool_args):
|
||||
self.tool_start_called = True
|
||||
|
||||
async def on_tool_end(self, run_context, tool, tool_args, tool_result):
|
||||
self.tool_end_called = True
|
||||
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
self.agent_done_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider():
|
||||
return MockProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_executor():
|
||||
return MockToolExecutor()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hooks():
|
||||
return MockHooks()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_set():
|
||||
"""创建测试用的工具集"""
|
||||
tool = FunctionTool(
|
||||
name="test_tool",
|
||||
description="测试工具",
|
||||
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
|
||||
handler=AsyncMock(),
|
||||
)
|
||||
return ToolSet(tools=[tool])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_request(tool_set):
|
||||
"""创建测试用的ProviderRequest"""
|
||||
return ProviderRequest(prompt="请帮我查询信息", func_tool=tool_set, contexts=[])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
"""创建ToolLoopAgentRunner实例"""
|
||||
return ToolLoopAgentRunner()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_step_limit_functionality(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试最大步数限制功能"""
|
||||
|
||||
# 设置模拟provider,让它总是返回工具调用
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = (
|
||||
100 # 设置一个很大的值,确保不会自然结束
|
||||
)
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数来测试限制功能
|
||||
max_steps = 3
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该在达到最大步数后完成"
|
||||
|
||||
# 验证工具被禁用(这是最重要的验证点)
|
||||
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
||||
|
||||
# 验证有最终响应
|
||||
final_responses = [r for r in responses if r.type == "llm_result"]
|
||||
assert len(final_responses) > 0, "应该有最终的LLM响应"
|
||||
|
||||
# 验证最后一条消息是assistant的最终回答
|
||||
last_message = runner.run_context.messages[-1]
|
||||
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_completion_without_max_step(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试正常完成(不触发最大步数限制)"""
|
||||
|
||||
# 设置模拟provider,让它在第2次调用时返回正常响应
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 2
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置足够大的最大步数
|
||||
max_steps = 10
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该正常完成"
|
||||
|
||||
# 验证没有触发最大步数限制 - 通过检查provider调用次数
|
||||
# mock_provider在第2次调用后返回正常响应,所以不应该达到max_steps(10)
|
||||
assert mock_provider.call_count < max_steps, (
|
||||
f"正常完成时调用次数({mock_provider.call_count})应该小于最大步数({max_steps})"
|
||||
)
|
||||
|
||||
# 验证没有最大步数警告消息(注意:实际注入的是user角色的消息)
|
||||
user_messages = [m for m in runner.run_context.messages if m.role == "user"]
|
||||
max_step_messages = [
|
||||
m for m in user_messages if "工具调用次数已达到上限" in m.content
|
||||
]
|
||||
assert len(max_step_messages) == 0, "正常完成时不应该有步数限制消息"
|
||||
|
||||
# 验证工具仍然可用(没有被禁用)
|
||||
assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_step_with_streaming(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试流式响应下的最大步数限制"""
|
||||
|
||||
# 设置模拟provider
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 100
|
||||
|
||||
# 初始化runner,启用流式响应
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数
|
||||
max_steps = 2
|
||||
|
||||
# 收集所有响应
|
||||
responses = []
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
responses.append(response)
|
||||
|
||||
# 验证结果
|
||||
assert runner.done(), "代理应该在达到最大步数后完成"
|
||||
|
||||
# 验证有流式响应
|
||||
streaming_responses = [r for r in responses if r.type == "streaming_delta"]
|
||||
assert len(streaming_responses) > 0, "应该有流式响应"
|
||||
|
||||
# 验证工具被禁用
|
||||
assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"
|
||||
|
||||
# 验证最后一条消息是assistant的最终回答
|
||||
last_message = runner.run_context.messages[-1]
|
||||
assert last_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hooks_called_with_max_step(
|
||||
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
|
||||
):
|
||||
"""测试达到最大步数时钩子函数是否被正确调用"""
|
||||
|
||||
# 设置模拟provider
|
||||
mock_provider.should_call_tools = True
|
||||
mock_provider.max_calls_before_normal_response = 100
|
||||
|
||||
# 初始化runner
|
||||
await runner.reset(
|
||||
provider=mock_provider,
|
||||
request=provider_request,
|
||||
run_context=ContextWrapper(context=None),
|
||||
tool_executor=mock_tool_executor,
|
||||
agent_hooks=mock_hooks,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# 设置较小的最大步数
|
||||
max_steps = 2
|
||||
|
||||
# 执行步骤
|
||||
async for response in runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
# 验证钩子函数被调用
|
||||
assert mock_hooks.agent_begin_called, "on_agent_begin应该被调用"
|
||||
assert mock_hooks.agent_done_called, "on_agent_done应该被调用"
|
||||
assert mock_hooks.tool_start_called, "on_tool_start应该被调用"
|
||||
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user