Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b4498a554 | |||
| 5e5207da95 | |||
| def8b730b7 | |||
| 22a109c2ae | |||
| 6416707e35 | |||
| 4658998b85 | |||
| d233fb8b1e | |||
| fc2a67188f | |||
| d69592aaa8 | |||
| f3397f6f08 | |||
| be92e4f395 |
@@ -264,8 +264,9 @@ pre-commit install
|
||||
|
||||
<div align="center">
|
||||
|
||||
_陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。_
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ class Main(star.Star):
|
||||
"fetch_url",
|
||||
"web_search_tavily",
|
||||
"tavily_extract_web_page",
|
||||
"web_search_bocha",
|
||||
]
|
||||
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
@@ -30,6 +31,9 @@ class Main(star.Star):
|
||||
self.tavily_key_index = 0
|
||||
self.tavily_key_lock = asyncio.Lock()
|
||||
|
||||
self.bocha_key_index = 0
|
||||
self.bocha_key_lock = asyncio.Lock()
|
||||
|
||||
# 将 str 类型的 key 迁移至 list[str],并保存
|
||||
cfg = self.context.get_config()
|
||||
provider_settings = cfg.get("provider_settings")
|
||||
@@ -45,6 +49,14 @@ class Main(star.Star):
|
||||
provider_settings["websearch_tavily_key"] = []
|
||||
cfg.save_config()
|
||||
|
||||
bocha_key = provider_settings.get("websearch_bocha_key")
|
||||
if isinstance(bocha_key, str):
|
||||
if bocha_key:
|
||||
provider_settings["websearch_bocha_key"] = [bocha_key]
|
||||
else:
|
||||
provider_settings["websearch_bocha_key"] = []
|
||||
cfg.save_config()
|
||||
|
||||
self.bing_search = Bing()
|
||||
self.sogo_search = Sogo()
|
||||
self.baidu_initialized = False
|
||||
@@ -341,7 +353,7 @@ class Main(star.Star):
|
||||
}
|
||||
)
|
||||
if result.favicon:
|
||||
sp.temorary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
# ret = "\n".join(ret_ls)
|
||||
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
|
||||
return ret
|
||||
@@ -382,6 +394,160 @@ class Main(star.Star):
|
||||
return "Error: Tavily web searcher does not return any results."
|
||||
return ret
|
||||
|
||||
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
|
||||
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
|
||||
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
|
||||
if not bocha_keys:
|
||||
raise ValueError("错误:BoCha API密钥未在AstrBot中配置。")
|
||||
|
||||
async with self.bocha_key_lock:
|
||||
key = bocha_keys[self.bocha_key_index]
|
||||
self.bocha_key_index = (self.bocha_key_index + 1) % len(bocha_keys)
|
||||
return key
|
||||
|
||||
async def _web_search_bocha(
|
||||
self,
|
||||
cfg: AstrBotConfig,
|
||||
payload: dict,
|
||||
) -> list[SearchResult]:
|
||||
"""使用 BoCha 搜索引擎进行搜索"""
|
||||
bocha_key = await self._get_bocha_key(cfg)
|
||||
url = "https://api.bochaai.com/v1/web-search"
|
||||
header = {
|
||||
"Authorization": f"Bearer {bocha_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=header,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise Exception(
|
||||
f"BoCha web search failed: {reason}, status: {response.status}",
|
||||
)
|
||||
data = await response.json()
|
||||
data = data["data"]["webPages"]["value"]
|
||||
results = []
|
||||
for item in data:
|
||||
result = SearchResult(
|
||||
title=item.get("name"),
|
||||
url=item.get("url"),
|
||||
snippet=item.get("snippet"),
|
||||
favicon=item.get("siteIcon"),
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
@llm_tool("web_search_bocha")
|
||||
async def search_from_bocha(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
query: str,
|
||||
freshness: str = "noLimit",
|
||||
summary: bool = False,
|
||||
include: str = "",
|
||||
exclude: str = "",
|
||||
count: int = 10,
|
||||
) -> str:
|
||||
"""
|
||||
A web search tool based on Bocha Search API, used to retrieve web pages
|
||||
related to the user's query.
|
||||
|
||||
Args:
|
||||
query (string): Required. User's search query.
|
||||
|
||||
freshness (string): Optional. Specifies the time range of the search.
|
||||
Supported values:
|
||||
- "noLimit": No time limit (default, recommended).
|
||||
- "oneDay": Within one day.
|
||||
- "oneWeek": Within one week.
|
||||
- "oneMonth": Within one month.
|
||||
- "oneYear": Within one year.
|
||||
- "YYYY-MM-DD..YYYY-MM-DD": Search within a specific date range.
|
||||
Example: "2025-01-01..2025-04-06".
|
||||
- "YYYY-MM-DD": Search on a specific date.
|
||||
Example: "2025-04-06".
|
||||
It is recommended to use "noLimit", as the search algorithm will
|
||||
automatically optimize time relevance. Manually restricting the
|
||||
time range may result in no search results.
|
||||
|
||||
summary (boolean): Optional. Whether to include a text summary
|
||||
for each search result.
|
||||
- True: Include summary.
|
||||
- False: Do not include summary (default).
|
||||
|
||||
include (string): Optional. Specifies the domains to include in
|
||||
the search. Multiple domains can be separated by "|" or ",".
|
||||
A maximum of 100 domains is allowed.
|
||||
Examples:
|
||||
- "qq.com"
|
||||
- "qq.com|m.163.com"
|
||||
|
||||
exclude (string): Optional. Specifies the domains to exclude from
|
||||
the search. Multiple domains can be separated by "|" or ",".
|
||||
A maximum of 100 domains is allowed.
|
||||
Examples:
|
||||
- "qq.com"
|
||||
- "qq.com|m.163.com"
|
||||
|
||||
count (number): Optional. Number of search results to return.
|
||||
- Range: 1–50
|
||||
- Default: 10
|
||||
The actual number of returned results may be less than the
|
||||
specified count.
|
||||
"""
|
||||
logger.info(f"web_searcher - search_from_bocha: {query}")
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
|
||||
if not cfg.get("provider_settings", {}).get("websearch_bocha_key", []):
|
||||
raise ValueError("Error: BoCha API key is not configured in AstrBot.")
|
||||
|
||||
# build payload
|
||||
payload = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
}
|
||||
|
||||
# freshness:时间范围
|
||||
if freshness:
|
||||
payload["freshness"] = freshness
|
||||
|
||||
# 是否返回摘要
|
||||
payload["summary"] = summary
|
||||
|
||||
# include:限制搜索域
|
||||
if include:
|
||||
payload["include"] = include
|
||||
|
||||
# exclude:排除搜索域
|
||||
if exclude:
|
||||
payload["exclude"] = exclude
|
||||
|
||||
results = await self._web_search_bocha(cfg, payload)
|
||||
if not results:
|
||||
return "Error: BoCha web searcher does not return any results."
|
||||
|
||||
ret_ls = []
|
||||
ref_uuid = str(uuid.uuid4())[:4]
|
||||
for idx, result in enumerate(results, 1):
|
||||
index = f"{ref_uuid}.{idx}"
|
||||
ret_ls.append(
|
||||
{
|
||||
"title": f"{result.title}",
|
||||
"url": f"{result.url}",
|
||||
"snippet": f"{result.snippet}",
|
||||
"index": index,
|
||||
}
|
||||
)
|
||||
if result.favicon:
|
||||
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
# ret = "\n".join(ret_ls)
|
||||
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
|
||||
return ret
|
||||
|
||||
@filter.on_llm_request(priority=-10000)
|
||||
async def edit_web_search_tools(
|
||||
self,
|
||||
@@ -419,6 +585,7 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
elif provider == "tavily":
|
||||
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
||||
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
||||
@@ -429,6 +596,7 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
elif provider == "baidu_ai_search":
|
||||
try:
|
||||
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
|
||||
@@ -440,5 +608,15 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
except Exception as e:
|
||||
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
|
||||
elif provider == "bocha":
|
||||
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
|
||||
if web_search_bocha:
|
||||
tool_set.add_tool(web_search_bocha)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.14.4"
|
||||
__version__ = "4.14.6"
|
||||
|
||||
@@ -254,6 +254,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
parts.append(TextPart(text=llm_resp.completion_text))
|
||||
if len(parts) == 0:
|
||||
logger.warning(
|
||||
"LLM returned empty assistant message with no tool calls."
|
||||
)
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
@@ -309,6 +313,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
parts.append(TextPart(text=llm_resp.completion_text))
|
||||
if len(parts) == 0:
|
||||
parts = None
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
|
||||
@@ -246,8 +246,18 @@ class ToolSet:
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
# Avoid side effects by not modifying the original schema
|
||||
origin_type = schema.get("type")
|
||||
target_type = origin_type
|
||||
|
||||
# Compatibility fix: Gemini API expects 'type' to be a string (enum),
|
||||
# but standard JSON Schema (MCP) allows lists (e.g. ["string", "null"]).
|
||||
# We fallback to the first non-null type.
|
||||
if isinstance(origin_type, list):
|
||||
target_type = next((t for t in origin_type if t != "null"), "string")
|
||||
|
||||
if target_type in supported_types:
|
||||
result["type"] = target_type
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"],
|
||||
set(),
|
||||
|
||||
@@ -59,7 +59,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
platform_name = run_context.context.event.get_platform_name()
|
||||
if (
|
||||
platform_name == "webchat"
|
||||
and tool.name == "web_search_tavily"
|
||||
and tool.name in ["web_search_tavily", "web_search_bocha"]
|
||||
and len(run_context.messages) > 0
|
||||
and tool_result
|
||||
and len(tool_result.content)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.14.4"
|
||||
VERSION = "4.14.6"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -74,6 +74,7 @@ DEFAULT_CONFIG = {
|
||||
"web_search": False,
|
||||
"websearch_provider": "default",
|
||||
"websearch_tavily_key": [],
|
||||
"websearch_bocha_key": [],
|
||||
"websearch_baidu_app_builder_key": "",
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
@@ -2563,7 +2564,7 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": {
|
||||
"description": "网页搜索提供商",
|
||||
"type": "string",
|
||||
"options": ["default", "tavily", "baidu_ai_search"],
|
||||
"options": ["default", "tavily", "baidu_ai_search", "bocha"],
|
||||
"condition": {
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
@@ -2578,6 +2579,16 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_bocha_key": {
|
||||
"description": "BoCha API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可添加多个 Key 进行轮询。",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "bocha",
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""使用此功能应该先 pip install baidu-aip"""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from aip import AipContentCensor
|
||||
|
||||
from . import ContentSafetyStrategy
|
||||
@@ -23,7 +25,8 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
||||
count = len(res["data"])
|
||||
parts = [f"百度审核服务发现 {count} 处违规:\n"]
|
||||
for i in res["data"]:
|
||||
parts.append(f"{i['msg']};\n")
|
||||
# 百度 AIP 返回结构是动态 dict;类型检查时 i 可能被推断为序列,转成 dict 后用 get 取字段
|
||||
parts.append(f"{cast(dict[str, Any], i).get('msg', '')};\n")
|
||||
parts.append("\n判断结果:" + res["conclusion"])
|
||||
info = "".join(parts)
|
||||
return False, info
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
@@ -13,7 +13,7 @@ class MessageSession:
|
||||
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
platform_id: str | None = None
|
||||
platform_id: str = field(init=False)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
@@ -444,9 +444,20 @@ class DiscordPlatformAdapter(Platform):
|
||||
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
|
||||
|
||||
# 2. 构建 AstrBotMessage
|
||||
channel = ctx.channel
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(ctx.channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(ctx.channel)
|
||||
if channel is not None:
|
||||
abm.type = self._get_message_type(channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(channel)
|
||||
else:
|
||||
# 防守式兜底:channel 取不到时,仍能根据 guild_id/channel_id 推断会话信息
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if ctx.guild_id is not None
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.group_id = str(ctx.channel_id)
|
||||
|
||||
abm.message_str = message_str_for_filter
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(ctx.author.id),
|
||||
|
||||
@@ -3,13 +3,10 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
@@ -125,44 +122,23 @@ class LarkPlatformAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
if "%" in session.session_id:
|
||||
session.session_id = session.session_id.split("%")[1]
|
||||
receive_id = session.session_id
|
||||
if "%" in receive_id:
|
||||
receive_id = receive_id.split("%")[1]
|
||||
else:
|
||||
id_type = "open_id"
|
||||
receive_id = session.session_id
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(session.session_id)
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
# 复用 LarkMessageEvent 中的通用发送逻辑
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message_chain,
|
||||
self.lark_api,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=id_type,
|
||||
)
|
||||
|
||||
response = await self.lark_api.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
|
||||
@@ -6,6 +6,8 @@ from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
@@ -17,10 +19,15 @@ from lark_oapi.api.im.v1 import (
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Plain
|
||||
from astrbot.api.message_components import At, File, Plain, Record, Video
|
||||
from astrbot.api.message_components import Image as AstrBotImage
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.media_utils import (
|
||||
convert_audio_to_opus,
|
||||
convert_video_format,
|
||||
get_media_duration,
|
||||
)
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -35,6 +42,144 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _send_im_message(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
content: str,
|
||||
msg_type: str,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
) -> bool:
|
||||
"""发送飞书 IM 消息的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
content: 消息内容(JSON字符串)
|
||||
msg_type: 消息类型(post/file/audio/media等)
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return False
|
||||
|
||||
if reply_message_id:
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(reply_message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.areply(request)
|
||||
else:
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
)
|
||||
|
||||
if receive_id_type is None or receive_id is None:
|
||||
logger.error(
|
||||
"[Lark] 主动发送消息时,receive_id 和 receive_id_type 不能为空",
|
||||
)
|
||||
return False
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"[Lark] 发送飞书消息失败({response.code}): {response.msg}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _upload_lark_file(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
path: str,
|
||||
file_type: str,
|
||||
duration: int | None = None,
|
||||
) -> str | None:
|
||||
"""上传文件到飞书的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
path: 文件路径
|
||||
file_type: 文件类型(stream/opus/mp4等)
|
||||
duration: 媒体时长(毫秒),可选
|
||||
|
||||
Returns:
|
||||
成功返回file_key,失败返回None
|
||||
"""
|
||||
if not path or not os.path.exists(path):
|
||||
logger.error(f"[Lark] 文件不存在: {path}")
|
||||
return None
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传文件")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path, "rb") as file_obj:
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(os.path.basename(path))
|
||||
.file(file_obj)
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder.duration(duration)
|
||||
|
||||
request = (
|
||||
CreateFileRequest.builder()
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.file.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(
|
||||
f"[Lark] 无法上传文件({response.code}): {response.msg}"
|
||||
)
|
||||
return None
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
|
||||
return None
|
||||
|
||||
file_key = response.data.file_key
|
||||
logger.debug(f"[Lark] 文件上传成功: {file_key}")
|
||||
return file_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开或上传文件: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list:
|
||||
ret = []
|
||||
@@ -103,6 +248,18 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
ret.append([{"tag": "img", "image_key": image_key}])
|
||||
_stage.clear()
|
||||
elif isinstance(comp, File):
|
||||
# 文件将通过 _send_file_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到文件组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Record):
|
||||
# 音频将通过 _send_audio_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到音频组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Video):
|
||||
# 视频将通过 _send_media_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到视频组件,将单独发送")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||
|
||||
@@ -110,40 +267,270 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
@staticmethod
|
||||
async def send_message_chain(
|
||||
message_chain: MessageChain,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""通用的消息链发送方法
|
||||
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
Args:
|
||||
message_chain: 要发送的消息链
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型,如 'open_id', 'chat_id'(用于主动发送)
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
# 分离文件、音频、视频组件和其他组件
|
||||
file_components: list[File] = []
|
||||
audio_components: list[Record] = []
|
||||
media_components: list[Video] = []
|
||||
other_components = []
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, File):
|
||||
file_components.append(comp)
|
||||
elif isinstance(comp, Record):
|
||||
audio_components.append(comp)
|
||||
elif isinstance(comp, Video):
|
||||
media_components.append(comp)
|
||||
else:
|
||||
other_components.append(comp)
|
||||
|
||||
# 先发送非文件内容(如果有)
|
||||
if other_components:
|
||||
temp_chain = MessageChain()
|
||||
temp_chain.chain = other_components
|
||||
res = await LarkMessageEvent._convert_to_lark(temp_chain, lark_client)
|
||||
|
||||
if res: # 只在有内容时发送
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps(wrapped),
|
||||
msg_type="post",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
# 发送附件
|
||||
for file_comp in file_components:
|
||||
await LarkMessageEvent._send_file_message(
|
||||
file_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for audio_comp in audio_components:
|
||||
await LarkMessageEvent._send_audio_message(
|
||||
audio_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for media_comp in media_components:
|
||||
await LarkMessageEvent._send_media_message(
|
||||
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message,
|
||||
self.bot,
|
||||
reply_message_id=self.message_obj.message_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
async def _send_file_message(
|
||||
file_comp: File,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送文件消息
|
||||
|
||||
Args:
|
||||
file_comp: 文件组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
file_path = file_comp.file or ""
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client, path=file_path, file_type="stream"
|
||||
)
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
content = json.dumps({"file_key": file_key})
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=content,
|
||||
msg_type="file",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_audio_message(
|
||||
audio_comp: Record,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送音频消息
|
||||
|
||||
Args:
|
||||
audio_comp: 音频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取音频文件路径
|
||||
try:
|
||||
original_audio_path = await audio_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_audio_path or not os.path.exists(original_audio_path):
|
||||
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
|
||||
return
|
||||
|
||||
# 转换为opus格式
|
||||
converted_audio_path = None
|
||||
try:
|
||||
audio_path = await convert_audio_to_opus(original_audio_path)
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if audio_path != original_audio_path:
|
||||
converted_audio_path = audio_path
|
||||
else:
|
||||
audio_path = original_audio_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 音频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
audio_path = original_audio_path
|
||||
|
||||
# 获取音频时长
|
||||
duration = await get_media_duration(audio_path)
|
||||
|
||||
# 上传音频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=audio_path,
|
||||
file_type="opus",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时音频文件
|
||||
if converted_audio_path and os.path.exists(converted_audio_path):
|
||||
try:
|
||||
os.remove(converted_audio_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的音频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="audio",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_media_message(
|
||||
media_comp: Video,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送视频消息
|
||||
|
||||
Args:
|
||||
media_comp: 视频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取视频文件路径
|
||||
try:
|
||||
original_video_path = await media_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_video_path or not os.path.exists(original_video_path):
|
||||
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
|
||||
return
|
||||
|
||||
# 转换为mp4格式
|
||||
converted_video_path = None
|
||||
try:
|
||||
video_path = await convert_video_format(original_video_path, "mp4")
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if video_path != original_video_path:
|
||||
converted_video_path = video_path
|
||||
else:
|
||||
video_path = original_video_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 视频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
video_path = original_video_path
|
||||
|
||||
# 获取视频时长
|
||||
duration = await get_media_duration(video_path)
|
||||
|
||||
# 上传视频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=video_path,
|
||||
file_type="mp4",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时视频文件
|
||||
if converted_video_path and os.path.exists(converted_video_path):
|
||||
try:
|
||||
os.remove(converted_video_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的视频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="media",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
|
||||
@@ -29,43 +29,11 @@ class QueueListener:
|
||||
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
|
||||
self.webchat_queue_mgr = webchat_queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, conversation_id: str):
|
||||
"""Listen to a specific conversation queue"""
|
||||
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}",
|
||||
)
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""Monitor for new conversation queues and start listeners"""
|
||||
monitored_conversations = set()
|
||||
|
||||
while True:
|
||||
# Check for new conversations
|
||||
current_conversations = set(self.webchat_queue_mgr.queues.keys())
|
||||
new_conversations = current_conversations - monitored_conversations
|
||||
|
||||
# Start listeners for new conversations
|
||||
for conversation_id in new_conversations:
|
||||
task = asyncio.create_task(self.listen_to_queue(conversation_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_conversations.add(conversation_id)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
# Clean up monitored conversations that no longer exist
|
||||
removed_conversations = monitored_conversations - current_conversations
|
||||
monitored_conversations -= removed_conversations
|
||||
|
||||
await asyncio.sleep(1) # Check for new conversations every second
|
||||
"""Register callback and keep adapter task alive."""
|
||||
self.webchat_queue_mgr.set_listener(self.callback)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
|
||||
@@ -26,8 +26,12 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
streaming: bool = False,
|
||||
) -> str | None:
|
||||
cid = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
request_id = str(message_id)
|
||||
conversation_id = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
if not message:
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -124,9 +128,13 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
reasoning_content = ""
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
message_id = self.message_obj.message_id
|
||||
request_id = str(message_id)
|
||||
conversation_id = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
async for chain in generator:
|
||||
# 处理音频流(Live Mode)
|
||||
if chain.type == "audio_chunk":
|
||||
|
||||
@@ -1,35 +1,147 @@
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""Conversation ID to asyncio.Queue mapping"""
|
||||
self.back_queues = {}
|
||||
"""Conversation ID to asyncio.Queue mapping for responses"""
|
||||
self.back_queues: dict[str, asyncio.Queue] = {}
|
||||
"""Request ID to asyncio.Queue mapping for responses"""
|
||||
self._conversation_back_requests: dict[str, set[str]] = {}
|
||||
self._request_conversation: dict[str, str] = {}
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a queue for the given conversation ID"""
|
||||
if conversation_id not in self.queues:
|
||||
self.queues[conversation_id] = asyncio.Queue()
|
||||
self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[conversation_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
return self.queues[conversation_id]
|
||||
|
||||
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given conversation ID"""
|
||||
if conversation_id not in self.back_queues:
|
||||
self.back_queues[conversation_id] = asyncio.Queue()
|
||||
return self.back_queues[conversation_id]
|
||||
def get_or_create_back_queue(
|
||||
self,
|
||||
request_id: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given request ID"""
|
||||
if request_id not in self.back_queues:
|
||||
self.back_queues[request_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
if conversation_id:
|
||||
self._request_conversation[request_id] = conversation_id
|
||||
if conversation_id not in self._conversation_back_requests:
|
||||
self._conversation_back_requests[conversation_id] = set()
|
||||
self._conversation_back_requests[conversation_id].add(request_id)
|
||||
return self.back_queues[request_id]
|
||||
|
||||
def remove_back_queue(self, request_id: str):
|
||||
"""Remove back queue for the given request ID"""
|
||||
self.back_queues.pop(request_id, None)
|
||||
conversation_id = self._request_conversation.pop(request_id, None)
|
||||
if conversation_id:
|
||||
request_ids = self._conversation_back_requests.get(conversation_id)
|
||||
if request_ids is not None:
|
||||
request_ids.discard(request_id)
|
||||
if not request_ids:
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
|
||||
def remove_queues(self, conversation_id: str):
|
||||
"""Remove queues for the given conversation ID"""
|
||||
if conversation_id in self.queues:
|
||||
del self.queues[conversation_id]
|
||||
if conversation_id in self.back_queues:
|
||||
del self.back_queues[conversation_id]
|
||||
for request_id in list(
|
||||
self._conversation_back_requests.get(conversation_id, set())
|
||||
):
|
||||
self.remove_back_queue(request_id)
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
self.remove_queue(conversation_id)
|
||||
|
||||
def remove_queue(self, conversation_id: str):
|
||||
"""Remove input queue and listener for the given conversation ID"""
|
||||
self.queues.pop(conversation_id, None)
|
||||
|
||||
close_event = self._queue_close_events.pop(conversation_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(conversation_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, conversation_id: str) -> bool:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[tuple], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for conversation_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
|
||||
def _start_listener_if_needed(self, conversation_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if conversation_id in self._listener_tasks:
|
||||
task = self._listener_tasks[conversation_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(conversation_id)
|
||||
close_event = self._queue_close_events.get(conversation_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(conversation_id, queue, close_event),
|
||||
name=f"webchat_listener_{conversation_id}",
|
||||
)
|
||||
self._listener_tasks[conversation_id] = task
|
||||
task.add_done_callback(
|
||||
lambda _: self._listener_tasks.pop(conversation_id, None)
|
||||
)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
conversation_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -51,44 +51,13 @@ class WecomAIQueueListener:
|
||||
) -> None:
|
||||
self.queue_mgr = queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, session_id: str):
|
||||
"""监听特定会话的队列"""
|
||||
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""监控新会话队列并启动监听器"""
|
||||
monitored_sessions = set()
|
||||
|
||||
"""注册监听回调并定期清理过期响应。"""
|
||||
self.queue_mgr.set_listener(self.callback)
|
||||
while True:
|
||||
# 检查新会话
|
||||
current_sessions = set(self.queue_mgr.queues.keys())
|
||||
new_sessions = current_sessions - monitored_sessions
|
||||
|
||||
# 为新会话启动监听器
|
||||
for session_id in new_sessions:
|
||||
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_sessions.add(session_id)
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
# 清理已不存在的会话
|
||||
removed_sessions = monitored_sessions - current_sessions
|
||||
monitored_sessions -= removed_sessions
|
||||
|
||||
# 清理过期的待处理响应
|
||||
self.queue_mgr.cleanup_expired_responses()
|
||||
|
||||
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -212,7 +181,12 @@ class WecomAIBotAdapter(Platform):
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not self.queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
if self.queue_mgr.is_stream_finished(stream_id):
|
||||
logger.debug(
|
||||
f"Stream already finished, returning end message: {stream_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
@@ -243,10 +217,10 @@ class WecomAIBotAdapter(Platform):
|
||||
latest_plain_content = msg["data"] or ""
|
||||
elif msg["type"] == "image":
|
||||
image_base64.append(msg["image_data"])
|
||||
elif msg["type"] == "end":
|
||||
elif msg["type"] in {"end", "complete"}:
|
||||
# stream end
|
||||
finish = True
|
||||
self.queue_mgr.remove_queues(stream_id)
|
||||
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -12,7 +13,7 @@ from astrbot.api import logger
|
||||
class WecomAIQueueMgr:
|
||||
"""企业微信智能机器人队列管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||
|
||||
@@ -21,6 +22,13 @@ class WecomAIQueueMgr:
|
||||
|
||||
self.pending_responses: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的响应缓存,用于流式响应"""
|
||||
self.completed_streams: dict[str, float] = {}
|
||||
"""已结束的 stream 缓存,用于兼容平台后续重复轮询"""
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输入队列
|
||||
@@ -33,7 +41,9 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.queues:
|
||||
self.queues[session_id] = asyncio.Queue()
|
||||
self.queues[session_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[session_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(session_id)
|
||||
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||
return self.queues[session_id]
|
||||
|
||||
@@ -48,20 +58,21 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.back_queues:
|
||||
self.back_queues[session_id] = asyncio.Queue()
|
||||
self.back_queues[session_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||
return self.back_queues[session_id]
|
||||
|
||||
def remove_queues(self, session_id: str):
|
||||
def remove_queues(self, session_id: str, mark_finished: bool = False):
|
||||
"""移除指定会话的所有队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
mark_finished: 是否标记为已正常结束
|
||||
|
||||
"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
self.remove_queue(session_id)
|
||||
|
||||
if session_id in self.back_queues:
|
||||
del self.back_queues[session_id]
|
||||
@@ -70,6 +81,23 @@ class WecomAIQueueMgr:
|
||||
if session_id in self.pending_responses:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
if mark_finished:
|
||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||
|
||||
def remove_queue(self, session_id: str):
|
||||
"""仅移除输入队列和对应监听任务"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
|
||||
close_event = self._queue_close_events.pop(session_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(session_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的队列
|
||||
@@ -121,6 +149,20 @@ class WecomAIQueueMgr:
|
||||
"""
|
||||
return self.pending_responses.get(session_id)
|
||||
|
||||
def is_stream_finished(
|
||||
self,
|
||||
session_id: str,
|
||||
max_age_seconds: int = 60,
|
||||
) -> bool:
|
||||
"""判断 stream 是否在短期内已结束"""
|
||||
finished_at = self.completed_streams.get(session_id)
|
||||
if finished_at is None:
|
||||
return False
|
||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||
"""清理过期的待处理响应
|
||||
|
||||
@@ -136,8 +178,75 @@ class WecomAIQueueMgr:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||
self.remove_queues(session_id)
|
||||
logger.debug(f"[WecomAI] 清理过期响应及队列: {session_id}")
|
||||
expired_finished = [
|
||||
session_id
|
||||
for session_id, finished_at in self.completed_streams.items()
|
||||
if current_time - finished_at > 60
|
||||
]
|
||||
for session_id in expired_finished:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for session_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(session_id)
|
||||
|
||||
def _start_listener_if_needed(self, session_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if session_id in self._listener_tasks:
|
||||
task = self._listener_tasks[session_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(session_id)
|
||||
close_event = self._queue_close_events.get(session_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(session_id, queue, close_event),
|
||||
name=f"wecomai_listener_{session_id}",
|
||||
)
|
||||
self._listener_tasks[session_id] = task
|
||||
task.add_done_callback(lambda _: self._listener_tasks.pop(session_id, None))
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
session_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
"""获取队列统计信息
|
||||
|
||||
@@ -63,7 +63,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""媒体文件处理工具
|
||||
|
||||
提供音视频格式转换、时长获取等功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
async def get_media_duration(file_path: str) -> int | None:
|
||||
"""使用ffprobe获取媒体文件时长
|
||||
|
||||
Args:
|
||||
file_path: 媒体文件路径
|
||||
|
||||
Returns:
|
||||
时长(毫秒),如果获取失败返回None
|
||||
"""
|
||||
try:
|
||||
# 使用ffprobe获取时长
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
file_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0 and stdout:
|
||||
duration_seconds = float(stdout.decode().strip())
|
||||
duration_ms = int(duration_seconds * 1000)
|
||||
logger.debug(f"[Media Utils] 获取媒体时长: {duration_ms}ms")
|
||||
return duration_ms
|
||||
else:
|
||||
logger.warning(f"[Media Utils] 无法获取媒体文件时长: {file_path}")
|
||||
return None
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(
|
||||
"[Media Utils] ffprobe未安装或不在PATH中,无法获取媒体时长。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Media Utils] 获取媒体时长时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) -> str:
|
||||
"""使用ffmpeg将音频转换为opus格式
|
||||
|
||||
Args:
|
||||
audio_path: 原始音频文件路径
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的opus文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是opus格式,直接返回
|
||||
if audio_path.lower().endswith(".opus"):
|
||||
return audio_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换为opus格式
|
||||
# -y: 覆盖输出文件
|
||||
# -i: 输入文件
|
||||
# -acodec libopus: 使用opus编码器
|
||||
# -ac 1: 单声道
|
||||
# -ar 16000: 采样率16kHz
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
audio_path,
|
||||
"-acodec",
|
||||
"libopus",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
"16000",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的opus输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(f"[Media Utils] 清理失败的opus输出文件时出错: {e}")
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换音频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 音频转换成功: {audio_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换音频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换音频格式时出错: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def convert_video_format(
|
||||
video_path: str, output_format: str = "mp4", output_path: str | None = None
|
||||
) -> str:
|
||||
"""使用ffmpeg转换视频格式
|
||||
|
||||
Args:
|
||||
video_path: 原始视频文件路径
|
||||
output_format: 目标格式,默认mp4
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的视频文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是目标格式,直接返回
|
||||
if video_path.lower().endswith(f".{output_format}"):
|
||||
return video_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换视频格式
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
video_path,
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-c:a",
|
||||
"aac",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的{output_format}输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"[Media Utils] 清理失败的{output_format}输出文件时出错: {e}"
|
||||
)
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换视频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 视频转换成功: {video_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换视频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换视频格式时出错: {e}")
|
||||
raise
|
||||
@@ -23,7 +23,7 @@ class SharedPreferences:
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
self.temporary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
"""automatically clear per 24 hours. Might be helpful in some cases XD"""
|
||||
|
||||
self._sync_loop = asyncio.new_event_loop()
|
||||
@@ -37,7 +37,7 @@ class SharedPreferences:
|
||||
self._scheduler.start()
|
||||
|
||||
def _clear_temporary_cache(self):
|
||||
self.temorary_cache.clear()
|
||||
self.temporary_cache.clear()
|
||||
|
||||
async def get_async(
|
||||
self,
|
||||
|
||||
@@ -238,6 +238,7 @@ class ChatRoute(Route):
|
||||
Returns:
|
||||
包含 used 列表的字典,记录被引用的搜索结果
|
||||
"""
|
||||
supported = ["web_search_tavily", "web_search_bocha"]
|
||||
# 从 accumulated_parts 中找到所有 web_search_tavily 的工具调用结果
|
||||
web_search_results = {}
|
||||
tool_call_parts = [
|
||||
@@ -248,7 +249,7 @@ class ChatRoute(Route):
|
||||
|
||||
for part in tool_call_parts:
|
||||
for tool_call in part["tool_calls"]:
|
||||
if tool_call.get("name") != "web_search_tavily" or not tool_call.get(
|
||||
if tool_call.get("name") not in supported or not tool_call.get(
|
||||
"result"
|
||||
):
|
||||
continue
|
||||
@@ -278,7 +279,7 @@ class ChatRoute(Route):
|
||||
if ref_index not in web_search_results:
|
||||
continue
|
||||
payload = {"index": ref_index, **web_search_results[ref_index]}
|
||||
if favicon := sp.temorary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
payload["favicon"] = favicon
|
||||
used_refs.append(payload)
|
||||
|
||||
@@ -353,12 +354,15 @@ class ChatRoute(Route):
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
webchat_conv_id = session_id
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
|
||||
# 构建用户消息段(包含 path 用于传递给 adapter)
|
||||
message_parts = await self._build_user_message_parts(message)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
message_id,
|
||||
webchat_conv_id,
|
||||
)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
@@ -531,6 +535,8 @@ class ChatRoute(Route):
|
||||
refs = {}
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
# 将消息放入会话特定的队列
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
|
||||
@@ -23,7 +23,7 @@ class CronRoute(Route):
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
def _serialize_job(self, job):
|
||||
def _serialize_job(self, job) -> dict:
|
||||
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
|
||||
for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
|
||||
if isinstance(data.get(k), datetime):
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
from quart import request
|
||||
@@ -75,7 +76,7 @@ class KnowledgeBaseRoute(Route):
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self, task_id: str, status: str, result: any = None, error: str | None = None
|
||||
self, task_id: str, status: str, result: Any = None, error: str | None = None
|
||||
) -> None:
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": status,
|
||||
|
||||
@@ -256,143 +256,148 @@ class LiveChatRoute(Route):
|
||||
await queue.put((session.username, cid, payload))
|
||||
|
||||
# 3. 等待响应并流式发送 TTS 音频
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
|
||||
|
||||
bot_text = ""
|
||||
audio_playing = False
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(session, user_text, bot_text)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(
|
||||
session, user_text, bot_text
|
||||
)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
break
|
||||
break
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
||||
|
||||
@@ -2,14 +2,13 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from flask.json.provider import DefaultJSONProvider
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HyperConfig
|
||||
from psutil._common import addr as psutil_addr
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
|
||||
@@ -29,6 +28,11 @@ from .routes.session_management import SessionManagementRoute
|
||||
from .routes.subagent import SubAgentRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
|
||||
|
||||
class _AddrWithPort(Protocol):
|
||||
port: int
|
||||
|
||||
|
||||
APP: Quart
|
||||
|
||||
|
||||
@@ -168,7 +172,7 @@ class AstrBotDashboard:
|
||||
"""获取占用端口的进程详细信息"""
|
||||
try:
|
||||
for conn in psutil.net_connections(kind="inet"):
|
||||
if cast(psutil_addr, conn.laddr).port == port:
|
||||
if cast(_AddrWithPort, conn.laddr).port == port:
|
||||
try:
|
||||
process = psutil.Process(conn.pid)
|
||||
# 获取详细信息
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
## What's Changed
|
||||
|
||||
### Fix
|
||||
- fix: `fix: messages[x] assistant content must contain at least one part` after tool calling ([#4928](https://github.com/AstrBotDevs/AstrBot/issues/4928)) after tool calls.
|
||||
- fix: TypeError when MCP schema type is a list ([#4867](https://github.com/AstrBotDevs/AstrBot/issues/4867))
|
||||
- fix: Fixed an issue that caused scheduled task execution failures with specific providers 修复特定提供商导致的定时任务执行失败的问题 ([#4872](https://github.com/AstrBotDevs/AstrBot/issues/4872))
|
||||
|
||||
|
||||
### Feature
|
||||
- feat: add bocha web search tool ([#4902](https://github.com/AstrBotDevs/AstrBot/issues/4902))
|
||||
- feat: systemd support ([#4880](https://github.com/AstrBotDevs/AstrBot/issues/4880))
|
||||
@@ -0,0 +1,10 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 修复一些原因导致 Tavily WebSearch、Bocha WebSearch 无法使用的问题
|
||||
|
||||
### xinzeng
|
||||
- 飞书支持 Bot 发送文件、图片和视频消息类型。
|
||||
|
||||
### 优化
|
||||
- 优化 WebChat 和 企业微信 AI 会话队列生命周期管理,减少内存泄漏,提高性能。
|
||||
@@ -108,6 +108,10 @@
|
||||
"description": "Tavily API Key",
|
||||
"hint": "Multiple keys can be added for rotation."
|
||||
},
|
||||
"websearch_bocha_key": {
|
||||
"description": "BoCha API Key",
|
||||
"hint": "Multiple keys can be added for rotation."
|
||||
},
|
||||
"websearch_baidu_app_builder_key": {
|
||||
"description": "Baidu Qianfan Smart Cloud APP Builder API Key",
|
||||
"hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
|
||||
|
||||
@@ -111,6 +111,10 @@
|
||||
"description": "Tavily API Key",
|
||||
"hint": "可添加多个 Key 进行轮询。"
|
||||
},
|
||||
"websearch_bocha_key": {
|
||||
"description": "BoCha API Key",
|
||||
"hint": "可添加多个 Key 进行轮询。"
|
||||
},
|
||||
"websearch_baidu_app_builder_key": {
|
||||
"description": "百度千帆智能云 APP Builder API Key",
|
||||
"hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)"
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.14.4"
|
||||
version = "4.14.6"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
[Unit]
|
||||
Description=AstrBot Service
|
||||
After=network-online.target
|
||||
Wants=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=%h/.local/share/astrbot
|
||||
ExecStart=/usr/bin/sh -c '/usr/bin/astrbot run || { /usr/bin/astrbot init && /usr/bin/astrbot run; }'
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
Environment=PYTHONUNBUFFERED=1
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
Reference in New Issue
Block a user