diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml new file mode 100644 index 000000000..e4bc932ce --- /dev/null +++ b/.github/workflows/code-format.yml @@ -0,0 +1,34 @@ +name: Code Format Check + +on: + pull_request: + branches: [ master ] + push: + branches: [ master ] + +jobs: + format-check: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install UV + run: pip install uv + + - name: Install dependencies + run: uv sync + + - name: Check code formatting with ruff + run: | + uv run ruff format --check . + + - name: Check code style with ruff + run: | + uv run ruff check . \ No newline at end of file diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml index 3c96297e8..9299dd6f5 100644 --- a/.github/workflows/dashboard_ci.yml +++ b/.github/workflows/dashboard_ci.yml @@ -37,6 +37,7 @@ jobs: !dist/**/*.md - name: Create GitHub Release + if: github.event_name == 'push' uses: ncipollo/release-action@v1 with: tag: release-${{ github.sha }} diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index 5b9f14d32..cd1fcd97b 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -124,15 +124,17 @@ def build_plug_list(plugins_dir: Path) -> list: if metadata and all( k in metadata for k in ["name", "desc", "version", "author", "repo"] ): - result.append({ - "name": str(metadata.get("name", "")), - "desc": str(metadata.get("desc", "")), - "version": str(metadata.get("version", "")), - "author": str(metadata.get("author", "")), - "repo": str(metadata.get("repo", "")), - "status": PluginStatus.INSTALLED, - "local_path": str(plugin_dir), - }) + result.append( + { + "name": str(metadata.get("name", "")), + "desc": str(metadata.get("desc", "")), + "version": str(metadata.get("version", "")), + "author": str(metadata.get("author", "")), + "repo": str(metadata.get("repo", "")), + "status": PluginStatus.INSTALLED, + "local_path": str(plugin_dir), + } + ) # 获取在线插件列表 online_plugins = [] @@ -142,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list: resp.raise_for_status() data = resp.json() for plugin_id, plugin_info in data.items(): - online_plugins.append({ - "name": str(plugin_id), - "desc": str(plugin_info.get("desc", "")), - "version": str(plugin_info.get("version", "")), - "author": str(plugin_info.get("author", "")), - "repo": str(plugin_info.get("repo", "")), - "status": PluginStatus.NOT_INSTALLED, - "local_path": None, - }) + online_plugins.append( + { + "name": str(plugin_id), + "desc": str(plugin_info.get("desc", "")), + "version": str(plugin_info.get("version", "")), + "author": str(plugin_info.get("author", "")), + "repo": str(plugin_info.get("repo", "")), + "status": PluginStatus.NOT_INSTALLED, + "local_path": None, + } + ) except Exception as e: click.echo(f"获取在线插件列表失败: {e}", err=True) diff --git a/astrbot/core/agent/response.py b/astrbot/core/agent/response.py index 3f683a233..8eb1854f6 100644 --- a/astrbot/core/agent/response.py +++ b/astrbot/core/agent/response.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import typing as T from astrbot.core.message.message_event_result import MessageChain + class AgentResponseData(T.TypedDict): chain: MessageChain diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 58ea2ca43..a0febf8c9 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -14,4 +14,5 @@ class ContextWrapper(Generic[TContext]): context: TContext event: AstrMessageEvent + NoContext = ContextWrapper[None] diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 796a7b336..901cdc4ed 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -53,7 +53,7 @@ async def do_migration_v4( await migration_webchat_data(db_helper, platform_id_map) # 执行偏好设置迁移 - await migration_preferences(db_helper,platform_id_map) + await migration_preferences(db_helper, platform_id_map) # 执行平台统计表迁移 await migration_platform_table(db_helper, platform_id_map) diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index dda2cbcaf..6a661bd3d 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -5,6 +5,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path _VT = TypeVar("_VT") + class SharedPreferences: def __init__(self, path=None): if path is None: @@ -42,4 +43,5 @@ class SharedPreferences: self._data.clear() self._save_preferences() + sp = SharedPreferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index e7e734abd..ad86c51f3 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -4,6 +4,7 @@ from astrbot.core.db.po import Platform, Stats from typing import Tuple, List, Dict, Any from dataclasses import dataclass + @dataclass class Conversation: """LLM 对话存储 @@ -76,7 +77,7 @@ PRAGMA encoding = 'UTF-8'; """ -class SQLiteDatabase(): +class SQLiteDatabase: def __init__(self, db_path: str) -> None: super().__init__() self.db_path = db_path diff --git a/astrbot/core/db/vec_db/faiss_impl/__init__.py b/astrbot/core/db/vec_db/faiss_impl/__init__.py index 11fc79d60..41f60466c 100644 --- a/astrbot/core/db/vec_db/faiss_impl/__init__.py +++ b/astrbot/core/db/vec_db/faiss_impl/__init__.py @@ -1,3 +1,3 @@ from .vec_db import FaissVecDB -__all__ = ["FaissVecDB"] \ No newline at end of file +__all__ = ["FaissVecDB"] diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index bc23922ef..7c2ae1c01 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -113,7 +113,8 @@ class FaissVecDB(BaseVecDB): reranked_results, key=lambda x: x.relevance_score, reverse=True ) top_k_results = [ - top_k_results[reranked_result.index] for reranked_result in reranked_results + top_k_results[reranked_result.index] + for reranked_result in reranked_results ] return top_k_results diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 02e87e6d0..27b47cbe3 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -77,7 +77,7 @@ async def call_event_hook( Returns: bool: 如果事件被终止,返回 True - # """ + #""" handlers = star_handlers_registry.get_handlers_by_event_type( hook_type, plugins_name=event.plugins_name ) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 75ea317ad..c38dfddef 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -24,7 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata -from .message_session import MessageSession, MessageSesion # noqa +from .message_session import MessageSession, MessageSesion # noqa class AstrMessageEvent(abc.ABC): diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index e1368cf0b..1308c482a 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -68,23 +68,23 @@ class PlatformManager: ) case "lark": from .sources.lark.lark_adapter import ( - LarkPlatformAdapter, - ) # noqa: F401 + LarkPlatformAdapter, # noqa: F401 + ) case "dingtalk": from .sources.dingtalk.dingtalk_adapter import ( DingtalkPlatformAdapter, # noqa: F401 ) case "telegram": from .sources.telegram.tg_adapter import ( - TelegramPlatformAdapter, - ) # noqa: F401 + TelegramPlatformAdapter, # noqa: F401 + ) case "wecom": from .sources.wecom.wecom_adapter import ( - WecomPlatformAdapter, - ) # noqa: F401 + WecomPlatformAdapter, # noqa: F401 + ) case "weixin_official_account": from .sources.weixin_official_account.weixin_offacc_adapter import ( - WeixinOfficialAccountPlatformAdapter, # noqa + WeixinOfficialAccountPlatformAdapter, # noqa: F401 ) case "discord": from .sources.discord.discord_platform_adapter import ( @@ -94,8 +94,8 @@ class PlatformManager: from .sources.slack.slack_adapter import SlackAdapter # noqa: F401 case "satori": from .sources.satori.satori_adapter import ( - SatoriPlatformAdapter, - ) # noqa: F401 + SatoriPlatformAdapter, # noqa: F401 + ) except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 9265fff63..7d3702666 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -321,7 +321,9 @@ class AiocqhttpAdapter(Platform): user_id=int(m["data"]["qq"]), no_cache=False, ) - nickname = at_info.get("nick", "") or at_info.get("nickname", "") + nickname = at_info.get("nick", "") or at_info.get( + "nickname", "" + ) is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"} abm.message.append( diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index bded9bfd0..1e6ddd49f 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent): logger.debug(f"send image: {ret}") except Exception as e: - logger.error(f"钉钉图片处理失败: {e}") - logger.warning(f"跳过图片发送: {image_path}") + logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送") continue + async def send(self, message: MessageChain): await self.send_with_client(self.client, message) await super().send(message) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 2e3fd89b3..78894491f 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot): await self.on_ready_once_callback() except Exception as e: logger.error( - f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True) + f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True + ) def _create_message_data(self, message: discord.Message) -> dict: """从 discord.Message 创建数据字典""" @@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot): message_data = self._create_message_data(message) await self.on_message_received(message_data) - def _extract_interaction_content(self, interaction: discord.Interaction) -> str: """从交互中提取内容""" interaction_type = interaction.type diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index dbeda38ad..07e712161 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent): self.url = url self.disabled = disabled + class DiscordReference(BaseMessageComponent): """Discord引用组件""" + type: str = "discord_reference" + def __init__(self, message_id: str, channel_id: str): self.message_id = message_id self.channel_id = channel_id @@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent): self.components = components or [] self.timeout = timeout - def to_discord_view(self) -> discord.ui.View: """转换为Discord View对象""" view = discord.ui.View(timeout=self.timeout) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index e61a12fcc..2c8d055fc 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent): # 解析消息链为 Discord 所需的对象 try: - content, files, view, embeds, reference_message_id = await self._parse_to_discord(message) + ( + content, + files, + view, + embeds, + reference_message_id, + ) = await self._parse_to_discord(message) except Exception as e: logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True) return @@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent): if await asyncio.to_thread(path.exists): file_bytes = await asyncio.to_thread(path.read_bytes) files.append( - discord.File(BytesIO(file_bytes), - filename=i.name) + discord.File(BytesIO(file_bytes), filename=i.name) ) else: logger.warning( diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index f6284542c..2096237ce 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -94,10 +94,15 @@ class QQOfficialMessageEvent(AstrMessageEvent): plain_text, image_base64, image_path, - record_file_path + record_file_path, ) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer) - if not plain_text and not image_base64 and not image_path and not record_file_path: + if ( + not plain_text + and not image_base64 + and not image_path + and not record_file_path + ): return payload = { @@ -118,7 +123,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - if record_file_path: # group record msg + if record_file_path: # group record msg media = await self.upload_group_and_c2c_record( record_file_path, 3, group_openid=source.group_openid ) @@ -134,9 +139,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - if record_file_path: # c2c record + if record_file_path: # c2c record media = await self.upload_group_and_c2c_record( - record_file_path, 3, openid = source.author.user_openid + record_file_path, 3, openid=source.author.user_openid ) payload["media"] = media payload["msg_type"] = 7 @@ -190,58 +195,55 @@ class QQOfficialMessageEvent(AstrMessageEvent): return await self.bot.api._http.request(route, json=payload) async def upload_group_and_c2c_record( - self, - file_source: str, - file_type: int, - srv_send_msg: bool = False, - **kwargs + self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs ) -> Optional[Media]: """ 上传媒体文件 """ # 构建基础payload - payload = { - "file_type": file_type, - "srv_send_msg": srv_send_msg - } - + payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} + # 处理文件数据 if os.path.exists(file_source): # 读取本地文件 - async with aiofiles.open(file_source, 'rb') as f: + async with aiofiles.open(file_source, "rb") as f: file_content = await f.read() # use base64 encode - payload["file_data"] = base64.b64encode(file_content).decode('utf-8') + payload["file_data"] = base64.b64encode(file_content).decode("utf-8") else: # 使用URL payload["url"] = file_source - + # 添加接收者信息和确定路由 if "openid" in kwargs: payload["openid"] = kwargs["openid"] route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) elif "group_openid" in kwargs: - payload["group_openid"] =kwargs["group_openid"] - route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"]) + payload["group_openid"] = kwargs["group_openid"] + route = Route( + "POST", + "/v2/groups/{group_openid}/files", + group_openid=kwargs["group_openid"], + ) else: return None - + try: # 使用底层HTTP请求 result = await self.bot.api._http.request(route, json=payload) - + if result: return Media( file_uuid=result.get("file_uuid"), file_info=result.get("file_info"), ttl=result.get("ttl", 0), - file_id=result.get("id", "") + file_id=result.get("id", ""), ) except Exception as e: logger.error(f"上传请求错误: {e}") - + return None - + async def post_c2c_message( self, openid: str, @@ -286,19 +288,23 @@ class QQOfficialMessageEvent(AstrMessageEvent): image_base64 = image_base64.removeprefix("base64://") elif isinstance(i, Record): if i.file: - record_wav_path = await i.convert_to_file_path() # wav 路径 + record_wav_path = await i.convert_to_file_path() # wav 路径 temp_dir = os.path.join(get_astrbot_data_path(), "temp") - record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk") + record_tecent_silk_path = os.path.join( + temp_dir, f"{uuid.uuid4()}.silk" + ) try: - duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path) + duration = await wav_to_tencent_silk( + record_wav_path, record_tecent_silk_path + ) if duration > 0: record_file_path = record_tecent_silk_path else: - record_file_path = None + record_file_path = None logger.error("转换音频格式时出错:音频时长不大于0") except Exception as e: logger.error(f"处理语音时出错: {e}") - record_file_path = None + record_file_path = None else: logger.debug(f"qq_official 忽略 {i.type}") return plain_text, image_base64, image_file_path, record_file_path diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 07dc1011b..7e75f3c20 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -308,7 +308,9 @@ class SlackAdapter(Platform): base64_content = base64.b64encode(content).decode("utf-8") return base64_content else: - logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}") + logger.error( + f"Failed to download slack file: {resp.status} {await resp.text()}" + ) raise Exception(f"下载文件失败: {resp.status}") async def run(self) -> Awaitable[Any]: diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 9acd61c85..86f9f9764 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent): "text": {"type": "mrkdwn", "text": "文件上传失败"}, } file_url = response["files"][0]["permalink"] - return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}} + return { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"文件: <{file_url}|{segment.name or '文件'}>", + }, + } else: return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 5b3a1d916..62e1998f7 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent): return chunks @classmethod - async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str): + async def send_with_client( + cls, client: ExtBot, message: MessageChain, user_name: str + ): image_path = None has_reply = False diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 96e172212..6c365cb3a 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -1,5 +1,6 @@ import asyncio + class WebChatQueueMgr: def __init__(self) -> None: self.queues = {} @@ -30,4 +31,5 @@ class WebChatQueueMgr: """Check if a queue exists for the given conversation ID""" return conversation_id in self.queues + webchat_queue_mgr = WebChatQueueMgr() diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 7d5984416..6b835ecb5 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform): def _extract_auth_key(self, data): """Helper method to extract auth_key from response data.""" if isinstance(data, dict): - auth_keys = data.get("authKeys") # 新接口 + auth_keys = data.get("authKeys") # 新接口 if isinstance(auth_keys, list) and auth_keys: return auth_keys[0] - elif isinstance(data, list) and data: # 旧接口 + elif isinstance(data, list) and data: # 旧接口 return data[0] return None @@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform): try: async with session.post(url, params=params, json=payload) as response: if response.status != 200: - logger.error(f"生成授权码失败: {response.status}, {await response.text()}") + logger.error( + f"生成授权码失败: {response.status}, {await response.text()}" + ) return response_data = await response.json() @@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform): if self.auth_key: logger.info("成功获取授权码") else: - logger.error(f"生成授权码成功但未找到授权码: {response_data}") + logger.error( + f"生成授权码成功但未找到授权码: {response_data}" + ) else: logger.error(f"生成授权码失败: {response_data}") except aiohttp.ClientConnectorError as e: diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py index 316f6da33..118667975 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI): 注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。 :return: 接口调用结果 """ - data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid} + data = { + "token": token, + "cursor": cursor, + "limit": limit, + "open_kfid": open_kfid, + } return self._post("kf/sync_msg", data=data) def get_service_state(self, open_kfid, external_userid): @@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI): } return self._post("kf/service_state/get", data=data) - def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""): + def trans_service_state( + self, open_kfid, external_userid, service_state, servicer_userid="" + ): """ 变更会话状态 @@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI): """ return self._get("kf/customer/get_upgrade_service_config") - def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None): + def upgrade_service( + self, open_kfid, external_userid, service_type, member=None, groupchat=None + ): """ 为客户升级为专员或客户群服务 @@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI): data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time} return self._post("kf/get_corp_statistic", data=data) - def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None): + def get_servicer_statistic( + self, start_time, end_time, open_kfid=None, servicer_userid=None + ): """ 获取「客户数据统计」接待人员明细数据 diff --git a/astrbot/core/platform/sources/wecom/wecom_kf_message.py b/astrbot/core/platform/sources/wecom/wecom_kf_message.py index 493d0405c..42fc20d65 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf_message.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf_message.py @@ -26,6 +26,7 @@ from optionaldict import optionaldict from wechatpy.client.api.base import BaseWeChatAPI + class WeChatKFMessage(BaseWeChatAPI): """ 发送微信客服消息 @@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI): msg={"msgtype": "news", "link": {"link": articles_data}}, ) - def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""): + def send_msgmenu( + self, user_id, open_kfid, head_content, menu_list, tail_content, msgid="" + ): return self.send( user_id, open_kfid, msgid, msg={ "msgtype": "msgmenu", - "msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content}, + "msgmenu": { + "head_content": head_content, + "list": menu_list, + "tail_content": tail_content, + }, }, ) - def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""): + def send_location( + self, user_id, open_kfid, name, address, latitude, longitude, msgid="" + ): return self.send( user_id, open_kfid, msgid, msg={ "msgtype": "location", - "msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude}, + "msgmenu": { + "name": name, + "address": address, + "latitude": latitude, + "longitude": longitude, + }, }, ) - def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""): + def send_miniprogram( + self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid="" + ): return self.send( user_id, open_kfid, msgid, msg={ "msgtype": "miniprogram", - "msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath}, + "msgmenu": { + "appid": appid, + "title": title, + "thumb_media_id": thumb_media_id, + "pagepath": pagepath, + }, }, ) diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index 04186ff9d..9ea1e5332 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform): self.wexin_event_workers[msg.id] = future await self.convert_message(msg, future) # I love shield so much! - result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s + result = await asyncio.wait_for( + asyncio.shield(future), 60 + ) # wait for 60s logger.debug(f"Got future result: {result}") self.wexin_event_workers.pop(msg.id, None) return result # xml. see weixin_offacc_event.py diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 102812705..4077cc1ab 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): return logger.info(f"微信公众平台上传语音返回: {response}") - if active_send_mode: self.client.message.send_voice( message_obj.sender.user_id, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index aac03e7a8..8ece29a2b 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -297,6 +297,7 @@ class LLMResponse: ) return ret + @dataclass class RerankResult: index: int diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 19f62edfa..bbf526d4e 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -366,7 +366,10 @@ class ProviderManager: if not self.curr_provider_inst: self.curr_provider_inst = inst - elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]: + elif provider_metadata.provider_type in [ + ProviderType.EMBEDDING, + ProviderType.RERANK, + ]: inst = provider_metadata.cls_type( provider_config, self.provider_settings ) diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 893f27463..49c78239e 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider): # FishAudio的reference_id通常是32位十六进制字符串 # 例如: 626bb6d3f3364c9cbc3aa6a67300a664 - pattern = r'^[a-fA-F0-9]{32}$' + pattern = r"^[a-fA-F0-9]{32}$" return bool(re.match(pattern, reference_id.strip())) async def _generate_request(self, text: str) -> dict: diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 9db43896c..e5400a757 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -7,6 +7,7 @@ from astrbot.core.config import AstrBotConfig from .custom_filter import CustomFilter from ..star_handler import StarHandlerMetadata + class GreedyStr(str): """标记指令完成其他参数接收后的所有剩余文本。""" @@ -159,7 +160,7 @@ class CommandFilter(HandlerFilter): break elif message_str.startswith(_full): # 命令名后面无论是空格还是直接连参数都可以 - message_str = message_str[len(_full):].lstrip() + message_str = message_str[len(_full) :].lstrip() ok = True break diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 67d253636..88d8ae64d 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter): + self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg) ) raise ValueError( - f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" - + tree + f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree ) # complete_command_names = [name + " " for name in complete_command_names] diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index 5c7303e8d..d1fdf77c8 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -84,7 +84,10 @@ class SessionPluginManager: session_config["disabled_plugins"] = disabled_plugins session_plugin_config[session_id] = session_config sp.put( - "session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id + "session_plugin_config", + session_plugin_config, + scope="umo", + scope_id=session_id, ) logger.info( diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 5fb1b1dfa..91e7ef0d7 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -791,11 +791,11 @@ class PluginManager: if star_metadata.star_cls is None: return - if '__del__' in star_metadata.star_cls_type.__dict__: + if "__del__" in star_metadata.star_cls_type.__dict__: asyncio.get_event_loop().run_in_executor( None, star_metadata.star_cls.__del__ ) - elif 'terminate' in star_metadata.star_cls_type.__dict__: + elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() async def turn_on_plugin(self, plugin_name: str): diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 42ed168ff..14bd1ac9b 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -30,8 +30,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.star.context import Context from astrbot.core.star.star import star_map from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import AiocqhttpMessageEvent -from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import AiocqhttpAdapter +from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, +) +from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, +) + class StarTools: """ @@ -77,7 +82,11 @@ class StarTools: @classmethod async def send_message_by_id( - cls, type: str, id: str, message_chain: MessageChain, platform: str = "aiocqhttp" + cls, + type: str, + id: str, + message_chain: MessageChain, + platform: str = "aiocqhttp", ): """ 根据 id(例如qq号, 群号等) 直接, 主动地发送消息 @@ -92,7 +101,9 @@ class StarTools: raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": - adapter = next((p for p in platforms if isinstance(p, AiocqhttpAdapter)), None) + adapter = next( + (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None + ) if adapter is None: raise ValueError("未找到适配器: AiocqhttpAdapter") await AiocqhttpMessageEvent.send_message( @@ -115,7 +126,7 @@ class StarTools: message_str: str, message_id: str = "", raw_message: object = None, - group_id: str = "" + group_id: str = "", ) -> AstrBotMessage: """ 创建一个AstrBot消息对象 @@ -152,7 +163,6 @@ class StarTools: @classmethod async def create_event( cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True - ) -> None: """ 创建并提交事件到指定平台 @@ -167,7 +177,9 @@ class StarTools: raise ValueError("StarTools not initialized") platforms = cls._context.platform_manager.get_insts() if platform == "aiocqhttp": - adapter = next((p for p in platforms if isinstance(p, AiocqhttpAdapter)), None) + adapter = next( + (p for p in platforms if isinstance(p, AiocqhttpAdapter)), None + ) if adapter is None: raise ValueError("未找到适配器: AiocqhttpAdapter") event = AiocqhttpMessageEvent( @@ -277,7 +289,9 @@ class StarTools: if not plugin_name: raise ValueError("无法获取插件名称") - data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) + data_dir = Path( + os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name) + ) try: data_dir.mkdir(parents=True, exist_ok=True) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 23daeab6b..0983cf8d5 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,7 +1,6 @@ import typing import traceback import os -import copy from .route import Route, Response, RouteContext from astrbot.core.provider.entities import ProviderType from quart import request diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index 5ae98cc21..e47f9d77c 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -10,7 +10,9 @@ class LogRoute(Route): super().__init__(context) self.log_broker = log_broker self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"]) - self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"]) + self.app.add_url_rule( + "/api/log-history", view_func=self.log_history, methods=["GET"] + ) async def log(self): async def stream(): @@ -48,9 +50,15 @@ class LogRoute(Route): """获取日志历史""" try: logs = list(self.log_broker.log_cache) - return Response().ok(data={ - "logs": logs, - }).__dict__ + return ( + Response() + .ok( + data={ + "logs": logs, + } + ) + .__dict__ + ) except BaseException as e: logger.error(f"获取日志历史失败: {e}") return Response().error(f"获取日志历史失败: {e}").__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index de3ba7c90..ec455ce3d 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,4 +1,3 @@ -from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass from quart import Quart diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 79a601b25..1f33136ed 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -1,6 +1,5 @@ import traceback -import aiohttp from quart import request from astrbot.core import logger diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index aca0c42cd..f1d13fa69 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -35,7 +35,9 @@ class LongTermMemory: else False ) image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"] - image_caption_provider_id = cfg["provider_settings"]["default_image_caption_provider_id"] + image_caption_provider_id = cfg["provider_settings"][ + "default_image_caption_provider_id" + ] active_reply = cfg["provider_ltm_settings"]["active_reply"] enable_active_reply = active_reply.get("enable", False) ar_method = active_reply["method"] diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 015d8e403..c9ce6908c 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -5,7 +5,7 @@ import astrbot.api.star as star import astrbot.api.event.filter as filter from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api.provider import ProviderRequest -from astrbot.api import llm_tool, agent, logger, AstrBotConfig +from astrbot.api import llm_tool, logger, AstrBotConfig from astrbot.core.provider.func_tool_manager import FunctionToolManager from .engines import SearchResult from .engines.bing import Bing @@ -35,7 +35,9 @@ class Main(star.Star): if provider_settings: tavily_key = provider_settings.get("websearch_tavily_key") if isinstance(tavily_key, str): - logger.info("检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。") + logger.info( + "检测到旧版 websearch_tavily_key (字符串格式),自动迁移为列表格式并保存。" + ) if tavily_key: provider_settings["websearch_tavily_key"] = [tavily_key] else: