- name: Upload to Cloudflare R2
  env:
    R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
    R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
    R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
    R2_BUCKET_NAME: "astrbot"
    R2_OBJECT_NAME: "astrbot-webui-latest.zip"
    VERSION_TAG: ${{ github.ref_name }}
  run: |
    echo "Installing rclone..."
    curl https://rclone.org/install.sh | sudo bash

    echo "Configuring rclone remote..."
    mkdir -p ~/.config/rclone
    # Unquoted heredoc on purpose: the R2_* variables from `env` above must be
    # expanded by the shell while the config file is written.
    cat <<EOF > ~/.config/rclone/rclone.conf
    [r2]
    type = s3
    provider = Cloudflare
    access_key_id = $R2_ACCESS_KEY_ID
    secret_access_key = $R2_SECRET_ACCESS_KEY
    endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
    EOF

    echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
    # `rclone copy` treats the destination as a directory, which would create
    # "<object name>/dist.zip". `copyto` uploads the single file under the
    # exact destination object name.
    rclone copyto dashboard/dist.zip "r2:$R2_BUCKET_NAME/$R2_OBJECT_NAME" --progress
    rclone copyto dashboard/dist.zip "r2:$R2_BUCKET_NAME/astrbot-webui-${VERSION_TAG}.zip" --progress
