chore: auto ann fix by ruff (#4903)
* chore: auto fix by ruff * refactor: 统一修正返回类型注解为 None/bool 以匹配实现 * refactor: 将 _get_next_page 改为异步并移除多余的请求错误抛出 * refactor: 将 get_client 的返回类型改为 object * style: 为 LarkMessageEvent 的相关方法添加返回类型注解 None --------- Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
|
||||
|
||||
class LongTermMemory:
|
||||
def __init__(self, acm: AstrBotConfigManager, context: star.Context):
|
||||
def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
|
||||
self.acm = acm
|
||||
self.context = context
|
||||
self.session_chats = defaultdict(list)
|
||||
@@ -111,7 +111,7 @@ class LongTermMemory:
|
||||
|
||||
return False
|
||||
|
||||
async def handle_message(self, event: AstrMessageEvent):
|
||||
async def handle_message(self, event: AstrMessageEvent) -> None:
|
||||
"""仅支持群聊"""
|
||||
if event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
@@ -148,7 +148,7 @@ class LongTermMemory:
|
||||
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
|
||||
self.session_chats[event.unified_msg_origin].pop(0)
|
||||
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
"""当触发 LLM 请求前,调用此方法修改 req"""
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
@@ -171,7 +171,9 @@ class LongTermMemory:
|
||||
)
|
||||
req.system_prompt += chats_str
|
||||
|
||||
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
|
||||
async def after_req_llm(
|
||||
self, event: AstrMessageEvent, llm_resp: LLMResponse
|
||||
) -> None:
|
||||
if event.unified_msg_origin not in self.session_chats:
|
||||
return
|
||||
|
||||
|
||||
@@ -85,7 +85,9 @@ class Main(star.Star):
|
||||
logger.error(f"主动回复失败: {e}")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
async def decorate_llm_req(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest
|
||||
) -> None:
|
||||
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
@@ -94,7 +96,9 @@ class Main(star.Star):
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
async def record_llm_resp_to_ltm(
|
||||
self, event: AstrMessageEvent, resp: LLMResponse
|
||||
) -> None:
|
||||
"""在 LLM 响应后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
@@ -103,7 +107,7 @@ class Main(star.Star):
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.after_message_sent()
|
||||
async def after_message_sent(self, event: AstrMessageEvent):
|
||||
async def after_message_sent(self, event: AstrMessageEvent) -> None:
|
||||
"""消息发送后处理"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
|
||||
@@ -5,10 +5,10 @@ from astrbot.core.utils.io import download_dashboard
|
||||
|
||||
|
||||
class AdminCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""授权管理员。op <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
@@ -21,7 +21,7 @@ class AdminCommands:
|
||||
self.context.get_config().save_config()
|
||||
event.set_result(MessageEventResult().message("授权成功。"))
|
||||
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
if not admin_id:
|
||||
event.set_result(
|
||||
@@ -39,7 +39,7 @@ class AdminCommands:
|
||||
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
|
||||
)
|
||||
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""添加白名单。wl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
@@ -53,7 +53,7 @@ class AdminCommands:
|
||||
cfg.save_config()
|
||||
event.set_result(MessageEventResult().message("添加白名单成功。"))
|
||||
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""删除白名单。dwl <sid>"""
|
||||
if not sid:
|
||||
event.set_result(
|
||||
@@ -70,7 +70,7 @@ class AdminCommands:
|
||||
except ValueError:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
async def update_dashboard(self, event: AstrMessageEvent) -> None:
|
||||
"""更新管理面板"""
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
|
||||
@@ -11,10 +11,10 @@ from .utils.rst_scene import RstScene
|
||||
|
||||
|
||||
class AlterCmdCommands(CommandParserMixin):
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str):
|
||||
async def update_reset_permission(self, scene_key: str, perm_type: str) -> None:
|
||||
"""更新reset命令在特定场景下的权限设置"""
|
||||
from astrbot.api import sp
|
||||
|
||||
@@ -26,7 +26,7 @@ class AlterCmdCommands(CommandParserMixin):
|
||||
alter_cmd_cfg["astrbot"] = plugin_cfg
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
async def alter_cmd(self, event: AstrMessageEvent) -> None:
|
||||
token = self.parse_commands(event.message_str)
|
||||
if token.len < 3:
|
||||
await event.send(
|
||||
|
||||
@@ -16,7 +16,7 @@ THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
|
||||
|
||||
|
||||
class ConversationCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def _get_current_persona_id(self, session_id):
|
||||
@@ -33,7 +33,7 @@ class ConversationCommands:
|
||||
return None
|
||||
return conv.persona_id
|
||||
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
async def reset(self, message: AstrMessageEvent) -> None:
|
||||
"""重置 LLM 会话"""
|
||||
umo = message.unified_msg_origin
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
@@ -98,7 +98,7 @@ class ConversationCommands:
|
||||
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话记录"""
|
||||
if not self.context.get_using_provider(message.unified_msg_origin):
|
||||
message.set_result(
|
||||
@@ -141,7 +141,7 @@ class ConversationCommands:
|
||||
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话列表"""
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
@@ -216,7 +216,7 @@ class ConversationCommands:
|
||||
message.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
return
|
||||
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
async def new_conv(self, message: AstrMessageEvent) -> None:
|
||||
"""创建新对话"""
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
|
||||
@@ -242,7 +242,7 @@ class ConversationCommands:
|
||||
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
|
||||
)
|
||||
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""创建新群聊对话"""
|
||||
if sid:
|
||||
session = str(
|
||||
@@ -273,7 +273,7 @@ class ConversationCommands:
|
||||
self,
|
||||
message: AstrMessageEvent,
|
||||
index: int | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""通过 /ls 前面的序号切换对话"""
|
||||
if not isinstance(index, int):
|
||||
message.set_result(
|
||||
@@ -308,7 +308,7 @@ class ConversationCommands:
|
||||
),
|
||||
)
|
||||
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
|
||||
"""重命名对话"""
|
||||
if not new_name:
|
||||
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
|
||||
@@ -319,7 +319,7 @@ class ConversationCommands:
|
||||
)
|
||||
message.set_result(MessageEventResult().message("重命名对话成功。"))
|
||||
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
async def del_conv(self, message: AstrMessageEvent) -> None:
|
||||
"""删除当前对话"""
|
||||
cfg = self.context.get_config(umo=message.unified_msg_origin)
|
||||
is_unique_session = cfg["platform_settings"]["unique_session"]
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
@@ -34,7 +34,7 @@ class HelpCommand:
|
||||
lines: list[str] = []
|
||||
hidden_commands = {"set", "unset", "websearch"}
|
||||
|
||||
def walk(items: list[dict], indent: int = 0):
|
||||
def walk(items: list[dict], indent: int = 0) -> None:
|
||||
for item in items:
|
||||
if not item.get("reserved") or not item.get("enabled"):
|
||||
continue
|
||||
@@ -62,7 +62,7 @@ class HelpCommand:
|
||||
walk(commands)
|
||||
return lines
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
async def help(self, event: AstrMessageEvent) -> None:
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
|
||||
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
|
||||
|
||||
class LLMCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
async def llm(self, event: AstrMessageEvent) -> None:
|
||||
"""开启/关闭 LLM"""
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
enable = cfg["provider_settings"].get("enable", True)
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PersonaCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
def _build_tree_output(
|
||||
@@ -50,7 +50,7 @@ class PersonaCommands:
|
||||
|
||||
return lines
|
||||
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
async def persona(self, message: AstrMessageEvent) -> None:
|
||||
l = message.message_str.split(" ") # noqa: E741
|
||||
umo = message.unified_msg_origin
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from astrbot.core.star.star_manager import PluginManager
|
||||
|
||||
|
||||
class PluginCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
async def plugin_ls(self, event: AstrMessageEvent) -> None:
|
||||
"""获取已经安装的插件列表。"""
|
||||
parts = ["已加载的插件:\n"]
|
||||
for plugin in self.context.get_all_stars():
|
||||
@@ -30,7 +30,7 @@ class PluginCommands:
|
||||
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
|
||||
)
|
||||
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""禁用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
|
||||
@@ -43,7 +43,7 @@ class PluginCommands:
|
||||
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
|
||||
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""启用插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
|
||||
@@ -56,7 +56,7 @@ class PluginCommands:
|
||||
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
|
||||
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
|
||||
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
|
||||
"""安装插件"""
|
||||
if DEMO_MODE:
|
||||
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
|
||||
@@ -77,7 +77,7 @@ class PluginCommands:
|
||||
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
|
||||
return
|
||||
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""获取插件帮助"""
|
||||
if not plugin_name:
|
||||
event.set_result(
|
||||
|
||||
@@ -8,7 +8,7 @@ from astrbot.core.provider.entities import ProviderType
|
||||
|
||||
|
||||
class ProviderCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
def _log_reachability_failure(
|
||||
@@ -17,7 +17,7 @@ class ProviderCommands:
|
||||
provider_capability_type: ProviderType | None,
|
||||
err_code: str,
|
||||
err_reason: str,
|
||||
):
|
||||
) -> None:
|
||||
"""记录不可达原因到日志。"""
|
||||
meta = provider.meta()
|
||||
logger.warning(
|
||||
@@ -49,7 +49,7 @@ class ProviderCommands:
|
||||
event: AstrMessageEvent,
|
||||
idx: str | int | None = None,
|
||||
idx2: int | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""查看或者切换 LLM Provider"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
@@ -228,7 +228,7 @@ class ProviderCommands:
|
||||
self,
|
||||
message: AstrMessageEvent,
|
||||
idx_or_name: int | str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""查看或者切换模型"""
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
@@ -293,7 +293,7 @@ class ProviderCommands:
|
||||
MessageEventResult().message(f"切换模型到 {prov.get_model()}。"),
|
||||
)
|
||||
|
||||
async def key(self, message: AstrMessageEvent, index: int | None = None):
|
||||
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
|
||||
prov = self.context.get_using_provider(message.unified_msg_origin)
|
||||
if not prov:
|
||||
message.set_result(
|
||||
|
||||
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
|
||||
|
||||
class SetUnsetCommands:
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
|
||||
"""设置会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
@@ -19,7 +19,7 @@ class SetUnsetCommands:
|
||||
),
|
||||
)
|
||||
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
|
||||
"""移除会话变量"""
|
||||
uid = event.unified_msg_origin
|
||||
session_var = await sp.session_get(uid, "session_variables", {})
|
||||
|
||||
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
class SIDCommand:
|
||||
"""会话ID命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
async def sid(self, event: AstrMessageEvent) -> None:
|
||||
"""获取消息来源信息"""
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
|
||||
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
class T2ICommand:
|
||||
"""文本转图片命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
async def t2i(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转图片"""
|
||||
config = self.context.get_config(umo=event.unified_msg_origin)
|
||||
if config["t2i"]:
|
||||
|
||||
@@ -8,10 +8,10 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
class TTSCommand:
|
||||
"""文本转语音命令类"""
|
||||
|
||||
def __init__(self, context: star.Context):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
async def tts(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
|
||||
@@ -35,84 +35,84 @@ class Main(star.Star):
|
||||
self.sid_c = SIDCommand(self.context)
|
||||
|
||||
@filter.command("help")
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
async def help(self, event: AstrMessageEvent) -> None:
|
||||
"""查看帮助"""
|
||||
await self.help_c.help(event)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("llm")
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
async def llm(self, event: AstrMessageEvent) -> None:
|
||||
"""开启/关闭 LLM"""
|
||||
await self.llm_c.llm(event)
|
||||
|
||||
@filter.command_group("plugin")
|
||||
def plugin(self):
|
||||
def plugin(self) -> None:
|
||||
"""插件管理"""
|
||||
|
||||
@plugin.command("ls")
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
async def plugin_ls(self, event: AstrMessageEvent) -> None:
|
||||
"""获取已经安装的插件列表。"""
|
||||
await self.plugin_c.plugin_ls(event)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@plugin.command("off")
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""禁用插件"""
|
||||
await self.plugin_c.plugin_off(event, plugin_name)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@plugin.command("on")
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""启用插件"""
|
||||
await self.plugin_c.plugin_on(event, plugin_name)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@plugin.command("get")
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
|
||||
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
|
||||
"""安装插件"""
|
||||
await self.plugin_c.plugin_get(event, plugin_repo)
|
||||
|
||||
@plugin.command("help")
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
|
||||
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
|
||||
"""获取插件帮助"""
|
||||
await self.plugin_c.plugin_help(event, plugin_name)
|
||||
|
||||
@filter.command("t2i")
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
async def t2i(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转图片"""
|
||||
await self.t2i_c.t2i(event)
|
||||
|
||||
@filter.command("tts")
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
async def tts(self, event: AstrMessageEvent) -> None:
|
||||
"""开关文本转语音(会话级别)"""
|
||||
await self.tts_c.tts(event)
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
async def sid(self, event: AstrMessageEvent) -> None:
|
||||
"""获取会话 ID 和 管理员 ID"""
|
||||
await self.sid_c.sid(event)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("op")
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
|
||||
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
|
||||
"""授权管理员。op <admin_id>"""
|
||||
await self.admin_c.op(event, admin_id)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("deop")
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str):
|
||||
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
|
||||
"""取消授权管理员。deop <admin_id>"""
|
||||
await self.admin_c.deop(event, admin_id)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("wl")
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = ""):
|
||||
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
|
||||
"""添加白名单。wl <sid>"""
|
||||
await self.admin_c.wl(event, sid)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dwl")
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str):
|
||||
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
|
||||
"""删除白名单。dwl <sid>"""
|
||||
await self.admin_c.dwl(event, sid)
|
||||
|
||||
@@ -123,12 +123,12 @@ class Main(star.Star):
|
||||
event: AstrMessageEvent,
|
||||
idx: str | int | None = None,
|
||||
idx2: int | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""查看或者切换 LLM Provider"""
|
||||
await self.provider_c.provider(event, idx, idx2)
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
async def reset(self, message: AstrMessageEvent) -> None:
|
||||
"""重置 LLM 会话"""
|
||||
await self.conversation_c.reset(message)
|
||||
|
||||
@@ -138,74 +138,76 @@ class Main(star.Star):
|
||||
self,
|
||||
message: AstrMessageEvent,
|
||||
idx_or_name: int | str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""查看或者切换模型"""
|
||||
await self.provider_c.model_ls(message, idx_or_name)
|
||||
|
||||
@filter.command("history")
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1):
|
||||
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话记录"""
|
||||
await self.conversation_c.his(message, page)
|
||||
|
||||
@filter.command("ls")
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1):
|
||||
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
|
||||
"""查看对话列表"""
|
||||
await self.conversation_c.convs(message, page)
|
||||
|
||||
@filter.command("new")
|
||||
async def new_conv(self, message: AstrMessageEvent):
|
||||
async def new_conv(self, message: AstrMessageEvent) -> None:
|
||||
"""创建新对话"""
|
||||
await self.conversation_c.new_conv(message)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("groupnew")
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
|
||||
async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None:
|
||||
"""创建新群聊对话"""
|
||||
await self.conversation_c.groupnew_conv(message, sid)
|
||||
|
||||
@filter.command("switch")
|
||||
async def switch_conv(self, message: AstrMessageEvent, index: int | None = None):
|
||||
async def switch_conv(
|
||||
self, message: AstrMessageEvent, index: int | None = None
|
||||
) -> None:
|
||||
"""通过 /ls 前面的序号切换对话"""
|
||||
await self.conversation_c.switch_conv(message, index)
|
||||
|
||||
@filter.command("rename")
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
|
||||
async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None:
|
||||
"""重命名对话"""
|
||||
await self.conversation_c.rename_conv(message, new_name)
|
||||
|
||||
@filter.command("del")
|
||||
async def del_conv(self, message: AstrMessageEvent):
|
||||
async def del_conv(self, message: AstrMessageEvent) -> None:
|
||||
"""删除当前对话"""
|
||||
await self.conversation_c.del_conv(message)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("key")
|
||||
async def key(self, message: AstrMessageEvent, index: int | None = None):
|
||||
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
|
||||
"""查看或者切换 Key"""
|
||||
await self.provider_c.key(message, index)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("persona")
|
||||
async def persona(self, message: AstrMessageEvent):
|
||||
async def persona(self, message: AstrMessageEvent) -> None:
|
||||
"""查看或者切换 Persona"""
|
||||
await self.persona_c.persona(message)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
async def update_dashboard(self, event: AstrMessageEvent) -> None:
|
||||
"""更新管理面板"""
|
||||
await self.admin_c.update_dashboard(event)
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
|
||||
await self.setunset_c.set_variable(event, key, value)
|
||||
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
|
||||
await self.setunset_c.unset_variable(event, key)
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("alter_cmd", alias={"alter"})
|
||||
async def alter_cmd(self, event: AstrMessageEvent):
|
||||
async def alter_cmd(self, event: AstrMessageEvent) -> None:
|
||||
"""修改命令权限"""
|
||||
await self.alter_cmd_c.alter_cmd(event)
|
||||
|
||||
@@ -17,11 +17,11 @@ from astrbot.core.utils.session_waiter import (
|
||||
class Main(Star):
|
||||
"""会话控制"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
def __init__(self, context: Context) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
|
||||
async def handle_session_control_agent(self, event: AstrMessageEvent):
|
||||
async def handle_session_control_agent(self, event: AstrMessageEvent) -> None:
|
||||
"""会话控制代理"""
|
||||
for session_filter in FILTERS:
|
||||
session_id = session_filter.filter(event)
|
||||
@@ -90,7 +90,7 @@ class Main(Star):
|
||||
async def empty_mention_waiter(
|
||||
controller: SessionController,
|
||||
event: AstrMessageEvent,
|
||||
):
|
||||
) -> None:
|
||||
event.message_obj.message.insert(
|
||||
0,
|
||||
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),
|
||||
|
||||
@@ -49,7 +49,7 @@ class SearchEngine:
|
||||
def _set_selector(self, selector: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_next_page(self, query: str):
|
||||
async def _get_next_page(self, query: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _get_html(self, url: str, data: dict | None = None) -> str:
|
||||
|
||||
@@ -199,7 +199,7 @@ class Main(star.Star):
|
||||
return results
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
|
||||
"""网页搜索指令(已废弃)"""
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
@@ -246,7 +246,7 @@ class Main(star.Star):
|
||||
|
||||
return ret
|
||||
|
||||
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
|
||||
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None:
|
||||
if self.baidu_initialized:
|
||||
return
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
@@ -553,7 +553,7 @@ class Main(star.Star):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
):
|
||||
) -> None:
|
||||
"""Get the session conversation for the given event."""
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
prov_settings = cfg.get("provider_settings", {})
|
||||
|
||||
@@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
|
||||
|
||||
|
||||
@click.group(name="conf")
|
||||
def conf():
|
||||
def conf() -> None:
|
||||
"""配置管理命令
|
||||
|
||||
支持的配置项:
|
||||
@@ -149,7 +149,7 @@ def conf():
|
||||
@conf.command(name="set")
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def set_config(key: str, value: str):
|
||||
def set_config(key: str, value: str) -> None:
|
||||
"""设置配置项的值"""
|
||||
if key not in CONFIG_VALIDATORS:
|
||||
raise click.ClickException(f"不支持的配置项: {key}")
|
||||
@@ -178,7 +178,7 @@ def set_config(key: str, value: str):
|
||||
|
||||
@conf.command(name="get")
|
||||
@click.argument("key", required=False)
|
||||
def get_config(key: str | None = None):
|
||||
def get_config(key: str | None = None) -> None:
|
||||
"""获取配置项的值,不提供key则显示所有可配置项"""
|
||||
config = _load_config()
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ..utils import (
|
||||
|
||||
|
||||
@click.group()
|
||||
def plug():
|
||||
def plug() -> None:
|
||||
"""插件管理"""
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def _get_data_path() -> Path:
|
||||
return (base / "data").resolve()
|
||||
|
||||
|
||||
def display_plugins(plugins, title=None, color=None):
|
||||
def display_plugins(plugins, title=None, color=None) -> None:
|
||||
if title:
|
||||
click.echo(click.style(title, fg=color, bold=True))
|
||||
|
||||
@@ -45,7 +45,7 @@ def display_plugins(plugins, title=None, color=None):
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
def new(name: str) -> None:
|
||||
"""创建新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
@@ -100,7 +100,7 @@ def new(name: str):
|
||||
|
||||
@plug.command()
|
||||
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
|
||||
def list(all: bool):
|
||||
def list(all: bool) -> None:
|
||||
"""列出插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
@@ -141,7 +141,7 @@ def list(all: bool):
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
@click.option("--proxy", help="代理服务器地址")
|
||||
def install(name: str, proxy: str | None):
|
||||
def install(name: str, proxy: str | None) -> None:
|
||||
"""安装插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
@@ -164,7 +164,7 @@ def install(name: str, proxy: str | None):
|
||||
|
||||
@plug.command()
|
||||
@click.argument("name")
|
||||
def remove(name: str):
|
||||
def remove(name: str) -> None:
|
||||
"""卸载插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
@@ -187,7 +187,7 @@ def remove(name: str):
|
||||
@plug.command()
|
||||
@click.argument("name", required=False)
|
||||
@click.option("--proxy", help="Github代理地址")
|
||||
def update(name: str, proxy: str | None):
|
||||
def update(name: str, proxy: str | None) -> None:
|
||||
"""更新插件"""
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins"
|
||||
@@ -225,7 +225,7 @@ def update(name: str, proxy: str | None):
|
||||
|
||||
@plug.command()
|
||||
@click.argument("query")
|
||||
def search(query: str):
|
||||
def search(query: str) -> None:
|
||||
"""搜索插件"""
|
||||
base_path = _get_data_path()
|
||||
plugins = build_plug_list(base_path / "plugins")
|
||||
|
||||
@@ -10,7 +10,7 @@ from filelock import FileLock, Timeout
|
||||
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
async def run_astrbot(astrbot_root: Path) -> None:
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
@@ -19,7 +19,7 @@ class PluginStatus(str, Enum):
|
||||
NOT_PUBLISHED = "未发布"
|
||||
|
||||
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
|
||||
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
|
||||
"""从 Git 仓库下载代码并解压到指定路径"""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
try:
|
||||
|
||||
@@ -57,7 +57,9 @@ class TruncateByTurnsCompressor:
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
def __init__(
|
||||
self, truncate_turns: int = 1, compression_threshold: float = 0.82
|
||||
) -> None:
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
@@ -152,7 +154,7 @@ class LLMSummaryCompressor:
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -13,7 +13,7 @@ class ContextManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
|
||||
@@ -14,7 +14,7 @@ class HandoffTool(FunctionTool, Generic[TContext]):
|
||||
parameters: dict | None = None,
|
||||
tool_description: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
self.agent = agent
|
||||
|
||||
# Avoid passing duplicate `description` to the FunctionTool dataclass.
|
||||
|
||||
@@ -9,22 +9,22 @@ from .run_context import ContextWrapper, TContext
|
||||
|
||||
|
||||
class BaseAgentRunHooks(Generic[TContext]):
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
|
||||
async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ...
|
||||
async def on_tool_start(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
): ...
|
||||
) -> None: ...
|
||||
async def on_tool_end(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool: FunctionTool,
|
||||
tool_args: dict | None,
|
||||
tool_result: mcp.types.CallToolResult | None,
|
||||
): ...
|
||||
) -> None: ...
|
||||
async def on_agent_done(
|
||||
self,
|
||||
run_context: ContextWrapper[TContext],
|
||||
llm_response: LLMResponse,
|
||||
): ...
|
||||
) -> None: ...
|
||||
|
||||
@@ -108,7 +108,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# Initialize session and client objects
|
||||
self.session: mcp.ClientSession | None = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
@@ -126,7 +126,7 @@ class MCPClient:
|
||||
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
|
||||
self._reconnecting: bool = False # For logging and debugging
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
|
||||
"""Connect to MCP server
|
||||
|
||||
If `url` parameter exists:
|
||||
@@ -144,7 +144,7 @@ class MCPClient:
|
||||
|
||||
cfg = _prepare_config(mcp_server_config.copy())
|
||||
|
||||
def logging_callback(msg: str):
|
||||
def logging_callback(msg: str) -> None:
|
||||
# Handle MCP service error logs
|
||||
print(f"MCP Server {name} Error: {msg}")
|
||||
self.server_errlogs.append(msg)
|
||||
@@ -214,7 +214,7 @@ class MCPClient:
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
def callback(msg: str) -> None:
|
||||
# Handle MCP service error logs
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
@@ -343,7 +343,7 @@ class MCPClient:
|
||||
|
||||
return await _call_with_retry()
|
||||
|
||||
async def cleanup(self):
|
||||
async def cleanup(self) -> None:
|
||||
"""Clean up resources including old exit stacks from reconnections"""
|
||||
# Close current exit stack
|
||||
try:
|
||||
@@ -365,7 +365,7 @@ class MCPTool(FunctionTool, Generic[TContext]):
|
||||
|
||||
def __init__(
|
||||
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name=mcp_tool.name,
|
||||
description=mcp_tool.description or "",
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core import logger
|
||||
|
||||
|
||||
class CozeAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = None
|
||||
@@ -277,7 +277,7 @@ class CozeAPIClient:
|
||||
logger.error(f"获取Coze消息列表失败: {e!s}")
|
||||
raise Exception(f"获取Coze消息列表失败: {e!s}")
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""关闭会话"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
@@ -288,7 +288,7 @@ if __name__ == "__main__":
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
async def test_coze_api_client():
|
||||
async def test_coze_api_client() -> None:
|
||||
api_key = os.getenv("COZE_API_KEY", "")
|
||||
bot_id = os.getenv("COZE_BOT_ID", "")
|
||||
client = CozeAPIClient(api_key=api_key)
|
||||
|
||||
@@ -67,7 +67,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
def has_rag_options(self) -> bool:
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = ClientSession(trust_env=True)
|
||||
@@ -155,7 +155,7 @@ class DifyAPIClient:
|
||||
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
await self.session.close()
|
||||
|
||||
async def get_chat_convs(self, user: str, limit: int = 20):
|
||||
|
||||
+10
-10
@@ -64,7 +64,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
with a task identifier while the real work continues asynchronously.
|
||||
"""
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
||||
|
||||
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
|
||||
@@ -88,7 +88,7 @@ class ToolSet:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self.tools) == 0
|
||||
|
||||
def add_tool(self, tool: FunctionTool):
|
||||
def add_tool(self, tool: FunctionTool) -> None:
|
||||
"""Add a tool to the set."""
|
||||
# 检查是否已存在同名工具
|
||||
for i, existing_tool in enumerate(self.tools):
|
||||
@@ -97,7 +97,7 @@ class ToolSet:
|
||||
return
|
||||
self.tools.append(tool)
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
def remove_tool(self, name: str) -> None:
|
||||
"""Remove a tool by its name."""
|
||||
self.tools = [tool for tool in self.tools if tool.name != name]
|
||||
|
||||
@@ -156,7 +156,7 @@ class ToolSet:
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
):
|
||||
) -> None:
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -176,7 +176,7 @@ class ToolSet:
|
||||
self.add_tool(_func)
|
||||
|
||||
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
|
||||
def remove_func(self, name: str):
|
||||
def remove_func(self, name: str) -> None:
|
||||
"""Remove a function tool by its name."""
|
||||
self.remove_tool(name)
|
||||
|
||||
@@ -325,22 +325,22 @@ class ToolSet:
|
||||
"""获取所有工具的名称列表"""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def merge(self, other: "ToolSet"):
|
||||
def merge(self, other: "ToolSet") -> None:
|
||||
"""Merge another ToolSet into this one."""
|
||||
for tool in other.tools:
|
||||
self.add_tool(tool)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.tools)
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.tools) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"ToolSet(tools={self.tools})"
|
||||
|
||||
@@ -12,7 +12,7 @@ from astrbot.core.star.star_handler import EventType
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
async def on_agent_done(self, run_context, llm_response) -> None:
|
||||
# 执行事件钩子
|
||||
if llm_response and llm_response.reasoning_content:
|
||||
# we will use this in result_decorate stage to inject reasoning content to chain
|
||||
@@ -31,7 +31,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
):
|
||||
) -> None:
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnUsingLLMToolEvent,
|
||||
@@ -45,7 +45,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
tool: FunctionTool[Any],
|
||||
tool_args: dict | None,
|
||||
tool_result: CallToolResult | None,
|
||||
):
|
||||
) -> None:
|
||||
run_context.context.event.clear_result()
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
|
||||
@@ -295,7 +295,7 @@ async def _run_agent_feeder(
|
||||
max_step: int,
|
||||
show_tool_use: bool,
|
||||
show_reasoning: bool,
|
||||
):
|
||||
) -> None:
|
||||
"""运行 Agent 并将文本输出分句放入队列"""
|
||||
buffer = ""
|
||||
try:
|
||||
@@ -352,7 +352,7 @@ async def _safe_tts_stream_wrapper(
|
||||
tts_provider: TTSProvider,
|
||||
text_queue: asyncio.Queue[str | None],
|
||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||
):
|
||||
) -> None:
|
||||
"""包装原生流式 TTS 确保异常处理和队列关闭"""
|
||||
try:
|
||||
await tts_provider.get_audio_stream(text_queue, audio_queue)
|
||||
@@ -366,7 +366,7 @@ async def _simulated_stream_tts(
|
||||
tts_provider: TTSProvider,
|
||||
text_queue: asyncio.Queue[str | None],
|
||||
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
|
||||
):
|
||||
) -> None:
|
||||
"""模拟流式 TTS 分句生成音频"""
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -57,7 +57,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
elif tool.is_background_task:
|
||||
task_id = uuid.uuid4().hex
|
||||
|
||||
async def _run_in_background():
|
||||
async def _run_in_background() -> None:
|
||||
try:
|
||||
await cls._execute_background(
|
||||
tool=tool,
|
||||
@@ -153,7 +153,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
run_context: ContextWrapper[AstrAgentContext],
|
||||
task_id: str,
|
||||
**tool_args,
|
||||
):
|
||||
) -> None:
|
||||
from astrbot.core.astr_main_agent import (
|
||||
MainAgentBuildConfig,
|
||||
_get_session_conv,
|
||||
|
||||
@@ -36,7 +36,7 @@ class AstrBotConfigManager:
|
||||
default_config: AstrBotConfig,
|
||||
ucr: UmopConfigRouter,
|
||||
sp: SharedPreferences,
|
||||
):
|
||||
) -> None:
|
||||
self.sp = sp
|
||||
self.ucr = ucr
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
@@ -56,7 +56,7 @@ class AstrBotConfigManager:
|
||||
)
|
||||
return self.abconf_data
|
||||
|
||||
def _load_all_configs(self):
|
||||
def _load_all_configs(self) -> None:
|
||||
"""Load all configurations from the shared preferences."""
|
||||
abconf_data = self._get_abconf_data()
|
||||
self.abconf_data = abconf_data
|
||||
|
||||
@@ -59,7 +59,7 @@ class AstrBotExporter:
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
) -> None:
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
|
||||
@@ -110,7 +110,7 @@ class ImportPreCheckResult:
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
@@ -161,7 +161,7 @@ class AstrBotImporter:
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
) -> None:
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
|
||||
@@ -22,7 +22,7 @@ class ComputerBooter:
|
||||
"""
|
||||
...
|
||||
|
||||
async def download_file(self, remote_path: str, local_path: str):
|
||||
async def download_file(self, remote_path: str, local_path: str) -> None:
|
||||
"""Download file from the computer."""
|
||||
...
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ class LocalBooter(ComputerBooter):
|
||||
"LocalBooter does not support upload_file operation. Use shell instead."
|
||||
)
|
||||
|
||||
async def download_file(self, remote_path: str, local_path: str):
|
||||
async def download_file(self, remote_path: str, local_path: str) -> None:
|
||||
raise NotImplementedError(
|
||||
"LocalBooter does not support download_file operation. Use shell instead."
|
||||
)
|
||||
|
||||
@@ -100,7 +100,7 @@ class FileUploadTool(FunctionTool):
|
||||
self,
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
local_path: str,
|
||||
):
|
||||
) -> str | None:
|
||||
sb = await get_booter(
|
||||
context.context.context,
|
||||
context.context.event.unified_msg_origin,
|
||||
|
||||
@@ -33,7 +33,7 @@ class AstrBotConfig(dict):
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
default_config: dict = DEFAULT_CONFIG,
|
||||
schema: dict | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
|
||||
@@ -66,7 +66,7 @@ class AstrBotConfig(dict):
|
||||
"""将 Schema 转换成 Config"""
|
||||
conf = {}
|
||||
|
||||
def _parse_schema(schema: dict, conf: dict):
|
||||
def _parse_schema(schema: dict, conf: dict) -> None:
|
||||
for k, v in schema.items():
|
||||
if v["type"] not in DEFAULT_VALUE_MAP:
|
||||
raise TypeError(
|
||||
@@ -148,7 +148,7 @@ class AstrBotConfig(dict):
|
||||
|
||||
return has_new
|
||||
|
||||
def save_config(self, replace_config: dict | None = None):
|
||||
def save_config(self, replace_config: dict | None = None) -> None:
|
||||
"""将配置写入文件
|
||||
|
||||
如果传入 replace_config,则将配置替换为 replace_config
|
||||
@@ -164,14 +164,14 @@ class AstrBotConfig(dict):
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def __delattr__(self, key):
|
||||
def __delattr__(self, key) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
self.save_config()
|
||||
except KeyError:
|
||||
raise AttributeError(f"没有找到 Key: '{key}'")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
def __setattr__(self, key, value) -> None:
|
||||
self[key] = value
|
||||
|
||||
def check_exist(self) -> bool:
|
||||
|
||||
@@ -16,7 +16,7 @@ from astrbot.core.db.po import Conversation, ConversationV2
|
||||
class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
def __init__(self, db_helper: BaseDatabase) -> None:
|
||||
self.session_conversations: dict[str, str] = {}
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
@@ -106,7 +106,9 @@ class ConversationManager:
|
||||
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
|
||||
return conv.conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
async def switch_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str
|
||||
) -> None:
|
||||
"""切换会话的对话
|
||||
|
||||
Args:
|
||||
@@ -121,7 +123,7 @@ class ConversationManager:
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
conversation_id: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
Args:
|
||||
@@ -138,7 +140,7 @@ class ConversationManager:
|
||||
self.session_conversations.pop(unified_msg_origin, None)
|
||||
await sp.session_remove(unified_msg_origin, "sel_conv_id")
|
||||
|
||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
|
||||
async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None:
|
||||
"""删除会话的所有对话
|
||||
|
||||
Args:
|
||||
|
||||
@@ -24,7 +24,7 @@ class CronMessageEvent(AstrMessageEvent):
|
||||
sender_name: str = "Scheduler",
|
||||
extras: dict[str, Any] | None = None,
|
||||
message_type: MessageType = MessageType.FRIEND_MESSAGE,
|
||||
):
|
||||
) -> None:
|
||||
platform_meta = PlatformMetadata(
|
||||
name="cron",
|
||||
description="CronJob",
|
||||
@@ -53,13 +53,13 @@ class CronMessageEvent(AstrMessageEvent):
|
||||
if extras:
|
||||
self._extras.update(extras)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
if message is None:
|
||||
return
|
||||
await self.context_obj.send_message(self.session, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False) -> None:
|
||||
async for chain in generator:
|
||||
await self.send(chain)
|
||||
|
||||
|
||||
@@ -25,14 +25,14 @@ if TYPE_CHECKING:
|
||||
class CronJobManager:
|
||||
"""Central scheduler for BasicCronJob and ActiveAgentCronJob."""
|
||||
|
||||
def __init__(self, db: BaseDatabase):
|
||||
def __init__(self, db: BaseDatabase) -> None:
|
||||
self.db = db
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self._basic_handlers: dict[str, Callable[..., Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._started = False
|
||||
|
||||
async def start(self, ctx: "Context"):
|
||||
async def start(self, ctx: "Context") -> None:
|
||||
self.ctx: Context = ctx # star context
|
||||
async with self._lock:
|
||||
if self._started:
|
||||
@@ -41,14 +41,14 @@ class CronJobManager:
|
||||
self._started = True
|
||||
await self.sync_from_db()
|
||||
|
||||
async def shutdown(self):
|
||||
async def shutdown(self) -> None:
|
||||
async with self._lock:
|
||||
if not self._started:
|
||||
return
|
||||
self.scheduler.shutdown(wait=False)
|
||||
self._started = False
|
||||
|
||||
async def sync_from_db(self):
|
||||
async def sync_from_db(self) -> None:
|
||||
jobs = await self.db.list_cron_jobs()
|
||||
for job in jobs:
|
||||
if not job.enabled or not job.persistent:
|
||||
@@ -136,11 +136,11 @@ class CronJobManager:
|
||||
async def list_jobs(self, job_type: str | None = None) -> list[CronJob]:
|
||||
return await self.db.list_cron_jobs(job_type)
|
||||
|
||||
def _remove_scheduled(self, job_id: str):
|
||||
def _remove_scheduled(self, job_id: str) -> None:
|
||||
if self.scheduler.get_job(job_id):
|
||||
self.scheduler.remove_job(job_id)
|
||||
|
||||
def _schedule_job(self, job: CronJob):
|
||||
def _schedule_job(self, job: CronJob) -> None:
|
||||
if not self._started:
|
||||
self.scheduler.start()
|
||||
self._started = True
|
||||
@@ -188,7 +188,7 @@ class CronJobManager:
|
||||
aps_job = self.scheduler.get_job(job_id)
|
||||
return aps_job.next_run_time if aps_job else None
|
||||
|
||||
async def _run_job(self, job_id: str):
|
||||
async def _run_job(self, job_id: str) -> None:
|
||||
job = await self.db.get_cron_job(job_id)
|
||||
if not job or not job.enabled:
|
||||
return
|
||||
@@ -222,7 +222,7 @@ class CronJobManager:
|
||||
# one-shot: remove after execution regardless of success
|
||||
await self.delete_job(job_id)
|
||||
|
||||
async def _run_basic_job(self, job: CronJob):
|
||||
async def _run_basic_job(self, job: CronJob) -> None:
|
||||
handler = self._basic_handlers.get(job.job_id)
|
||||
if not handler:
|
||||
raise RuntimeError(f"Basic cron job handler not found for {job.job_id}")
|
||||
@@ -231,7 +231,7 @@ class CronJobManager:
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
async def _run_active_agent_job(self, job: CronJob, start_time: datetime):
|
||||
async def _run_active_agent_job(self, job: CronJob, start_time: datetime) -> None:
|
||||
payload = job.payload or {}
|
||||
session_str = payload.get("session")
|
||||
if not session_str:
|
||||
@@ -266,7 +266,7 @@ class CronJobManager:
|
||||
message: str,
|
||||
session_str: str,
|
||||
extras: dict,
|
||||
):
|
||||
) -> None:
|
||||
"""Woke the main agent to handle the cron job message."""
|
||||
from astrbot.core.astr_main_agent import (
|
||||
MainAgentBuildConfig,
|
||||
|
||||
@@ -43,7 +43,7 @@ class BaseDatabase(abc.ABC):
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库连接"""
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -43,7 +43,7 @@ def get_platform_type(
|
||||
async def migration_conversation_table(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
):
|
||||
) -> None:
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
|
||||
)
|
||||
@@ -101,7 +101,7 @@ async def migration_conversation_table(
|
||||
async def migration_platform_table(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
):
|
||||
) -> None:
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
|
||||
)
|
||||
@@ -180,7 +180,7 @@ async def migration_platform_table(
|
||||
async def migration_webchat_data(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
):
|
||||
) -> None:
|
||||
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
|
||||
db_helper_v3 = SQLiteV3DatabaseV3(
|
||||
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
|
||||
@@ -236,7 +236,7 @@ async def migration_webchat_data(
|
||||
async def migration_persona_data(
|
||||
db_helper: BaseDatabase,
|
||||
astrbot_config: AstrBotConfig,
|
||||
):
|
||||
) -> None:
|
||||
"""迁移 Persona 数据到新的表中。
|
||||
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
|
||||
"""
|
||||
@@ -279,7 +279,7 @@ async def migration_persona_data(
|
||||
async def migration_preferences(
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
):
|
||||
) -> None:
|
||||
# 1. global scope migration
|
||||
keys = [
|
||||
"inactivated_llm_tools",
|
||||
|
||||
@@ -3,7 +3,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
|
||||
|
||||
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
|
||||
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None:
|
||||
abconf_data = acm.abconf_data
|
||||
|
||||
if not isinstance(abconf_data, dict):
|
||||
|
||||
@@ -12,7 +12,7 @@ from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
async def migrate_token_usage(db_helper: BaseDatabase) -> None:
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
|
||||
@@ -17,7 +17,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
|
||||
|
||||
|
||||
async def migrate_webchat_session(db_helper: BaseDatabase):
|
||||
async def migrate_webchat_session(db_helper: BaseDatabase) -> None:
|
||||
"""Create PlatformSession records from platform_message_history.
|
||||
|
||||
This migration extracts all unique user_ids from platform_message_history
|
||||
|
||||
@@ -8,7 +8,7 @@ _VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path=None):
|
||||
def __init__(self, path=None) -> None:
|
||||
if path is None:
|
||||
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
|
||||
self.path = path
|
||||
@@ -23,7 +23,7 @@ class SharedPreferences:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
def _save_preferences(self) -> None:
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
@@ -31,16 +31,16 @@ class SharedPreferences:
|
||||
def get(self, key, default: _VT = None) -> _VT:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
def put(self, key, value) -> None:
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
def remove(self, key) -> None:
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -144,7 +144,7 @@ class SQLiteDatabase:
|
||||
|
||||
conn.commit()
|
||||
|
||||
def insert_platform_metrics(self, metrics: dict):
|
||||
def insert_platform_metrics(self, metrics: dict) -> None:
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
@@ -153,7 +153,7 @@ class SQLiteDatabase:
|
||||
(k, v, int(time.time())),
|
||||
)
|
||||
|
||||
def insert_llm_metrics(self, metrics: dict):
|
||||
def insert_llm_metrics(self, metrics: dict) -> None:
|
||||
for k, v in metrics.items():
|
||||
self._exec_sql(
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class SQLiteDatabase:
|
||||
|
||||
return Conversation(*res)
|
||||
|
||||
def new_conversation(self, user_id: str, cid: str):
|
||||
def new_conversation(self, user_id: str, cid: str) -> None:
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
@@ -287,7 +287,7 @@ class SQLiteDatabase:
|
||||
)
|
||||
return conversations
|
||||
|
||||
def update_conversation(self, user_id: str, cid: str, history: str):
|
||||
def update_conversation(self, user_id: str, cid: str, history: str) -> None:
|
||||
"""更新对话,并且同时更新时间"""
|
||||
updated_at = int(time.time())
|
||||
self._exec_sql(
|
||||
@@ -297,7 +297,7 @@ class SQLiteDatabase:
|
||||
(history, updated_at, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
||||
def update_conversation_title(self, user_id: str, cid: str, title: str) -> None:
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
||||
@@ -305,7 +305,9 @@ class SQLiteDatabase:
|
||||
(title, user_id, cid),
|
||||
)
|
||||
|
||||
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
||||
def update_conversation_persona_id(
|
||||
self, user_id: str, cid: str, persona_id: str
|
||||
) -> None:
|
||||
self._exec_sql(
|
||||
"""
|
||||
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
||||
@@ -313,7 +315,7 @@ class SQLiteDatabase:
|
||||
(persona_id, user_id, cid),
|
||||
)
|
||||
|
||||
def delete_conversation(self, user_id: str, cid: str):
|
||||
def delete_conversation(self, user_id: str, cid: str) -> None:
|
||||
self._exec_sql(
|
||||
"""
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
|
||||
@@ -305,7 +305,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
await session.execute(query)
|
||||
return await self.get_conversation_by_id(cid)
|
||||
|
||||
async def delete_conversation(self, cid):
|
||||
async def delete_conversation(self, cid) -> None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -461,7 +461,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
platform_id,
|
||||
user_id,
|
||||
offset_sec=86400,
|
||||
):
|
||||
) -> None:
|
||||
"""Delete platform message history records newer than the specified offset."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
@@ -645,7 +645,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
await session.execute(query)
|
||||
return await self.get_persona_by_id(persona_id)
|
||||
|
||||
async def delete_persona(self, persona_id):
|
||||
async def delete_persona(self, persona_id) -> None:
|
||||
"""Delete a persona by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
@@ -903,7 +903,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def remove_preference(self, scope, scope_id, key):
|
||||
async def remove_preference(self, scope, scope_id, key) -> None:
|
||||
"""Remove a preference by scope ID and key."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
@@ -917,7 +917,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def clear_preferences(self, scope, scope_id):
|
||||
async def clear_preferences(self, scope, scope_id) -> None:
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
@@ -1195,7 +1195,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
def runner() -> None:
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
@@ -1218,7 +1218,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
def runner() -> None:
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
@@ -1253,7 +1253,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
|
||||
result = None
|
||||
|
||||
def runner():
|
||||
def runner() -> None:
|
||||
nonlocal result
|
||||
result = asyncio.run(_inner())
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class Result:
|
||||
|
||||
|
||||
class BaseVecDB:
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""初始化向量数据库"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -33,7 +33,7 @@ class Document(BaseDocModel, table=True):
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
def __init__(self, db_path: str) -> None:
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.engine: AsyncEngine | None = None
|
||||
@@ -43,7 +43,7 @@ class DocumentStorage:
|
||||
"sqlite_init.sql",
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
await self.connect()
|
||||
async with self.engine.begin() as conn: # type: ignore
|
||||
@@ -80,7 +80,7 @@ class DocumentStorage:
|
||||
|
||||
await conn.commit()
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the SQLite database."""
|
||||
if self.engine is None:
|
||||
self.engine = create_async_engine(
|
||||
@@ -211,7 +211,7 @@ class DocumentStorage:
|
||||
await session.flush() # Flush to get all IDs
|
||||
return [doc.id for doc in documents] # type: ignore
|
||||
|
||||
async def delete_document_by_doc_id(self, doc_id: str):
|
||||
async def delete_document_by_doc_id(self, doc_id: str) -> None:
|
||||
"""Delete a document by its doc_id.
|
||||
|
||||
Args:
|
||||
@@ -249,7 +249,7 @@ class DocumentStorage:
|
||||
return self._document_to_dict(document)
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None:
|
||||
"""Update a document by its doc_id.
|
||||
|
||||
Args:
|
||||
@@ -269,7 +269,7 @@ class DocumentStorage:
|
||||
document.updated_at = datetime.now()
|
||||
session.add(document)
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
async def delete_documents(self, metadata_filters: dict) -> None:
|
||||
"""Delete documents by their metadata filters.
|
||||
|
||||
Args:
|
||||
@@ -384,7 +384,7 @@ class DocumentStorage:
|
||||
"updated_at": row[5],
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.engine:
|
||||
await self.engine.dispose()
|
||||
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str | None = None):
|
||||
def __init__(self, dimension: int, path: str | None = None) -> None:
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
@@ -20,7 +20,7 @@ class EmbeddingStorage:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
async def insert(self, vector: np.ndarray, id: int) -> None:
|
||||
"""插入向量
|
||||
|
||||
Args:
|
||||
@@ -38,7 +38,7 @@ class EmbeddingStorage:
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
await self.save_index()
|
||||
|
||||
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
|
||||
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
|
||||
"""批量插入向量
|
||||
|
||||
Args:
|
||||
@@ -71,7 +71,7 @@ class EmbeddingStorage:
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def delete(self, ids: list[int]):
|
||||
async def delete(self, ids: list[int]) -> None:
|
||||
"""删除向量
|
||||
|
||||
Args:
|
||||
@@ -83,7 +83,7 @@ class EmbeddingStorage:
|
||||
self.index.remove_ids(id_array)
|
||||
await self.save_index()
|
||||
|
||||
async def save_index(self):
|
||||
async def save_index(self) -> None:
|
||||
"""保存索引
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,7 +20,7 @@ class FaissVecDB(BaseVecDB):
|
||||
index_store_path: str,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
rerank_provider: RerankProvider | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.doc_store_path = doc_store_path
|
||||
self.index_store_path = index_store_path
|
||||
self.embedding_provider = embedding_provider
|
||||
@@ -32,7 +32,7 @@ class FaissVecDB(BaseVecDB):
|
||||
self.embedding_provider = embedding_provider
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
await self.document_storage.initialize()
|
||||
|
||||
async def insert(
|
||||
@@ -165,7 +165,7 @@ class FaissVecDB(BaseVecDB):
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: str):
|
||||
async def delete(self, doc_id: str) -> None:
|
||||
"""删除一条文档块(chunk)"""
|
||||
# 获得对应的 int id
|
||||
result = await self.document_storage.get_document_by_doc_id(doc_id)
|
||||
@@ -177,7 +177,7 @@ class FaissVecDB(BaseVecDB):
|
||||
await self.document_storage.delete_document_by_doc_id(doc_id)
|
||||
await self.embedding_storage.delete([int_id])
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self, metadata_filter: dict | None = None) -> int:
|
||||
@@ -192,7 +192,7 @@ class FaissVecDB(BaseVecDB):
|
||||
)
|
||||
return count
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
async def delete_documents(self, metadata_filters: dict) -> None:
|
||||
"""根据元数据过滤器删除文档"""
|
||||
docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters,
|
||||
|
||||
@@ -28,13 +28,13 @@ class EventBus:
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
) -> None:
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
async def dispatch(self) -> None:
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
@@ -47,7 +47,7 @@ class EventBus:
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -9,12 +9,12 @@ from urllib.parse import unquote, urlparse
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
def __init__(self, default_timeout: float = 300) -> None:
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
async def _cleanup_expired_tokens(self) -> None:
|
||||
"""清理过期的令牌"""
|
||||
now = time.time()
|
||||
expired_tokens = [
|
||||
|
||||
@@ -17,13 +17,13 @@ from astrbot.dashboard.server import AstrBotDashboard
|
||||
class InitialLoader:
|
||||
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
|
||||
|
||||
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
||||
def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None:
|
||||
self.db = db
|
||||
self.logger = logger
|
||||
self.log_broker = log_broker
|
||||
self.webui_dir: str | None = None
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
|
||||
try:
|
||||
|
||||
@@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker):
|
||||
按照固定的字符数分块,并支持块之间的重叠。
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
|
||||
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
|
||||
"""初始化分块器
|
||||
|
||||
Args:
|
||||
|
||||
@@ -11,7 +11,7 @@ class RecursiveCharacterChunker(BaseChunker):
|
||||
length_function: Callable[[str], int] = len,
|
||||
is_separator_regex: bool = False,
|
||||
separators: list[str] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""初始化递归字符文本分割器
|
||||
|
||||
Args:
|
||||
|
||||
@@ -253,7 +253,7 @@ class KBSQLiteDatabase:
|
||||
"knowledge_base": row[1],
|
||||
}
|
||||
|
||||
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
|
||||
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
|
||||
"""删除单个文档及其相关数据"""
|
||||
# 在知识库表中删除
|
||||
async with self.get_db() as session, session.begin():
|
||||
|
||||
@@ -31,7 +31,7 @@ from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
|
||||
class RateLimiter:
|
||||
"""一个简单的速率限制器"""
|
||||
|
||||
def __init__(self, max_rpm: int):
|
||||
def __init__(self, max_rpm: int) -> None:
|
||||
self.max_per_minute = max_rpm
|
||||
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
|
||||
self.last_call_time = 0
|
||||
@@ -116,7 +116,7 @@ class KBHelper:
|
||||
provider_manager: ProviderManager,
|
||||
kb_root_dir: str,
|
||||
chunker: BaseChunker,
|
||||
):
|
||||
) -> None:
|
||||
self.kb_db = kb_db
|
||||
self.kb = kb
|
||||
self.prov_mgr = provider_manager
|
||||
@@ -130,7 +130,7 @@ class KBHelper:
|
||||
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
await self._ensure_vec_db()
|
||||
|
||||
async def get_ep(self) -> EmbeddingProvider:
|
||||
@@ -174,7 +174,7 @@ class KBHelper:
|
||||
self.vec_db = vec_db
|
||||
return vec_db
|
||||
|
||||
async def delete_vec_db(self):
|
||||
async def delete_vec_db(self) -> None:
|
||||
"""删除知识库的向量数据库和所有相关文件"""
|
||||
import shutil
|
||||
|
||||
@@ -182,7 +182,7 @@ class KBHelper:
|
||||
if self.kb_dir.exists():
|
||||
shutil.rmtree(self.kb_dir)
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
|
||||
@@ -293,7 +293,7 @@ class KBHelper:
|
||||
await progress_callback("chunking", 100, 100)
|
||||
|
||||
# 阶段3: 生成向量(带进度回调)
|
||||
async def embedding_progress_callback(current, total):
|
||||
async def embedding_progress_callback(current, total) -> None:
|
||||
if progress_callback:
|
||||
await progress_callback("embedding", current, total)
|
||||
|
||||
@@ -360,7 +360,7 @@ class KBHelper:
|
||||
doc = await self.kb_db.get_document_by_id(doc_id)
|
||||
return doc
|
||||
|
||||
async def delete_document(self, doc_id: str):
|
||||
async def delete_document(self, doc_id: str) -> None:
|
||||
"""删除单个文档及其相关数据"""
|
||||
await self.kb_db.delete_document_by_id(
|
||||
doc_id=doc_id,
|
||||
@@ -372,7 +372,7 @@ class KBHelper:
|
||||
)
|
||||
await self.refresh_kb()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str, doc_id: str):
|
||||
async def delete_chunk(self, chunk_id: str, doc_id: str) -> None:
|
||||
"""删除单个文本块及其相关数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await vec_db.delete(chunk_id)
|
||||
@@ -383,7 +383,7 @@ class KBHelper:
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
|
||||
async def refresh_kb(self):
|
||||
async def refresh_kb(self) -> None:
|
||||
if self.kb:
|
||||
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
|
||||
if kb:
|
||||
|
||||
@@ -26,14 +26,14 @@ class KnowledgeBaseManager:
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
) -> None:
|
||||
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||
self.provider_manager = provider_manager
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
self.kb_insts: dict[str, KBHelper] = {}
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""初始化知识库模块"""
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
@@ -58,13 +58,13 @@ class KnowledgeBaseManager:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
async def _init_kb_database(self) -> None:
|
||||
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
||||
|
||||
async def load_kbs(self):
|
||||
async def load_kbs(self) -> None:
|
||||
"""加载所有知识库实例"""
|
||||
kb_records = await self.kb_db.list_kbs()
|
||||
for record in kb_records:
|
||||
@@ -275,7 +275,7 @@ class KnowledgeBaseManager:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
"""终止所有知识库实例,关闭数据库连接"""
|
||||
for kb_id, kb_helper in self.kb_insts.items():
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,7 @@ import aiohttp
|
||||
class URLExtractor:
|
||||
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
|
||||
|
||||
def __init__(self, tavily_keys: list[str]):
|
||||
def __init__(self, tavily_keys: list[str]) -> None:
|
||||
"""
|
||||
初始化 URL 提取器
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class RetrievalManager:
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
):
|
||||
) -> None:
|
||||
"""初始化检索管理器
|
||||
|
||||
Args:
|
||||
|
||||
@@ -31,7 +31,7 @@ class RankFusion:
|
||||
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None:
|
||||
"""初始化结果融合器
|
||||
|
||||
Args:
|
||||
|
||||
@@ -34,7 +34,7 @@ class SparseRetriever:
|
||||
- 使用 BM25 算法计算相关度
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase) -> None:
|
||||
"""初始化稀疏检索器
|
||||
|
||||
Args:
|
||||
|
||||
+15
-15
@@ -91,7 +91,7 @@ class LogBroker:
|
||||
发布-订阅模式
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
|
||||
self.subscribers: list[Queue] = [] # 订阅者列表
|
||||
|
||||
@@ -106,7 +106,7 @@ class LogBroker:
|
||||
self.subscribers.append(q)
|
||||
return q
|
||||
|
||||
def unregister(self, q: Queue):
|
||||
def unregister(self, q: Queue) -> None:
|
||||
"""取消订阅
|
||||
|
||||
Args:
|
||||
@@ -115,7 +115,7 @@ class LogBroker:
|
||||
"""
|
||||
self.subscribers.remove(q)
|
||||
|
||||
def publish(self, log_entry: dict):
|
||||
def publish(self, log_entry: dict) -> None:
|
||||
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
|
||||
|
||||
Args:
|
||||
@@ -137,11 +137,11 @@ class LogQueueHandler(logging.Handler):
|
||||
继承自 logging.Handler
|
||||
"""
|
||||
|
||||
def __init__(self, log_broker: LogBroker):
|
||||
def __init__(self, log_broker: LogBroker) -> None:
|
||||
super().__init__()
|
||||
self.log_broker = log_broker
|
||||
|
||||
def emit(self, record):
|
||||
def emit(self, record) -> None:
|
||||
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
|
||||
这个方法会在每次日志记录时被调用
|
||||
|
||||
@@ -201,7 +201,7 @@ class LogManager:
|
||||
class PluginFilter(logging.Filter):
|
||||
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
|
||||
|
||||
def filter(self, record):
|
||||
def filter(self, record) -> bool:
|
||||
record.plugin_tag = (
|
||||
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
||||
)
|
||||
@@ -213,7 +213,7 @@ class LogManager:
|
||||
"""
|
||||
|
||||
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
||||
def filter(self, record):
|
||||
def filter(self, record) -> bool:
|
||||
dirname = os.path.dirname(record.pathname)
|
||||
record.filename = (
|
||||
os.path.basename(dirname)
|
||||
@@ -226,14 +226,14 @@ class LogManager:
|
||||
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
|
||||
|
||||
# 添加短日志级别名称
|
||||
def filter(self, record):
|
||||
def filter(self, record) -> bool:
|
||||
record.short_levelname = get_short_level_name(record.levelname)
|
||||
return True
|
||||
|
||||
class AstrBotVersionTagFilter(logging.Filter):
|
||||
"""在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。"""
|
||||
|
||||
def filter(self, record):
|
||||
def filter(self, record) -> bool:
|
||||
if record.levelno >= logging.WARNING:
|
||||
record.astrbot_version_tag = f" [v{VERSION}]"
|
||||
else:
|
||||
@@ -251,7 +251,7 @@ class LogManager:
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
|
||||
"""设置队列处理器, 用于将日志消息发送到 LogBroker
|
||||
|
||||
Args:
|
||||
@@ -301,7 +301,7 @@ class LogManager:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _remove_file_handlers(cls, logger: logging.Logger):
|
||||
def _remove_file_handlers(cls, logger: logging.Logger) -> None:
|
||||
for handler in cls._get_file_handlers(logger):
|
||||
logger.removeHandler(handler)
|
||||
try:
|
||||
@@ -310,7 +310,7 @@ class LogManager:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _remove_trace_file_handlers(cls, logger: logging.Logger):
|
||||
def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None:
|
||||
for handler in cls._get_trace_file_handlers(logger):
|
||||
logger.removeHandler(handler)
|
||||
try:
|
||||
@@ -326,7 +326,7 @@ class LogManager:
|
||||
max_mb: int | None = None,
|
||||
backup_count: int = 3,
|
||||
trace: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
|
||||
max_bytes = 0
|
||||
if max_mb and max_mb > 0:
|
||||
@@ -365,7 +365,7 @@ class LogManager:
|
||||
logger: logging.Logger,
|
||||
config: dict | None,
|
||||
override_level: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""根据配置设置日志级别和文件日志。
|
||||
|
||||
Args:
|
||||
@@ -413,7 +413,7 @@ class LogManager:
|
||||
cls._add_file_handler(logger, file_path, max_mb=max_mb)
|
||||
|
||||
@classmethod
|
||||
def configure_trace_logger(cls, config: dict | None):
|
||||
def configure_trace_logger(cls, config: dict | None) -> None:
|
||||
"""为 trace 事件配置独立的文件日志,不向控制台输出。"""
|
||||
if not config:
|
||||
return
|
||||
|
||||
@@ -66,7 +66,7 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
@@ -89,7 +89,7 @@ class Plain(BaseMessageComponent):
|
||||
text: str
|
||||
convert: bool | None = True
|
||||
|
||||
def __init__(self, text: str, convert: bool = True, **_):
|
||||
def __init__(self, text: str, convert: bool = True, **_) -> None:
|
||||
super().__init__(text=text, convert=convert, **_)
|
||||
|
||||
def toDict(self):
|
||||
@@ -103,7 +103,7 @@ class Face(BaseMessageComponent):
|
||||
type = ComponentType.Face
|
||||
id: int
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class Record(BaseMessageComponent):
|
||||
# 额外
|
||||
path: str | None
|
||||
|
||||
def __init__(self, file: str | None, **_):
|
||||
def __init__(self, file: str | None, **_) -> None:
|
||||
for k in _:
|
||||
if k == "url":
|
||||
pass
|
||||
@@ -221,7 +221,7 @@ class Video(BaseMessageComponent):
|
||||
# 额外
|
||||
path: str | None = ""
|
||||
|
||||
def __init__(self, file: str, **_):
|
||||
def __init__(self, file: str, **_) -> None:
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -255,7 +255,7 @@ class Video(BaseMessageComponent):
|
||||
return os.path.abspath(url)
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
|
||||
async def register_to_file_service(self):
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""将视频注册到文件服务。
|
||||
|
||||
Returns:
|
||||
@@ -303,7 +303,7 @@ class At(BaseMessageComponent):
|
||||
qq: int | str # 此处str为all时代表所有人
|
||||
name: str | None = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
def toDict(self):
|
||||
@@ -316,28 +316,28 @@ class At(BaseMessageComponent):
|
||||
class AtAll(At):
|
||||
qq: str = "all"
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class RPS(BaseMessageComponent): # TODO
|
||||
type = ComponentType.RPS
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Dice(BaseMessageComponent): # TODO
|
||||
type = ComponentType.Dice
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
class Shake(BaseMessageComponent): # TODO
|
||||
type = ComponentType.Shake
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -348,7 +348,7 @@ class Share(BaseMessageComponent):
|
||||
content: str | None = ""
|
||||
image: str | None = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -357,7 +357,7 @@ class Contact(BaseMessageComponent): # TODO
|
||||
_type: str # type 字段冲突
|
||||
id: int | None = 0
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -368,7 +368,7 @@ class Location(BaseMessageComponent): # TODO
|
||||
title: str | None = ""
|
||||
content: str | None = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -382,7 +382,7 @@ class Music(BaseMessageComponent):
|
||||
content: str | None = ""
|
||||
image: str | None = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
# for k in _.keys():
|
||||
# if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]:
|
||||
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
|
||||
@@ -402,7 +402,7 @@ class Image(BaseMessageComponent):
|
||||
path: str | None = ""
|
||||
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
|
||||
|
||||
def __init__(self, file: str | None, **_):
|
||||
def __init__(self, file: str | None, **_) -> None:
|
||||
super().__init__(file=file, **_)
|
||||
|
||||
@staticmethod
|
||||
@@ -525,7 +525,7 @@ class Reply(BaseMessageComponent):
|
||||
seq: int | None = 0
|
||||
"""deprecated"""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -534,7 +534,7 @@ class Poke(BaseMessageComponent):
|
||||
id: int | None = 0
|
||||
qq: int | None = 0
|
||||
|
||||
def __init__(self, type: str, **_):
|
||||
def __init__(self, type: str, **_) -> None:
|
||||
type = f"Poke:{type}"
|
||||
super().__init__(type=type, **_)
|
||||
|
||||
@@ -543,7 +543,7 @@ class Forward(BaseMessageComponent):
|
||||
type = ComponentType.Forward
|
||||
id: str
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ class Node(BaseMessageComponent):
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
def __init__(self, content: list[BaseMessageComponent], **_):
|
||||
def __init__(self, content: list[BaseMessageComponent], **_) -> None:
|
||||
if isinstance(content, Node):
|
||||
# back
|
||||
content = [content]
|
||||
@@ -605,7 +605,7 @@ class Nodes(BaseMessageComponent):
|
||||
type = ComponentType.Nodes
|
||||
nodes: list[Node]
|
||||
|
||||
def __init__(self, nodes: list[Node], **_):
|
||||
def __init__(self, nodes: list[Node], **_) -> None:
|
||||
super().__init__(nodes=nodes, **_)
|
||||
|
||||
def toDict(self):
|
||||
@@ -631,7 +631,7 @@ class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: dict
|
||||
|
||||
def __init__(self, data: str | dict, **_):
|
||||
def __init__(self, data: str | dict, **_) -> None:
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
super().__init__(data=data, **_)
|
||||
@@ -650,7 +650,7 @@ class File(BaseMessageComponent):
|
||||
file_: str | None = "" # 本地路径
|
||||
url: str | None = "" # url
|
||||
|
||||
def __init__(self, name: str, file: str = "", url: str = ""):
|
||||
def __init__(self, name: str, file: str = "", url: str = "") -> None:
|
||||
"""文件消息段。"""
|
||||
super().__init__(name=name, file_=file, url=url)
|
||||
|
||||
@@ -686,7 +686,7 @@ class File(BaseMessageComponent):
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
def file(self, value: str):
|
||||
def file(self, value: str) -> None:
|
||||
"""向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||
|
||||
Args:
|
||||
@@ -721,7 +721,7 @@ class File(BaseMessageComponent):
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
async def _download_file(self) -> None:
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
@@ -736,7 +736,7 @@ class File(BaseMessageComponent):
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
async def register_to_file_service(self):
|
||||
async def register_to_file_service(self) -> str:
|
||||
"""将文件注册到文件服务。
|
||||
|
||||
Returns:
|
||||
@@ -786,7 +786,7 @@ class WechatEmoji(BaseMessageComponent):
|
||||
md5_len: int | None = 0
|
||||
cdnurl: str | None = ""
|
||||
|
||||
def __init__(self, **_):
|
||||
def __init__(self, **_) -> None:
|
||||
super().__init__(**_)
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ DEFAULT_PERSONALITY = Personality(
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
|
||||
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None:
|
||||
self.db = db_helper
|
||||
self.acm = acm
|
||||
default_ps = acm.default_conf.get("provider_settings", {})
|
||||
@@ -29,7 +29,7 @@ class PersonaManager:
|
||||
self.selected_default_persona_v3: Personality | None = None
|
||||
self.persona_v3_config: list[dict] = []
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
self.personas = await self.get_all_personas()
|
||||
self.get_v3_persona_data()
|
||||
logger.info(f"已加载 {len(self.personas)} 个人格。")
|
||||
@@ -58,7 +58,7 @@ class PersonaManager:
|
||||
except Exception:
|
||||
return DEFAULT_PERSONALITY
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
async def delete_persona(self, persona_id: str) -> None:
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
|
||||
@@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
当前只会检查文本的。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
config = ctx.astrbot_config["content_safety"]
|
||||
self.strategy_selector = StrategySelector(config)
|
||||
|
||||
|
||||
@@ -336,7 +336,7 @@ class InternalAgentSubStage(Stage):
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
) -> None:
|
||||
if (
|
||||
not req
|
||||
or not req.conversation
|
||||
|
||||
@@ -19,7 +19,7 @@ class RateLimitStage(Stage):
|
||||
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# 存储每个会话的请求时间队列
|
||||
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
|
||||
# 为每个会话设置一个锁,避免并发冲突
|
||||
|
||||
@@ -35,7 +35,7 @@ class RespondStage(Stage):
|
||||
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||
@@ -91,7 +91,7 @@ class RespondStage(Stage):
|
||||
# random
|
||||
return random.uniform(self.interval[0], self.interval[1])
|
||||
|
||||
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
|
||||
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool:
|
||||
"""检查消息链是否为空
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,7 +20,7 @@ from ..stage import Stage, register_stage, registered_stages
|
||||
|
||||
@register_stage
|
||||
class ResultDecorateStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"]
|
||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||
|
||||
@@ -15,21 +15,21 @@ from .stage import registered_stages
|
||||
class PipelineScheduler:
|
||||
"""管道调度器,负责调度各个阶段的执行"""
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
def __init__(self, context: PipelineContext) -> None:
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__name__),
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
self.stages = [] # 存储阶段实例
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage_cls in registered_stages:
|
||||
stage_instance = stage_cls() # 创建实例
|
||||
await stage_instance.initialize(self.ctx)
|
||||
self.stages.append(stage_instance)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None:
|
||||
"""依次执行各个阶段
|
||||
|
||||
Args:
|
||||
@@ -72,7 +72,7 @@ class PipelineScheduler:
|
||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||
break
|
||||
|
||||
async def execute(self, event: AstrMessageEvent):
|
||||
async def execute(self, event: AstrMessageEvent) -> None:
|
||||
"""执行 pipeline
|
||||
|
||||
Args:
|
||||
|
||||
@@ -38,7 +38,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
):
|
||||
) -> None:
|
||||
self.message_str = message_str
|
||||
"""纯文本的消息"""
|
||||
self.message_obj = message_obj
|
||||
@@ -91,7 +91,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
return str(self.session)
|
||||
|
||||
@unified_msg_origin.setter
|
||||
def unified_msg_origin(self, value: str):
|
||||
def unified_msg_origin(self, value: str) -> None:
|
||||
"""设置统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
|
||||
self.new_session = MessageSession.from_str(value)
|
||||
self.session = self.new_session
|
||||
@@ -102,7 +102,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
return self.session.session_id
|
||||
|
||||
@session_id.setter
|
||||
def session_id(self, value: str):
|
||||
def session_id(self, value: str) -> None:
|
||||
"""设置用户的会话 ID。可以直接使用下面的 unified_msg_origin"""
|
||||
self.session.session_id = value
|
||||
|
||||
@@ -191,7 +191,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
|
||||
def set_extra(self, key, value):
|
||||
def set_extra(self, key, value) -> None:
|
||||
"""设置额外的信息。"""
|
||||
self._extras[key] = value
|
||||
|
||||
@@ -201,7 +201,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
return self._extras
|
||||
return self._extras.get(key, default)
|
||||
|
||||
def clear_extra(self):
|
||||
def clear_extra(self) -> None:
|
||||
"""清除额外的信息。"""
|
||||
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
|
||||
self._extras.clear()
|
||||
@@ -234,7 +234,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
generator: AsyncGenerator[MessageChain, None],
|
||||
use_fallback: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp。
|
||||
@@ -244,13 +244,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
async def _pre_send(self) -> None:
|
||||
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
async def _post_send(self):
|
||||
async def _post_send(self) -> None:
|
||||
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
|
||||
|
||||
def set_result(self, result: MessageEventResult | str):
|
||||
def set_result(self, result: MessageEventResult | str) -> None:
|
||||
"""设置消息事件的结果。
|
||||
|
||||
Note:
|
||||
@@ -279,14 +279,14 @@ class AstrMessageEvent(abc.ABC):
|
||||
result.chain = []
|
||||
self._result = result
|
||||
|
||||
def stop_event(self):
|
||||
def stop_event(self) -> None:
|
||||
"""终止事件传播。"""
|
||||
if self._result is None:
|
||||
self.set_result(MessageEventResult().stop_event())
|
||||
else:
|
||||
self._result.stop_event()
|
||||
|
||||
def continue_event(self):
|
||||
def continue_event(self) -> None:
|
||||
"""继续事件传播。"""
|
||||
if self._result is None:
|
||||
self.set_result(MessageEventResult().continue_event())
|
||||
@@ -299,7 +299,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
return False # 默认是继续传播
|
||||
return self._result.is_stopped()
|
||||
|
||||
def should_call_llm(self, call_llm: bool):
|
||||
def should_call_llm(self, call_llm: bool) -> None:
|
||||
"""是否在此消息事件中禁止默认的 LLM 请求。
|
||||
|
||||
只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。
|
||||
@@ -310,7 +310,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
def clear_result(self):
|
||||
def clear_result(self) -> None:
|
||||
"""清除消息事件的结果。"""
|
||||
self._result = None
|
||||
|
||||
@@ -404,7 +404,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
"""平台适配器"""
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
"""发送消息到消息平台。
|
||||
|
||||
Args:
|
||||
@@ -423,7 +423,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def react(self, emoji: str):
|
||||
async def react(self, emoji: str) -> None:
|
||||
"""对消息添加表情回应。
|
||||
|
||||
默认实现为发送一条包含该表情的消息。
|
||||
|
||||
@@ -11,7 +11,7 @@ class MessageMember:
|
||||
user_id: str # 发送者id
|
||||
nickname: str | None = None
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# 使用 f-string 来构建返回的字符串表示形式
|
||||
return (
|
||||
f"User ID: {self.user_id},"
|
||||
@@ -34,7 +34,7 @@ class Group:
|
||||
members: list[MessageMember] | None = None
|
||||
"""所有群成员"""
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# 使用 f-string 来构建返回的字符串表示形式
|
||||
return (
|
||||
f"Group ID: {self.group_id}\n"
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str | None):
|
||||
def group_id(self, value: str | None) -> None:
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -13,7 +13,7 @@ from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
|
||||
class PlatformManager:
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
@@ -38,7 +38,7 @@ class PlatformManager:
|
||||
sanitized = platform_id.replace(":", "_").replace("!", "_")
|
||||
return sanitized, sanitized != platform_id
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
"""初始化所有平台适配器"""
|
||||
for platform in self.platforms_config:
|
||||
try:
|
||||
@@ -58,7 +58,7 @@ class PlatformManager:
|
||||
),
|
||||
)
|
||||
|
||||
async def load_platform(self, platform_config: dict):
|
||||
async def load_platform(self, platform_config: dict) -> None:
|
||||
"""实例化一个平台"""
|
||||
# 动态导入
|
||||
try:
|
||||
@@ -176,7 +176,9 @@ class PlatformManager:
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
|
||||
async def _task_wrapper(
|
||||
self, task: asyncio.Task, platform: Platform | None = None
|
||||
) -> None:
|
||||
# 设置平台状态为运行中
|
||||
if platform:
|
||||
platform.status = PlatformStatus.RUNNING
|
||||
@@ -198,7 +200,7 @@ class PlatformManager:
|
||||
if platform:
|
||||
platform.record_error(error_msg, tb_str)
|
||||
|
||||
async def reload(self, platform_config: dict):
|
||||
async def reload(self, platform_config: dict) -> None:
|
||||
await self.terminate_platform(platform_config["id"])
|
||||
if platform_config["enable"]:
|
||||
await self.load_platform(platform_config)
|
||||
@@ -209,7 +211,7 @@ class PlatformManager:
|
||||
if key not in config_ids:
|
||||
await self.terminate_platform(key)
|
||||
|
||||
async def terminate_platform(self, platform_id: str):
|
||||
async def terminate_platform(self, platform_id: str) -> None:
|
||||
if platform_id in self._inst_map:
|
||||
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
|
||||
|
||||
@@ -231,7 +233,7 @@ class PlatformManager:
|
||||
if getattr(inst, "terminate", None):
|
||||
await inst.terminate()
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
for inst in self.platform_insts:
|
||||
if getattr(inst, "terminate", None):
|
||||
await inst.terminate()
|
||||
|
||||
@@ -15,7 +15,7 @@ class MessageSession:
|
||||
session_id: str
|
||||
platform_id: str = field(init=False)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@@ -34,7 +34,7 @@ class PlatformError:
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, config: dict, event_queue: Queue):
|
||||
def __init__(self, config: dict, event_queue: Queue) -> None:
|
||||
super().__init__()
|
||||
# 平台配置
|
||||
self.config = config
|
||||
@@ -53,7 +53,7 @@ class Platform(abc.ABC):
|
||||
return self._status
|
||||
|
||||
@status.setter
|
||||
def status(self, value: PlatformStatus):
|
||||
def status(self, value: PlatformStatus) -> None:
|
||||
"""设置平台运行状态"""
|
||||
self._status = value
|
||||
if value == PlatformStatus.RUNNING and self._started_at is None:
|
||||
@@ -69,12 +69,12 @@ class Platform(abc.ABC):
|
||||
"""获取最近的错误"""
|
||||
return self._errors[-1] if self._errors else None
|
||||
|
||||
def record_error(self, message: str, traceback_str: str | None = None):
|
||||
def record_error(self, message: str, traceback_str: str | None = None) -> None:
|
||||
"""记录一个错误"""
|
||||
self._errors.append(PlatformError(message=message, traceback=traceback_str))
|
||||
self._status = PlatformStatus.ERROR
|
||||
|
||||
def clear_errors(self):
|
||||
def clear_errors(self) -> None:
|
||||
"""清除错误记录"""
|
||||
self._errors.clear()
|
||||
if self._status == PlatformStatus.ERROR:
|
||||
@@ -121,7 +121,7 @@ class Platform(abc.ABC):
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
"""终止一个平台的运行实例。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -140,11 +140,11 @@ class Platform(abc.ABC):
|
||||
"""
|
||||
await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name)
|
||||
|
||||
def commit_event(self, event: AstrMessageEvent):
|
||||
def commit_event(self, event: AstrMessageEvent) -> None:
|
||||
"""提交一个事件到事件队列。"""
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
def get_client(self):
|
||||
def get_client(self) -> object:
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
async def webhook_callback(self, request: Any) -> Any:
|
||||
|
||||
@@ -26,7 +26,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
platform_meta,
|
||||
session_id,
|
||||
bot: CQHttp,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@@ -72,7 +72,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
is_group: bool,
|
||||
session_id: str | None,
|
||||
messages: list[dict],
|
||||
):
|
||||
) -> None:
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id_int = (
|
||||
int(session_id) if session_id and session_id.isdigit() else None
|
||||
@@ -97,7 +97,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
event: Event | None = None,
|
||||
is_group: bool = False,
|
||||
session_id: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""发送消息至 QQ 协议端(aiocqhttp)。
|
||||
|
||||
Args:
|
||||
@@ -143,7 +143,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
await cls._dispatch_send(bot, event, is_group, session_id, messages)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
"""发送消息"""
|
||||
event = getattr(self.message_obj, "raw_message", None)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class AiocqhttpAdapter(Platform):
|
||||
)
|
||||
|
||||
@self.bot.on_request()
|
||||
async def request(event: Event):
|
||||
async def request(event: Event) -> None:
|
||||
try:
|
||||
abm = await self.convert_message(event)
|
||||
if not abm:
|
||||
@@ -72,7 +72,7 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
@self.bot.on_notice()
|
||||
async def notice(event: Event):
|
||||
async def notice(event: Event) -> None:
|
||||
try:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
@@ -82,7 +82,7 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
@self.bot.on_message("group")
|
||||
async def group(event: Event):
|
||||
async def group(event: Event) -> None:
|
||||
try:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
@@ -92,7 +92,7 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
@self.bot.on_message("private")
|
||||
async def private(event: Event):
|
||||
async def private(event: Event) -> None:
|
||||
try:
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
@@ -102,14 +102,14 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
@self.bot.on_websocket_connection
|
||||
def on_websocket_connection(_):
|
||||
def on_websocket_connection(_) -> None:
|
||||
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
|
||||
|
||||
async def send_by_session(
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
is_group = session.message_type == MessageType.GROUP_MESSAGE
|
||||
if is_group:
|
||||
session_id = session.session_id.split("_")[-1]
|
||||
@@ -435,17 +435,17 @@ class AiocqhttpAdapter(Platform):
|
||||
self.shutdown_event = asyncio.Event()
|
||||
return coro
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
self.shutdown_event.set()
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
async def shutdown_trigger_placeholder(self) -> None:
|
||||
await self.shutdown_event.wait()
|
||||
logger.info("aiocqhttp 适配器已被关闭")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
async def handle_msg(self, message: AstrBotMessage) -> None:
|
||||
message_event = AiocqhttpMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast
|
||||
from typing import NoReturn, cast
|
||||
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -90,7 +90,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
@@ -104,7 +104,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
async def create_message_card(
|
||||
self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
|
||||
):
|
||||
) -> bool | None:
|
||||
if not self.card_template_id:
|
||||
return False
|
||||
|
||||
@@ -122,7 +122,9 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(f"创建钉钉卡片失败: {e}")
|
||||
return False
|
||||
|
||||
async def send_card_message(self, message_id: str, content: str, is_final: bool):
|
||||
async def send_card_message(
|
||||
self, message_id: str, content: str, is_final: bool
|
||||
) -> None:
|
||||
if message_id not in self.card_instance_id_dict:
|
||||
return
|
||||
|
||||
@@ -276,7 +278,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return ""
|
||||
return (await resp.json())["data"]["accessToken"]
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
async def handle_msg(self, abm: AstrBotMessage) -> None:
|
||||
event = DingtalkMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
@@ -288,10 +290,10 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
# await self.client_.start()
|
||||
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
|
||||
def start_client(loop: asyncio.AbstractEventLoop):
|
||||
def start_client(loop: asyncio.AbstractEventLoop) -> None:
|
||||
try:
|
||||
self._shutdown_event = threading.Event()
|
||||
task = loop.create_task(self.client_.start())
|
||||
@@ -307,8 +309,8 @@ class DingtalkPlatformAdapter(Platform):
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, start_client, loop)
|
||||
|
||||
async def terminate(self):
|
||||
def monkey_patch_close():
|
||||
async def terminate(self) -> None:
|
||||
def monkey_patch_close() -> NoReturn:
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
if self.client_.websocket is not None:
|
||||
|
||||
@@ -17,7 +17,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
session_id,
|
||||
client: dingtalk_stream.ChatbotHandler,
|
||||
adapter: "Any" = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
self.adapter = adapter
|
||||
@@ -26,7 +26,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
self,
|
||||
client: dingtalk_stream.ChatbotHandler,
|
||||
message: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
|
||||
ats = []
|
||||
# fixes: #4218
|
||||
@@ -80,7 +80,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
|
||||
continue
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ else:
|
||||
class DiscordBotClient(discord.Bot):
|
||||
"""Discord客户端封装"""
|
||||
|
||||
def __init__(self, token: str, proxy: str | None = None):
|
||||
def __init__(self, token: str, proxy: str | None = None) -> None:
|
||||
self.token = token
|
||||
self.proxy = proxy
|
||||
|
||||
@@ -32,7 +32,7 @@ class DiscordBotClient(discord.Bot):
|
||||
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
async def on_ready(self):
|
||||
async def on_ready(self) -> None:
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
if self.user is None:
|
||||
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||
@@ -93,7 +93,7 @@ class DiscordBotClient(discord.Bot):
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
return
|
||||
@@ -130,12 +130,12 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
return str(interaction_data)
|
||||
|
||||
async def start_polling(self):
|
||||
async def start_polling(self) -> None:
|
||||
"""开始轮询消息,这是个阻塞方法"""
|
||||
await self.start(self.token)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""关闭客户端"""
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
|
||||
@@ -19,7 +19,7 @@ class DiscordEmbed(BaseMessageComponent):
|
||||
image: str | None = None,
|
||||
footer: str | None = None,
|
||||
fields: list[dict] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.title = title
|
||||
self.description = description
|
||||
self.color = color
|
||||
@@ -71,7 +71,7 @@ class DiscordButton(BaseMessageComponent):
|
||||
emoji: str | None = None,
|
||||
url: str | None = None,
|
||||
disabled: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
self.label = label
|
||||
self.custom_id = custom_id
|
||||
self.style = style
|
||||
@@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent):
|
||||
|
||||
type: str = "discord_reference"
|
||||
|
||||
def __init__(self, message_id: str, channel_id: str):
|
||||
def __init__(self, message_id: str, channel_id: str) -> None:
|
||||
self.message_id = message_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
@@ -99,7 +99,7 @@ class DiscordView(BaseMessageComponent):
|
||||
self,
|
||||
components: list[BaseMessageComponent] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
"""通过会话发送消息"""
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
@@ -122,11 +122,11 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
"""主要运行逻辑"""
|
||||
|
||||
# 初始化回调函数
|
||||
async def on_received(message_data):
|
||||
async def on_received(message_data) -> None:
|
||||
logger.debug(f"[Discord] 收到消息: {message_data}")
|
||||
if self.client_self_id is None:
|
||||
self.client_self_id = message_data.get("bot_id")
|
||||
@@ -143,7 +143,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
self.client = DiscordBotClient(token, proxy)
|
||||
self.client.on_message_received = on_received
|
||||
|
||||
async def callback():
|
||||
async def callback() -> None:
|
||||
if self.enable_command_register:
|
||||
await self._collect_and_register_commands()
|
||||
if self.activity_name:
|
||||
@@ -251,7 +251,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
# 由于 on_interaction 已被禁用,我们只处理普通消息
|
||||
return self._convert_message_to_abm(data)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
|
||||
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None:
|
||||
"""处理消息"""
|
||||
message_event = DiscordPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
@@ -323,7 +323,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
self.commit_event(message_event)
|
||||
|
||||
@override
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
"""终止适配器"""
|
||||
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
|
||||
self.shutdown_event.set()
|
||||
@@ -358,11 +358,11 @@ class DiscordPlatformAdapter(Platform):
|
||||
logger.warning(f"[Discord] 客户端关闭异常: {e}")
|
||||
logger.info("[Discord] 适配器已终止。")
|
||||
|
||||
def register_handler(self, handler_info):
|
||||
def register_handler(self, handler_info) -> None:
|
||||
"""注册处理器信息"""
|
||||
self.registered_handlers.append(handler_info)
|
||||
|
||||
async def _collect_and_register_commands(self):
|
||||
async def _collect_and_register_commands(self) -> None:
|
||||
"""收集所有指令并注册到Discord"""
|
||||
logger.info("[Discord] 开始收集并注册斜杠指令...")
|
||||
registered_commands = []
|
||||
@@ -420,7 +420,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
|
||||
async def dynamic_callback(
|
||||
ctx: discord.ApplicationContext, params: str | None = None
|
||||
):
|
||||
) -> None:
|
||||
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
|
||||
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
|
||||
logger.debug(f"[Discord] 回调函数参数: {ctx}")
|
||||
|
||||
@@ -28,7 +28,7 @@ from .components import DiscordEmbed, DiscordView
|
||||
class DiscordViewComponent(BaseMessageComponent):
|
||||
type: str = "discord_view"
|
||||
|
||||
def __init__(self, view: discord.ui.View):
|
||||
def __init__(self, view: discord.ui.View) -> None:
|
||||
self.view = view
|
||||
|
||||
|
||||
@@ -41,12 +41,12 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
client: DiscordBotClient,
|
||||
interaction_followup_webhook: discord.Webhook | None = None,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
self.interaction_followup_webhook = interaction_followup_webhook
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
"""发送消息到Discord平台"""
|
||||
# 解析消息链为 Discord 所需的对象
|
||||
try:
|
||||
@@ -267,7 +267,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
content = content[:2000]
|
||||
return content, files, view, embeds, reference_message_id
|
||||
|
||||
async def react(self, emoji: str):
|
||||
async def react(self, emoji: str) -> None:
|
||||
"""对原消息添加反应"""
|
||||
try:
|
||||
if hasattr(self.message_obj, "raw_message") and hasattr(
|
||||
|
||||
@@ -53,10 +53,10 @@ class LarkPlatformAdapter(Platform):
|
||||
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
|
||||
|
||||
# 初始化 WebSocket 长连接相关配置
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
await self.convert_msg(event)
|
||||
|
||||
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
asyncio.create_task(on_msg_event_recv(event))
|
||||
|
||||
self.event_handler = (
|
||||
@@ -91,7 +91,7 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
self.event_id_timestamps: dict[str, float] = {}
|
||||
|
||||
def _clean_expired_events(self):
|
||||
def _clean_expired_events(self) -> None:
|
||||
"""清理超过 30 分钟的事件记录"""
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
@@ -121,7 +121,7 @@ class LarkPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
receive_id = session.session_id
|
||||
@@ -149,7 +149,7 @@ class LarkPlatformAdapter(Platform):
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
if event.event is None:
|
||||
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||
return
|
||||
@@ -299,7 +299,7 @@ class LarkPlatformAdapter(Platform):
|
||||
logger.debug(abm)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
async def handle_msg(self, abm: AstrBotMessage) -> None:
|
||||
event = LarkMessageEvent(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
@@ -310,7 +310,7 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
|
||||
async def handle_webhook_event(self, event_data: dict):
|
||||
async def handle_webhook_event(self, event_data: dict) -> None:
|
||||
"""处理 Webhook 事件
|
||||
|
||||
Args:
|
||||
@@ -332,7 +332,7 @@ class LarkPlatformAdapter(Platform):
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
if self.connection_mode == "webhook":
|
||||
# Webhook 模式
|
||||
if self.webhook_server is None:
|
||||
@@ -355,7 +355,7 @@ class LarkPlatformAdapter(Platform):
|
||||
|
||||
return await self.webhook_server.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
if self.connection_mode == "socket":
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已关闭")
|
||||
|
||||
@@ -38,7 +38,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
platform_meta,
|
||||
session_id,
|
||||
bot: lark.Client,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@@ -274,7 +274,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""通用的消息链发送方法
|
||||
|
||||
Args:
|
||||
@@ -342,7 +342,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message,
|
||||
@@ -358,7 +358,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""发送文件消息
|
||||
|
||||
Args:
|
||||
@@ -392,7 +392,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""发送音频消息
|
||||
|
||||
Args:
|
||||
@@ -465,7 +465,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""发送视频消息
|
||||
|
||||
Args:
|
||||
@@ -531,7 +531,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
async def react(self, emoji: str) -> None:
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
return
|
||||
|
||||
@@ -21,7 +21,7 @@ from astrbot.api import logger
|
||||
class AESCipher:
|
||||
"""AES 加密/解密工具类"""
|
||||
|
||||
def __init__(self, key: str):
|
||||
def __init__(self, key: str) -> None:
|
||||
self.bs = AES.block_size
|
||||
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
|
||||
|
||||
@@ -52,7 +52,7 @@ class LarkWebhookServer:
|
||||
仅支持统一 Webhook 模式
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict, event_queue: asyncio.Queue):
|
||||
def __init__(self, config: dict, event_queue: asyncio.Queue) -> None:
|
||||
"""初始化 Webhook 服务器
|
||||
|
||||
Args:
|
||||
@@ -197,7 +197,7 @@ class LarkWebhookServer:
|
||||
|
||||
return {}
|
||||
|
||||
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
|
||||
def set_callback(self, callback: Callable[[dict], Awaitable[None]]) -> None:
|
||||
"""设置事件回调函数
|
||||
|
||||
Args:
|
||||
|
||||
@@ -121,7 +121,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
if not self.instance_url or not self.access_token:
|
||||
logger.error("[Misskey] 配置不完整,无法启动")
|
||||
return
|
||||
@@ -150,7 +150,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
|
||||
await self._start_websocket_connection()
|
||||
|
||||
def _register_event_handlers(self, streaming):
|
||||
def _register_event_handlers(self, streaming) -> None:
|
||||
"""注册事件处理器"""
|
||||
streaming.add_message_handler("notification", self._handle_notification)
|
||||
streaming.add_message_handler("main:notification", self._handle_notification)
|
||||
@@ -194,7 +194,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
message: AstrBotMessage,
|
||||
poll: dict[str, Any],
|
||||
message_parts: list[str],
|
||||
):
|
||||
) -> None:
|
||||
"""处理投票数据,将其添加到消息中"""
|
||||
try:
|
||||
if not isinstance(message.raw_message, dict):
|
||||
@@ -233,7 +233,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
|
||||
return fields
|
||||
|
||||
async def _start_websocket_connection(self):
|
||||
async def _start_websocket_connection(self) -> None:
|
||||
backoff_delay = 1.0
|
||||
max_backoff = 300.0
|
||||
backoff_multiplier = 1.5
|
||||
@@ -281,7 +281,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
await asyncio.sleep(sleep_time)
|
||||
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||
|
||||
async def _handle_notification(self, data: dict[str, Any]):
|
||||
async def _handle_notification(self, data: dict[str, Any]) -> None:
|
||||
try:
|
||||
notification_type = data.get("type")
|
||||
logger.debug(
|
||||
@@ -305,7 +305,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理通知失败: {e}")
|
||||
|
||||
async def _handle_chat_message(self, data: dict[str, Any]):
|
||||
async def _handle_chat_message(self, data: dict[str, Any]) -> None:
|
||||
try:
|
||||
sender_id = str(
|
||||
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""),
|
||||
@@ -340,7 +340,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||
|
||||
async def _debug_handler(self, data: dict[str, Any]):
|
||||
async def _debug_handler(self, data: dict[str, Any]) -> None:
|
||||
event_type = data.get("type", "unknown")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}",
|
||||
@@ -754,7 +754,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
)
|
||||
return message
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
self._running = False
|
||||
if self.api:
|
||||
await self.api.close()
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import random
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import Any, NoReturn
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -43,7 +43,7 @@ class WebSocketError(APIError):
|
||||
|
||||
|
||||
class StreamingClient:
|
||||
def __init__(self, instance_url: str, access_token: str):
|
||||
def __init__(self, instance_url: str, access_token: str) -> None:
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self.websocket: Any | None = None
|
||||
@@ -90,7 +90,7 @@ class StreamingClient:
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
async def disconnect(self) -> None:
|
||||
self._running = False
|
||||
if self.websocket:
|
||||
await self.websocket.close()
|
||||
@@ -116,7 +116,7 @@ class StreamingClient:
|
||||
self.channels[channel_id] = channel_type
|
||||
return channel_id
|
||||
|
||||
async def unsubscribe_channel(self, channel_id: str):
|
||||
async def unsubscribe_channel(self, channel_id: str) -> None:
|
||||
if (
|
||||
not self.is_connected
|
||||
or not self.websocket
|
||||
@@ -136,10 +136,10 @@ class StreamingClient:
|
||||
self,
|
||||
event_type: str,
|
||||
handler: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
) -> None:
|
||||
self.message_handlers[event_type] = handler
|
||||
|
||||
async def listen(self):
|
||||
async def listen(self) -> None:
|
||||
if not self.is_connected or not self.websocket:
|
||||
raise WebSocketError("WebSocket 未连接")
|
||||
|
||||
@@ -187,7 +187,7 @@ class StreamingClient:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_message(self, data: dict[str, Any]):
|
||||
async def _handle_message(self, data: dict[str, Any]) -> None:
|
||||
message_type = data.get("type")
|
||||
body = data.get("body", {})
|
||||
|
||||
@@ -334,7 +334,7 @@ class MisskeyAPI:
|
||||
download_timeout: int = 15,
|
||||
chunk_size: int = 64 * 1024,
|
||||
max_download_bytes: int | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
@@ -375,7 +375,7 @@ class MisskeyAPI:
|
||||
self._session = aiohttp.ClientSession(headers=headers)
|
||||
return self._session
|
||||
|
||||
def _handle_response_status(self, status: int, endpoint: str):
|
||||
def _handle_response_status(self, status: int, endpoint: str) -> NoReturn:
|
||||
"""处理 HTTP 响应状态码"""
|
||||
if status == 400:
|
||||
logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})")
|
||||
@@ -449,7 +449,6 @@ class MisskeyAPI:
|
||||
)
|
||||
|
||||
self._handle_response_status(response.status, endpoint)
|
||||
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||
|
||||
@retry_async(
|
||||
max_retries=API_MAX_RETRIES,
|
||||
|
||||
@@ -26,7 +26,7 @@ class MisskeyPlatformEvent(AstrMessageEvent):
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@@ -40,7 +40,7 @@ class MisskeyPlatformEvent(AstrMessageEvent):
|
||||
|
||||
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
"""发送消息,使用适配器的完整上传和发送逻辑"""
|
||||
try:
|
||||
logger.debug(
|
||||
|
||||
@@ -403,7 +403,7 @@ def cache_user_info(
|
||||
raw_data: dict[str, Any],
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""缓存用户信息"""
|
||||
if is_chat:
|
||||
user_cache_data = {
|
||||
@@ -429,7 +429,7 @@ def cache_room_info(
|
||||
user_cache: dict[str, Any],
|
||||
raw_data: dict[str, Any],
|
||||
client_self_id: str,
|
||||
):
|
||||
) -> None:
|
||||
"""缓存房间信息"""
|
||||
room_data = raw_data.get("toRoom")
|
||||
room_id = raw_data.get("toRoomId")
|
||||
|
||||
@@ -32,12 +32,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
bot: Client,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
self.send_buffer = None
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain) -> None:
|
||||
self.send_buffer = message
|
||||
await self._post_send()
|
||||
|
||||
|
||||
@@ -35,11 +35,13 @@ for handler in logging.root.handlers[:]:
|
||||
|
||||
# QQ 机器人官方框架
|
||||
class botClient(Client):
|
||||
def set_platform(self, platform: QQOfficialPlatformAdapter):
|
||||
def set_platform(self, platform: QQOfficialPlatformAdapter) -> None:
|
||||
self.platform = platform
|
||||
|
||||
# 收到群消息
|
||||
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||
async def on_group_at_message_create(
|
||||
self, message: botpy.message.GroupMessage
|
||||
) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
@@ -49,7 +51,7 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
async def on_at_message_create(self, message: botpy.message.Message):
|
||||
async def on_at_message_create(self, message: botpy.message.Message) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
@@ -59,7 +61,9 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||
async def on_direct_message_create(
|
||||
self, message: botpy.message.DirectMessage
|
||||
) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.FRIEND_MESSAGE,
|
||||
@@ -68,7 +72,7 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.FRIEND_MESSAGE,
|
||||
@@ -76,7 +80,7 @@ class botClient(Client):
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
def _commit(self, abm: AstrBotMessage):
|
||||
def _commit(self, abm: AstrBotMessage) -> None:
|
||||
self.platform.commit_event(
|
||||
QQOfficialMessageEvent(
|
||||
abm.message_str,
|
||||
@@ -128,7 +132,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
@@ -222,6 +226,6 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
def get_client(self) -> botClient:
|
||||
return self.client
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
await self.client.close()
|
||||
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
|
||||
|
||||
@@ -26,11 +26,13 @@ for handler in logging.root.handlers[:]:
|
||||
|
||||
# QQ 机器人官方框架
|
||||
class botClient(Client):
|
||||
def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"):
|
||||
def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None:
|
||||
self.platform = platform
|
||||
|
||||
# 收到群消息
|
||||
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
|
||||
async def on_group_at_message_create(
|
||||
self, message: botpy.message.GroupMessage
|
||||
) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
@@ -40,7 +42,7 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
async def on_at_message_create(self, message: botpy.message.Message):
|
||||
async def on_at_message_create(self, message: botpy.message.Message) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
@@ -50,7 +52,9 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
|
||||
async def on_direct_message_create(
|
||||
self, message: botpy.message.DirectMessage
|
||||
) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.FRIEND_MESSAGE,
|
||||
@@ -59,7 +63,7 @@ class botClient(Client):
|
||||
self._commit(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None:
|
||||
abm = QQOfficialPlatformAdapter._parse_from_qqofficial(
|
||||
message,
|
||||
MessageType.FRIEND_MESSAGE,
|
||||
@@ -67,7 +71,7 @@ class botClient(Client):
|
||||
abm.session_id = abm.sender.user_id
|
||||
self._commit(abm)
|
||||
|
||||
def _commit(self, abm: AstrBotMessage):
|
||||
def _commit(self, abm: AstrBotMessage) -> None:
|
||||
self.platform.commit_event(
|
||||
QQOfficialWebhookMessageEvent(
|
||||
abm.message_str,
|
||||
@@ -110,7 +114,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
@@ -121,7 +125,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
support_proactive_message=False,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async def run(self) -> None:
|
||||
self.webhook_helper = QQOfficialWebhook(
|
||||
self.config,
|
||||
self._event_queue,
|
||||
@@ -149,7 +153,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
# 复用 webhook_helper 的回调处理逻辑
|
||||
return await self.webhook_helper.handle_callback(request)
|
||||
|
||||
async def terminate(self):
|
||||
async def terminate(self) -> None:
|
||||
if self.webhook_helper:
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
await self.client.close()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user