chore: auto ann fix by ruff (#4903)

* chore: auto fix by ruff

* refactor: 统一修正返回类型注解为 None/bool 以匹配实现

* refactor: 将 _get_next_page 改为异步并移除多余的请求错误抛出

* refactor: 将 get_client 的返回类型改为 object

* style: 为 LarkMessageEvent 的相关方法添加返回类型注解 None

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
This commit is contained in:
Dt8333
2026-02-09 00:22:24 +08:00
committed by GitHub
parent e1b71540c7
commit 7dd95d8a59
183 changed files with 785 additions and 732 deletions
@@ -17,7 +17,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class LongTermMemory:
def __init__(self, acm: AstrBotConfigManager, context: star.Context):
def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
self.acm = acm
self.context = context
self.session_chats = defaultdict(list)
@@ -111,7 +111,7 @@ class LongTermMemory:
return False
async def handle_message(self, event: AstrMessageEvent):
async def handle_message(self, event: AstrMessageEvent) -> None:
"""仅支持群聊"""
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
@@ -148,7 +148,7 @@ class LongTermMemory:
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
"""当触发 LLM 请求前,调用此方法修改 req"""
if event.unified_msg_origin not in self.session_chats:
return
@@ -171,7 +171,9 @@ class LongTermMemory:
)
req.system_prompt += chats_str
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
async def after_req_llm(
self, event: AstrMessageEvent, llm_resp: LLMResponse
) -> None:
if event.unified_msg_origin not in self.session_chats:
return
+7 -3
View File
@@ -85,7 +85,9 @@ class Main(star.Star):
logger.error(f"主动回复失败: {e}")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
async def decorate_llm_req(
self, event: AstrMessageEvent, req: ProviderRequest
) -> None:
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -94,7 +96,9 @@ class Main(star.Star):
logger.error(f"ltm: {e}")
@filter.on_llm_response()
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
async def record_llm_resp_to_ltm(
self, event: AstrMessageEvent, resp: LLMResponse
) -> None:
"""在 LLM 响应后记录对话"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -103,7 +107,7 @@ class Main(star.Star):
logger.error(f"ltm: {e}")
@filter.after_message_sent()
async def after_message_sent(self, event: AstrMessageEvent):
async def after_message_sent(self, event: AstrMessageEvent) -> None:
"""消息发送后处理"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -5,10 +5,10 @@ from astrbot.core.utils.io import download_dashboard
class AdminCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
if not admin_id:
event.set_result(
@@ -21,7 +21,7 @@ class AdminCommands:
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
async def deop(self, event: AstrMessageEvent, admin_id: str = ""):
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""取消授权管理员。deop <admin_id>"""
if not admin_id:
event.set_result(
@@ -39,7 +39,7 @@ class AdminCommands:
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
)
async def wl(self, event: AstrMessageEvent, sid: str = ""):
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
if not sid:
event.set_result(
@@ -53,7 +53,7 @@ class AdminCommands:
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
async def dwl(self, event: AstrMessageEvent, sid: str = ""):
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""删除白名单。dwl <sid>"""
if not sid:
event.set_result(
@@ -70,7 +70,7 @@ class AdminCommands:
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
async def update_dashboard(self, event: AstrMessageEvent):
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
@@ -11,10 +11,10 @@ from .utils.rst_scene import RstScene
class AlterCmdCommands(CommandParserMixin):
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def update_reset_permission(self, scene_key: str, perm_type: str):
async def update_reset_permission(self, scene_key: str, perm_type: str) -> None:
"""更新reset命令在特定场景下的权限设置"""
from astrbot.api import sp
@@ -26,7 +26,7 @@ class AlterCmdCommands(CommandParserMixin):
alter_cmd_cfg["astrbot"] = plugin_cfg
await sp.global_put("alter_cmd", alter_cmd_cfg)
async def alter_cmd(self, event: AstrMessageEvent):
async def alter_cmd(self, event: AstrMessageEvent) -> None:
token = self.parse_commands(event.message_str)
if token.len < 3:
await event.send(
@@ -16,7 +16,7 @@ THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
class ConversationCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def _get_current_persona_id(self, session_id):
@@ -33,7 +33,7 @@ class ConversationCommands:
return None
return conv.persona_id
async def reset(self, message: AstrMessageEvent):
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
umo = message.unified_msg_origin
cfg = self.context.get_config(umo=message.unified_msg_origin)
@@ -98,7 +98,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret))
async def his(self, message: AstrMessageEvent, page: int = 1):
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
message.set_result(
@@ -141,7 +141,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret).use_t2i(False))
async def convs(self, message: AstrMessageEvent, page: int = 1):
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
@@ -216,7 +216,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret).use_t2i(False))
return
async def new_conv(self, message: AstrMessageEvent):
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
@@ -242,7 +242,7 @@ class ConversationCommands:
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
)
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
"""创建新群聊对话"""
if sid:
session = str(
@@ -273,7 +273,7 @@ class ConversationCommands:
self,
message: AstrMessageEvent,
index: int | None = None,
):
) -> None:
"""通过 /ls 前面的序号切换对话"""
if not isinstance(index, int):
message.set_result(
@@ -308,7 +308,7 @@ class ConversationCommands:
),
)
async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
"""重命名对话"""
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
@@ -319,7 +319,7 @@ class ConversationCommands:
)
message.set_result(MessageEventResult().message("重命名对话成功。"))
async def del_conv(self, message: AstrMessageEvent):
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
is_unique_session = cfg["platform_settings"]["unique_session"]
@@ -8,7 +8,7 @@ from astrbot.core.utils.io import get_dashboard_version
class HelpCommand:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def _query_astrbot_notice(self):
@@ -34,7 +34,7 @@ class HelpCommand:
lines: list[str] = []
hidden_commands = {"set", "unset", "websearch"}
def walk(items: list[dict], indent: int = 0):
def walk(items: list[dict], indent: int = 0) -> None:
for item in items:
if not item.get("reserved") or not item.get("enabled"):
continue
@@ -62,7 +62,7 @@ class HelpCommand:
walk(commands)
return lines
async def help(self, event: AstrMessageEvent):
async def help(self, event: AstrMessageEvent) -> None:
"""查看帮助"""
notice = ""
try:
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
class LLMCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def llm(self, event: AstrMessageEvent):
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
enable = cfg["provider_settings"].get("enable", True)
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
class PersonaCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
def _build_tree_output(
@@ -50,7 +50,7 @@ class PersonaCommands:
return lines
async def persona(self, message: AstrMessageEvent):
async def persona(self, message: AstrMessageEvent) -> None:
l = message.message_str.split(" ") # noqa: E741
umo = message.unified_msg_origin
@@ -8,10 +8,10 @@ from astrbot.core.star.star_manager import PluginManager
class PluginCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def plugin_ls(self, event: AstrMessageEvent):
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
parts = ["已加载的插件:\n"]
for plugin in self.context.get_all_stars():
@@ -30,7 +30,7 @@ class PluginCommands:
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
)
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
@@ -43,7 +43,7 @@ class PluginCommands:
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
@@ -56,7 +56,7 @@ class PluginCommands:
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
@@ -77,7 +77,7 @@ class PluginCommands:
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
return
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
if not plugin_name:
event.set_result(
@@ -8,7 +8,7 @@ from astrbot.core.provider.entities import ProviderType
class ProviderCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
def _log_reachability_failure(
@@ -17,7 +17,7 @@ class ProviderCommands:
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
) -> None:
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
@@ -49,7 +49,7 @@ class ProviderCommands:
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
):
) -> None:
"""查看或者切换 LLM Provider"""
umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
@@ -228,7 +228,7 @@ class ProviderCommands:
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
):
) -> None:
"""查看或者切换模型"""
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
@@ -293,7 +293,7 @@ class ProviderCommands:
MessageEventResult().message(f"切换模型到 {prov.get_model()}"),
)
async def key(self, message: AstrMessageEvent, index: int | None = None):
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class SetUnsetCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""设置会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {})
@@ -19,7 +19,7 @@ class SetUnsetCommands:
),
)
async def unset_variable(self, event: AstrMessageEvent, key: str):
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""移除会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {})
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class SIDCommand:
"""会话ID命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def sid(self, event: AstrMessageEvent):
async def sid(self, event: AstrMessageEvent) -> None:
"""获取消息来源信息"""
sid = event.unified_msg_origin
user_id = str(event.get_sender_id())
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class T2ICommand:
"""文本转图片命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def t2i(self, event: AstrMessageEvent):
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
config = self.context.get_config(umo=event.unified_msg_origin)
if config["t2i"]:
@@ -8,10 +8,10 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager
class TTSCommand:
"""文本转语音命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def tts(self, event: AstrMessageEvent):
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
+33 -31
View File
@@ -35,84 +35,84 @@ class Main(star.Star):
self.sid_c = SIDCommand(self.context)
@filter.command("help")
async def help(self, event: AstrMessageEvent):
async def help(self, event: AstrMessageEvent) -> None:
"""查看帮助"""
await self.help_c.help(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("llm")
async def llm(self, event: AstrMessageEvent):
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
await self.llm_c.llm(event)
@filter.command_group("plugin")
def plugin(self):
def plugin(self) -> None:
"""插件管理"""
@plugin.command("ls")
async def plugin_ls(self, event: AstrMessageEvent):
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
await self.plugin_c.plugin_ls(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("off")
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
await self.plugin_c.plugin_off(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("on")
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
await self.plugin_c.plugin_on(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("get")
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
await self.plugin_c.plugin_get(event, plugin_repo)
@plugin.command("help")
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
await self.plugin_c.plugin_help(event, plugin_name)
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
await self.t2i_c.t2i(event)
@filter.command("tts")
async def tts(self, event: AstrMessageEvent):
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
await self.tts_c.tts(event)
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):
async def sid(self, event: AstrMessageEvent) -> None:
"""获取会话 ID 和 管理员 ID"""
await self.sid_c.sid(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
await self.admin_c.op(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str):
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
"""取消授权管理员。deop <admin_id>"""
await self.admin_c.deop(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str = ""):
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
await self.admin_c.wl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str):
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
"""删除白名单。dwl <sid>"""
await self.admin_c.dwl(event, sid)
@@ -123,12 +123,12 @@ class Main(star.Star):
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
):
) -> None:
"""查看或者切换 LLM Provider"""
await self.provider_c.provider(event, idx, idx2)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
await self.conversation_c.reset(message)
@@ -138,74 +138,76 @@ class Main(star.Star):
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
):
) -> None:
"""查看或者切换模型"""
await self.provider_c.model_ls(message, idx_or_name)
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
await self.conversation_c.his(message, page)
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1):
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
await self.conversation_c.convs(message, page)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
await self.conversation_c.new_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("groupnew")
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None:
"""创建新群聊对话"""
await self.conversation_c.groupnew_conv(message, sid)
@filter.command("switch")
async def switch_conv(self, message: AstrMessageEvent, index: int | None = None):
async def switch_conv(
self, message: AstrMessageEvent, index: int | None = None
) -> None:
"""通过 /ls 前面的序号切换对话"""
await self.conversation_c.switch_conv(message, index)
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None:
"""重命名对话"""
await self.conversation_c.rename_conv(message, new_name)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
await self.conversation_c.del_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int | None = None):
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
"""查看或者切换 Key"""
await self.provider_c.key(message, index)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
async def persona(self, message: AstrMessageEvent) -> None:
"""查看或者切换 Persona"""
await self.persona_c.persona(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent):
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await self.admin_c.update_dashboard(event)
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
await self.setunset_c.set_variable(event, key, value)
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
await self.setunset_c.unset_variable(event, key)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent):
async def alter_cmd(self, event: AstrMessageEvent) -> None:
"""修改命令权限"""
await self.alter_cmd_c.alter_cmd(event)
@@ -17,11 +17,11 @@ from astrbot.core.utils.session_waiter import (
class Main(Star):
"""会话控制"""
def __init__(self, context: Context):
def __init__(self, context: Context) -> None:
super().__init__(context)
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
async def handle_session_control_agent(self, event: AstrMessageEvent):
async def handle_session_control_agent(self, event: AstrMessageEvent) -> None:
"""会话控制代理"""
for session_filter in FILTERS:
session_id = session_filter.filter(event)
@@ -90,7 +90,7 @@ class Main(Star):
async def empty_mention_waiter(
controller: SessionController,
event: AstrMessageEvent,
):
) -> None:
event.message_obj.message.insert(
0,
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),
@@ -49,7 +49,7 @@ class SearchEngine:
def _set_selector(self, selector: str) -> str:
raise NotImplementedError
def _get_next_page(self, query: str):
async def _get_next_page(self, query: str) -> str:
raise NotImplementedError
async def _get_html(self, url: str, data: dict | None = None) -> str:
+3 -3
View File
@@ -199,7 +199,7 @@ class Main(star.Star):
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
"""网页搜索指令(已废弃)"""
event.set_result(
MessageEventResult().message(
@@ -246,7 +246,7 @@ class Main(star.Star):
return ret
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None:
if self.baidu_initialized:
return
cfg = self.context.get_config(umo=umo)
@@ -553,7 +553,7 @@ class Main(star.Star):
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
) -> None:
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
+3 -3
View File
@@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
@click.group(name="conf")
def conf():
def conf() -> None:
"""配置管理命令
支持的配置项:
@@ -149,7 +149,7 @@ def conf():
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str):
def set_config(key: str, value: str) -> None:
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"不支持的配置项: {key}")
@@ -178,7 +178,7 @@ def set_config(key: str, value: str):
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None):
def get_config(key: str | None = None) -> None:
"""获取配置项的值,不提供key则显示所有可配置项"""
config = _load_config()
+8 -8
View File
@@ -15,7 +15,7 @@ from ..utils import (
@click.group()
def plug():
def plug() -> None:
"""插件管理"""
@@ -28,7 +28,7 @@ def _get_data_path() -> Path:
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None):
def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
@@ -45,7 +45,7 @@ def display_plugins(plugins, title=None, color=None):
@plug.command()
@click.argument("name")
def new(name: str):
def new(name: str) -> None:
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
@@ -100,7 +100,7 @@ def new(name: str):
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool):
def list(all: bool) -> None:
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -141,7 +141,7 @@ def list(all: bool):
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None):
def install(name: str, proxy: str | None) -> None:
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
@@ -164,7 +164,7 @@ def install(name: str, proxy: str | None):
@plug.command()
@click.argument("name")
def remove(name: str):
def remove(name: str) -> None:
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -187,7 +187,7 @@ def remove(name: str):
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None):
def update(name: str, proxy: str | None) -> None:
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
@@ -225,7 +225,7 @@ def update(name: str, proxy: str | None):
@plug.command()
@click.argument("query")
def search(query: str):
def search(query: str) -> None:
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
+1 -1
View File
@@ -10,7 +10,7 @@ from filelock import FileLock, Timeout
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
async def run_astrbot(astrbot_root: Path) -> None:
"""运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
+1 -1
View File
@@ -19,7 +19,7 @@ class PluginStatus(str, Enum):
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
+4 -2
View File
@@ -57,7 +57,9 @@ class TruncateByTurnsCompressor:
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
def __init__(
self, truncate_turns: int = 1, compression_threshold: float = 0.82
) -> None:
"""Initialize the truncate by turns compressor.
Args:
@@ -152,7 +154,7 @@ class LLMSummaryCompressor:
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
) -> None:
"""Initialize the LLM summary compressor.
Args:
+1 -1
View File
@@ -13,7 +13,7 @@ class ContextManager:
def __init__(
self,
config: ContextConfig,
):
) -> None:
"""Initialize the context manager.
There are two strategies to handle context limit reached:
+1 -1
View File
@@ -14,7 +14,7 @@ class HandoffTool(FunctionTool, Generic[TContext]):
parameters: dict | None = None,
tool_description: str | None = None,
**kwargs,
):
) -> None:
self.agent = agent
# Avoid passing duplicate `description` to the FunctionTool dataclass.
+4 -4
View File
@@ -9,22 +9,22 @@ from .run_context import ContextWrapper, TContext
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
) -> None: ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
) -> None: ...
async def on_agent_done(
self,
run_context: ContextWrapper[TContext],
llm_response: LLMResponse,
): ...
) -> None: ...
+6 -6
View File
@@ -108,7 +108,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class MCPClient:
def __init__(self):
def __init__(self) -> None:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
@@ -126,7 +126,7 @@ class MCPClient:
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
async def connect_to_server(self, mcp_server_config: dict, name: str):
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
"""Connect to MCP server
If `url` parameter exists:
@@ -144,7 +144,7 @@ class MCPClient:
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
def logging_callback(msg: str) -> None:
# Handle MCP service error logs
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
@@ -214,7 +214,7 @@ class MCPClient:
**cfg,
)
def callback(msg: str):
def callback(msg: str) -> None:
# Handle MCP service error logs
self.server_errlogs.append(msg)
@@ -343,7 +343,7 @@ class MCPClient:
return await _call_with_retry()
async def cleanup(self):
async def cleanup(self) -> None:
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
@@ -365,7 +365,7 @@ class MCPTool(FunctionTool, Generic[TContext]):
def __init__(
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
):
) -> None:
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
@@ -10,7 +10,7 @@ from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None:
self.api_key = api_key
self.api_base = api_base
self.session = None
@@ -277,7 +277,7 @@ class CozeAPIClient:
logger.error(f"获取Coze消息列表失败: {e!s}")
raise Exception(f"获取Coze消息列表失败: {e!s}")
async def close(self):
async def close(self) -> None:
"""关闭会话"""
if self.session:
await self.session.close()
@@ -288,7 +288,7 @@ if __name__ == "__main__":
import asyncio
import os
async def test_coze_api_client():
async def test_coze_api_client() -> None:
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
@@ -67,7 +67,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
def has_rag_options(self) -> bool:
"""判断是否有 RAG 选项
Returns:
@@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None:
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession(trust_env=True)
@@ -155,7 +155,7 @@ class DifyAPIClient:
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
return await resp.json() # {"id": "xxx", ...}
async def close(self):
async def close(self) -> None:
await self.session.close()
async def get_chat_convs(self, user: str, limit: int = 20):
+10 -10
View File
@@ -64,7 +64,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
with a task identifier while the real work continues asynchronously.
"""
def __repr__(self):
def __repr__(self) -> str:
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
@@ -88,7 +88,7 @@ class ToolSet:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
@@ -97,7 +97,7 @@ class ToolSet:
return
self.tools.append(tool)
def remove_tool(self, name: str):
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
@@ -156,7 +156,7 @@ class ToolSet:
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
) -> None:
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -176,7 +176,7 @@ class ToolSet:
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
def remove_func(self, name: str) -> None:
"""Remove a function tool by its name."""
self.remove_tool(name)
@@ -325,22 +325,22 @@ class ToolSet:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def merge(self, other: "ToolSet"):
def merge(self, other: "ToolSet") -> None:
"""Merge another ToolSet into this one."""
for tool in other.tools:
self.add_tool(tool)
def __len__(self):
def __len__(self) -> int:
return len(self.tools)
def __bool__(self):
def __bool__(self) -> bool:
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
def __repr__(self) -> str:
return f"ToolSet(tools={self.tools})"
def __str__(self):
def __str__(self) -> str:
return f"ToolSet(tools={self.tools})"
+3 -3
View File
@@ -12,7 +12,7 @@ from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
async def on_agent_done(self, run_context, llm_response) -> None:
# 执行事件钩子
if llm_response and llm_response.reasoning_content:
# we will use this in result_decorate stage to inject reasoning content to chain
@@ -31,7 +31,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
):
) -> None:
await call_event_hook(
run_context.context.event,
EventType.OnUsingLLMToolEvent,
@@ -45,7 +45,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
) -> None:
run_context.context.event.clear_result()
await call_event_hook(
run_context.context.event,
+3 -3
View File
@@ -295,7 +295,7 @@ async def _run_agent_feeder(
max_step: int,
show_tool_use: bool,
show_reasoning: bool,
):
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
buffer = ""
try:
@@ -352,7 +352,7 @@ async def _safe_tts_stream_wrapper(
tts_provider: TTSProvider,
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
):
) -> None:
"""包装原生流式 TTS 确保异常处理和队列关闭"""
try:
await tts_provider.get_audio_stream(text_queue, audio_queue)
@@ -366,7 +366,7 @@ async def _simulated_stream_tts(
tts_provider: TTSProvider,
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
):
) -> None:
"""模拟流式 TTS 分句生成音频"""
try:
while True:
+2 -2
View File
@@ -57,7 +57,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
elif tool.is_background_task:
task_id = uuid.uuid4().hex
async def _run_in_background():
async def _run_in_background() -> None:
try:
await cls._execute_background(
tool=tool,
@@ -153,7 +153,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
task_id: str,
**tool_args,
):
) -> None:
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
_get_session_conv,
+2 -2
View File
@@ -36,7 +36,7 @@ class AstrBotConfigManager:
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
) -> None:
self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {}
@@ -56,7 +56,7 @@ class AstrBotConfigManager:
)
return self.abconf_data
def _load_all_configs(self):
def _load_all_configs(self) -> None:
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
+1 -1
View File
@@ -59,7 +59,7 @@ class AstrBotExporter:
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
):
) -> None:
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
+2 -2
View File
@@ -110,7 +110,7 @@ class ImportPreCheckResult:
class ImportResult:
"""导入结果"""
def __init__(self):
def __init__(self) -> None:
self.success = True
self.imported_tables: dict[str, int] = {}
self.imported_files: dict[str, int] = {}
@@ -161,7 +161,7 @@ class AstrBotImporter:
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
kb_root_dir: str = KB_PATH,
):
) -> None:
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
+1 -1
View File
@@ -22,7 +22,7 @@ class ComputerBooter:
"""
...
async def download_file(self, remote_path: str, local_path: str):
async def download_file(self, remote_path: str, local_path: str) -> None:
"""Download file from the computer."""
...
+1 -1
View File
@@ -225,7 +225,7 @@ class LocalBooter(ComputerBooter):
"LocalBooter does not support upload_file operation. Use shell instead."
)
async def download_file(self, remote_path: str, local_path: str):
async def download_file(self, remote_path: str, local_path: str) -> None:
raise NotImplementedError(
"LocalBooter does not support download_file operation. Use shell instead."
)
+1 -1
View File
@@ -100,7 +100,7 @@ class FileUploadTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
local_path: str,
):
) -> str | None:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
+5 -5
View File
@@ -33,7 +33,7 @@ class AstrBotConfig(dict):
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict | None = None,
):
) -> None:
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
@@ -66,7 +66,7 @@ class AstrBotConfig(dict):
"""将 Schema 转换成 Config"""
conf = {}
def _parse_schema(schema: dict, conf: dict):
def _parse_schema(schema: dict, conf: dict) -> None:
for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError(
@@ -148,7 +148,7 @@ class AstrBotConfig(dict):
return has_new
def save_config(self, replace_config: dict | None = None):
def save_config(self, replace_config: dict | None = None) -> None:
"""将配置写入文件
如果传入 replace_config则将配置替换为 replace_config
@@ -164,14 +164,14 @@ class AstrBotConfig(dict):
except KeyError:
return None
def __delattr__(self, key):
def __delattr__(self, key) -> None:
try:
del self[key]
self.save_config()
except KeyError:
raise AttributeError(f"没有找到 Key: '{key}'")
def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
self[key] = value
def check_exist(self) -> bool:
+6 -4
View File
@@ -16,7 +16,7 @@ from astrbot.core.db.po import Conversation, ConversationV2
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
def __init__(self, db_helper: BaseDatabase) -> None:
self.session_conversations: dict[str, str] = {}
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
@@ -106,7 +106,9 @@ class ConversationManager:
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
async def switch_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> None:
"""切换会话的对话
Args:
@@ -121,7 +123,7 @@ class ConversationManager:
self,
unified_msg_origin: str,
conversation_id: str | None = None,
):
) -> None:
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
@@ -138,7 +140,7 @@ class ConversationManager:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None:
"""删除会话的所有对话
Args:
+3 -3
View File
@@ -24,7 +24,7 @@ class CronMessageEvent(AstrMessageEvent):
sender_name: str = "Scheduler",
extras: dict[str, Any] | None = None,
message_type: MessageType = MessageType.FRIEND_MESSAGE,
):
) -> None:
platform_meta = PlatformMetadata(
name="cron",
description="CronJob",
@@ -53,13 +53,13 @@ class CronMessageEvent(AstrMessageEvent):
if extras:
self._extras.update(extras)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
if message is None:
return
await self.context_obj.send_message(self.session, message)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
async def send_streaming(self, generator, use_fallback: bool = False) -> None:
async for chain in generator:
await self.send(chain)
+10 -10
View File
@@ -25,14 +25,14 @@ if TYPE_CHECKING:
class CronJobManager:
"""Central scheduler for BasicCronJob and ActiveAgentCronJob."""
def __init__(self, db: BaseDatabase):
def __init__(self, db: BaseDatabase) -> None:
self.db = db
self.scheduler = AsyncIOScheduler()
self._basic_handlers: dict[str, Callable[..., Any]] = {}
self._lock = asyncio.Lock()
self._started = False
async def start(self, ctx: "Context"):
async def start(self, ctx: "Context") -> None:
self.ctx: Context = ctx # star context
async with self._lock:
if self._started:
@@ -41,14 +41,14 @@ class CronJobManager:
self._started = True
await self.sync_from_db()
async def shutdown(self):
async def shutdown(self) -> None:
async with self._lock:
if not self._started:
return
self.scheduler.shutdown(wait=False)
self._started = False
async def sync_from_db(self):
async def sync_from_db(self) -> None:
jobs = await self.db.list_cron_jobs()
for job in jobs:
if not job.enabled or not job.persistent:
@@ -136,11 +136,11 @@ class CronJobManager:
async def list_jobs(self, job_type: str | None = None) -> list[CronJob]:
return await self.db.list_cron_jobs(job_type)
def _remove_scheduled(self, job_id: str):
def _remove_scheduled(self, job_id: str) -> None:
if self.scheduler.get_job(job_id):
self.scheduler.remove_job(job_id)
def _schedule_job(self, job: CronJob):
def _schedule_job(self, job: CronJob) -> None:
if not self._started:
self.scheduler.start()
self._started = True
@@ -188,7 +188,7 @@ class CronJobManager:
aps_job = self.scheduler.get_job(job_id)
return aps_job.next_run_time if aps_job else None
async def _run_job(self, job_id: str):
async def _run_job(self, job_id: str) -> None:
job = await self.db.get_cron_job(job_id)
if not job or not job.enabled:
return
@@ -222,7 +222,7 @@ class CronJobManager:
# one-shot: remove after execution regardless of success
await self.delete_job(job_id)
async def _run_basic_job(self, job: CronJob):
async def _run_basic_job(self, job: CronJob) -> None:
handler = self._basic_handlers.get(job.job_id)
if not handler:
raise RuntimeError(f"Basic cron job handler not found for {job.job_id}")
@@ -231,7 +231,7 @@ class CronJobManager:
if asyncio.iscoroutine(result):
await result
async def _run_active_agent_job(self, job: CronJob, start_time: datetime):
async def _run_active_agent_job(self, job: CronJob, start_time: datetime) -> None:
payload = job.payload or {}
session_str = payload.get("session")
if not session_str:
@@ -266,7 +266,7 @@ class CronJobManager:
message: str,
session_str: str,
extras: dict,
):
) -> None:
"""Woke the main agent to handle the cron job message."""
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
+1 -1
View File
@@ -43,7 +43,7 @@ class BaseDatabase(abc.ABC):
expire_on_commit=False,
)
async def initialize(self):
async def initialize(self) -> None:
"""初始化数据库连接"""
@asynccontextmanager
+5 -5
View File
@@ -43,7 +43,7 @@ def get_platform_type(
async def migration_conversation_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
@@ -101,7 +101,7 @@ async def migration_conversation_table(
async def migration_platform_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
@@ -180,7 +180,7 @@ async def migration_platform_table(
async def migration_webchat_data(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
@@ -236,7 +236,7 @@ async def migration_webchat_data(
async def migration_persona_data(
db_helper: BaseDatabase,
astrbot_config: AstrBotConfig,
):
) -> None:
"""迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 新的 Persona 数据存储在 persona 表中
"""
@@ -279,7 +279,7 @@ async def migration_persona_data(
async def migration_preferences(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
# 1. global scope migration
keys = [
"inactivated_llm_tools",
+1 -1
View File
@@ -3,7 +3,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None:
abconf_data = acm.abconf_data
if not isinstance(abconf_data, dict):
@@ -12,7 +12,7 @@ from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
async def migrate_token_usage(db_helper: BaseDatabase) -> None:
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
@@ -17,7 +17,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
async def migrate_webchat_session(db_helper: BaseDatabase):
async def migrate_webchat_session(db_helper: BaseDatabase) -> None:
"""Create PlatformSession records from platform_message_history.
This migration extracts all unique user_ids from platform_message_history
@@ -8,7 +8,7 @@ _VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
def __init__(self, path=None) -> None:
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
@@ -23,7 +23,7 @@ class SharedPreferences:
os.remove(self.path)
return {}
def _save_preferences(self):
def _save_preferences(self) -> None:
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
@@ -31,16 +31,16 @@ class SharedPreferences:
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
def put(self, key, value) -> None:
self._data[key] = value
self._save_preferences()
def remove(self, key):
def remove(self, key) -> None:
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
def clear(self) -> None:
self._data.clear()
self._save_preferences()
+10 -8
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple | None = None):
def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
conn = self.conn
try:
c = self.conn.cursor()
@@ -144,7 +144,7 @@ class SQLiteDatabase:
conn.commit()
def insert_platform_metrics(self, metrics: dict):
def insert_platform_metrics(self, metrics: dict) -> None:
for k, v in metrics.items():
self._exec_sql(
"""
@@ -153,7 +153,7 @@ class SQLiteDatabase:
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
def insert_llm_metrics(self, metrics: dict) -> None:
for k, v in metrics.items():
self._exec_sql(
"""
@@ -249,7 +249,7 @@ class SQLiteDatabase:
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
def new_conversation(self, user_id: str, cid: str) -> None:
history = "[]"
updated_at = int(time.time())
created_at = updated_at
@@ -287,7 +287,7 @@ class SQLiteDatabase:
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
def update_conversation(self, user_id: str, cid: str, history: str) -> None:
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
@@ -297,7 +297,7 @@ class SQLiteDatabase:
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
def update_conversation_title(self, user_id: str, cid: str, title: str) -> None:
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
@@ -305,7 +305,9 @@ class SQLiteDatabase:
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
def update_conversation_persona_id(
self, user_id: str, cid: str, persona_id: str
) -> None:
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
@@ -313,7 +315,7 @@ class SQLiteDatabase:
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
def delete_conversation(self, user_id: str, cid: str) -> None:
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
+8 -8
View File
@@ -305,7 +305,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(query)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async def delete_conversation(self, cid) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -461,7 +461,7 @@ class SQLiteDatabase(BaseDatabase):
platform_id,
user_id,
offset_sec=86400,
):
) -> None:
"""Delete platform message history records newer than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
@@ -645,7 +645,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
async def delete_persona(self, persona_id) -> None:
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
@@ -903,7 +903,7 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalars().all()
async def remove_preference(self, scope, scope_id, key):
async def remove_preference(self, scope, scope_id, key) -> None:
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
@@ -917,7 +917,7 @@ class SQLiteDatabase(BaseDatabase):
)
await session.commit()
async def clear_preferences(self, scope, scope_id):
async def clear_preferences(self, scope, scope_id) -> None:
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
@@ -1195,7 +1195,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
@@ -1218,7 +1218,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
@@ -1253,7 +1253,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
+1 -1
View File
@@ -9,7 +9,7 @@ class Result:
class BaseVecDB:
async def initialize(self):
async def initialize(self) -> None:
"""初始化向量数据库"""
@abc.abstractmethod
@@ -33,7 +33,7 @@ class Document(BaseDocModel, table=True):
class DocumentStorage:
def __init__(self, db_path: str):
def __init__(self, db_path: str) -> None:
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
@@ -43,7 +43,7 @@ class DocumentStorage:
"sqlite_init.sql",
)
async def initialize(self):
async def initialize(self) -> None:
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
await self.connect()
async with self.engine.begin() as conn: # type: ignore
@@ -80,7 +80,7 @@ class DocumentStorage:
await conn.commit()
async def connect(self):
async def connect(self) -> None:
"""Connect to the SQLite database."""
if self.engine is None:
self.engine = create_async_engine(
@@ -211,7 +211,7 @@ class DocumentStorage:
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
async def delete_document_by_doc_id(self, doc_id: str) -> None:
"""Delete a document by its doc_id.
Args:
@@ -249,7 +249,7 @@ class DocumentStorage:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None:
"""Update a document by its doc_id.
Args:
@@ -269,7 +269,7 @@ class DocumentStorage:
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
async def delete_documents(self, metadata_filters: dict) -> None:
"""Delete documents by their metadata filters.
Args:
@@ -384,7 +384,7 @@ class DocumentStorage:
"updated_at": row[5],
}
async def close(self):
async def close(self) -> None:
"""Close the connection to the SQLite database."""
if self.engine:
await self.engine.dispose()
@@ -10,7 +10,7 @@ import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None):
def __init__(self, dimension: int, path: str | None = None) -> None:
self.dimension = dimension
self.path = path
self.index = None
@@ -20,7 +20,7 @@ class EmbeddingStorage:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
async def insert(self, vector: np.ndarray, id: int):
async def insert(self, vector: np.ndarray, id: int) -> None:
"""插入向量
Args:
@@ -38,7 +38,7 @@ class EmbeddingStorage:
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index()
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
"""批量插入向量
Args:
@@ -71,7 +71,7 @@ class EmbeddingStorage:
distances, indices = self.index.search(vector, k)
return distances, indices
async def delete(self, ids: list[int]):
async def delete(self, ids: list[int]) -> None:
"""删除向量
Args:
@@ -83,7 +83,7 @@ class EmbeddingStorage:
self.index.remove_ids(id_array)
await self.save_index()
async def save_index(self):
async def save_index(self) -> None:
"""保存索引
Args:
+5 -5
View File
@@ -20,7 +20,7 @@ class FaissVecDB(BaseVecDB):
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
) -> None:
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
@@ -32,7 +32,7 @@ class FaissVecDB(BaseVecDB):
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
async def initialize(self) -> None:
await self.document_storage.initialize()
async def insert(
@@ -165,7 +165,7 @@ class FaissVecDB(BaseVecDB):
return top_k_results
async def delete(self, doc_id: str):
async def delete(self, doc_id: str) -> None:
"""删除一条文档块(chunk"""
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
@@ -177,7 +177,7 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
async def close(self):
async def close(self) -> None:
await self.document_storage.close()
async def count_documents(self, metadata_filter: dict | None = None) -> int:
@@ -192,7 +192,7 @@ class FaissVecDB(BaseVecDB):
)
return count
async def delete_documents(self, metadata_filters: dict):
async def delete_documents(self, metadata_filters: dict) -> None:
"""根据元数据过滤器删除文档"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters,
+3 -3
View File
@@ -28,13 +28,13 @@ class EventBus:
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager,
):
) -> None:
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
async def dispatch(self):
async def dispatch(self) -> None:
while True:
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
@@ -47,7 +47,7 @@ class EventBus:
continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
"""用于记录事件信息
Args:
+2 -2
View File
@@ -9,12 +9,12 @@ from urllib.parse import unquote, urlparse
class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
def __init__(self, default_timeout: float = 300):
def __init__(self, default_timeout: float = 300) -> None:
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout
async def _cleanup_expired_tokens(self):
async def _cleanup_expired_tokens(self) -> None:
"""清理过期的令牌"""
now = time.time()
expired_tokens = [
+2 -2
View File
@@ -17,13 +17,13 @@ from astrbot.dashboard.server import AstrBotDashboard
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
async def start(self) -> None:
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
try:
@@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker):
按照固定的字符数分块,并支持块之间的重叠
"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
"""初始化分块器
Args:
@@ -11,7 +11,7 @@ class RecursiveCharacterChunker(BaseChunker):
length_function: Callable[[str], int] = len,
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
) -> None:
"""初始化递归字符文本分割器
Args:
+1 -1
View File
@@ -253,7 +253,7 @@ class KBSQLiteDatabase:
"knowledge_base": row[1],
}
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session, session.begin():
+9 -9
View File
@@ -31,7 +31,7 @@ from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
class RateLimiter:
"""一个简单的速率限制器"""
def __init__(self, max_rpm: int):
def __init__(self, max_rpm: int) -> None:
self.max_per_minute = max_rpm
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
self.last_call_time = 0
@@ -116,7 +116,7 @@ class KBHelper:
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
):
) -> None:
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
@@ -130,7 +130,7 @@ class KBHelper:
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
async def initialize(self) -> None:
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
@@ -174,7 +174,7 @@ class KBHelper:
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
async def delete_vec_db(self) -> None:
"""删除知识库的向量数据库和所有相关文件"""
import shutil
@@ -182,7 +182,7 @@ class KBHelper:
if self.kb_dir.exists():
shutil.rmtree(self.kb_dir)
async def terminate(self):
async def terminate(self) -> None:
if self.vec_db:
await self.vec_db.close()
@@ -293,7 +293,7 @@ class KBHelper:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
async def embedding_progress_callback(current, total) -> None:
if progress_callback:
await progress_callback("embedding", current, total)
@@ -360,7 +360,7 @@ class KBHelper:
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
async def delete_document(self, doc_id: str) -> None:
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
@@ -372,7 +372,7 @@ class KBHelper:
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
async def delete_chunk(self, chunk_id: str, doc_id: str) -> None:
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
@@ -383,7 +383,7 @@ class KBHelper:
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
async def refresh_kb(self) -> None:
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
+5 -5
View File
@@ -26,14 +26,14 @@ class KnowledgeBaseManager:
def __init__(
self,
provider_manager: ProviderManager,
):
) -> None:
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
async def initialize(self) -> None:
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
@@ -58,13 +58,13 @@ class KnowledgeBaseManager:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
async def _init_kb_database(self) -> None:
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
async def load_kbs(self) -> None:
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
@@ -275,7 +275,7 @@ class KnowledgeBaseManager:
return "\n".join(lines)
async def terminate(self):
async def terminate(self) -> None:
"""终止所有知识库实例,关闭数据库连接"""
for kb_id, kb_helper in self.kb_insts.items():
try:
@@ -6,7 +6,7 @@ import aiohttp
class URLExtractor:
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
def __init__(self, tavily_keys: list[str]):
def __init__(self, tavily_keys: list[str]) -> None:
"""
初始化 URL 提取器
@@ -44,7 +44,7 @@ class RetrievalManager:
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
):
) -> None:
"""初始化检索管理器
Args:
@@ -31,7 +31,7 @@ class RankFusion:
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None:
"""初始化结果融合器
Args:
@@ -34,7 +34,7 @@ class SparseRetriever:
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBSQLiteDatabase):
def __init__(self, kb_db: KBSQLiteDatabase) -> None:
"""初始化稀疏检索器
Args:
+15 -15
View File
@@ -91,7 +91,7 @@ class LogBroker:
发布-订阅模式
"""
def __init__(self):
def __init__(self) -> None:
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: list[Queue] = [] # 订阅者列表
@@ -106,7 +106,7 @@ class LogBroker:
self.subscribers.append(q)
return q
def unregister(self, q: Queue):
def unregister(self, q: Queue) -> None:
"""取消订阅
Args:
@@ -115,7 +115,7 @@ class LogBroker:
"""
self.subscribers.remove(q)
def publish(self, log_entry: dict):
def publish(self, log_entry: dict) -> None:
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
@@ -137,11 +137,11 @@ class LogQueueHandler(logging.Handler):
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
def __init__(self, log_broker: LogBroker) -> None:
super().__init__()
self.log_broker = log_broker
def emit(self, record):
def emit(self, record) -> None:
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
@@ -201,7 +201,7 @@ class LogManager:
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
def filter(self, record) -> bool:
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
@@ -213,7 +213,7 @@ class LogManager:
"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
def filter(self, record) -> bool:
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
@@ -226,14 +226,14 @@ class LogManager:
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
def filter(self, record) -> bool:
record.short_levelname = get_short_level_name(record.levelname)
return True
class AstrBotVersionTagFilter(logging.Filter):
"""在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。"""
def filter(self, record):
def filter(self, record) -> bool:
if record.levelno >= logging.WARNING:
record.astrbot_version_tag = f" [v{VERSION}]"
else:
@@ -251,7 +251,7 @@ class LogManager:
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
@@ -301,7 +301,7 @@ class LogManager:
]
@classmethod
def _remove_file_handlers(cls, logger: logging.Logger):
def _remove_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_file_handlers(logger):
logger.removeHandler(handler)
try:
@@ -310,7 +310,7 @@ class LogManager:
pass
@classmethod
def _remove_trace_file_handlers(cls, logger: logging.Logger):
def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_trace_file_handlers(logger):
logger.removeHandler(handler)
try:
@@ -326,7 +326,7 @@ class LogManager:
max_mb: int | None = None,
backup_count: int = 3,
trace: bool = False,
):
) -> None:
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
max_bytes = 0
if max_mb and max_mb > 0:
@@ -365,7 +365,7 @@ class LogManager:
logger: logging.Logger,
config: dict | None,
override_level: str | None = None,
):
) -> None:
"""根据配置设置日志级别和文件日志。
Args:
@@ -413,7 +413,7 @@ class LogManager:
cls._add_file_handler(logger, file_path, max_mb=max_mb)
@classmethod
def configure_trace_logger(cls, config: dict | None):
def configure_trace_logger(cls, config: dict | None) -> None:
"""为 trace 事件配置独立的文件日志,不向控制台输出。"""
if not config:
return
+27 -27
View File
@@ -66,7 +66,7 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def toDict(self):
@@ -89,7 +89,7 @@ class Plain(BaseMessageComponent):
text: str
convert: bool | None = True
def __init__(self, text: str, convert: bool = True, **_):
def __init__(self, text: str, convert: bool = True, **_) -> None:
super().__init__(text=text, convert=convert, **_)
def toDict(self):
@@ -103,7 +103,7 @@ class Face(BaseMessageComponent):
type = ComponentType.Face
id: int
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -118,7 +118,7 @@ class Record(BaseMessageComponent):
# 额外
path: str | None
def __init__(self, file: str | None, **_):
def __init__(self, file: str | None, **_) -> None:
for k in _:
if k == "url":
pass
@@ -221,7 +221,7 @@ class Video(BaseMessageComponent):
# 额外
path: str | None = ""
def __init__(self, file: str, **_):
def __init__(self, file: str, **_) -> None:
super().__init__(file=file, **_)
@staticmethod
@@ -255,7 +255,7 @@ class Video(BaseMessageComponent):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
async def register_to_file_service(self) -> str:
"""将视频注册到文件服务。
Returns:
@@ -303,7 +303,7 @@ class At(BaseMessageComponent):
qq: int | str # 此处str为all时代表所有人
name: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
def toDict(self):
@@ -316,28 +316,28 @@ class At(BaseMessageComponent):
class AtAll(At):
qq: str = "all"
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -348,7 +348,7 @@ class Share(BaseMessageComponent):
content: str | None = ""
image: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -357,7 +357,7 @@ class Contact(BaseMessageComponent): # TODO
_type: str # type 字段冲突
id: int | None = 0
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -368,7 +368,7 @@ class Location(BaseMessageComponent): # TODO
title: str | None = ""
content: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -382,7 +382,7 @@ class Music(BaseMessageComponent):
content: str | None = ""
image: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
# for k in _.keys():
# if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
@@ -402,7 +402,7 @@ class Image(BaseMessageComponent):
path: str | None = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: str | None, **_):
def __init__(self, file: str | None, **_) -> None:
super().__init__(file=file, **_)
@staticmethod
@@ -525,7 +525,7 @@ class Reply(BaseMessageComponent):
seq: int | None = 0
"""deprecated"""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -534,7 +534,7 @@ class Poke(BaseMessageComponent):
id: int | None = 0
qq: int | None = 0
def __init__(self, type: str, **_):
def __init__(self, type: str, **_) -> None:
type = f"Poke:{type}"
super().__init__(type=type, **_)
@@ -543,7 +543,7 @@ class Forward(BaseMessageComponent):
type = ComponentType.Forward
id: str
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -558,7 +558,7 @@ class Node(BaseMessageComponent):
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
def __init__(self, content: list[BaseMessageComponent], **_):
def __init__(self, content: list[BaseMessageComponent], **_) -> None:
if isinstance(content, Node):
# back
content = [content]
@@ -605,7 +605,7 @@ class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
nodes: list[Node]
def __init__(self, nodes: list[Node], **_):
def __init__(self, nodes: list[Node], **_) -> None:
super().__init__(nodes=nodes, **_)
def toDict(self):
@@ -631,7 +631,7 @@ class Json(BaseMessageComponent):
type = ComponentType.Json
data: dict
def __init__(self, data: str | dict, **_):
def __init__(self, data: str | dict, **_) -> None:
if isinstance(data, str):
data = json.loads(data)
super().__init__(data=data, **_)
@@ -650,7 +650,7 @@ class File(BaseMessageComponent):
file_: str | None = "" # 本地路径
url: str | None = "" # url
def __init__(self, name: str, file: str = "", url: str = ""):
def __init__(self, name: str, file: str = "", url: str = "") -> None:
"""文件消息段。"""
super().__init__(name=name, file_=file, url=url)
@@ -686,7 +686,7 @@ class File(BaseMessageComponent):
return ""
@file.setter
def file(self, value: str):
def file(self, value: str) -> None:
"""向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
@@ -721,7 +721,7 @@ class File(BaseMessageComponent):
return ""
async def _download_file(self):
async def _download_file(self) -> None:
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
@@ -736,7 +736,7 @@ class File(BaseMessageComponent):
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
async def register_to_file_service(self) -> str:
"""将文件注册到文件服务。
Returns:
@@ -786,7 +786,7 @@ class WechatEmoji(BaseMessageComponent):
md5_len: int | None = 0
cdnurl: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
+3 -3
View File
@@ -17,7 +17,7 @@ DEFAULT_PERSONALITY = Personality(
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None:
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
@@ -29,7 +29,7 @@ class PersonaManager:
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
async def initialize(self) -> None:
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
@@ -58,7 +58,7 @@ class PersonaManager:
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
async def delete_persona(self, persona_id: str) -> None:
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
@@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage):
当前只会检查文本的
"""
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
config = ctx.astrbot_config["content_safety"]
self.strategy_selector = StrategySelector(config)
@@ -336,7 +336,7 @@ class InternalAgentSubStage(Stage):
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
) -> None:
if (
not req
or not req.conversation
@@ -19,7 +19,7 @@ class RateLimitStage(Stage):
如果触发限流 stall 流水线直到下一个时间窗口来临时自动唤醒
"""
def __init__(self):
def __init__(self) -> None:
# 存储每个会话的请求时间队列
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
# 为每个会话设置一个锁,避免并发冲突
+2 -2
View File
@@ -35,7 +35,7 @@ class RespondStage(Stage):
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
}
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.platform_settings: dict = self.config.get("platform_settings", {})
@@ -91,7 +91,7 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool:
"""检查消息链是否为空
Args:
@@ -20,7 +20,7 @@ from ..stage import Stage, register_stage, registered_stages
@register_stage
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"]
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
+4 -4
View File
@@ -15,21 +15,21 @@ from .stage import registered_stages
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext):
def __init__(self, context: PipelineContext) -> None:
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
self.ctx = context # 上下文对象
self.stages = [] # 存储阶段实例
async def initialize(self):
async def initialize(self) -> None:
"""初始化管道调度器时, 初始化所有阶段"""
for stage_cls in registered_stages:
stage_instance = stage_cls() # 创建实例
await stage_instance.initialize(self.ctx)
self.stages.append(stage_instance)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None:
"""依次执行各个阶段
Args:
@@ -72,7 +72,7 @@ class PipelineScheduler:
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
async def execute(self, event: AstrMessageEvent) -> None:
"""执行 pipeline
Args:
+15 -15
View File
@@ -38,7 +38,7 @@ class AstrMessageEvent(abc.ABC):
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
):
) -> None:
self.message_str = message_str
"""纯文本的消息"""
self.message_obj = message_obj
@@ -91,7 +91,7 @@ class AstrMessageEvent(abc.ABC):
return str(self.session)
@unified_msg_origin.setter
def unified_msg_origin(self, value: str):
def unified_msg_origin(self, value: str) -> None:
"""设置统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self.new_session = MessageSession.from_str(value)
self.session = self.new_session
@@ -102,7 +102,7 @@ class AstrMessageEvent(abc.ABC):
return self.session.session_id
@session_id.setter
def session_id(self, value: str):
def session_id(self, value: str) -> None:
"""设置用户的会话 ID。可以直接使用下面的 unified_msg_origin"""
self.session.session_id = value
@@ -191,7 +191,7 @@ class AstrMessageEvent(abc.ABC):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value):
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
self._extras[key] = value
@@ -201,7 +201,7 @@ class AstrMessageEvent(abc.ABC):
return self._extras
return self._extras.get(key, default)
def clear_extra(self):
def clear_extra(self) -> None:
"""清除额外的信息。"""
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()
@@ -234,7 +234,7 @@ class AstrMessageEvent(abc.ABC):
self,
generator: AsyncGenerator[MessageChain, None],
use_fallback: bool = False,
):
) -> None:
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊
Fallback仅支持 aiocqhttp
@@ -244,13 +244,13 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def _pre_send(self):
async def _pre_send(self) -> None:
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
async def _post_send(self):
async def _post_send(self) -> None:
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
def set_result(self, result: MessageEventResult | str):
def set_result(self, result: MessageEventResult | str) -> None:
"""设置消息事件的结果。
Note:
@@ -279,14 +279,14 @@ class AstrMessageEvent(abc.ABC):
result.chain = []
self._result = result
def stop_event(self):
def stop_event(self) -> None:
"""终止事件传播。"""
if self._result is None:
self.set_result(MessageEventResult().stop_event())
else:
self._result.stop_event()
def continue_event(self):
def continue_event(self) -> None:
"""继续事件传播。"""
if self._result is None:
self.set_result(MessageEventResult().continue_event())
@@ -299,7 +299,7 @@ class AstrMessageEvent(abc.ABC):
return False # 默认是继续传播
return self._result.is_stopped()
def should_call_llm(self, call_llm: bool):
def should_call_llm(self, call_llm: bool) -> None:
"""是否在此消息事件中禁止默认的 LLM 请求。
只会阻止 AstrBot 默认的 LLM 请求链路不会阻止插件中的 LLM 请求
@@ -310,7 +310,7 @@ class AstrMessageEvent(abc.ABC):
"""获取消息事件的结果。"""
return self._result
def clear_result(self):
def clear_result(self) -> None:
"""清除消息事件的结果。"""
self._result = None
@@ -404,7 +404,7 @@ class AstrMessageEvent(abc.ABC):
"""平台适配器"""
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息到消息平台。
Args:
@@ -423,7 +423,7 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
"""对消息添加表情回应。
默认实现为发送一条包含该表情的消息
+3 -3
View File
@@ -11,7 +11,7 @@ class MessageMember:
user_id: str # 发送者id
nickname: str | None = None
def __str__(self):
def __str__(self) -> str:
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
@@ -34,7 +34,7 @@ class Group:
members: list[MessageMember] | None = None
"""所有群成员"""
def __str__(self):
def __str__(self) -> str:
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
def group_id(self, value: str | None):
def group_id(self, value: str | None) -> None:
"""设置 group_id"""
if value:
if self.group:
+9 -7
View File
@@ -13,7 +13,7 @@ from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue):
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
@@ -38,7 +38,7 @@ class PlatformManager:
sanitized = platform_id.replace(":", "_").replace("!", "_")
return sanitized, sanitized != platform_id
async def initialize(self):
async def initialize(self) -> None:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
@@ -58,7 +58,7 @@ class PlatformManager:
),
)
async def load_platform(self, platform_config: dict):
async def load_platform(self, platform_config: dict) -> None:
"""实例化一个平台"""
# 动态导入
try:
@@ -176,7 +176,9 @@ class PlatformManager:
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
async def _task_wrapper(
self, task: asyncio.Task, platform: Platform | None = None
) -> None:
# 设置平台状态为运行中
if platform:
platform.status = PlatformStatus.RUNNING
@@ -198,7 +200,7 @@ class PlatformManager:
if platform:
platform.record_error(error_msg, tb_str)
async def reload(self, platform_config: dict):
async def reload(self, platform_config: dict) -> None:
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
@@ -209,7 +211,7 @@ class PlatformManager:
if key not in config_ids:
await self.terminate_platform(key)
async def terminate_platform(self, platform_id: str):
async def terminate_platform(self, platform_id: str) -> None:
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
@@ -231,7 +233,7 @@ class PlatformManager:
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
async def terminate(self) -> None:
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
+1 -1
View File
@@ -15,7 +15,7 @@ class MessageSession:
session_id: str
platform_id: str = field(init=False)
def __str__(self):
def __str__(self) -> str:
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
+7 -7
View File
@@ -34,7 +34,7 @@ class PlatformError:
class Platform(abc.ABC):
def __init__(self, config: dict, event_queue: Queue):
def __init__(self, config: dict, event_queue: Queue) -> None:
super().__init__()
# 平台配置
self.config = config
@@ -53,7 +53,7 @@ class Platform(abc.ABC):
return self._status
@status.setter
def status(self, value: PlatformStatus):
def status(self, value: PlatformStatus) -> None:
"""设置平台运行状态"""
self._status = value
if value == PlatformStatus.RUNNING and self._started_at is None:
@@ -69,12 +69,12 @@ class Platform(abc.ABC):
"""获取最近的错误"""
return self._errors[-1] if self._errors else None
def record_error(self, message: str, traceback_str: str | None = None):
def record_error(self, message: str, traceback_str: str | None = None) -> None:
"""记录一个错误"""
self._errors.append(PlatformError(message=message, traceback=traceback_str))
self._status = PlatformStatus.ERROR
def clear_errors(self):
def clear_errors(self) -> None:
"""清除错误记录"""
self._errors.clear()
if self._status == PlatformStatus.ERROR:
@@ -121,7 +121,7 @@ class Platform(abc.ABC):
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
async def terminate(self):
async def terminate(self) -> None:
"""终止一个平台的运行实例。"""
@abc.abstractmethod
@@ -140,11 +140,11 @@ class Platform(abc.ABC):
"""
await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name)
def commit_event(self, event: AstrMessageEvent):
def commit_event(self, event: AstrMessageEvent) -> None:
"""提交一个事件到事件队列。"""
self._event_queue.put_nowait(event)
def get_client(self):
def get_client(self) -> object:
"""获取平台的客户端对象。"""
async def webhook_callback(self, request: Any) -> Any:
@@ -26,7 +26,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
platform_meta,
session_id,
bot: CQHttp,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@@ -72,7 +72,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
is_group: bool,
session_id: str | None,
messages: list[dict],
):
) -> None:
# session_id 必须是纯数字字符串
session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
@@ -97,7 +97,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
event: Event | None = None,
is_group: bool = False,
session_id: str | None = None,
):
) -> None:
"""发送消息至 QQ 协议端(aiocqhttp)。
Args:
@@ -143,7 +143,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息"""
event = getattr(self.message_obj, "raw_message", None)
@@ -61,7 +61,7 @@ class AiocqhttpAdapter(Platform):
)
@self.bot.on_request()
async def request(event: Event):
async def request(event: Event) -> None:
try:
abm = await self.convert_message(event)
if not abm:
@@ -72,7 +72,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_notice()
async def notice(event: Event):
async def notice(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -82,7 +82,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_message("group")
async def group(event: Event):
async def group(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -92,7 +92,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_message("private")
async def private(event: Event):
async def private(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -102,14 +102,14 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_websocket_connection
def on_websocket_connection(_):
def on_websocket_connection(_) -> None:
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
is_group = session.message_type == MessageType.GROUP_MESSAGE
if is_group:
session_id = session.session_id.split("_")[-1]
@@ -435,17 +435,17 @@ class AiocqhttpAdapter(Platform):
self.shutdown_event = asyncio.Event()
return coro
async def terminate(self):
async def terminate(self) -> None:
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self):
async def shutdown_trigger_placeholder(self) -> None:
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
async def handle_msg(self, message: AstrBotMessage) -> None:
message_event = AiocqhttpMessageEvent(
message_str=message.message_str,
message_obj=message,
@@ -2,7 +2,7 @@ import asyncio
import os
import threading
import uuid
from typing import cast
from typing import NoReturn, cast
import aiohttp
import dingtalk_stream
@@ -90,7 +90,7 @@ class DingtalkPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
@@ -104,7 +104,7 @@ class DingtalkPlatformAdapter(Platform):
async def create_message_card(
self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
):
) -> bool | None:
if not self.card_template_id:
return False
@@ -122,7 +122,9 @@ class DingtalkPlatformAdapter(Platform):
logger.error(f"创建钉钉卡片失败: {e}")
return False
async def send_card_message(self, message_id: str, content: str, is_final: bool):
async def send_card_message(
self, message_id: str, content: str, is_final: bool
) -> None:
if message_id not in self.card_instance_id_dict:
return
@@ -276,7 +278,7 @@ class DingtalkPlatformAdapter(Platform):
return ""
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = DingtalkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
@@ -288,10 +290,10 @@ class DingtalkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def run(self):
async def run(self) -> None:
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
def start_client(loop: asyncio.AbstractEventLoop) -> None:
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
@@ -307,8 +309,8 @@ class DingtalkPlatformAdapter(Platform):
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
async def terminate(self) -> None:
def monkey_patch_close() -> NoReturn:
raise KeyboardInterrupt("Graceful shutdown")
if self.client_.websocket is not None:
@@ -17,7 +17,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
session_id,
client: dingtalk_stream.ChatbotHandler,
adapter: "Any" = None,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.adapter = adapter
@@ -26,7 +26,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
self,
client: dingtalk_stream.ChatbotHandler,
message: MessageChain,
):
) -> None:
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
ats = []
# fixes: #4218
@@ -80,7 +80,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
continue
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
await self.send_with_client(self.client, message)
await super().send(message)
@@ -15,7 +15,7 @@ else:
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str | None = None):
def __init__(self, token: str, proxy: str | None = None) -> None:
self.token = token
self.proxy = proxy
@@ -32,7 +32,7 @@ class DiscordBotClient(discord.Bot):
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
async def on_ready(self):
async def on_ready(self) -> None:
"""当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
@@ -93,7 +93,7 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
async def on_message(self, message: discord.Message):
async def on_message(self, message: discord.Message) -> None:
"""当接收到消息时触发"""
if message.author.bot:
return
@@ -130,12 +130,12 @@ class DiscordBotClient(discord.Bot):
return str(interaction_data)
async def start_polling(self):
async def start_polling(self) -> None:
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
@override
async def close(self):
async def close(self) -> None:
"""关闭客户端"""
if not self.is_closed():
await super().close()
@@ -19,7 +19,7 @@ class DiscordEmbed(BaseMessageComponent):
image: str | None = None,
footer: str | None = None,
fields: list[dict] | None = None,
):
) -> None:
self.title = title
self.description = description
self.color = color
@@ -71,7 +71,7 @@ class DiscordButton(BaseMessageComponent):
emoji: str | None = None,
url: str | None = None,
disabled: bool = False,
):
) -> None:
self.label = label
self.custom_id = custom_id
self.style = style
@@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent):
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
def __init__(self, message_id: str, channel_id: str) -> None:
self.message_id = message_id
self.channel_id = channel_id
@@ -99,7 +99,7 @@ class DiscordView(BaseMessageComponent):
self,
components: list[BaseMessageComponent] | None = None,
timeout: float | None = None,
):
) -> None:
self.components = components or []
self.timeout = timeout
@@ -60,7 +60,7 @@ class DiscordPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
"""通过会话发送消息"""
if self.client.user is None:
logger.error(
@@ -122,11 +122,11 @@ class DiscordPlatformAdapter(Platform):
)
@override
async def run(self):
async def run(self) -> None:
"""主要运行逻辑"""
# 初始化回调函数
async def on_received(message_data):
async def on_received(message_data) -> None:
logger.debug(f"[Discord] 收到消息: {message_data}")
if self.client_self_id is None:
self.client_self_id = message_data.get("bot_id")
@@ -143,7 +143,7 @@ class DiscordPlatformAdapter(Platform):
self.client = DiscordBotClient(token, proxy)
self.client.on_message_received = on_received
async def callback():
async def callback() -> None:
if self.enable_command_register:
await self._collect_and_register_commands()
if self.activity_name:
@@ -251,7 +251,7 @@ class DiscordPlatformAdapter(Platform):
# 由于 on_interaction 已被禁用,我们只处理普通消息
return self._convert_message_to_abm(data)
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None:
"""处理消息"""
message_event = DiscordPlatformEvent(
message_str=message.message_str,
@@ -323,7 +323,7 @@ class DiscordPlatformAdapter(Platform):
self.commit_event(message_event)
@override
async def terminate(self):
async def terminate(self) -> None:
"""终止适配器"""
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
self.shutdown_event.set()
@@ -358,11 +358,11 @@ class DiscordPlatformAdapter(Platform):
logger.warning(f"[Discord] 客户端关闭异常: {e}")
logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info):
def register_handler(self, handler_info) -> None:
"""注册处理器信息"""
self.registered_handlers.append(handler_info)
async def _collect_and_register_commands(self):
async def _collect_and_register_commands(self) -> None:
"""收集所有指令并注册到Discord"""
logger.info("[Discord] 开始收集并注册斜杠指令...")
registered_commands = []
@@ -420,7 +420,7 @@ class DiscordPlatformAdapter(Platform):
async def dynamic_callback(
ctx: discord.ApplicationContext, params: str | None = None
):
) -> None:
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
@@ -28,7 +28,7 @@ from .components import DiscordEmbed, DiscordView
class DiscordViewComponent(BaseMessageComponent):
type: str = "discord_view"
def __init__(self, view: discord.ui.View):
def __init__(self, view: discord.ui.View) -> None:
self.view = view
@@ -41,12 +41,12 @@ class DiscordPlatformEvent(AstrMessageEvent):
session_id: str,
client: DiscordBotClient,
interaction_followup_webhook: discord.Webhook | None = None,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
try:
@@ -267,7 +267,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
content = content[:2000]
return content, files, view, embeds, reference_message_id
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
"""对原消息添加反应"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
@@ -53,10 +53,10 @@ class LarkPlatformAdapter(Platform):
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
await self.convert_msg(event)
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
asyncio.create_task(on_msg_event_recv(event))
self.event_handler = (
@@ -91,7 +91,7 @@ class LarkPlatformAdapter(Platform):
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
def _clean_expired_events(self) -> None:
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
@@ -121,7 +121,7 @@ class LarkPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
if session.message_type == MessageType.GROUP_MESSAGE:
id_type = "chat_id"
receive_id = session.session_id
@@ -149,7 +149,7 @@ class LarkPlatformAdapter(Platform):
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
@@ -299,7 +299,7 @@ class LarkPlatformAdapter(Platform):
logger.debug(abm)
await self.handle_msg(abm)
async def handle_msg(self, abm: AstrBotMessage):
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = LarkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
@@ -310,7 +310,7 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
async def handle_webhook_event(self, event_data: dict) -> None:
"""处理 Webhook 事件
Args:
@@ -332,7 +332,7 @@ class LarkPlatformAdapter(Platform):
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self):
async def run(self) -> None:
if self.connection_mode == "webhook":
# Webhook 模式
if self.webhook_server is None:
@@ -355,7 +355,7 @@ class LarkPlatformAdapter(Platform):
return await self.webhook_server.handle_callback(request)
async def terminate(self):
async def terminate(self) -> None:
if self.connection_mode == "socket":
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
@@ -38,7 +38,7 @@ class LarkMessageEvent(AstrMessageEvent):
platform_meta,
session_id,
bot: lark.Client,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@@ -274,7 +274,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""通用的消息链发送方法
Args:
@@ -342,7 +342,7 @@ class LarkMessageEvent(AstrMessageEvent):
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
await LarkMessageEvent.send_message_chain(
message,
@@ -358,7 +358,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送文件消息
Args:
@@ -392,7 +392,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送音频消息
Args:
@@ -465,7 +465,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送视频消息
Args:
@@ -531,7 +531,7 @@ class LarkMessageEvent(AstrMessageEvent):
receive_id_type=receive_id_type,
)
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
+3 -3
View File
@@ -21,7 +21,7 @@ from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
def __init__(self, key: str) -> None:
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@@ -52,7 +52,7 @@ class LarkWebhookServer:
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
def __init__(self, config: dict, event_queue: asyncio.Queue) -> None:
"""初始化 Webhook 服务器
Args:
@@ -197,7 +197,7 @@ class LarkWebhookServer:
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
def set_callback(self, callback: Callable[[dict], Awaitable[None]]) -> None:
"""设置事件回调函数
Args:
@@ -121,7 +121,7 @@ class MisskeyPlatformAdapter(Platform):
support_streaming_message=False,
)
async def run(self):
async def run(self) -> None:
if not self.instance_url or not self.access_token:
logger.error("[Misskey] 配置不完整,无法启动")
return
@@ -150,7 +150,7 @@ class MisskeyPlatformAdapter(Platform):
await self._start_websocket_connection()
def _register_event_handlers(self, streaming):
def _register_event_handlers(self, streaming) -> None:
"""注册事件处理器"""
streaming.add_message_handler("notification", self._handle_notification)
streaming.add_message_handler("main:notification", self._handle_notification)
@@ -194,7 +194,7 @@ class MisskeyPlatformAdapter(Platform):
message: AstrBotMessage,
poll: dict[str, Any],
message_parts: list[str],
):
) -> None:
"""处理投票数据,将其添加到消息中"""
try:
if not isinstance(message.raw_message, dict):
@@ -233,7 +233,7 @@ class MisskeyPlatformAdapter(Platform):
return fields
async def _start_websocket_connection(self):
async def _start_websocket_connection(self) -> None:
backoff_delay = 1.0
max_backoff = 300.0
backoff_multiplier = 1.5
@@ -281,7 +281,7 @@ class MisskeyPlatformAdapter(Platform):
await asyncio.sleep(sleep_time)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: dict[str, Any]):
async def _handle_notification(self, data: dict[str, Any]) -> None:
try:
notification_type = data.get("type")
logger.debug(
@@ -305,7 +305,7 @@ class MisskeyPlatformAdapter(Platform):
except Exception as e:
logger.error(f"[Misskey] 处理通知失败: {e}")
async def _handle_chat_message(self, data: dict[str, Any]):
async def _handle_chat_message(self, data: dict[str, Any]) -> None:
try:
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""),
@@ -340,7 +340,7 @@ class MisskeyPlatformAdapter(Platform):
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: dict[str, Any]):
async def _debug_handler(self, data: dict[str, Any]) -> None:
event_type = data.get("type", "unknown")
logger.debug(
f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}",
@@ -754,7 +754,7 @@ class MisskeyPlatformAdapter(Platform):
)
return message
async def terminate(self):
async def terminate(self) -> None:
self._running = False
if self.api:
await self.api.close()
@@ -3,7 +3,7 @@ import json
import random
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from typing import Any, NoReturn
try:
import aiohttp
@@ -43,7 +43,7 @@ class WebSocketError(APIError):
class StreamingClient:
def __init__(self, instance_url: str, access_token: str):
def __init__(self, instance_url: str, access_token: str) -> None:
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self.websocket: Any | None = None
@@ -90,7 +90,7 @@ class StreamingClient:
self.is_connected = False
return False
async def disconnect(self):
async def disconnect(self) -> None:
self._running = False
if self.websocket:
await self.websocket.close()
@@ -116,7 +116,7 @@ class StreamingClient:
self.channels[channel_id] = channel_type
return channel_id
async def unsubscribe_channel(self, channel_id: str):
async def unsubscribe_channel(self, channel_id: str) -> None:
if (
not self.is_connected
or not self.websocket
@@ -136,10 +136,10 @@ class StreamingClient:
self,
event_type: str,
handler: Callable[[dict], Awaitable[None]],
):
) -> None:
self.message_handlers[event_type] = handler
async def listen(self):
async def listen(self) -> None:
if not self.is_connected or not self.websocket:
raise WebSocketError("WebSocket 未连接")
@@ -187,7 +187,7 @@ class StreamingClient:
except Exception:
pass
async def _handle_message(self, data: dict[str, Any]):
async def _handle_message(self, data: dict[str, Any]) -> None:
message_type = data.get("type")
body = data.get("body", {})
@@ -334,7 +334,7 @@ class MisskeyAPI:
download_timeout: int = 15,
chunk_size: int = 64 * 1024,
max_download_bytes: int | None = None,
):
) -> None:
self.instance_url = instance_url.rstrip("/")
self.access_token = access_token
self._session: aiohttp.ClientSession | None = None
@@ -375,7 +375,7 @@ class MisskeyAPI:
self._session = aiohttp.ClientSession(headers=headers)
return self._session
def _handle_response_status(self, status: int, endpoint: str):
def _handle_response_status(self, status: int, endpoint: str) -> NoReturn:
"""处理 HTTP 响应状态码"""
if status == 400:
logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})")
@@ -449,7 +449,6 @@ class MisskeyAPI:
)
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@retry_async(
max_retries=API_MAX_RETRIES,
@@ -26,7 +26,7 @@ class MisskeyPlatformEvent(AstrMessageEvent):
platform_meta: PlatformMetadata,
session_id: str,
client,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@@ -40,7 +40,7 @@ class MisskeyPlatformEvent(AstrMessageEvent):
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息,使用适配器的完整上传和发送逻辑"""
try:
logger.debug(
@@ -403,7 +403,7 @@ def cache_user_info(
raw_data: dict[str, Any],
client_self_id: str,
is_chat: bool = False,
):
) -> None:
"""缓存用户信息"""
if is_chat:
user_cache_data = {
@@ -429,7 +429,7 @@ def cache_room_info(
user_cache: dict[str, Any],
raw_data: dict[str, Any],
client_self_id: str,
):
) -> None:
"""缓存房间信息"""
room_data = raw_data.get("toRoom")
room_id = raw_data.get("toRoomId")
@@ -32,12 +32,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
platform_meta: PlatformMetadata,
session_id: str,
bot: Client,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
self.send_buffer = message
await self._post_send()
@@ -35,11 +35,13 @@ for handler in logging.root.handlers[:]:
# QQ 机器人官方框架
class botClient(Client):
def set_platform(self, platform: QQOfficialPlatformAdapter):
def set_platform(self, platform: QQOfficialPlatformAdapter) -> None:
self.platform = platform
# 收到群消息
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
async def on_group_at_message_create(
self, message: botpy.message.GroupMessage
) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.GROUP_MESSAGE,
@@ -49,7 +51,7 @@ class botClient(Client):
self._commit(abm)
# 收到频道消息
async def on_at_message_create(self, message: botpy.message.Message):
async def on_at_message_create(self, message: botpy.message.Message) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.GROUP_MESSAGE,
@@ -59,7 +61,9 @@ class botClient(Client):
self._commit(abm)
# 收到私聊消息
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
async def on_direct_message_create(
self, message: botpy.message.DirectMessage
) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.FRIEND_MESSAGE,
@@ -68,7 +72,7 @@ class botClient(Client):
self._commit(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.FRIEND_MESSAGE,
@@ -76,7 +80,7 @@ class botClient(Client):
abm.session_id = abm.sender.user_id
self._commit(abm)
def _commit(self, abm: AstrBotMessage):
def _commit(self, abm: AstrBotMessage) -> None:
self.platform.commit_event(
QQOfficialMessageEvent(
abm.message_str,
@@ -128,7 +132,7 @@ class QQOfficialPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
@@ -222,6 +226,6 @@ class QQOfficialPlatformAdapter(Platform):
def get_client(self) -> botClient:
return self.client
async def terminate(self):
async def terminate(self) -> None:
await self.client.close()
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
@@ -26,11 +26,13 @@ for handler in logging.root.handlers[:]:
# QQ 机器人官方框架
class botClient(Client):
def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"):
def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None:
self.platform = platform
# 收到群消息
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
async def on_group_at_message_create(
self, message: botpy.message.GroupMessage
) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.GROUP_MESSAGE,
@@ -40,7 +42,7 @@ class botClient(Client):
self._commit(abm)
# 收到频道消息
async def on_at_message_create(self, message: botpy.message.Message):
async def on_at_message_create(self, message: botpy.message.Message) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.GROUP_MESSAGE,
@@ -50,7 +52,9 @@ class botClient(Client):
self._commit(abm)
# 收到私聊消息
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
async def on_direct_message_create(
self, message: botpy.message.DirectMessage
) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.FRIEND_MESSAGE,
@@ -59,7 +63,7 @@ class botClient(Client):
self._commit(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None:
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
message,
MessageType.FRIEND_MESSAGE,
@@ -67,7 +71,7 @@ class botClient(Client):
abm.session_id = abm.sender.user_id
self._commit(abm)
def _commit(self, abm: AstrBotMessage):
def _commit(self, abm: AstrBotMessage) -> None:
self.platform.commit_event(
QQOfficialWebhookMessageEvent(
abm.message_str,
@@ -110,7 +114,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
@@ -121,7 +125,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
support_proactive_message=False,
)
async def run(self):
async def run(self) -> None:
self.webhook_helper = QQOfficialWebhook(
self.config,
self._event_queue,
@@ -149,7 +153,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
# 复用 webhook_helper 的回调处理逻辑
return await self.webhook_helper.handle_callback(request)
async def terminate(self):
async def terminate(self) -> None:
if self.webhook_helper:
self.webhook_helper.shutdown_event.set()
await self.client.close()

Some files were not shown because too many files have changed in this diff Show More