manual trigger) + if: github.event_name == 'workflow_dispatch' + run: git checkout ${{ steps.get-latest-tag.outputs.latest_tag }} + + - name: Set QEMU uses: docker/setup-qemu-action@v3 - - name: 设置 Docker Buildx + - name: Set Docker Buildx uses: docker/setup-buildx-action@v3 - - name: 登录到 DockerHub + - name: Log in to DockerHub uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - name: 构建和推送 Docker hub + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: Soulter + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and Push Docker to DockerHub and Github GHCR uses: docker/build-push-action@v6 with: context: . @@ -36,8 +54,9 @@ jobs: push: true tags: | ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest - ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }} + ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} + ghcr.io/soulter/astrbot:latest + ghcr.io/soulter/astrbot:${{ github.event_name == 'workflow_dispatch' && steps.get-latest-tag.outputs.latest_tag || github.ref_name }} - name: Post build notifications run: echo "Docker image has been built and pushed successfully" - diff --git a/README.md b/README.md index 2bc168358..5846ebebc 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,15 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用 ## ✨ 近期更新 -1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器! +
1. AstrBot 现已自带知识库能力 + + 📚 详见[文档](https://astrbot.app/use/knowledge-base.html) + + ![image](https://github.com/user-attachments/assets/28b639b0-bb5c-4958-8e94-92ae8cfd1ab4) + +
+ +2. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器! ## ✨ 主要功能 @@ -171,7 +179,6 @@ pre-commit install - Star 这个项目! - 在[爱发电](https://afdian.com/a/soulter)支持我! -- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~ ## ✨ Demo diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 1ee0fac7f..98794c9dd 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -99,6 +99,12 @@ class AstrBotConfig(dict): has_new |= self.check_config_integrity( value, conf[key], path + "." + key if path else key ) + for key in list(conf.keys()): + if key not in refer_conf: + path_ = path + "." + key if path else key + logger.info(f"检查到配置项 {path_} 不存在,将从当前配置中删除") + del conf[key] + has_new = True return has_new def save_config(self, replace_config: Dict = None): diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 19960269a..6af27337b 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -862,8 +862,48 @@ CONFIG_METADATA_2 = { "api_base": "https://openspeech.bytedance.com/api/v1/tts", "timeout": 20, }, + "OpenAI Embedding": { + "id": "openai_embedding", + "type": "openai_embedding", + "provider_type": "embedding", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "", + "embedding_model": "", + "embedding_dimensions": 1536, + "timeout": 20, + }, + "Gemini Embedding": { + "id": "gemini_embedding", + "type": "gemini_embedding", + "provider_type": "embedding", + "enable": True, + "embedding_api_key": "", + "embedding_api_base": "", + "embedding_model": "gemini-embedding-exp-03-07", + "embedding_dimensions": 768, + "timeout": 20, + }, }, "items": { + "embedding_dimensions": { + "description": "嵌入维度", + "type": "int", + "hint": "嵌入向量的维度。根据模型不同,可能需要调整,请参考具体模型的文档。此配置项请务必填写正确,否则将导致向量数据库无法正常工作。", + }, + "embedding_model": { + "description": "嵌入模型", + "type": "string", + "hint": "嵌入模型名称。", + }, + 
"embedding_api_key": { + "description": "API Key", + "type": "string", + }, + "embedding_api_base": { + "description": "API Base URL", + "type": "string", + }, "volcengine_cluster": { "type": "string", "description": "火山引擎集群", diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index c990e9d8d..275926ba9 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -32,6 +32,7 @@ class RespondStage(Stage): Comp.Node: lambda comp: bool(comp.content), # 转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 Comp.File: lambda comp: bool(comp.file_ or comp.url), + Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } async def initialize(self, ctx: PipelineContext): diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 7e7755501..0ad38819e 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -6,9 +6,8 @@ from typing import Optional import aiohttp import websockets - from astrbot import logger -from astrbot.api.message_components import Plain, Image +from astrbot.api.message_components import Plain, Image, At from astrbot.api.platform import Platform, PlatformMetadata from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astrbot_message import ( @@ -22,6 +21,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter from .wechatpadpro_message_event import WeChatPadProMessageEvent +try: + from .xml_data_parser import GeweDataParser +except ImportError as e: + logger.warning( + f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}" + ) + @register_platform_adapter("wechatpadpro", "WeChatPadPro 消息平台适配器") class WeChatPadProAdapter(Platform): @@ -59,6 +65,18 @@ class 
WeChatPadProAdapter(Platform): ) # 持久化文件路径 self.ws_handle_task = None + # 添加图片消息缓存,用于引用消息处理 + self.cached_images = {} + """缓存图片消息。key是NewMsgId (对应引用消息的svrid),value是图片的base64数据""" + # 设置缓存大小限制,避免内存占用过大 + self.max_image_cache = 50 + + # 添加文本消息缓存,用于引用消息处理 + self.cached_texts = {} + """缓存文本消息。key是NewMsgId (对应引用消息的svrid),value是消息文本内容""" + # 设置文本缓存大小限制 + self.max_text_cache = 100 + async def run(self) -> None: """ 启动平台适配器的运行实例。 @@ -102,7 +120,7 @@ class WeChatPadProAdapter(Platform): logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。") await self.terminate() return - + # 登录成功后,连接 WebSocket 接收消息 self.ws_handle_task = asyncio.create_task(self.connect_websocket()) @@ -161,27 +179,21 @@ class WeChatPadProAdapter(Platform): return True # login_state == 3 为离线状态 elif login_state == 3: - logger.info( - "WeChatPadPro 设备不在线。" - ) + logger.info("WeChatPadPro 设备不在线。") return False else: - logger.error( - f"未知的在线状态: {login_state:}" - ) + logger.error(f"未知的在线状态: {login_state:}") return False # Code == 300 为微信退出状态。 elif response.status == 200 and response_data.get("Code") == 300: - logger.info( - "WeChatPadPro 设备已退出。" - ) + logger.info("WeChatPadPro 设备已退出。") return False else: logger.error( f"检查在线状态失败: {response.status}, {response_data}" ) return False - + except aiohttp.ClientConnectorError as e: logger.error(f"连接到 WeChatPadPro 服务失败: {e}") return False @@ -364,7 +376,9 @@ class WeChatPadProAdapter(Platform): logger.error(f"处理 WebSocket 消息时发生错误: {e}") break except Exception as e: - logger.error(f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。") + logger.error( + f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。" + ) await asyncio.sleep(5) async def handle_websocket_message(self, message: str): @@ -439,7 +453,7 @@ class WeChatPadProAdapter(Platform): ): # 再根据消息类型处理消息内容 await self._process_message_content(abm, raw_message, msg_type, content) - + return abm return None @@ -457,6 +471,7 @@ class WeChatPadProAdapter(Platform): """ if from_user_name == "weixin": 
return False + at_me = False if "@chatroom" in from_user_name: abm.type = MessageType.GROUP_MESSAGE abm.group_id = from_user_name @@ -478,6 +493,14 @@ class WeChatPadProAdapter(Platform): abm.session_id = f"{from_user_name}_{to_user_name}" else: abm.session_id = from_user_name + + msg_source = raw_message.get("msg_source", "") + if self.wxid in msg_source: + at_me = True + if "在群聊中@了你" in raw_message.get("push_content", ""): + at_me = True + if at_me: + abm.message.insert(0, At(qq=abm.self_id, name="")) else: abm.type = MessageType.FRIEND_MESSAGE abm.group_id = "" @@ -575,6 +598,25 @@ class WeChatPadProAdapter(Platform): abm.message.append(Plain(abm.message_str)) else: # 私聊消息 abm.message.append(Plain(abm.message_str)) + + # 缓存文本消息,以便引用消息可以查找 + try: + # 获取msg_id作为缓存的key + new_msg_id = raw_message.get("new_msg_id") + if new_msg_id: + # 限制缓存大小 + if ( + len(self.cached_texts) >= self.max_text_cache + and self.cached_texts + ): + # 删除最早的一条缓存 + oldest_key = next(iter(self.cached_texts)) + self.cached_texts.pop(oldest_key) + + logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}") + self.cached_texts[str(new_msg_id)] = content + except Exception as e: + logger.error(f"缓存文本消息失败: {e}") elif msg_type == 3: # 图片消息 from_user_name = raw_message.get("from_user_name", {}).get("str", "") @@ -588,15 +630,57 @@ class WeChatPadProAdapter(Platform): ) if image_bs64_data: abm.message.append(Image.fromBase64(image_bs64_data)) + # 缓存图片,以便引用消息可以查找 + try: + # 获取msg_id作为缓存的key + new_msg_id = raw_message.get("new_msg_id") + if new_msg_id: + # 限制缓存大小 + if ( + len(self.cached_images) >= self.max_image_cache + and self.cached_images + ): + # 删除最早的一条缓存 + oldest_key = next(iter(self.cached_images)) + self.cached_images.pop(oldest_key) + + logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}") + self.cached_images[str(new_msg_id)] = image_bs64_data + except Exception as e: + logger.error(f"缓存图片消息失败: {e}") elif msg_type == 47: # 视频消息 (注意:表情消息也是 47,需要区分) - logger.warning("收到视频消息,待实现。") + data_parser = 
GeweDataParser( + content=content, + is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), + raw_message=raw_message, + ) + emoji_message = data_parser.parse_emoji() + if emoji_message is not None: + abm.message.append(emoji_message) elif msg_type == 50: # 语音/视频 logger.warning("收到语音/视频消息,待实现。") elif msg_type == 49: - # 引用消息 - logger.warning("收到引用消息,待实现。") + try: + parser = GeweDataParser( + content=content, + is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), + cached_texts=self.cached_texts, + cached_images=self.cached_images, + raw_message=raw_message, + downloader=self._download_raw_image, + ) + components = await parser.parse_mutil_49() + if components: + abm.message.extend(components) + abm.message_str = "\n".join( + c.text for c in components if isinstance(c, Plain) + ) + except Exception as e: + logger.warning(f"msg_type 49 处理失败: {e}") + abm.message.append(Plain("[XML 消息处理失败]")) + abm.message_str = "[XML 消息处理失败]" else: logger.warning(f"收到未处理的消息类型: {msg_type}。") diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 04bb02936..3c37be345 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -7,7 +7,7 @@ import aiohttp from PIL import Image as PILImage # 使用别名避免冲突 from astrbot import logger -from astrbot.core.message.components import Image, Plain # Import Image +from astrbot.core.message.components import Image, Plain, WechatEmoji # Import Image from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType @@ -38,6 +38,8 @@ class WeChatPadProMessageEvent(AstrMessageEvent): await self._send_text(session, comp.text) elif isinstance(comp, Image): await self._send_image(session, comp) 
+ elif isinstance(comp, WechatEmoji): + await self._send_emoji(session, comp) await super().send(message) async def _send_image(self, session: aiohttp.ClientSession, comp: Image): @@ -73,12 +75,29 @@ class WeChatPadProMessageEvent(AstrMessageEvent): message_text = text payload = { "MsgItem": [ - {"MsgType": 1, "TextContent": message_text, "ToUserName": self.session_id} + { + "MsgType": 1, + "TextContent": message_text, + "ToUserName": self.session_id, + } ] } url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) + async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): + payload = { + "EmojiList": [ + { + "EmojiMd5": comp.md5, + "EmojiSize": comp.md5_len, + "ToUserName": self.session_id, + } + ] + } + url = f"{self.adapter.base_url}/message/SendEmojiMessage" + await self._post(session, url, payload) + @staticmethod def _validate_base64(b64: str) -> bytes: return base64.b64decode(b64, validate=True) diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py new file mode 100644 index 000000000..054ca1b48 --- /dev/null +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -0,0 +1,160 @@ +from defusedxml import ElementTree as eT +from astrbot.api import logger +from astrbot.api.message_components import ( + WechatEmoji as Emoji, + Plain, + Image, + BaseMessageComponent, +) + + +class GeweDataParser: + def __init__( + self, + content: str, + is_private_chat: bool = False, + cached_texts=None, + cached_images=None, + raw_message: dict = None, + downloader=None, + ): + self._xml = None + self.content = content + self.is_private_chat = is_private_chat + self.cached_texts = cached_texts or {} + self.cached_images = cached_images or {} + self.downloader = downloader + + raw_message = raw_message or {} + self.from_user_name = raw_message.get("from_user_name", {}).get("str", "") + self.to_user_name = 
raw_message.get("to_user_name", {}).get("str", "") + self.msg_id = raw_message.get("msg_id", "") + + def _format_to_xml(self): + if self._xml: + return self._xml + + try: + msg_str = self.content + if not self.is_private_chat: + parts = self.content.split(":\n", 1) + msg_str = parts[1] if len(parts) == 2 else self.content + + self._xml = eT.fromstring(msg_str) + return self._xml + except Exception as e: + logger.error(f"[XML解析失败] {e}") + raise + + async def parse_mutil_49(self) -> list[BaseMessageComponent] | None: + """ + 处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57) + """ + try: + appmsg_type = self._format_to_xml().findtext(".//appmsg/type") + if appmsg_type == "57": + return await self.parse_reply() + except Exception as e: + logger.warning(f"[parse_mutil_49] 解析失败: {e}") + return None + + async def parse_reply(self) -> list[BaseMessageComponent]: + """ + 处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49) + """ + components = [] + + try: + appmsg = self._format_to_xml().find("appmsg") + if appmsg is None: + return [Plain("[引用消息解析失败]")] + + refermsg = appmsg.find("refermsg") + if refermsg is None: + return [Plain("[引用消息解析失败]")] + + quote_type = int(refermsg.findtext("type", "0")) + nickname = refermsg.findtext("displayname", "未知发送者") + quote_content = refermsg.findtext("content", "") + svrid = refermsg.findtext("svrid") + + match quote_type: + case 1: # 文本引用 + quoted_text = self.cached_texts.get(str(svrid), quote_content) + components.append(Plain(f"[引用] {nickname}: {quoted_text}")) + + case 3: # 图片引用 + quoted_image_b64 = self.cached_images.get(str(svrid)) + if not quoted_image_b64: + try: + quote_xml = eT.fromstring(quote_content) + img = quote_xml.find("img") + cdn_url = ( + img.get("cdnbigimgurl") or img.get("cdnmidimgurl") + if img is not None + else None + ) + if cdn_url and self.downloader: + image_resp = await self.downloader( + self.from_user_name, self.to_user_name, self.msg_id + ) + quoted_image_b64 = ( + image_resp.get("Data", {}) + .get("Data", {}) + 
.get("Buffer") + ) + except Exception as e: + logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}") + + if quoted_image_b64: + components.extend( + [ + Image.fromBase64(quoted_image_b64), + Plain(f"[引用] {nickname}: [引用的图片]"), + ] + ) + else: + components.append( + Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]") + ) + + case 49: # 嵌套引用 + try: + nested_root = eT.fromstring(quote_content) + nested_title = nested_root.findtext(".//appmsg/title", "") + components.append(Plain(f"[引用] {nickname}: {nested_title}")) + except Exception as e: + logger.warning(f"[嵌套引用解析失败] err={e}") + components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]")) + + case _: # 其他未识别类型 + logger.info(f"[未知引用类型] quote_type={quote_type}") + components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]")) + + # 主消息标题 + title = appmsg.findtext("title", "") + if title: + components.append(Plain(title)) + + except Exception as e: + logger.error(f"[parse_reply] 总体解析失败: {e}") + return [Plain("[引用消息解析失败]")] + + return components + + def parse_emoji(self) -> Emoji | None: + """ + 处理 msg_type == 47 的表情消息(emoji) + """ + try: + emoji_element = self._format_to_xml().find(".//emoji") + if emoji_element is not None: + return Emoji( + md5=emoji_element.get("md5"), + md5_len=emoji_element.get("len"), + cdnurl=emoji_element.get("cdnurl"), + ) + except Exception as e: + logger.error(f"[parse_emoji] 解析失败: {e}") + + return None diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 6ad67da55..e01e46cf9 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -19,6 +19,7 @@ class ProviderType(enum.Enum): CHAT_COMPLETION = "chat_completion" SPEECH_TO_TEXT = "speech_to_text" TEXT_TO_SPEECH = "text_to_speech" + EMBEDDING = "embedding" @dataclass @@ -155,7 +156,9 @@ class ProviderRequest: if self.image_urls: user_content = { "role": "user", - "content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}], + "content": [ + {"type": "text", "text": self.prompt if 
self.prompt else "[图片]"} + ], } for image_url in self.image_urls: if image_url.startswith("http"): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 78337ce95..edfd9f581 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -98,6 +98,8 @@ class ProviderManager: """加载的 Speech To Text Provider 的实例""" self.tts_provider_insts: List[TTSProvider] = [] """加载的 Text To Speech Provider 的实例""" + self.embedding_provider_insts: List[Provider] = [] + """加载的 Embedding Provider 的实例""" self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -211,6 +213,10 @@ class ProviderManager: from .sources.volcengine_tts import ( ProviderVolcengineTTS as ProviderVolcengineTTS, ) + case "openai_embedding": + from .sources.openai_embedding_source import ( + OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, + ) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。" @@ -290,6 +296,14 @@ class ProviderManager: if not self.curr_provider_inst: self.curr_provider_inst = inst + elif provider_metadata.provider_type == ProviderType.EMBEDDING: + inst = provider_metadata.cls_type( + provider_config, self.provider_settings + ) + if getattr(inst, "initialize", None): + await inst.initialize() + self.embedding_provider_insts.append(inst) + self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error(traceback.format_exc()) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7019113c7..c285ebd42 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -192,6 +192,11 @@ class EmbeddingProvider(AbstractProvider): """获取文本的向量""" ... + @abc.abstractmethod + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + """批量获取文本的向量""" + ... 
+ @abc.abstractmethod def get_dim(self) -> int: """获取向量的维度""" diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py new file mode 100644 index 000000000..baccf52a2 --- /dev/null +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -0,0 +1,63 @@ +from google import genai +from google.genai import types +from google.genai.errors import APIError +from ..provider import EmbeddingProvider +from ..register import register_provider_adapter +from ..entities import ProviderType + + +@register_provider_adapter( + "gemini_embedding", + "Google Gemini Embedding 提供商适配器", + provider_type=ProviderType.EMBEDDING, +) +class GeminiEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.provider_config = provider_config + self.provider_settings = provider_settings + + api_key: str = provider_config.get("embedding_api_key") + api_base: str = provider_config.get("embedding_api_base", None) + timeout: int = int(provider_config.get("timeout", 20)) + + http_options = types.HttpOptions(timeout=timeout * 1000) + if api_base: + if api_base.endswith("/"): + api_base = api_base[:-1] + http_options.base_url = api_base + + self.client = genai.Client(api_key=api_key, http_options=http_options).aio + + self.model = provider_config.get( + "embedding_model", "gemini-embedding-exp-03-07" + ) + self.dimension = provider_config.get("embedding_dimensions", 768) + + async def get_embedding(self, text: str) -> list[float]: + """ + 获取文本的嵌入 + """ + try: + result = await self.client.models.embed_content( + model=self.model, contents=text + ) + return result.embeddings[0].values + except APIError as e: + raise Exception(f"Gemini Embedding API请求失败: {e.message}") + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + """ + 批量获取文本的嵌入 + """ + try: + result = await 
@register_provider_adapter(
    "openai_embedding",
    "OpenAI API Embedding 提供商适配器",
    provider_type=ProviderType.EMBEDDING,
)
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by an OpenAI-compatible embeddings API."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Create the async client from the provider configuration.

        Args:
            provider_config: Provider entry; reads ``embedding_api_key``,
                ``embedding_api_base``, ``embedding_model``,
                ``embedding_dimensions`` and ``timeout``.
            provider_settings: Global provider settings, stored unchanged.
        """
        super().__init__(provider_config, provider_settings)
        self.provider_config = provider_config
        self.provider_settings = provider_settings
        # The shipped config template stores "" for unset fields, so the
        # fallback must be applied with `or`; a `dict.get` default would be
        # skipped because the key exists.
        self.client = AsyncOpenAI(
            api_key=provider_config.get("embedding_api_key"),
            base_url=provider_config.get("embedding_api_base")
            or "https://api.openai.com/v1",
            timeout=int(provider_config.get("timeout") or 20),
        )
        self.model = provider_config.get("embedding_model") or "text-embedding-3-small"
        self.dimension = provider_config.get("embedding_dimensions") or 1536

    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for a single text."""
        embedding = await self.client.embeddings.create(input=text, model=self.model)
        return embedding.data[0].embedding

    async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, from a single request."""
        embeddings = await self.client.embeddings.create(input=texts, model=self.model)
        return [item.embedding for item in embeddings.data]

    def get_dim(self) -> int:
        """Return the configured embedding dimension."""
        return self.dimension
metadata.config = plugin_config if path not in inactivated_plugins: # 只有没有禁用插件时才实例化插件类 if plugin_config: - metadata.config = plugin_config + # metadata.config = plugin_config try: metadata.star_cls = metadata.star_cls_type( context=self.context, config=plugin_config diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 0163b11b4..a7c04d3d9 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -1,5 +1,5 @@ import logging -from pip import main as pip_main +import asyncio logger = logging.getLogger("astrbot") @@ -9,7 +9,7 @@ class PipInstaller: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url - def install( + async def install( self, package_name: str = None, requirements_path: str = None, @@ -29,12 +29,29 @@ class PipInstaller: args.extend(self.pip_install_arg.split()) logger.info(f"Pip 包管理器: pip {' '.join(args)}") + try: + process = await asyncio.create_subprocess_exec( + "pip", *args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) - result_code = pip_main(args) + assert process.stdout is not None + async for line in process.stdout: + logger.info(line.decode().strip()) - # 清除 pip.main 导致的多余的 logging handlers - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) + await process.wait() - if result_code != 0: - raise Exception(f"安装失败,错误码:{result_code}") + if process.returncode != 0: + raise Exception(f"安装失败,错误码:{process.returncode}") + except FileNotFoundError: + # 没有 pip + from pip import main as pip_main + result_code = await asyncio.to_thread(pip_main, args) + + # 清除 pip.main 导致的多余的 logging handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + if result_code != 0: + raise Exception(f"安装失败,错误码:{result_code}") diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 6c95aee33..f33929c5c 100644 --- a/astrbot/dashboard/routes/config.py +++ 
b/astrbot/dashboard/routes/config.py @@ -9,7 +9,7 @@ from astrbot.core.platform.register import platform_registry from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry from astrbot.core import logger -import asyncio # 用于并发执行获取供应商请求 +import asyncio def try_cast(value: str, type_: str): @@ -166,15 +166,18 @@ class ConfigRoute(Route): "/config/provider/delete": ("POST", self.post_delete_provider), "/config/llmtools": ("GET", self.get_llm_tools), "/config/provider/check_status": ("GET", self.check_all_providers_status), + "/config/provider/list": ("GET", self.get_provider_config_list), } self.register_routes() async def _test_single_provider(self, provider): """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() - # 使用更简洁的回退逻辑获取provider_name - provider_name = provider.provider_config.get("name") or getattr(meta, 'id', 'Unknown Provider') - + provider_name = provider.provider_config.get("id", "Unknown Provider") + if not provider_name and meta: + provider_name = meta.id + elif not provider_name: + provider_name = "Unknown Provider" status_info = { "id": getattr(meta, 'id', 'Unknown ID'), "model": getattr(meta, 'model', 'Unknown Model'), @@ -186,7 +189,7 @@ class ConfigRoute(Route): logger.debug(f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})") try: logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") - response = await asyncio.wait_for(provider.text_chat(prompt="Ping"), timeout=20.0) #超时二十秒 + response = await asyncio.wait_for(provider.text_chat(prompt="Ping"), timeout=20.0) # 超时 20 秒 logger.debug(f"Received response from {status_info['name']}: {response}") # 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用 if response is not None: @@ -248,6 +251,17 @@ class ConfigRoute(Route): return Response().ok(await self._get_astrbot_config()).__dict__ return Response().ok(await 
async def get_provider_config_list(self):
    """List the configured providers whose ``provider_type`` matches the query.

    Reads the ``provider_type`` query argument; responds with an error when
    it is missing or empty, otherwise with every entry of the ``provider``
    config section carrying that type.
    """
    wanted_type = request.args.get("provider_type", None)
    if not wanted_type:
        return Response().error("缺少参数 provider_type").__dict__

    configured = self.core_lifecycle.astrbot_config["provider"]
    matching = [
        entry
        for entry in configured
        if entry.get("provider_type", None) == wanted_type
    ]
    return Response().ok(matching).__dict__
