diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 48592b20c..6fa1631ed 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -22,17 +22,19 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +import asyncio import base64 import json import os -import uuid -import asyncio import typing as T +import uuid from enum import Enum + from pydantic.v1 import BaseModel + from astrbot.core import logger -from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 class ComponentType(Enum): @@ -468,7 +470,7 @@ class Node(BaseMessageComponent): uin: T.Optional[str] = "0" # qq号 content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表 seq: T.Optional[T.Union[str, list]] = "" # 忽略 - time: T.Optional[int] = 0 # 忽略 + time: T.Optional[int] = 0 # 忽略 def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_): if isinstance(content, list): @@ -502,10 +504,11 @@ class Nodes(BaseMessageComponent): } for node in self.nodes: d = node.toDict() - d["data"]["uin"] = str(node.uin) # 转为字符串 + d["data"]["uin"] = str(node.uin) # 转为字符串 ret["messages"].append(d) return ret + class Xml(BaseMessageComponent): type: ComponentType = "Xml" data: str @@ -590,11 +593,13 @@ class File(BaseMessageComponent): try: loop = asyncio.get_event_loop() if loop.is_running(): - logger.warning(( - "不可以在异步上下文中同步等待下载! " - "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" - "请使用 await get_file() 代替直接获取 .file 字段" - )) + logger.warning( + ( + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段" + ) + ) return "" else: # 等待下载完成 @@ -620,7 +625,7 @@ class File(BaseMessageComponent): else: self.file_ = value - async def get_file(self, allow_return_url: bool=False) -> str: + async def get_file(self, allow_return_url: bool = False) -> str: """异步获取文件。请注意在使用后清理下载的文件, 以免占用过多空间 Args: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 1e71838be..b9139ec3a 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,33 +1,31 @@ +import asyncio import os import sys import uuid -import asyncio -import quart -import aiohttp +import quart +from requests import Response +from wechatpy.enterprise import WeChatClient, parse_message +from wechatpy.enterprise.crypto import WeChatCrypto +from wechatpy.enterprise.messages import ImageMessage, TextMessage, VoiceMessage +from wechatpy.exceptions import InvalidSignatureException +from wechatpy.messages import BaseMessage + +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Image, Plain, Record from astrbot.api.platform import ( - Platform, AstrBotMessage, MessageMember, - PlatformMetadata, MessageType, + Platform, + PlatformMetadata, + register_platform_adapter, ) -from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, Image, Record -from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.api.platform import register_platform_adapter from astrbot.core import logger -from requests import Response - -from wechatpy.enterprise.crypto import WeChatCrypto -from wechatpy.enterprise import WeChatClient -from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage -from wechatpy.messages import BaseMessage -from wechatpy.exceptions import InvalidSignatureException -from wechatpy.enterprise import parse_message -from .wecom_event import WecomPlatformEvent +from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from .wecom_event import WecomPlatformEvent from .wecom_kf import WeChatKF from .wecom_kf_message import WeChatKFMessage @@ -299,7 +297,7 @@ class WecomPlatformAdapter(Platform): external_userid = msg.get("external_userid", None) abm = AstrBotMessage() abm.raw_message = msg - abm.raw_message["_wechat_kf_flag"] = None # 方便处理 + abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.self_id = msg["open_kfid"] abm.sender = MessageMember(external_userid, external_userid) abm.session_id = external_userid diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 516a9efbd..25d6adfe6 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -2,35 +2,37 @@ 插件的重载、启停、安装、卸载等操作。 """ -import inspect +import asyncio import functools +import inspect +import json +import logging import os import sys -import json import traceback -import yaml -import logging -import asyncio from types import ModuleType from typing import List + +import yaml + +from astrbot.core import logger, pip_installer, sp from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core import logger, sp, pip_installer -from .context import Context -from . import StarMetadata -from .updator import PluginUpdator -from astrbot.core.utils.io import remove_dir -from .star import star_registry, star_map -from .star_handler import star_handlers_registry from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import ( - get_astrbot_plugin_path, get_astrbot_config_path, + get_astrbot_plugin_path, ) +from astrbot.core.utils.io import remove_dir -from .filter.permission import PermissionTypeFilter, PermissionType +from . import StarMetadata +from .context import Context +from .filter.permission import PermissionType, PermissionTypeFilter +from .star import star_map, star_registry +from .star_handler import star_handlers_registry +from .updator import PluginUpdator try: - from watchfiles import awatch, PythonFilter + from watchfiles import PythonFilter, awatch except ImportError: if os.getenv("ASTRBOT_RELOAD", "0") == "1": logger.warning("未安装 watchfiles,无法实现插件的热重载。") @@ -138,13 +140,11 @@ class PluginManager: if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists( os.path.join(path, d, d + ".py") ): - modules.append( - { - "pname": d, - "module": module_str, - "module_path": os.path.join(path, d, module_str), - } - ) + modules.append({ + "pname": d, + "module": module_str, + "module_path": os.path.join(path, d, module_str), + }) return modules def _get_plugin_modules(self) -> List[dict]: diff --git a/astrbot/core/utils/path_util.py b/astrbot/core/utils/path_util.py index 034577e49..0d8511f0c 100644 --- a/astrbot/core/utils/path_util.py +++ b/astrbot/core/utils/path_util.py @@ -1,70 +1,72 @@ import os + from astrbot.core import logger -def path_Mapping(mappings, srcPath: str)->str: - """路径映射处理函数。尝试支援 Windows 和 Linux 的路径映射。 - Args: - mappings: 映射规则列表 - srcPath: 原路径 - Returns: - str: 处理后的路径 - """ - for mapping in mappings: - rule = mapping.split(":") - if len(rule) == 2: - from_, to_ = mapping.split(":") - elif len(rule) > 4 or len(rule) == 1: - # 切割后大于4个项目,或者只有1个项目,那肯定是错误的,只能是2,3,4个项目 - logger.warning(f"路径映射规则错误: {mapping}") - continue + +def path_Mapping(mappings, srcPath: str) -> str: + """路径映射处理函数。尝试支援 Windows 和 Linux 的路径映射。 + Args: + mappings: 映射规则列表 + srcPath: 原路径 + Returns: + str: 处理后的路径 + """ + for mapping in mappings: + rule = mapping.split(":") + if len(rule) == 2: + from_, to_ = mapping.split(":") + elif len(rule) > 4 or len(rule) == 1: + # 切割后大于4个项目,或者只有1个项目,那肯定是错误的,只能是2,3,4个项目 + logger.warning(f"路径映射规则错误: {mapping}") + continue + else: + # rule.len == 3 or 4 + if os.path.exists(rule[0] + ":" + rule[1]): + # 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接 + from_ = rule[0] + ":" + rule[1] + if len(rule) == 3: + to_ = rule[2] + else: + to_ = rule[2] + ":" + rule[3] else: - # rule.len == 3 or 4 - if(os.path.exists(rule[0]+":"+rule[1])): - # 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接 - from_ = rule[0] + ":" + rule[1] - if len(rule) == 3: - to_ = rule[2] - else: - to_ = rule[2] + ":" + rule[3] + # 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。 + from_ = rule[0] + if len(rule) == 3: + to_ = rule[1] + ":" + rule[2] else: - # 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。 - from_ = rule[0] - if len(rule) == 3: - to_ = rule[1] + ":" + rule[2] - else: - # 这种情况下存在四个项目,说明规则也是错误的 - logger.warning(f"路径映射规则错误: {mapping}") - continue + # 这种情况下存在四个项目,说明规则也是错误的 + logger.warning(f"路径映射规则错误: {mapping}") + continue - from_ = from_.removesuffix("/") - from_ = from_.removesuffix("\\") - to_ = to_.removesuffix("/") - to_ = to_.removesuffix("\\") - # logger.debug(f"\t路径映射-规则(处理): {from_} -> {to_}") + from_ = from_.removesuffix("/") + from_ = from_.removesuffix("\\") + to_ = to_.removesuffix("/") + to_ = to_.removesuffix("\\") + # logger.debug(f"\t路径映射-规则(处理): {from_} -> {to_}") - url = srcPath.removeprefix("file://") - if url.startswith(from_): - srcPath = url.replace(from_, to_, 1) - if ":" in srcPath: - # Windows路径处理 - srcPath = srcPath.replace("/", "\\") - else: - has_replaced_processed = False - if srcPath.startswith("."): - # 相对路径处理。如果是相对路径,可能是Linux路径,也可能是Windows路径 - sign = srcPath[1] - # 处理两个点的情况 - if sign == ".": - sign = srcPath[2] - if sign == "/": - srcPath = srcPath.replace("\\", "/") - has_replaced_processed = True - elif sign == "\\": - srcPath = srcPath.replace("/", "\\") - has_replaced_processed = True - if has_replaced_processed == False: - # 如果不是相对路径或不能处理,默认按照Linux路径处理 + url = srcPath.removeprefix("file://") + if url.startswith(from_): + srcPath = url.replace(from_, to_, 1) + if ":" in srcPath: + # Windows路径处理 + srcPath = srcPath.replace("/", "\\") + else: + has_replaced_processed = False + if srcPath.startswith("."): + # 相对路径处理。如果是相对路径,可能是Linux路径,也可能是Windows路径 + sign = srcPath[1] + # 处理两个点的情况 + if sign == ".": + sign = srcPath[2] + if sign == "/": srcPath = srcPath.replace("\\", "/") - logger.info(f"路径映射: {url} -> {srcPath}") - return srcPath - return srcPath \ No newline at end of file + has_replaced_processed = True + elif sign == "\\": + srcPath = srcPath.replace("/", "\\") + has_replaced_processed = True + if not has_replaced_processed: + # 如果不是相对路径或不能处理,默认按照Linux路径处理 + srcPath = srcPath.replace("\\", "/") + logger.info(f"路径映射: {url} -> {srcPath}") + return srcPath + return srcPath diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 6dd093546..f8cd9c8f6 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -1,11 +1,14 @@ -import os import json -import aiohttp +import os import traceback -from .route import Route, Response, RouteContext + +import aiohttp from quart import request -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + from astrbot.core import logger +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + +from .route import Response, Route, RouteContext DEFAULT_MCP_CONFIG = {"mcpServers": {}} @@ -130,13 +133,11 @@ class ToolsRoute(Route): if self.save_mcp_config(config): # 动态初始化新MCP客户端 - await self.tool_mgr.mcp_service_queue.put( - { - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - } - ) + await self.tool_mgr.mcp_service_queue.put({ + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + }) return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ else: return Response().error("保存配置失败").__dict__ @@ -194,37 +195,29 @@ class ToolsRoute(Route): if active: # 如果要激活服务器或者配置已更改 if name in self.tool_mgr.mcp_client_dict or not only_update_active: - await self.tool_mgr.mcp_service_queue.put( - { - "type": "terminate", - "name": name, - } - ) - await self.tool_mgr.mcp_service_queue.put( - { - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - } - ) + await self.tool_mgr.mcp_service_queue.put({ + "type": "terminate", + "name": name, + }) + await self.tool_mgr.mcp_service_queue.put({ + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + }) else: # 客户端不存在,初始化 - await self.tool_mgr.mcp_service_queue.put( - { - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - } - ) + await self.tool_mgr.mcp_service_queue.put({ + "type": "init", + "name": name, + "cfg": config["mcpServers"][name], + }) else: # 如果要停用服务器 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait( - { - "type": "terminate", - "name": name, - } - ) + self.tool_mgr.mcp_service_queue.put_nowait({ + "type": "terminate", + "name": name, + }) return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ else: @@ -252,12 +245,10 @@ class ToolsRoute(Route): if self.save_mcp_config(config): # 关闭并删除MCP客户端 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait( - { - "type": "terminate", - "name": name, - } - ) + self.tool_mgr.mcp_service_queue.put_nowait({ + "type": "terminate", + "name": name, + }) return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ else: @@ -269,9 +260,11 @@ class ToolsRoute(Route): async def get_mcp_markets(self): page = request.args.get("page", 1, type=int) page_size = request.args.get("page_size", 10, type=int) - BASE_URL = "https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format( - page, - page_size, + BASE_URL = ( + "https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format( + page, + page_size, + ) ) try: async with aiohttp.ClientSession() as session: @@ -287,4 +280,4 @@ class ToolsRoute(Route): ) except Exception as _: logger.error(traceback.format_exc()) - return Response().error("获取市场数据失败").__dict__ \ No newline at end of file + return Response().error("获取市场数据失败").__dict__