chore: fix bunches of type checking errors (#3213)

* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
Dt8333
2025-12-09 14:13:47 +08:00
committed by GitHub
parent aa6d07afcc
commit f624971613
96 changed files with 1218 additions and 524 deletions
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic
import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None
handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
+5 -1
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str,
*args,
**kwargs,
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
"""
config_path: str
default_config: dict
schema: dict | None
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
+1 -1
View File
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
+2 -3
View File
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core.db.po import (
Attachment,
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.AsyncSessionLocal = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" in conv.user_id:
continue
platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple = None):
def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close()
return Stats(platform, [], [])
return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> tuple:
def get_conversations(self, user_id: str) -> list[Conversation]:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
+16 -15
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats" # type: ignore
__tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore
__tablename__: str = "conversations"
inner_conversation_id: int = Field(
inner_conversation_id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas" # type: ignore
__tablename__: str = "personas"
id: int | None = Field(
primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore
__tablename__: str = "preferences"
id: int | None = Field(
default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages.
"""
__tablename__ = "platform_message_history" # type: ignore
__tablename__: str = "platform_message_history"
id: int | None = Field(
primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it.
"""
__tablename__ = "platform_sessions" # type: ignore
__tablename__: str = "platform_sessions"
inner_id: int | None = Field(
primary_key=True,
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments" # type: ignore
__tablename__: str = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,
@@ -261,17 +262,17 @@ class Personality(TypedDict):
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
prompt: str
name: str
begin_dialogs: list[str]
mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
_begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str
# ====
+4 -3
View File
@@ -3,6 +3,7 @@ import threading
import typing as T
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids)
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id
)
result = await session.execute(query)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount
async def insert_persona(
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径
"""
if self.index is None:
return
faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
astrbot_config_mgr: AstrBotConfigManager,
):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
+9 -3
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self):
data = {}
for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = []
content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d)
return ret
async def to_dict(self):
async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
if self.file_:
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
if self.name:
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self,
event: AstrMessageEvent,
check_text: str | None = None,
) -> None | AsyncGenerator[None, None]:
) -> AsyncGenerator[None, None]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler(
event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
)
for handler in handlers:
try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> AsyncGenerator[None, None]:
) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)
+1 -1
View File
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
event.get_result() and not event.is_stopped()
) or not event.get_result():
async for _ in self.agent_sub_stage.process(event):
yield
+4 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
if (result := event.get_result()) is None:
return False
if self.only_llm_result and result.is_llm_result():
return False
if event.get_platform_name() in [
@@ -185,7 +187,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
result.chain[idx] = component
# 检查消息链是否为空
try:
+10 -6
View File
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -130,11 +131,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -151,7 +154,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
)
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER
from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None)
logger.debug("pipeline 执行完毕。")
+5 -3
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
return self.message_obj.sender.nickname
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value):
"""设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
"""
self.call_llm = call_llm
def get_result(self) -> MessageEventResult:
def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。"""
return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self,
prompt: str,
func_tool_manager=None,
session_id: str = None,
session_id: str = "",
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group: Group # 群组
group: Group | None # 群组
sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
def group_id(self, value: str):
def group_id(self, value: str | None):
"""设置 group_id"""
if value:
if self.group:
+3 -3
View File
@@ -1,7 +1,7 @@
import abc
import uuid
from asyncio import Queue
from collections.abc import Awaitable
from collections.abc import Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -100,7 +100,7 @@ class Platform(abc.ABC):
}
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
@@ -116,7 +116,7 @@ class Platform(abc.ABC):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str | None = None
id: str
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata(
name=adapter_name,
description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
session_id: str | None,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
@@ -4,7 +4,7 @@ import logging
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata(
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象
@param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套
"""
assert event.sender is not None
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return None
raise ValueError(err)
# 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -2,6 +2,7 @@ import asyncio
import os
import threading
import uuid
from typing import cast
import aiohttp
import dingtalk_stream
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
abm = await outer_self.convert_msg(im)
await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
self.client,
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id
abm.message_id = cast(str, message.message_id)
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
message_type: str = cast(str, message.message_type)
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents:
plains = ""
if "text" in content:
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
cast(str, message.robot_code),
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
)
return None
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
return None
return ""
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
if self.client_.websocket is not None:
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()
def get_client(self):
return self.client
@@ -1,4 +1,5 @@
import asyncio
from typing import cast
import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
segment.text,
segment.text,
self.message_obj.raw_message,
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
)
logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys
from collections.abc import Awaitable, Callable
import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions
return {
"message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return {
"interaction": interaction,
"bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__(
self,
components: list[BaseMessageComponent] = None,
timeout: float = None,
components: list[BaseMessageComponent] | None = None,
timeout: float | None = None,
):
self.components = components or []
self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio
import re
import sys
from typing import Any
from typing import Any, cast
import discord
from discord.abc import Messageable
from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel
from astrbot import logger
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
) -> None:
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.client_self_id = None
self.client_self_id: str | None = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True)
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain,
):
"""通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage()
if "_" in session.session_id:
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id),
nickname=self.client.user.display_name,
)
message_obj.self_id = self.client_self_id
message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata(
"discord",
"Discord 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type(
self,
channel: Messageable,
channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None,
) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型"""
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str:
def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID"""
return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
message = data["message"]
content = message.content
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook,
)
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False
# User Mention
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
if self.client.user in raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
if not is_mention and raw_message.role_mentions:
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
if raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
bot_member = raw_message.guild.get_member(
self.client.user.id,
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
mentioned_roles = set(raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
# 如果是被@的消息,设置为唤醒状态
if is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info(
event_filter: Any,
handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None:
) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
# is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator
from io import BytesIO
from pathlib import Path
from typing import cast
import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel()
if not channel:
return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs)
except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None:
async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord(
self,
message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = []
files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"add_reaction",
):
await self.message_obj.raw_message.add_reaction(emoji)
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command
)
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
return self.message_obj.raw_message.data.get("custom_id", "")
return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception:
pass
return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
):
return any(
mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions
for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
)
return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"clean_content",
):
return self.message_obj.raw_message.clean_content
return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str
@@ -3,9 +3,14 @@ import base64
import json
import re
import uuid
from typing import cast
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
import astrbot.api.message_components as Comp
from astrbot import logger
@@ -74,6 +79,10 @@ class LarkPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
@@ -114,14 +123,21 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata(
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
abm.timestamp = cast(int, message.create_time) // 1000
abm.message = []
abm.type = (
MessageType.GROUP_MESSAGE
@@ -136,14 +152,28 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
if m.id is None:
continue
# 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
content_json_b = json.loads(message.content)
if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
@@ -168,27 +198,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls
elif message.message_type == "image":
content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
{
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
]
if message.message_type in ("post", "image"):
for comp in content_json_b:
if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip():
if comp.get("tag") == "at":
user_id = comp.get("user_id")
if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img":
image_key = comp["image_key"]
elif comp.get("tag") == "img":
image_key = comp.get("image_key")
if not image_key:
continue
request = (
GetMessageResourceRequest.builder()
.message_id(message.message_id)
.message_id(cast(str, message.message_id))
.file_key(image_key)
.type("image")
.build()
)
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -196,6 +246,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message:
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id
abm.raw_message = message
abm.sender = MessageMember(
@@ -235,5 +298,5 @@ class LarkPlatformAdapter(Platform):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
def get_client(self) -> lark.Client:
def get_client(self) -> lark.ws.Client:
return self.client
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path
file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
file_path = comp.file
file_path = comp.file if comp.file else ""
if image_file is None:
image_file = open(file_path, "rb")
if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = (
CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request)
if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key
logger.debug(image_key)
ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build()
)
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request)
if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -1,7 +1,6 @@
import asyncio
import os
import random
from collections.abc import Awaitable
from typing import Any
import astrbot.api.message_components as Comp
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict):
message.raw_message = {}
message.raw_message["poll"] = poll
message.poll = poll
message.__setattr__("poll", poll)
except Exception:
pass
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self,
session: MessageSession,
message_chain: MessageChain,
) -> Awaitable[Any]:
) -> None:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os
import random
import uuid
from typing import cast
import aiofiles
import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None
source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source,
(
botpy.message.Message,
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage,
botpy.message.C2CMessage,
),
)
):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
(
plain_text,
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
):
return None
payload = {
payload: dict = {
"content": plain_text,
"msg_id": self.message_obj.message_id,
}
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None
match type(source):
case botpy.message.GroupMessage:
match source:
case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid,
**payload,
)
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload,
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_message(
channel_id=source.channel_id,
**payload,
)
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer)
self.send_buffer = None
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type,
"srv_send_msg": False,
}
result = None
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs:
result = await self.bot.api._http.request(route, json=payload)
elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record(
self,
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload)
if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None,
media: message.Media | None = None,
msg_id: str | None = None,
msg_seq: str = 1,
msg_seq: int | None = 1,
event_id: str | None = None,
markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None,
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
else:
elif i.file:
image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging
import os
import time
from typing import cast
import botpy
import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
@staticmethod
def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage,
message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType,
):
abm = AstrBotMessage()
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.id
abm.tag = "qq_official"
# abm.tag = "qq_official"
msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message,
botpy.message.DirectMessage,
):
try:
if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id)
except BaseException as _:
else:
abm.self_id = ""
plain_content = message.content.replace(
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Any
from typing import Any, cast
import botpy
import botpy.message
@@ -36,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
async def run(self):
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import cast
import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13:
# validation
signed = await self.webhook_validation(data)
signed = await self.webhook_validation(cast(dict, data))
print(signed)
return signed
@@ -4,9 +4,11 @@ import hmac
import json
import logging
from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient
@@ -66,7 +68,7 @@ class SlackWebhookClient:
"""
try:
# 获取请求体和头部
body = await req.get_data()
body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
@@ -139,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler
self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件"""
try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
abm.self_id = self.bot_self_id
abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"]
is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"]
user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get(
"name",
mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
)
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:
async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -1,6 +1,7 @@
import asyncio
import re
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
if isinstance(segment, Image):
# upload file
url = segment.url or segment.file
if url.startswith("http"):
if url and url.startswith("http"):
return {
"type": "image",
"image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
image_url = response["files"][0]["url_private"]
image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
file_url = cast(list, response["files"])[0]["permalink"]
return {
"type": "section",
"text": {
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
)
members = []
for member_id in members_response["members"]:
for member_id in cast(Iterable, members_response["members"]):
try:
user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
members.append(
MessageMember(
user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"]
channel_data = cast(dict, channel_info["channel"])
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
@@ -1,6 +1,7 @@
import asyncio
import os
import re
from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if has_reply:
payload["reply_to_message_id"] = reply_message_id
payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id:
payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
md_text = telegramify_markdown.markdownify(
chunk,
max_line_length=None,
normalize_whitespace=False,
)
await client.send_message(
text=md_text,
parse_mode="MarkdownV2",
**payload,
**cast(Any, payload),
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.",
)
await client.send_message(text=chunk, **payload)
await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await client.send_document(document=i.file, filename=i.name, **payload)
path = await i.get_file()
name = i.name or os.path.basename(path)
await client.send_document(
document=path, filename=name, **cast(Any, payload)
)
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
path = await i.get_file()
name = i.name or os.path.basename(path)
await self.client.send_document(
document=i.file,
filename=i.name,
**payload,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
await self.client.send_voice(voice=path, **cast(Any, payload))
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
# delta 长度一般不会大于 4096,因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
markdown_text = telegramify_markdown.markdownify(
delta,
max_line_length=None,
normalize_whitespace=False,
)
await self.client.edit_message_text(
@@ -2,7 +2,7 @@ import asyncio
import os
import time
import uuid
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from typing import Any
from astrbot import logger
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
@@ -4,6 +4,7 @@ import json
import os
import time
import traceback
from typing import cast
import aiohttp
import anyio
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid
self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time")
abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type")
msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = ""
abm.message = []
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str,
to_user_name: str,
msg_id: int,
):
) -> dict | None:
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id")
msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image(
from_user_name,
to_user_name,
msg_id,
)
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid,
length=length,
)
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
self._shutdown_event.set()
if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception:
pass
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list(
self,
room_wx_id_list: list[str] = None,
user_names: list[str] = None,
room_wx_id_list: list[str] | None = None,
user_names: list[str] | None = None,
) -> dict | None:
"""获取联系人详情列表。"""
if room_wx_id_list is None:
@@ -2,7 +2,8 @@ import asyncio
import os
import sys
import uuid
from typing import Any
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -40,7 +41,7 @@ else:
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
"/callback/command",
@@ -60,7 +61,7 @@ class WecomServer:
config["corpid"].strip(),
)
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
@@ -114,7 +115,7 @@ class WecomServer:
logger.error("解密失败,签名异常,请检查配置。")
raise
else:
msg = parse_message(xml)
msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api
self.client.kf_message = self.wechat_kf_message_api
self.client.__setattr__("kf", self.wechat_kf_api)
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage()
if msg.type == "text":
assert isinstance(msg, TextMessage)
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
elif isinstance(msg, ImageMessage):
abm.message_str = "[图片]"
abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
else:
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid")
external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf:
# 微信客服
kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api:
if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。")
return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id()
for comp in message.chain:
if isinstance(comp, Plain):
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod
async def _send(
message_chain: MessageChain,
message_chain: MessageChain | None,
stream_id: str,
queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计"""
@@ -1,7 +1,8 @@
import asyncio
import sys
import uuid
from typing import Any
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -36,7 +37,7 @@ else:
class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key")
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
raise
else:
msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.config["secret"].strip(),
)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id]
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60,
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def convert_message(
self,
msg,
future: asyncio.Future = None,
future: asyncio.Future | None = None,
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)]
abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
else:
logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None)
if future:
future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
@@ -1,5 +1,6 @@
import asyncio
import uuid
from typing import cast
from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = TextReply(
content=chunk,
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = ImageReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = VoiceReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str,
func_args: list[dict],
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None:
"""添加函数调用工具
+126 -70
View File
@@ -1,5 +1,6 @@
import asyncio
import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +11,7 @@ from .entities import ProviderType
from .provider import (
EmbeddingProvider,
Provider,
Providers,
RerankProvider,
STTProvider,
TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager:
def __init__(
self,
@@ -48,7 +55,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
Providers,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -123,15 +130,13 @@ class ProviderManager:
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(
self,
provider_type: ProviderType,
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
self, provider_type: ProviderType, umo=None
) -> Providers | None:
"""获取正在使用的提供商实例。
Args:
@@ -191,7 +196,6 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
@@ -210,15 +214,37 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
@@ -358,73 +384,103 @@ class ProviderManager:
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
match provider_metadata.provider_type:
case ProviderType.SPEECH_TO_TEXT:
# STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
if isinstance(inst, HasInitialize):
await inst.initialize()
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
case ProviderType.TEXT_TO_SPEECH:
# TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
case ProviderType.CHAT_COMPLETION:
# 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type(
provider_config,
self.provider_settings,
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.tts_provider_insts.append(inst)
if self.provider_settings.get("provider_id") == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
case ProviderType.EMBEDDING:
if not issubclass(cls_type, EmbeddingProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.embedding_provider_insts.append(inst)
case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
provider_config,
self.provider_settings,
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
+12 -1
View File
@@ -2,6 +2,7 @@ import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
...
if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0
self.timeout = Timeout(10.0)
self.retry_count = 3
self.client = None
self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout)
self._client = AsyncClient(timeout=self.timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _sync_time(self):
try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None
self._client: AsyncClient | None = None
self.token = None
self.token_expire = 0
self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"),
}
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(
self._client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _refresh_token(self):
token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["):
json_str = ""
try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns:
重排序结果列表
"""
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = {
"model": model,
"text": text,
"messages": None,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
"text": text,
}
if not self.voice:
logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path)
ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
async def _get_reference_id_by_character(self, character: str) -> str:
async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id
Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:
async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip():
# 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
text = await response.aread()
body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
from typing import cast
from google import genai
from google.genai import types
from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config
self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key")
api_base: str = provider_config.get("embedding_api_base")
api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model,
contents=text,
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
try:
result = await self.client.models.embed_content(
model=self.model,
contents=texts,
contents=cast(types.ContentListUnion, text),
)
return [embedding.values for embedding in result.embeddings]
assert result.embeddings is not None
embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
embeddings.append(embedding.values)
return embeddings
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
@@ -4,6 +4,7 @@ import json
import logging
import random
from collections.abc import AsyncGenerator
from typing import cast
from google import genai
from google.genai import types
@@ -136,7 +137,7 @@ class ProviderGoogleGenAI(Provider):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"]
tool_list = []
tool_list: list[types.Tool] | None = []
model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False)
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"),
response_modalities=modalities,
tools=tool_list,
tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=(
types.ThinkingConfig(
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content],
) -> None:
if contents and isinstance(contents[-1], content_cls):
assert contents[-1].parts is not None
contents[-1].parts.extend(part)
else:
contents.append(content_cls(parts=part))
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
)
result = await self.client.models.generate_content(
model=self.get_model(),
contents=conversation,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
logger.debug(f"genai result: {result}")
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
)
result = await self.client.models.generate_content_stream(
model=self.get_model(),
contents=conversation,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
break
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body)
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求"""
try:
async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:])
if "extra_info" in data:
continue
audio = data.get("data", {}).get("audio")
audio: str | None = data.get("data", {}).get(
"audio"
)
if audio is not None:
yield audio
except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
if tools is None:
# 工具集未提供
# Should be unreachable
raise Exception("工具集未提供")
for tool in tools.func_list:
if (
tool_call.type == "function"
@@ -7,6 +7,7 @@ import asyncio
import os
import re
from datetime import datetime
from typing import cast
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("stt_model"))
self.set_model(provider_config["stt_model"])
self.model = None
self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: self.model(audio_url, language="auto", use_itn=True),
lambda: cast(SenseVoiceSmall, self.model)(
audio_url, language="auto", use_itn=True
),
)
# res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
}
if top_n is not None:
payload["top_n"] = top_n
assert self.client is not None
async with self.client.post(
f"{self.base_url}/v1/rerank",
json=payload,
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
async def _get_audio_format(self, file_path):
# 定义要检测的头部字节
@@ -1,6 +1,7 @@
import asyncio
import os
import uuid
from typing import cast
import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
self.model = None
async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
if not self.model:
raise RuntimeError("Whisper 模型未初始化")
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result["text"]
return cast(str, result["text"])
@@ -1,6 +1,11 @@
from typing import cast
from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client,
)
from xinference_client.client.restful.async_restful_client import (
AsyncRESTfulRerankModelHandle,
)
from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False,
)
self.client = None
self.model = None
self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None
async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return
if self.model_uid:
self.model = await self.client.get_model(self.model_uid)
self.model = cast(
AsyncRESTfulRerankModelHandle,
await self.client.get_model(self.model_uid),
)
except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}")
+2 -2
View File
@@ -285,7 +285,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None:
def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args:
@@ -296,7 +296,7 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
+22 -5
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Callable[..., Awaitable[Any]],
handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType,
dont_add=False,
**kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for (
sub_handle
) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else:
# 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]):
def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
+85 -4
View File
@@ -1,9 +1,9 @@
from __future__ import annotations
import enum
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter
from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers:
print(handler.handler_full_name)
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAstrBotLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPlatformLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.AdapterMessageEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMRequestEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMResponseEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnDecoratingResultEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnCallingFuncToolEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAfterMessageSentEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
def get_handlers_by_event_type(
self,
event_type: EventType,
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后
H = TypeVar("H", bound=Callable[..., Any])
@dataclass
class StarHandlerMetadata:
class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]]
handler: H
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter]
+3 -3
View File
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
async def check_update(
self,
url: str,
current_version: str,
url: str | None,
current_version: str | None,
consider_prerelease: bool = True,
) -> ReleaseInfo:
) -> ReleaseInfo | None:
"""检查更新"""
return await super().check_update(
self.ASTRBOT_RELEASE_API,
+1 -1
View File
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Image.Image | str) -> str:
def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的
try:
+17 -10
View File
@@ -20,16 +20,16 @@ class SessionController:
def __init__(self):
self.future = asyncio.Future()
self.current_event: asyncio.Event = None
self.current_event: asyncio.Event | None = None
"""当前正在等待的所用的异步事件"""
self.ts: float = None
self.ts: float | None = None
"""上次保持(keep)开始时的时间"""
self.timeout: float | int = None
self.timeout: float | int | None = None
"""上次保持(keep)开始时的超时时间"""
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception = None):
def stop(self, error: Exception | None = None):
"""立即结束这个会话"""
if not self.future.done():
if error:
@@ -53,6 +53,8 @@ class SessionController:
self.stop()
return
else:
assert self.timeout is not None
assert self.ts is not None
left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout
if timeout <= 0:
@@ -69,7 +71,7 @@ class SessionController:
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: int):
async def _holding(self, event: asyncio.Event, timeout: float):
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout)
@@ -108,7 +110,9 @@ class SessionWaiter:
):
self.session_id = session_id
self.session_filter = session_filter
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
self.handler: (
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
) = None # 处理函数
self.session_controller = SessionController()
self.record_history_chains = record_history_chains
@@ -119,7 +123,7 @@ class SessionWaiter:
async def register_wait(
self,
handler: Callable[[str], Awaitable[Any]],
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout: int = 30,
) -> Any:
"""等待外部输入并处理"""
@@ -137,7 +141,7 @@ class SessionWaiter:
finally:
self._cleanup()
def _cleanup(self, error: Exception = None):
def _cleanup(self, error: Exception | None = None):
"""清理会话"""
USER_SESSIONS.pop(self.session_id, None)
try:
@@ -161,6 +165,7 @@ class SessionWaiter:
)
try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
assert session.handler is not None
await session.handler(session.session_controller, event)
except Exception as e:
session.session_controller.stop(e)
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
:param record_history_chain: 是否自动记录历史消息链可以通过 controller.get_history_chains() 获取深拷贝
"""
def decorator(func: Callable[[str], Awaitable[Any]]):
def decorator(
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
):
@functools.wraps(func)
async def wrapper(
event: AstrMessageEvent,
session_filter: SessionFilter = None,
session_filter: SessionFilter | None = None,
*args,
**kwargs,
):
+32
View File
@@ -53,6 +53,38 @@ class SharedPreferences:
ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret
@overload
async def session_get(
self,
umo: str,
key: str,
default: _VT = None,
) -> _VT: ...
@overload
async def session_get(
self,
umo: None,
key: str,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: str,
key: None,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: None,
key: None,
default: Any = None,
) -> list[Preference]: ...
async def session_get(
self,
umo: str | None,
+2 -2
View File
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
class RenderStrategy(ABC):
@abstractmethod
def render(self, text: str, return_url: bool) -> str:
async def render(self, text: str, return_url: bool) -> str:
pass
@abstractmethod
def render_custom_template(
async def render_custom_template(
self,
tmpl_str: str,
tmpl_data: dict,
+25 -31
View File
@@ -20,7 +20,7 @@ class FontManager:
_font_cache = {}
@classmethod
def get_font(cls, size: int) -> ImageFont.FreeTypeFont:
def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
"""获取指定大小的字体,优先从缓存获取"""
if size in cls._font_cache:
return cls._font_cache[size]
@@ -66,23 +66,17 @@ class TextMeasurer:
"""测量文本尺寸的工具类"""
@staticmethod
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
"""获取文本的尺寸"""
try:
# PIL 9.0.0 以上版本
return (
font.getbbox(text)[2:]
if hasattr(font, "getbbox")
else font.getsize(text)
)
except Exception:
# 兼容旧版本
return font.getsize(text)
# 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
left, top, right, bottom = font.getbbox("Hello world")
return int(right - left), int(bottom - top)
@staticmethod
def split_text_to_fit_width(
text: str, font: ImageFont.FreeTypeFont, max_width: int
) -> List[str]:
text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
) -> list[str]:
"""将文本拆分为多行,确保每行不超过指定宽度"""
lines = []
if not text:
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
# 倾斜变换,使用仿射变换实现斜体效果
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
italic_img = text_img.transform(
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC
text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
)
# 粘贴到原图像
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
class CodeBlockElement(MarkdownElement):
"""代码块元素"""
def __init__(self, content: List[str]):
def __init__(self, content: list[str]):
super().__init__("\n".join(content))
def calculate_height(self, image_width: int, font_size: int) -> int:
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
def render(
self,
image: Image.Image,
draw: ImageDraw.Draw,
draw: ImageDraw.ImageDraw,
x: int,
y: int,
image_width: int,
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
if pasted_image.width > max_width:
ratio = max_width / pasted_image.width
new_size = (int(max_width), int(pasted_image.height * ratio))
pasted_image = pasted_image.resize(new_size, Image.LANCZOS)
pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
# 计算居中位置
paste_x = x + (image_width - pasted_image.width) // 2 - 10
@@ -705,7 +699,7 @@ class MarkdownParser:
"""Markdown解析器,将文本解析为元素"""
@staticmethod
async def parse(text: str) -> List[MarkdownElement]:
async def parse(text: str) -> list[MarkdownElement]:
elements = []
lines = text.split("\n")
@@ -847,7 +841,7 @@ class MarkdownRenderer:
self,
font_size: int = 26,
width: int = 800,
bg_color: Tuple[int, int, int] = (255, 255, 255),
bg_color: tuple[int, int, int] = (255, 255, 255),
):
self.font_size = font_size
self.width = width
+1 -1
View File
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=input_path, output=output_path)
ff.convert(input_file=input_path, output_file=output_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
+6 -3
View File
@@ -60,9 +60,12 @@ class VersionComparator:
return -1
if isinstance(p1, str) and isinstance(p2, int):
return 1
if (isinstance(p1, int) and isinstance(p2, int)) or (
isinstance(p1, str) and isinstance(p2, str)
):
if isinstance(p1, int) and isinstance(p2, int):
if p1 > p2:
return 1
if p1 < p2:
return -1
if isinstance(p1, str) and isinstance(p2, str):
if p1 > p2:
return 1
if p1 < p2:
+14 -9
View File
@@ -4,7 +4,9 @@ import mimetypes
import os
import uuid
from contextlib import asynccontextmanager
from typing import cast
from quart import Response as QuartResponse
from quart import g, make_response, request, send_file
from astrbot.core import logger
@@ -424,16 +426,19 @@ class ChatRoute(Route):
sender_name=username,
)
response = await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
},
response = cast(
QuartResponse,
await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
},
),
)
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
response.timeout = None # fix SSE auto disconnect issue
return response
async def delete_webchat_session(self):
+6 -5
View File
@@ -3,6 +3,7 @@ import inspect
import os
import traceback
import uuid
from typing import Any
from quart import request
@@ -26,7 +27,7 @@ from astrbot.core.star.star import star_registry
from .route import Response, Route, RouteContext
def try_cast(value: str, type_: str):
def try_cast(value: Any, type_: str):
if type_ == "int":
try:
return int(value)
@@ -505,9 +506,9 @@ class ConfigRoute(Route):
if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
# 初始化
if getattr(inst, "initialize", None):
await inst.initialize()
init_fn = getattr(inst, "initialize", None)
if inspect.iscoroutinefunction(init_fn):
await init_fn()
# 获取嵌入向量维度
vec = await inst.get_embedding("echo")
@@ -777,7 +778,7 @@ class ConfigRoute(Route):
return {"metadata": CONFIG_METADATA_2, "config": config}
async def _get_plugin_config(self, plugin_name: str):
ret = {"metadata": None, "config": None}
ret: dict = {"metadata": None, "config": None}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
+13 -8
View File
@@ -1,6 +1,8 @@
import asyncio
import json
from typing import cast
from quart import Response as QuartResponse
from quart import make_response
from astrbot.core import LogBroker, logger
@@ -39,14 +41,17 @@ class LogRoute(Route):
if queue:
self.log_broker.unregister(queue)
response = await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
response = cast(
QuartResponse,
await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
),
)
response.timeout = None
return response
+4
View File
@@ -579,6 +579,10 @@ class PluginRoute(Route):
logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__
if not plugin_obj.root_dir_name:
logger.warning(f"插件 {plugin_name} 目录不存在")
return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path,
plugin_obj.root_dir_name or "",
+2
View File
@@ -12,6 +12,8 @@ class RouteContext:
class Route:
routes: list | dict
def __init__(self, context: RouteContext):
self.app = context.app
self.config = context.config
+6 -3
View File
@@ -2,9 +2,12 @@ import asyncio
import logging
import os
import socket
from typing import cast
import jwt
import psutil
from flask.json.provider import DefaultJSONProvider
from psutil._common import addr as psutil_addr
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
@@ -21,7 +24,7 @@ from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
from .routes.t2i import T2iRoute
APP: Quart = None
APP: Quart
class AstrBotDashboard:
@@ -48,7 +51,7 @@ class AstrBotDashboard:
self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
self.app.json.sort_keys = False
cast(DefaultJSONProvider, self.app.json).sort_keys = False
self.app.before_request(self.auth_middleware)
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
@@ -147,7 +150,7 @@ class AstrBotDashboard:
"""获取占用端口的进程详细信息"""
try:
for conn in psutil.net_connections(kind="inet"):
if conn.laddr.port == port:
if cast(psutil_addr, conn.laddr).port == port:
try:
process = psutil.Process(conn.pid)
# 获取详细信息
+5
View File
@@ -139,6 +139,11 @@ class ProcessLLMRequest:
# group name identifier
if cfg.get("group_name_display") and event.message_obj.group_id:
if not event.message_obj.group:
logger.error(
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
)
return
group_name = event.message_obj.group.group_name
if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n"
+8 -1
View File
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.message_components import File, Image
from astrbot.api.provider import ProviderRequest
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url
@@ -224,6 +225,8 @@ class Main(star.Star):
del self.user_waiting[uid]
elif isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
if image_url is None:
raise ValueError("Image URL is None")
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
elif image_url.startswith("file:///"):
@@ -240,6 +243,8 @@ class Main(star.Star):
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
if not request.prompt:
request.prompt = ""
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
@@ -477,7 +482,9 @@ class Main(star.Star):
# file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_path)]
chain: list[BaseMessageComponent] = [
File(name=file_name, file=file_path)
]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
+11 -8
View File
@@ -5,6 +5,7 @@ import uuid
import zoneinfo
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@@ -62,13 +63,13 @@ class Main(star.Star):
misfire_grace_time=60,
)
elif "cron" in reminder:
trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
self.scheduler.add_job(
self._reminder_callback,
trigger="cron",
trigger=trigger,
id=id_,
args=[group, reminder],
misfire_grace_time=60,
**self._parse_cron_expr(reminder["cron"]),
)
def check_is_outdated(self, reminder: dict):
@@ -101,10 +102,10 @@ class Main(star.Star):
async def reminder_tool(
self,
event: AstrMessageEvent,
text: str = None,
datetime_str: str = None,
cron_expression: str = None,
human_readable_cron: str = None,
text: str | None = None,
datetime_str: str | None = None,
cron_expression: str | None = None,
human_readable_cron: str | None = None,
):
"""Call this function when user is asking for setting a reminder.
@@ -139,17 +140,19 @@ class Main(star.Star):
"id": str(uuid.uuid4()),
}
self.reminder_data[event.unified_msg_origin].append(d)
trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
self.scheduler.add_job(
self._reminder_callback,
"cron",
trigger,
id=d["id"],
misfire_grace_time=60,
**self._parse_cron_expr(cron_expression),
args=[event.unified_msg_origin, d],
)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
if datetime_str is None:
raise ValueError("datetime_str cannot be None.")
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(
+16 -9
View File
@@ -3,7 +3,7 @@ import urllib.parse
from dataclasses import dataclass
from aiohttp import ClientSession
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
@@ -45,13 +45,13 @@ class SearchEngine:
self.page = 1
self.headers = HEADERS
def _set_selector(self, selector: str) -> None:
def _set_selector(self, selector: str) -> str:
raise NotImplementedError
def _get_next_page(self):
def _get_next_page(self, query: str):
raise NotImplementedError
async def _get_html(self, url: str, data: dict = None) -> str:
async def _get_html(self, url: str, data: dict | None = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
@@ -83,6 +83,9 @@ class SearchEngine:
"""清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def _get_url(self, tag: Tag) -> str:
return self.tidy_text(tag.get_text())
async def search(self, query: str, num_results: int) -> list[SearchResult]:
query = urllib.parse.quote(query)
@@ -92,12 +95,16 @@ class SearchEngine:
links = soup.select(self._set_selector("links"))
results = []
for link in links:
title = self.tidy_text(
link.select_one(self._set_selector("title")).text,
)
url = link.select_one(self._set_selector("url"))
# Safely get the title text (select_one may return None)
title_elem = link.select_one(self._set_selector("title"))
title = ""
if title_elem is not None:
title = self.tidy_text(title_elem.get_text())
url_tag = link.select_one(self._set_selector("url"))
snippet = ""
if title and url:
if title and url_tag:
url = self._get_url(url_tag)
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
+1 -9
View File
@@ -1,4 +1,4 @@
from . import USER_AGENT_BING, SearchEngine, SearchResult
from . import USER_AGENT_BING, SearchEngine
class Bing(SearchEngine):
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
self.base_url = base_url
continue
raise Exception("Bing search failed")
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if not isinstance(result.url, str):
result.url = result.url.text
return results
+10 -4
View File
@@ -1,7 +1,8 @@
import random
import re
from typing import cast
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
from . import USER_AGENTS, SearchEngine, SearchResult
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
url = f"{self.base_url}/web?query={query}"
return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
return cast(str, tag.get("href"))
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
result.url = result.url.get("href")
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
soup = BeautifulSoup(html, "html.parser")
script = soup.find("script")
if script:
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(
1,
script_text = (
script.string if script.string is not None else script.get_text()
)
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
if match:
url = match.group(1)
return url
+90
View File
@@ -0,0 +1,90 @@
"""Minimal type stubs for faiss used in this project.
This file only exposes a small subset of the faiss API that the
project uses, including the runtime-monkeypatched signatures such as
`Index.add_with_ids` so Pyright/Pylance stops reporting false positives.
"""
from typing import Any, overload
import numpy as np
class Index:
d: int
ntotal: int
code_size: int
nprobe: int
def add(self, x: np.ndarray) -> None: ...
def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ...
def search(
self,
x: np.ndarray,
k: int,
*,
params: Any = ...,
D: np.ndarray | None = ...,
I: np.ndarray | None = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
def remove_ids(self, x: np.ndarray) -> int: ...
@overload
def reconstruct(self, key: int) -> np.ndarray: ...
@overload
def reconstruct(self, key: int, x: np.ndarray) -> None: ...
def reconstruct(
self, key: int, x: np.ndarray | None = ...
) -> np.ndarray | None: ...
@overload
def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ...
@overload
def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ...
def reconstruct_n(
self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ...
) -> np.ndarray | None: ...
def range_search(
self, x: np.ndarray, thresh: float, *, params: Any = ...
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ...
def sa_encode(self, x: np.ndarray) -> np.ndarray: ...
def sa_decode(self, codes: np.ndarray) -> np.ndarray: ...
class IndexFlatL2(Index):
def __init__(self, d: int) -> None: ...
class IndexIDMap(Index):
index: Index
def __init__(self, index: Index) -> None: ...
def read_index(path: str) -> Index: ...
def write_index(index: Index, path: str | None = ...) -> None: ...
def normalize_L2(x: np.ndarray) -> None: ...
# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers
# expose `downcast_*` helpers to convert generic objects to these concrete
# types). We keep these minimal — only the names are important for typing.
class IndexBinary(Index):
def __init__(self, d: int) -> None: ...
class InvertedLists:
def __len__(self) -> int: ...
class AdditiveQuantizer:
pass
class Quantizer:
pass
class VectorTransform:
pass
# SWIG-provided downcast helpers (present in some faiss Python builds).
def downcast_IndexBinary(obj: Any) -> IndexBinary: ...
def downcast_InvertedLists(obj: Any) -> InvertedLists: ...
def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ...
def downcast_Quantizer(obj: Any) -> Quantizer: ...
def downcast_VectorTransform(obj: Any) -> VectorTransform: ...
def downcast_index(obj: Any) -> Index: ...
# version exposed by runtime
__version__: str