diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index bae1f37d4..dbc779c95 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -11,24 +11,42 @@ jobs: runs-on: ubuntu-latest steps: - - name: 拉取源码 + - name: Pull The Codes uses: actions/checkout@v3 with: - fetch-depth: 1 + fetch-depth: 0 # Must be 0 so we can fetch tags - - name: 设置 QEMU + - name: Get latest tag (only on manual trigger) + id: get-latest-tag + if: github.event_name == 'workflow_dispatch' + run: | + tag=$(git describe --tags --abbrev=0) + echo "latest_tag=$tag" >> $GITHUB_OUTPUT + + - name: Checkout to latest tag (only on manual trigger) + if: github.event_name == 'workflow_dispatch' + run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }} + + - name: Set QEMU uses: docker/setup-qemu-action@v3 - - name: 设置 Docker Buildx + - name: Set Docker Buildx uses: docker/setup-buildx-action@v3 - - name: 登录到 DockerHub + - name: Log in to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - name: 构建和推送 Docker hub + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: Soulter + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and Push Docker to DockerHub and Github GHCR uses: docker/build-push-action@v6 with: context: . @@ -36,8 +54,9 @@ jobs: push: true tags: | ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest - ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }} + ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} + ghcr.io/soulter/astrbot:latest + ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} - name: Post build notifications run: echo "Docker image has been built and pushed successfully" - diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 1ee0fac7f..98794c9dd 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -99,6 +99,12 @@ class AstrBotConfig(dict): has_new |= self.check_config_integrity( value, conf[key], path + "." + key if path else key ) + for key in list(conf.keys()): + if key not in refer_conf: + path_ = path + "." + key if path else key + logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除") + del conf[key] + has_new = True return has_new def save_config(self, replace_config: Dict = None): diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index c990e9d8d..275926ba9 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -32,6 +32,7 @@ class RespondStage(Stage): Comp.Node: lambda comp: bool(comp.content), # 转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 Comp.File: lambda comp: bool(comp.file_ or comp.url), + Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } async def initialize(self, ctx: PipelineContext): diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 8db301dc8..3cddeccce 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -1,14 +1,15 @@ import asyncio +import base64 import json import os import time from typing import Optional import aiohttp +import anyio import websockets - from astrbot import logger -from astrbot.api.message_components import Plain, Image, At +from astrbot.api.message_components import Plain, Image, At, Record from astrbot.api.platform import Platform, PlatformMetadata from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astrbot_message import ( @@ -22,6 +23,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter from .wechatpadpro_message_event import WeChatPadProMessageEvent +try: + from .xml_data_parser import GeweDataParser +except ImportError as e: + logger.warning( + f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}" + ) + @register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器") class WeChatPadProAdapter(Platform): @@ -59,6 +67,18 @@ class WeChatPadProAdapter(Platform): ) # 持久化文件路径 self.ws_handle_task = None + # 添加图片消息缓存,用于引用消息处理 + self.cached_images = {} + """缓存图片消息。key是NewMsgId (对应引用消息的svrid),value是图片的base64数据""" + # 设置缓存大小限制,避免内存占用过大 + self.max_image_cache = 50 + + # 添加文本消息缓存,用于引用消息处理 + self.cached_texts = {} + """缓存文本消息。key是NewMsgId (对应引用消息的svrid),value是消息文本内容""" + # 设置文本缓存大小限制 + self.max_text_cache = 100 + async def run(self) -> None: """ 启动平台适配器的运行实例。 @@ -102,7 +122,7 @@ class WeChatPadProAdapter(Platform): logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。") await self.terminate() return - + # 登录成功后,连接 WebSocket 接收消息 self.ws_handle_task = asyncio.create_task(self.connect_websocket()) @@ -161,27 +181,21 @@ class WeChatPadProAdapter(Platform): return True # login_state == 3 为离线状态 elif login_state == 3: - logger.info( - "WeChatPadPro 设备不在线。" - ) + logger.info("WeChatPadPro 设备不在线。") return False else: - logger.error( - f"未知的在线状态: {login_state:}" - ) + logger.error(f"未知的在线状态: {login_state:}") return False # Code == 300 为微信退出状态。 elif response.status == 200 and response_data.get("Code") == 300: - logger.info( - "WeChatPadPro 设备已退出。" - ) + logger.info("WeChatPadPro 设备已退出。") return False else: logger.error( f"检查在线状态失败: {response.status}, {response_data}" ) return False - + except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return False @@ -364,7 +378,9 @@ class WeChatPadProAdapter(Platform): logger.error(f"处理 WebSocket 消息时发生错误: {e}") break except Exception as e: - logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。") + logger.error( + f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。" + ) await asyncio.sleep(5) async def handle_websocket_message(self, message: str): @@ -439,7 +455,7 @@ class WeChatPadProAdapter(Platform): ): # 再根据消息类型处理消息内容 await self._process_message_content(abm, raw_message, msg_type, content) - + return abm return None @@ -457,6 +473,7 @@ class WeChatPadProAdapter(Platform): """ if from_user_name == "weixin": return False + at_me = False if "@chatroom" in from_user_name: abm.type = MessageType.GROUP_MESSAGE abm.group_id = from_user_name @@ -478,6 +495,14 @@ class WeChatPadProAdapter(Platform): abm.session_id = f"{from_user_name}_{to_user_name}" else: abm.session_id = from_user_name + + msg_source = raw_message.get("msg_source", "") + if self.wxid in msg_source: + at_me = True + if "在群聊中@了你" in raw_message.get("push_content", ""): + at_me = True + if at_me: + abm.message.insert(0, At(qq=abm.self_id, name="")) else: abm.type = MessageType.FRIEND_MESSAGE abm.group_id = "" @@ -558,6 +583,32 @@ class WeChatPadProAdapter(Platform): logger.error(f"下载图片时发生错误: {e}") return None + async def download_voice( + self, to_user_name: str, new_msg_id: str, bufid: str, length: int + ): + """下载原始音频。""" + url = f"{self.base_url}/message/GetMsgVoice" + params = {"key": self.auth_key} + payload = { + "Bufid": bufid, + "ToUserName": to_user_name, + "NewMsgId": new_msg_id, + "Length": length, + } + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, params=params, json=payload) as response: + if response.status == 200: + return await response.json() + logger.error(f"下载音频失败: {response.status}") + return None + except aiohttp.ClientConnectorError as e: + logger.error(f"连接到 WeChatPadPro 服务失败: {e}") + return None + except Exception as e: + logger.error(f"下载音频时发生错误: {e}") + return None + async def _process_message_content( self, abm: AstrBotMessage, raw_message: dict, msg_type: int, content: str ): @@ -571,28 +622,28 @@ class WeChatPadProAdapter(Platform): if len(parts) == 2: message_content = parts[1] abm.message_str = message_content - + # 检查是否@了机器人,参考 gewechat 的实现方式 # 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符(四分之一空格) at_me = False - + # 检查 msg_source 中是否包含机器人的 wxid # wechatpadpro 的格式: wxid # gewechat 的格式: msg_source = raw_message.get("msg_source", "") if f"{abm.self_id}" in msg_source or f"{abm.self_id}," in msg_source or f",{abm.self_id}" in msg_source: at_me = True - + # 也检查 push_content 中是否有@提示 push_content = raw_message.get("push_content", "") if "在群聊中@了你" in push_content: at_me = True - + if at_me: # 被@了,在消息开头插入At组件(参考gewechat的做法) bot_nickname = await self._get_group_member_nickname(abm.group_id, abm.self_id) abm.message.insert(0, At(qq=abm.self_id, name=bot_nickname or abm.self_id)) - + # 只有当消息内容不仅仅是@时才添加Plain组件 if "\u2005" in message_content: # 检查@之后是否还有其他内容 @@ -613,6 +664,25 @@ class WeChatPadProAdapter(Platform): abm.message.append(Plain(abm.message_str)) else: # 私聊消息 abm.message.append(Plain(abm.message_str)) + + # 缓存文本消息,以便引用消息可以查找 + try: + # 获取msg_id作为缓存的key + new_msg_id = raw_message.get("new_msg_id") + if new_msg_id: + # 限制缓存大小 + if ( + len(self.cached_texts) >= self.max_text_cache + and self.cached_texts + ): + # 删除最早的一条缓存 + oldest_key = next(iter(self.cached_texts)) + self.cached_texts.pop(oldest_key) + + logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}") + self.cached_texts[str(new_msg_id)] = content + except Exception as e: + logger.error(f"缓存文本消息失败: {e}") elif msg_type == 3: # 图片消息 from_user_name = raw_message.get("from_user_name", {}).get("str", "") @@ -626,15 +696,87 @@ class WeChatPadProAdapter(Platform): ) if image_bs64_data: abm.message.append(Image.fromBase64(image_bs64_data)) + # 缓存图片,以便引用消息可以查找 + try: + # 获取msg_id作为缓存的key + new_msg_id = raw_message.get("new_msg_id") + if new_msg_id: + # 限制缓存大小 + if ( + len(self.cached_images) >= self.max_image_cache + and self.cached_images + ): + # 删除最早的一条缓存 + oldest_key = next(iter(self.cached_images)) + self.cached_images.pop(oldest_key) + + logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}") + self.cached_images[str(new_msg_id)] = image_bs64_data + except Exception as e: + logger.error(f"缓存图片消息失败: {e}") elif msg_type == 47: # 视频消息 (注意:表情消息也是 47,需要区分) - logger.warning("收到视频消息,待实现。") + data_parser = GeweDataParser( + content=content, + is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), + raw_message=raw_message, + ) + emoji_message = data_parser.parse_emoji() + if emoji_message is not None: + abm.message.append(emoji_message) elif msg_type == 50: - # 语音/视频 logger.warning("收到语音/视频消息,待实现。") + elif msg_type == 34: + # 语音消息 + bufid = 0 + to_user_name = raw_message.get("to_user_name", {}).get("str", "") + new_msg_id = raw_message.get("new_msg_id") + data_parser = GeweDataParser( + content=content, + is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), + raw_message=raw_message, + ) + + voicemsg = data_parser._format_to_xml().find("voicemsg") + bufid = voicemsg.get("bufid") or "0" + length = int(voicemsg.get("length") or 0) + voice_resp = await self.download_voice( + to_user_name=to_user_name, + new_msg_id=new_msg_id, + bufid=bufid, + length=length, + ) + voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) + if voice_bs64_data: + voice_bs64_data = base64.b64decode(voice_bs64_data) + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + file_path = os.path.join( + temp_dir, f"wechatpadpro_voice_{abm.message_id}.silk" + ) + + async with await anyio.open_file(file_path, "wb") as f: + await f.write(voice_bs64_data) + abm.message.append(Record(file=file_path, url=file_path)) elif msg_type == 49: - # 引用消息 - logger.warning("收到引用消息,待实现。") + try: + parser = GeweDataParser( + content=content, + is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), + cached_texts=self.cached_texts, + cached_images=self.cached_images, + raw_message=raw_message, + downloader=self._download_raw_image, + ) + components = await parser.parse_mutil_49() + if components: + abm.message.extend(components) + abm.message_str = "\n".join( + c.text for c in components if isinstance(c, Plain) + ) + except Exception as e: + logger.warning(f"msg_type 49 处理失败: {e}") + abm.message.append(Plain("[XML 消息处理失败]")) + abm.message_str = "[XML 消息处理失败]" else: logger.warning(f"收到未处理的消息类型: {msg_type}。") diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 04bb02936..ab836ad28 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -7,11 +7,17 @@ import aiohttp from PIL import Image as PILImage # 使用别名避免冲突 from astrbot import logger -from astrbot.core.message.components import Image, Plain # Import Image +from astrbot.core.message.components import ( + Image, + Plain, + WechatEmoji, + Record, +) # Import Image from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk_base64 if TYPE_CHECKING: from .wechatpadpro_adapter import WeChatPadProAdapter @@ -38,6 +44,10 @@ class WeChatPadProMessageEvent(AstrMessageEvent): await self._send_text(session, comp.text) elif isinstance(comp, Image): await self._send_image(session, comp) + elif isinstance(comp, WechatEmoji): + await self._send_emoji(session, comp) + elif isinstance(comp, Record): + await self._send_voice(session, comp) await super().send(message) async def _send_image(self, session: aiohttp.ClientSession, comp: Image): @@ -73,12 +83,42 @@ class WeChatPadProMessageEvent(AstrMessageEvent): message_text = text payload = { "MsgItem": [ - {"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id} + { + "MsgType": 1, + "TextContent": message_text, + "ToUserName": self.session_id, + } ] } url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) + async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): + payload = { + "EmojiList": [ + { + "EmojiMd5": comp.md5, + "EmojiSize": comp.md5_len, + "ToUserName": self.session_id, + } + ] + } + url = f"{self.adapter.base_url}/message/SendEmojiMessage" + await self._post(session, url, payload) + + async def _send_voice(self, session: aiohttp.ClientSession, comp: Record): + record_path = await comp.convert_to_file_path() + # 默认已经存在 data/temp 中 + b64, duration = await wav_to_tencent_silk_base64(record_path) + payload = { + "ToUserName": self.session_id, + "VoiceData": b64, + "VoiceFormat": 4, + "VoiceSecond": duration, + } + url = f"{self.adapter.base_url}/message/SendVoice" + await self._post(session, url, payload) + @staticmethod def _validate_base64(b64: str) -> bytes: return base64.b64decode(b64, validate=True) diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py new file mode 100644 index 000000000..054ca1b48 --- /dev/null +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -0,0 +1,160 @@ +from defusedxml import ElementTree as eT +from astrbot.api import logger +from astrbot.api.message_components import ( + WechatEmoji as Emoji, + Plain, + Image, + BaseMessageComponent, +) + + +class GeweDataParser: + def __init__( + self, + content: str, + is_private_chat: bool = False, + cached_texts=None, + cached_images=None, + raw_message: dict = None, + downloader=None, + ): + self._xml = None + self.content = content + self.is_private_chat = is_private_chat + self.cached_texts = cached_texts or {} + self.cached_images = cached_images or {} + self.downloader = downloader + + raw_message = raw_message or {} + self.from_user_name = raw_message.get("from_user_name", {}).get("str", "") + self.to_user_name = raw_message.get("to_user_name", {}).get("str", "") + self.msg_id = raw_message.get("msg_id", "") + + def _format_to_xml(self): + if self._xml: + return self._xml + + try: + msg_str = self.content + if not self.is_private_chat: + parts = self.content.split(":\n", 1) + msg_str = parts[1] if len(parts) == 2 else self.content + + self._xml = eT.fromstring(msg_str) + return self._xml + except Exception as e: + logger.error(f"[XML解析失败] {e}") + raise + + async def parse_mutil_49(self) -> list[BaseMessageComponent] | None: + """ + 处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57) + """ + try: + appmsg_type = self._format_to_xml().findtext(".//appmsg/type") + if appmsg_type == "57": + return await self.parse_reply() + except Exception as e: + logger.warning(f"[parse_mutil_49] 解析失败: {e}") + return None + + async def parse_reply(self) -> list[BaseMessageComponent]: + """ + 处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49) + """ + components = [] + + try: + appmsg = self._format_to_xml().find("appmsg") + if appmsg is None: + return [Plain("[引用消息解析失败]")] + + refermsg = appmsg.find("refermsg") + if refermsg is None: + return [Plain("[引用消息解析失败]")] + + quote_type = int(refermsg.findtext("type", "0")) + nickname = refermsg.findtext("displayname", "未知发送者") + quote_content = refermsg.findtext("content", "") + svrid = refermsg.findtext("svrid") + + match quote_type: + case 1: # 文本引用 + quoted_text = self.cached_texts.get(str(svrid), quote_content) + components.append(Plain(f"[引用] {nickname}: {quoted_text}")) + + case 3: # 图片引用 + quoted_image_b64 = self.cached_images.get(str(svrid)) + if not quoted_image_b64: + try: + quote_xml = eT.fromstring(quote_content) + img = quote_xml.find("img") + cdn_url = ( + img.get("cdnbigimgurl") or img.get("cdnmidimgurl") + if img is not None + else None + ) + if cdn_url and self.downloader: + image_resp = await self.downloader( + self.from_user_name, self.to_user_name, self.msg_id + ) + quoted_image_b64 = ( + image_resp.get("Data", {}) + .get("Data", {}) + .get("Buffer") + ) + except Exception as e: + logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}") + + if quoted_image_b64: + components.extend( + [ + Image.fromBase64(quoted_image_b64), + Plain(f"[引用] {nickname}: [引用的图片]"), + ] + ) + else: + components.append( + Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]") + ) + + case 49: # 嵌套引用 + try: + nested_root = eT.fromstring(quote_content) + nested_title = nested_root.findtext(".//appmsg/title", "") + components.append(Plain(f"[引用] {nickname}: {nested_title}")) + except Exception as e: + logger.warning(f"[嵌套引用解析失败] err={e}") + components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]")) + + case _: # 其他未识别类型 + logger.info(f"[未知引用类型] quote_type={quote_type}") + components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]")) + + # 主消息标题 + title = appmsg.findtext("title", "") + if title: + components.append(Plain(title)) + + except Exception as e: + logger.error(f"[parse_reply] 总体解析失败: {e}") + return [Plain("[引用消息解析失败]")] + + return components + + def parse_emoji(self) -> Emoji | None: + """ + 处理 msg_type == 47 的表情消息(emoji) + """ + try: + emoji_element = self._format_to_xml().find(".//emoji") + if emoji_element is not None: + return Emoji( + md5=emoji_element.get("md5"), + md5_len=emoji_element.get("len"), + cdnurl=emoji_element.get("cdnurl"), + ) + except Exception as e: + logger.error(f"[parse_emoji] 解析失败: {e}") + + return None diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 319515c52..c3ad45868 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -104,11 +104,13 @@ class ProviderAnthropic(ProviderOpenAIOfficial): session_id: str = None, image_urls: List[str] = [], func_tool: FuncCall = None, - contexts=[], + contexts=None, system_prompt=None, tool_calls_result: ToolCallsResult = None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] if not prompt: prompt = "" diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 2c4930692..f719190a1 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -74,6 +74,8 @@ class ProviderDashscope(ProviderOpenAIOfficial): system_prompt: str = None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index ad0605f14..619aae13d 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -61,12 +61,14 @@ class ProviderDify(Provider): self, prompt: str, session_id: str = None, - image_urls: List[str] = [], + image_urls: List[str] = None, func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, **kwargs, ) -> LLMResponse: + if image_urls is None: + image_urls = [] result = "" conversation_id = self.conversation_ids.get(session_id, "") diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index 85994fd59..8648512d0 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -60,10 +60,12 @@ class LLMTunerModelLoader(Provider): session_id: str = None, image_urls: List[str] = None, func_tool: FuncCall = None, - contexts: List = [], + contexts: List = None, system_prompt: str = None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] system_prompt = "" new_record = {"role": "user", "content": prompt} query_context = [*contexts, new_record] diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 2f7490317..e7e9d4a14 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -31,10 +31,12 @@ class ProviderZhipu(ProviderOpenAIOfficial): session_id: str = None, image_urls: List[str] = None, func_tool: FuncCall = None, - contexts=[], + contexts=None, system_prompt=None, **kwargs, ) -> LLMResponse: + if contexts is None: + contexts = [] new_record = await self.assemble_context(prompt, image_urls) context_query = [] diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 44e018dfc..13c93f226 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -451,11 +451,11 @@ class PluginManager: metadata.repo = metadata_yaml.repo except Exception: pass - + metadata.config = plugin_config if path not in inactivated_plugins: # 只有没有禁用插件时才实例化插件类 if plugin_config: - metadata.config = plugin_config + # metadata.config = plugin_config try: metadata.star_cls = metadata.star_cls_type( context=self.context, config=plugin_config diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 45f8b8a23..14cb5331a 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -18,7 +18,8 @@ class PluginUpdator(RepoZipUpdator): return self.plugin_store_path async def install(self, repo_url: str, proxy="") -> str: - repo_name = self.format_repo_name(repo_url) + _, repo_name, _ = self.parse_github_url(repo_url) + repo_name = self.format_name(repo_name) plugin_path = os.path.join(self.plugin_store_path, repo_name) await self.download_from_repo_url(plugin_path, repo_url, proxy) self.unzip_file(plugin_path + ".zip", plugin_path) @@ -54,7 +55,7 @@ class PluginUpdator(RepoZipUpdator): def unzip_file(self, zip_path: str, target_dir: str): os.makedirs(target_dir, exist_ok=True) update_dir = "" - logger.info(f"解压文件: {zip_path}") + logger.info(f"正在解压压缩包: {zip_path}") with zipfile.ZipFile(zip_path, "r") as z: update_dir = z.namelist()[0] z.extractall(target_dir) diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index f7b2eb5a4..00886bbf7 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -1,5 +1,10 @@ +import base64 import wave +import os from io import BytesIO +import asyncio +import tempfile +from astrbot.core.utils.astrbot_path import get_astrbot_data_path async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: @@ -50,3 +55,46 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: rate = wav.getframerate() duration = pilk.encode(wav_path, output_path, pcm_rate=rate, tencent=True) return duration + + +async def wav_to_tencent_silk_base64(wav_path: str) -> str: + """ + 将 WAV 文件转为 Silk,并返回 Base64 字符串。 + 默认采样率为 24000,输出临时文件为 temp/output.silk。 + + 参数: + - wav_path: 输入 .wav 文件路径(需为 PCM 16bit) + + 返回: + - Base64 编码的 Silk 字符串 + - duration: 音频时长(秒) + """ + try: + import pilk + except ImportError as e: + raise Exception("pysilk 模块未安装,请安装 pysilk") from e + + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) + + with wave.open(wav_path, "rb") as wav: + rate = wav.getframerate() + + with tempfile.NamedTemporaryFile( + suffix=".silk", delete=False, dir=temp_dir + ) as tmp_file: + silk_path = tmp_file.name + + try: + duration = await asyncio.to_thread( + pilk.encode, wav_path, silk_path, pcm_rate=rate, tencent=True + ) + + with open(silk_path, "rb") as f: + silk_bytes = await asyncio.to_thread(f.read) + silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") + + return silk_b64, duration # 已是秒 + finally: + if os.path.exists(silk_path): + os.remove(silk_path) diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 137c7444a..2d2b7b834 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -1,5 +1,6 @@ import aiohttp import os +import re import zipfile import shutil @@ -119,28 +120,60 @@ class RepoZipUpdator: ) async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): - repo_namespace = repo_url.split("/")[-2:] - author = repo_namespace[0] - repo = repo_namespace[1] + author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") - release_url = f"https://api.github.com/repos/{author}/{repo}/releases" - releases = await self.fetch_release_info(url=release_url) - if not releases: - # download from the default branch directly. - logger.info(f"正在从默认分支下载 {author}/{repo} ") + + if branch: + logger.info(f"正在从指定分支 {branch} 下载 {author}/{repo}") release_url = ( - f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + f"https://github.com/{author}/{repo}/archive/refs/heads/{branch}.zip" ) else: - release_url = releases[0]["zipball_url"] + try: + release_url = f"https://api.github.com/repos/{author}/{repo}/releases" + releases = await self.fetch_release_info(url=release_url) + except Exception as e: + logger.warning( + f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支" + ) + releases = [] + if not releases: + # 如果没有最新版本,下载默认分支 + logger.info(f"正在从默认分支下载 {author}/{repo}") + release_url = ( + f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" + ) + else: + release_url = releases[0]["zipball_url"] if proxy: release_url = f"{proxy}/{release_url}" - logger.info(f"使用代理下载: {release_url}") + logger.info( + f"检查到设置了镜像站,将使用镜像站下载 {author}/{repo} 仓库源码: {release_url}" + ) await download_file(release_url, target_path + ".zip") + def parse_github_url(self, url: str): + """使用正则表达式解析 GitHub 仓库 URL,支持 `.git` 后缀和 `tree/branch` 结构 + Returns: + tuple[str, str, str]: 返回作者名、仓库名和分支名 + Raises: + ValueError: 如果 URL 格式不正确 + """ + cleaned_url = url.rstrip("/") + pattern = r"^https://github\.com/([a-zA-Z0-9_-]+)/([a-zA-Z0-9_-]+)(\.git)?(?:/tree/([a-zA-Z0-9_-]+))?$" + match = re.match(pattern, cleaned_url) + + if match: + author = match.group(1) + repo = match.group(2) + branch = match.group(4) + return author, repo, branch + else: + raise ValueError("无效的 GitHub URL") + def unzip_file(self, zip_path: str, target_dir: str): """ 解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir @@ -174,16 +207,5 @@ class RepoZipUpdator: f"删除更新文件失败,可以手动删除 {zip_path} 和 {os.path.join(target_dir, update_dir)}" ) - def format_repo_name(self, repo_url: str) -> str: - if repo_url.endswith("/"): - repo_url = repo_url[:-1] - - repo_namespace = repo_url.split("/")[-2:] - repo = repo_namespace[1] - - repo = self.format_name(repo) - - return repo - def format_name(self, name: str) -> str: return name.replace("-", "_").lower() diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 4b25c977f..8b158caaa 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -174,14 +174,15 @@ class ConfigRoute(Route): """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() provider_name = provider.provider_config.get("id", "Unknown Provider") + logger.debug(f"Got provider meta: {meta}") if not provider_name and meta: provider_name = meta.id elif not provider_name: provider_name = "Unknown Provider" status_info = { - "id": meta.id if meta else "Unknown ID", - "model": meta.model if meta else "Unknown Model", - "type": meta.type if meta else "Unknown Type", + "id": getattr(meta, 'id', 'Unknown ID'), + "model": getattr(meta, 'model', 'Unknown Model'), + "type": getattr(meta, 'type', 'Unknown Type'), "name": provider_name, "status": "unavailable", # 默认为不可用 "error": None, @@ -189,7 +190,7 @@ class ConfigRoute(Route): logger.debug(f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})") try: logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") - response = await asyncio.wait_for(provider.text_chat(prompt="Ping"), timeout=20.0) # 超时 20 秒 + response = await asyncio.wait_for(provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0) logger.debug(f"Received response from {status_info['name']}: {response}") # 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用 if response is not None: @@ -209,7 +210,7 @@ class ConfigRoute(Route): logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.") except asyncio.TimeoutError: - status_info["error"] = "Connection timed out after 10 seconds during test call." + status_info["error"] = "Connection timed out after 45 seconds during test call." logger.warning(f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.") except Exception as e: error_message = str(e) diff --git a/dashboard/src/config.ts b/dashboard/src/config.ts index 9f70123a7..f52812ad8 100644 --- a/dashboard/src/config.ts +++ b/dashboard/src/config.ts @@ -8,10 +8,10 @@ export type ConfigProps = { }; function checkUITheme() { + /* 检查localStorage有无记忆的主题选项,如有则使用,否则使用默认值 */ const theme = localStorage.getItem("uiTheme"); - console.log('memorized theme: ', theme); if (!theme || !(['PurpleTheme', 'PurpleThemeDark'].includes(theme))) { - localStorage.setItem("uiTheme", "PurpleTheme"); + localStorage.setItem("uiTheme", "PurpleTheme"); // todo: 这部分可以根据vuetify.ts的默认主题动态调整 return 'PurpleTheme'; } else return theme; } diff --git a/dashboard/src/layouts/full/FullLayout.vue b/dashboard/src/layouts/full/FullLayout.vue index 0380b350e..a0754cafd 100644 --- a/dashboard/src/layouts/full/FullLayout.vue +++ b/dashboard/src/layouts/full/FullLayout.vue @@ -2,14 +2,13 @@ import { RouterView } from 'vue-router'; import VerticalSidebarVue from './vertical-sidebar/VerticalSidebar.vue'; import VerticalHeaderVue from './vertical-header/VerticalHeader.vue'; -import { useCustomizerStore } from '../../stores/customizer'; +import { useCustomizerStore } from '@/stores/customizer'; const customizer = useCustomizerStore(); @@ -265,6 +267,7 @@ export default { loading_: false, upload_file: null, pluginMarketData: [], + showPluginFullName: false, loadingDialog: { show: false, title: "加载中...", @@ -283,8 +286,8 @@ export default { pluginMarketHeaders: [ { title: '名称', key: 'name', maxWidth: '200px' }, { title: '描述', key: 'desc', maxWidth: '250px' }, - { title: '作者', key: 'author', maxWidth: '70px' }, - { title: 'Star数', key: 'stars', maxWidth: '100px' }, + { title: '作者', key: 'author', maxWidth: '90px' }, + { title: 'Star数', key: 'stars', maxWidth: '80px' }, { title: '最近更新', key: 'updated_at', maxWidth: '100px' }, { title: '标签', key: 'tags', maxWidth: '100px' }, { title: '操作', key: 'actions', sortable: false } @@ -319,6 +322,7 @@ export default { this.loading_ = true this.commonStore.getPluginCollections().then((data) => { this.pluginMarketData = data; + this.trimExtensionName(); this.checkAlreadyInstalled(); this.checkUpdate(); this.loading_ = false @@ -367,11 +371,23 @@ export default { getExtensions() { axios.get('/api/plugin/get').then((res) => { this.extension_data = res.data; + this.trimExtensionName(); this.checkAlreadyInstalled(); this.checkUpdate() }); }, - + trimExtensionName() { + this.pluginMarketData.forEach(plugin => { + if (plugin.name) { + let name = plugin.name.trim().toLowerCase(); + if (name.startsWith("astrbot_plugin_")) { + plugin.trimmedName = name.substring(15); + } else if (name.startsWith("astrbot_") || name.startsWith("astrbot-")) { + plugin.trimmedName = name.substring(8); + } else plugin.trimmedName = plugin.name; + } + }); + }, checkUpdate() { // 创建在线插件的map const onlinePluginsMap = new Map(); diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue index 6936fd7dd..5c9649246 100644 --- a/dashboard/src/views/ExtensionPage.vue +++ b/dashboard/src/views/ExtensionPage.vue @@ -191,6 +191,17 @@ const updateExtension = async (extension_name) => { Object.assign(extension_data, res.data); onLoadingDialogResult(1, res.data.message); + setTimeout(async () => { + toast(`正在刷新插件列表...`, "info", 2000); + try { + await getExtensions(); + toast("插件列表已刷新!", "success"); + + } catch (error) { + const errorMsg = error.response?.data?.message || error.message || String(error); + toast(`刷新插件列表时发生错误: ${errorMsg}`, "error"); + } + }, 1000); } catch (err) { toast(err, "error"); } diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue index 014526d60..8426eb048 100644 --- a/dashboard/src/views/PlatformPage.vue +++ b/dashboard/src/views/PlatformPage.vue @@ -110,14 +110,14 @@ :metadata="metadata['platform_group']?.metadata" metadataKey="platform" /> - - + + mdi-refresh 刷新 diff --git a/dashboard/src/views/alkaid/LongTermMemory.vue b/dashboard/src/views/alkaid/LongTermMemory.vue index e534c70d8..762c1f8ac 100644 --- a/dashboard/src/views/alkaid/LongTermMemory.vue +++ b/dashboard/src/views/alkaid/LongTermMemory.vue @@ -1,12 +1,12 @@ @@ -199,6 +292,16 @@ export default { isSearching: false, searchResults: [], hasSearched: false, + + // 添加边点击相关数据 + selectedEdge: null, + selectedEdgeFactId: null, + selectedEdgeFactData: null, + showFactDialog: false, + isLoadingFactData: false, + + // 改进元数据展示 + parsedMetadata: null, } }, mounted() { @@ -393,6 +496,83 @@ export default { this.ltmGetGraph(); }, + // 添加获取Fact详情的方法 + getFactDetails(factId) { + if (!factId) return; + + this.isLoadingFactData = true; + this.selectedEdgeFactData = null; + this.parsedMetadata = null; + + axios.get('/api/plug/alkaid/ltm/graph/fact', { + params: { fact_id: factId } + }) + .then(response => { + if (response.data.status === 'ok') { + this.selectedEdgeFactData = response.data.data; + // 解析元数据 + this.parsedMetadata = this.parseMetadata(this.selectedEdgeFactData.metadata); + this.showFactDialog = true; + } else { + this.$toast.error('获取记忆详情失败: ' + response.data.message); + } + }) + .catch(error => { + console.error('获取记忆详情失败:', error); + this.$toast.error('获取记忆详情失败: ' + (error.response?.data?.message || error.message)); + }) + .finally(() => { + this.isLoadingFactData = false; + }); + }, + + // 添加元数据解析方法 + parseMetadata(metadata) { + if (!metadata) return null; + + try { + // 如果是字符串,尝试解析JSON + if (typeof metadata === 'string') { + try { + return JSON.parse(metadata); + } catch (e) { + return { value: metadata }; // 如果无法解析为JSON,则作为单个值返回 + } + } + + // 如果已经是对象,直接返回 + if (typeof metadata === 'object') { + return metadata; + } + + return { value: String(metadata) }; + } catch (e) { + console.error('解析元数据出错:', e); + return { error: '无法解析元数据' }; + } + }, + + // 格式化元数据值 + formatMetadataValue(value) { + if (value === null || value === undefined) return '无'; + + if (typeof value === 'object') { + return JSON.stringify(value); + } + + return String(value); + }, + + // 格式化时间戳的辅助方法 + formatTime(timestamp) { + if (!timestamp) return '未知'; + try { + return new Date(timestamp).toLocaleString(); + } catch (e) { + return timestamp; + } + }, + initD3Graph() { const container = document.getElementById("graph-container"); if (!container) return; @@ -431,6 +611,8 @@ export default { if (!this.svg || !this.simulation) return; const g = this.g; g.selectAll("*").remove(); + + // 添加箭头定义 g.append("defs").append("marker") .attr("id", "arrowhead") .attr("viewBox", "0 -5 10 10") @@ -442,13 +624,22 @@ export default { .append("path") .attr("d", "M0,-5L10,0L0,5") .attr("fill", "#999"); + + // 预处理边数据,标识和处理重复边 + const linkGroups = this.identifyParallelLinks(this.links); + + // 使用路径替代直线来绘制边,以便支持曲线 const link = g.append("g") - .selectAll("line") + .selectAll("path") .data(this.links) - .join("line") + .join("path") .attr("stroke", d => d.color) .attr("stroke-width", 1.5) - .attr("marker-end", "url(#arrowhead)"); + .attr("fill", "none") + .attr("marker-end", "url(#arrowhead)") + .style("cursor", "pointer"); + + // 边标签需要相应调整位置 const edgeLabels = g.append("g") .selectAll("text") .data(this.links) @@ -457,7 +648,22 @@ export default { .attr("font-size", "8px") .attr("text-anchor", "middle") .attr("fill", "#666") - .attr("dy", -5); + .style("cursor", "pointer") + .on("click", (event, d) => { + event.stopPropagation(); + + // 检查边数据中是否有fact_id + const factId = d.originalData?.fact_id; + if (factId) { + this.selectedEdge = d; + this.selectedEdgeFactId = factId; + this.getFactDetails(factId); + } else { + this.$toast.info('该关系没有关联的记忆数据'); + } + }); + + // 节点绘制部分保持不变 const node = g.append("g") .selectAll("circle") .data(this.nodes) @@ -466,6 +672,7 @@ export default { .attr("fill", d => d.color) .style("cursor", "pointer") .call(this.dragBehavior()); + const nodeLabels = g.append("g") .selectAll("text") .data(this.nodes) @@ -475,27 +682,33 @@ export default { .attr("text-anchor", "middle") .attr("fill", "#333") .attr("dy", -12); + node.on("click", (event, d) => { event.stopPropagation(); this.selectedNode = d.originalData; }); + + // 给SVG添加全局点击事件,用于关闭气泡 this.svg.on("click", () => { this.selectedNode = null; }); + this.simulation .nodes(this.nodes) .on("tick", () => { - link - .attr("x1", d => d.source.x) - .attr("y1", d => d.source.y) - .attr("x2", d => d.target.x) - .attr("y2", d => d.target.y); + // 更新边的路径 + link.attr("d", d => this.generateLinkPath(d)); + + // 更新边标签位置 edgeLabels - .attr("x", d => (d.source.x + d.target.x) / 2) - .attr("y", d => (d.source.y + d.target.y) / 2); + .attr("x", d => this.getLinkLabelX(d)) + .attr("y", d => this.getLinkLabelY(d)); + + // 更新节点位置 node .attr("cx", d => d.x) .attr("cy", d => d.y); + nodeLabels .attr("x", d => d.x) .attr("y", d => d.y); @@ -506,6 +719,175 @@ export default { this.simulation.alpha(1).restart(); }, + + // 识别并标记平行边(连接相同两个节点的多条边) + identifyParallelLinks(links) { + // 创建一个映射来存储连接相同节点对的边 + const linkMap = new Map(); + + // 遍历所有边,按照起点和终点进行分组 + links.forEach(link => { + // 创建边的键,确保无论边的方向如何,同一对节点生成的键都相同 + const sourceId = typeof link.source === 'object' ? link.source.id : link.source; + const targetId = typeof link.target === 'object' ? link.target.id : link.target; + + const forwardKey = `${sourceId}-${targetId}`; + const reverseKey = `${targetId}-${sourceId}`; + + // 判断是从source到target的边还是反向边 + const isForwardLink = sourceId < targetId; + const key = isForwardLink ? forwardKey : reverseKey; + + // 使用方向信息 + if (!linkMap.has(key)) { + linkMap.set(key, []); + } + + // 存储边和其方向 + linkMap.get(key).push({ + link, + isForward: isForwardLink + }); + }); + + // 处理每一组平行边,为它们分配曲率 + linkMap.forEach((parallels, key) => { + if (parallels.length > 1) { + // 有多条平行边,分配不同曲率 + parallels.forEach((item, index) => { + // 根据边的数量计算适当的曲率 + const totalLinks = parallels.length; + // 基础曲率,可根据边数调整 + const baseCurvature = 0.45; + // 根据边的索引计算曲率:中间的边较直,两侧的边较弯 + let curvature; + + if (totalLinks % 2 === 1) { + // 奇数条边,中间的边直线,其他边弯曲 + const middleIndex = Math.floor(totalLinks / 2); + if (index === middleIndex) { + curvature = 0; // 中间的边为直线 + } else { + // 到中间边的距离决定曲率大小 + const distance = Math.abs(index - middleIndex); + const direction = index < middleIndex ? -1 : 1; + curvature = direction * baseCurvature * distance; + } + } else { + // 偶数条边,所有边都弯曲 + const middleIndex = totalLinks / 2 - 0.5; + const distance = Math.abs(index - middleIndex); + const direction = index < middleIndex ? -1 : 1; + curvature = direction * baseCurvature * distance; + } + + // 如果是反向边,翻转曲率方向 + if (!item.isForward) { + curvature = -curvature; + } + + // 存储曲率值到边对象 + item.link.curvature = curvature; + }); + } else { + // 只有一条边,不需要弯曲 + parallels[0].link.curvature = 0; + } + }); + + return linkMap; + }, + + // 根据曲率生成边的路径 + generateLinkPath(d) { + // 确保source和target是对象 + const source = typeof d.source === 'object' ? d.source : this.nodes.find(n => n.id === d.source); + const target = typeof d.target === 'object' ? d.target : this.nodes.find(n => n.id === d.target); + + if (!source || !target) return ''; + + // 如果是直线(无曲率) + if (!d.curvature || d.curvature === 0) { + return `M${source.x},${source.y}L${target.x},${target.y}`; + } + + // 计算曲线的控制点 + const dx = target.x - source.x; + const dy = target.y - source.y; + const dr = Math.sqrt(dx * dx + dy * dy); + + // 控制点偏移距离,由曲率决定 + const offset = dr * d.curvature; + + // 计算中点 + const midX = (source.x + target.x) / 2; + const midY = (source.y + target.y) / 2; + + // 计算垂直于连线的方向向量 + const nx = -dy / dr; + const ny = dx / dr; + + // 计算控制点坐标 + const cpx = midX + offset * nx; + const cpy = midY + offset * ny; + + // 创建二次贝塞尔曲线路径 + return `M${source.x},${source.y} Q${cpx},${cpy} ${target.x},${target.y}`; + }, + + // 新增方法:计算边标签的X坐标 + getLinkLabelX(d) { + const source = typeof d.source === 'object' ? d.source : this.nodes.find(n => n.id === d.source); + const target = typeof d.target === 'object' ? d.target : this.nodes.find(n => n.id === d.target); + + if (!source || !target) return 0; + + // 如果是直线 + if (!d.curvature || d.curvature === 0) { + return (source.x + target.x) / 2; + } + + // 计算曲线上的点 + const dx = target.x - source.x; + const dy = target.y - source.y; + const dr = Math.sqrt(dx * dx + dy * dy); + + // 中点 + const midX = (source.x + target.x) / 2; + + // 垂直向量 + const nx = -dy / dr; + + // 曲线路径上的点,使用曲率进行调整 + return midX + d.curvature * dr * nx * 0.5; + }, + + // 新增方法:计算边标签的Y坐标 + getLinkLabelY(d) { + const source = typeof d.source === 'object' ? d.source : this.nodes.find(n => n.id === d.source); + const target = typeof d.target === 'object' ? d.target : this.nodes.find(n => n.id === d.target); + + if (!source || !target) return 0; + + // 如果是直线 + if (!d.curvature || d.curvature === 0) { + return (source.y + target.y) / 2; + } + + // 计算曲线上的点 + const dx = target.x - source.x; + const dy = target.y - source.y; + const dr = Math.sqrt(dx * dx + dy * dy); + + // 中点 + const midY = (source.y + target.y) / 2; + + // 垂直向量 + const ny = dx / dr; + + // 曲线路径上的点,使用曲率进行调整 + return midY + d.curvature * dr * ny * 0.5; + }, dragBehavior() { return d3.drag() @@ -578,4 +960,43 @@ export default { background-color: #f2f6f9; } +/* 为连接线添加交互样式 */ +#graph-container line { + transition: stroke-width 0.2s; +} + +#graph-container line:hover { + stroke-width: 3px; + cursor: pointer; +} + +/* 添加美化详情卡片的样式 */ +.fact-detail-card :deep(.v-card-title) { + border-bottom-left-radius: 0; + border-bottom-right-radius: 0; +} + +.fact-detail-card :deep(.metadata-table) { + border-radius: 8px; + overflow: hidden; +} + +.fact-detail-card :deep(.v-table) { + background: transparent; +} + +.fact-detail-card :deep(.v-table th) { + color: var(--v-primary-base); + font-weight: bold; + background-color: rgba(var(--v-theme-primary), 0.05); +} + +.fact-detail-card :deep(pre) { + background-color: #f5f5f5; + padding: 8px; + border-radius: 4px; + max-height: 150px; + overflow: auto; + font-size: 12px; +} diff --git a/dashboard/src/views/dashboards/default/components/MessageStat.vue b/dashboard/src/views/dashboards/default/components/MessageStat.vue index 3f8bdfbed..2a7648613 100644 --- a/dashboard/src/views/dashboards/default/components/MessageStat.vue +++ b/dashboard/src/views/dashboards/default/components/MessageStat.vue @@ -69,6 +69,7 @@