或不合法。").__dict__ try: - pip_installer.install(package, mirror=mirror) + await pip_installer.install(package, mirror=mirror) return Response().ok(None, "安装成功。").__dict__ except Exception as e: logger.error(f"/api/update_pip: {traceback.format_exc()}") diff --git a/changelogs/v3.5.13.md b/changelogs/v3.5.13.md index 1a674f1f8..7ff838da4 100644 --- a/changelogs/v3.5.13.md +++ b/changelogs/v3.5.13.md @@ -1,6 +1,9 @@ # What's Changed -1. 新增:WebUI 支持暗夜模式 -2. 修复:修复 WebUI Chat 接口的未授权访问安全漏洞、插件 README 可能存在的 XSS 注入漏洞 -3. 优化:优化 Vec DB 在 indexing 过程时的数据库事务处理 -4. 修复:WebUI 下,插件市场的推荐卡片无法点击帮助文档的问题 +1. 新增:WebUI 支持暗夜模式。 +2. 修复:修复 WebUI Chat 接口的未授权访问安全漏洞、插件 README 可能存在的 XSS 注入漏洞。 +3. 优化:优化 Vec DB 在 indexing 过程时的数据库事务处理。 +4. 修复:WebUI 下,插件市场的推荐卡片无法点击帮助文档的问题。 +5. 新增:知识库。 +6. 新增:WebUI 提供商测试功能,一键检测可用性。 +7. 新增:WebUI 提供商分类功能,按能力分类提供商。 diff --git a/dashboard/src/components/shared/ConsoleDisplayer.vue b/dashboard/src/components/shared/ConsoleDisplayer.vue index e780d6569..a05c6eea7 100644 --- a/dashboard/src/components/shared/ConsoleDisplayer.vue +++ b/dashboard/src/components/shared/ConsoleDisplayer.vue @@ -5,7 +5,7 @@ import { useCommonStore } from '@/stores/common';