chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint
* chore(core.provider): 🚨 修复基类错误Lint
* chore(core.utils): 补全session_get()的重载
* chore(core.provider): 🚨 修正实现错误Lint
* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint
* chore(core.platform): 修正错误实现Lint
* fix(core.provider): 修复循环调用和错误assert
* chore(core.platform): 修复部分实现Lint
* chore(core.provider): 补充Dify.text_chat_stream的参数类型
* chore(core.pipeline): 🚨 修复错误Lint
* fix(core.slack): 补充遗漏导入
* chore(core.utils): 修复错误的session_get声明
* chore(core.platform): 移除Lark adapter import中的wildcard
* chore(core.db): 修复声明和部分逻辑
* chore(core.db): 添加typings,使faiss参数能被正确识别。
* chore(core): 修复声明
* chore(core): 修改声明
* chore: 补充faiss声明
* chore(dashboard): 修改实现,减少报错
* chore(package): 修改部分声明与实现,减少报错
* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查
* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查
* chore(core.config): 添加类型标注,通过类型检查
* chore(core.message): 为File._download_file添加检查,通过类型检查
* fix: 将断言改为条件判断以实现优雅关闭的容错性
* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 增强 Lark 相关空值/异常检查并完善日志输出
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将断言替换为条件检查并加入日志与错误处理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: 移除LLM生成的无用注释
* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断
* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志
* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 去除cast,直接使用字段与字典访问,修正端口解析
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: group_name_display 时若 group 对象为空则记录错误并返回
* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置
* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_
* refactor: 移除 typing.cast,简化内容安全检查调用
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast
* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值
* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全
* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错
* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"
This reverts commit 1cbfdf9d1b.
* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接
* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型
* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_resp_result = None
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
assert isinstance(llm_response, LLMResponse)
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any, Generic
|
||||
|
||||
import jsonschema
|
||||
@@ -7,6 +7,8 @@ from deprecated import deprecated
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
|
||||
from .run_context import ContextWrapper, TContext
|
||||
|
||||
ParametersType = dict[str, Any]
|
||||
@@ -38,7 +40,10 @@ class ToolSchema:
|
||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||
"""A callable tool, for function calling."""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]] | None = None
|
||||
handler: (
|
||||
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||
| None
|
||||
) = None
|
||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||
|
||||
handler_module_path: str | None = None
|
||||
|
||||
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
async def call_local_llm_tool(
|
||||
context: ContextWrapper[AstrAgentContext],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[
|
||||
...,
|
||||
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||
],
|
||||
method_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
|
||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||
"""
|
||||
|
||||
config_path: str
|
||||
default_config: dict
|
||||
schema: dict | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||
|
||||
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||
|
||||
tasks_ = [event_bus_task, *extra_tasks]
|
||||
for task in tasks_:
|
||||
|
||||
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.AsyncSessionLocal = sessionmaker(
|
||||
self.AsyncSessionLocal = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
|
||||
@@ -70,6 +70,7 @@ async def migration_conversation_table(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" not in conv.user_id:
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
@@ -207,6 +208,7 @@ async def migration_webchat_data(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||
)
|
||||
continue
|
||||
if ":" in conv.user_id:
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
|
||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
||||
conn.text_factory = str
|
||||
return conn
|
||||
|
||||
def _exec_sql(self, sql: str, params: tuple = None):
|
||||
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||
conn = self.conn
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
@@ -224,9 +224,11 @@ class SQLiteDatabase:
|
||||
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
return Stats(platform)
|
||||
|
||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
||||
def get_conversation_by_user_id(
|
||||
self, user_id: str, cid: str
|
||||
) -> Conversation | None:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
@@ -258,7 +260,7 @@ class SQLiteDatabase:
|
||||
(user_id, cid, history, updated_at, created_at),
|
||||
)
|
||||
|
||||
def get_conversations(self, user_id: str) -> tuple:
|
||||
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
|
||||
+16
-15
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
||||
Note: In astrbot v4, we moved `platform` table to here.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_stats" # type: ignore
|
||||
__tablename__: str = "platform_stats"
|
||||
|
||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
timestamp: datetime = Field(nullable=False)
|
||||
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
|
||||
|
||||
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations" # type: ignore
|
||||
__tablename__: str = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(
|
||||
inner_conversation_id: int | None = Field(
|
||||
default=None,
|
||||
primary_key=True,
|
||||
sa_column_kwargs={"autoincrement": True},
|
||||
)
|
||||
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
|
||||
It can be used to customize the behavior of LLMs.
|
||||
"""
|
||||
|
||||
__tablename__ = "personas" # type: ignore
|
||||
__tablename__: str = "personas"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
|
||||
class Preference(SQLModel, table=True):
|
||||
"""This class represents preferences for bots."""
|
||||
|
||||
__tablename__ = "preferences" # type: ignore
|
||||
__tablename__: str = "preferences"
|
||||
|
||||
id: int | None = Field(
|
||||
default=None,
|
||||
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
or platform-specific messages.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_message_history" # type: ignore
|
||||
__tablename__: str = "platform_message_history"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
|
||||
Each session can have multiple conversations (对话) associated with it.
|
||||
"""
|
||||
|
||||
__tablename__ = "platform_sessions" # type: ignore
|
||||
__tablename__: str = "platform_sessions"
|
||||
|
||||
inner_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
|
||||
Attachments can be images, files, or other media types.
|
||||
"""
|
||||
|
||||
__tablename__ = "attachments" # type: ignore
|
||||
__tablename__: str = "attachments"
|
||||
|
||||
inner_attachment_id: int | None = Field(
|
||||
primary_key=True,
|
||||
@@ -261,17 +262,17 @@ class Personality(TypedDict):
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
prompt: str
|
||||
name: str
|
||||
begin_dialogs: list[str]
|
||||
mood_imitation_dialogs: list[str]
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
tools: list[str] | None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
_begin_dialogs_processed: list[dict]
|
||||
_mood_imitation_dialogs_processed: str
|
||||
|
||||
|
||||
# ====
|
||||
|
||||
@@ -3,6 +3,7 @@ import threading
|
||||
import typing as T
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
|
||||
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(Attachment).where(
|
||||
Attachment.attachment_id.in_(attachment_ids)
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id) == attachment_id
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount > 0
|
||||
|
||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
query = delete(Attachment).where(
|
||||
col(Attachment.attachment_id).in_(attachment_ids)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
result = T.cast(CursorResult, await session.execute(query))
|
||||
return result.rowcount
|
||||
|
||||
async def insert_persona(
|
||||
|
||||
@@ -90,4 +90,6 @@ class EmbeddingStorage:
|
||||
path (str): 保存索引的路径
|
||||
|
||||
"""
|
||||
if self.index is None:
|
||||
return
|
||||
faiss.write_index(self.index, self.path)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EventBus:
|
||||
self,
|
||||
event_queue: Queue,
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
astrbot_config_mgr: AstrBotConfigManager,
|
||||
):
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
@@ -40,6 +40,11 @@ class EventBus:
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
if not scheduler:
|
||||
logger.error(
|
||||
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||
)
|
||||
continue
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
|
||||
@@ -166,7 +166,11 @@ class RetrievalManager:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
vec_db = kb_options[kb_id]["vec_db"]
|
||||
if not isinstance(vec_db, FaissVecDB):
|
||||
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||
continue
|
||||
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
|
||||
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
|
||||
class BaseMessageComponent(BaseModel):
|
||||
type: ComponentType
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def toDict(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
|
||||
id: int | None = 0 # 忽略
|
||||
name: str | None = "" # qq昵称
|
||||
uin: str | None = "0" # qq号
|
||||
content: list[BaseMessageComponent] | None = []
|
||||
content: list[BaseMessageComponent] = []
|
||||
seq: str | list | None = "" # 忽略
|
||||
time: int | None = 0 # 忽略
|
||||
|
||||
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
|
||||
ret["messages"].append(d)
|
||||
return ret
|
||||
|
||||
async def to_dict(self):
|
||||
async def to_dict(self) -> dict:
|
||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||
ret = {"messages": []}
|
||||
for node in self.nodes:
|
||||
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("Download failed: No URL provided in File component.")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
if self.name:
|
||||
|
||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
begin_dialogs: list[str] | None = None,
|
||||
tools: list[str] | None = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
|
||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
check_text: str | None = None,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""检查内容安全"""
|
||||
text = check_text if check_text else event.get_message_str()
|
||||
ok, info = self.strategy_selector.check(text)
|
||||
|
||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
async def call_handler(
|
||||
event: AstrMessageEvent,
|
||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
@@ -91,6 +91,7 @@ async def call_event_hook(
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
assert inspect.iscoroutinefunction(handler.handler)
|
||||
logger.debug(
|
||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
) -> AsyncGenerator[Any, None]:
|
||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||
"activated_handlers",
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
|
||||
):
|
||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||
if (
|
||||
event.get_result() and not event.get_result().is_stopped()
|
||||
event.get_result() and not event.is_stopped()
|
||||
) or not event.get_result():
|
||||
async for _ in self.agent_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
@@ -117,7 +117,9 @@ class RespondStage(Stage):
|
||||
if not self.enable_seg:
|
||||
return False
|
||||
|
||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -185,7 +187,7 @@ class RespondStage(Stage):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
result.chain[idx] = component
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
|
||||
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
@@ -130,6 +131,8 @@ class ResultDecorateStage(Stage):
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain):
|
||||
text += comp.text
|
||||
|
||||
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||
async for _ in self.content_safe_check_stage.process(
|
||||
event,
|
||||
check_text=text,
|
||||
@@ -151,7 +154,8 @@ class ResultDecorateStage(Stage):
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
|
||||
if (result := event.get_result()) is None or not result.chain:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||
WecomAIBotMessageEvent,
|
||||
)
|
||||
|
||||
from . import STAGES_ORDER
|
||||
from .context import PipelineContext
|
||||
@@ -78,7 +82,7 @@ class PipelineScheduler:
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
||||
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
def get_sender_name(self) -> str:
|
||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||
if isinstance(self.message_obj.sender.nickname, str):
|
||||
return self.message_obj.sender.nickname
|
||||
return ""
|
||||
|
||||
def set_extra(self, key, value):
|
||||
"""设置额外的信息。"""
|
||||
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self.call_llm = call_llm
|
||||
|
||||
def get_result(self) -> MessageEventResult:
|
||||
def get_result(self) -> MessageEventResult | None:
|
||||
"""获取消息事件的结果。"""
|
||||
return self._result
|
||||
|
||||
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
session_id: str = None,
|
||||
session_id: str = "",
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
system_prompt: str = "",
|
||||
|
||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||
message_id: str # 消息id
|
||||
group: Group # 群组
|
||||
group: Group | None # 群组
|
||||
sender: MessageMember # 发送者
|
||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
||||
return ""
|
||||
|
||||
@group_id.setter
|
||||
def group_id(self, value: str):
|
||||
def group_id(self, value: str | None):
|
||||
"""设置 group_id"""
|
||||
if value:
|
||||
if self.group:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -100,7 +100,7 @@ class Platform(abc.ABC):
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -116,7 +116,7 @@ class Platform(abc.ABC):
|
||||
self,
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
) -> None:
|
||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||
|
||||
异步方法。
|
||||
|
||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str | None = None
|
||||
id: str
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict | None = None
|
||||
|
||||
@@ -40,6 +40,7 @@ def register_platform_adapter(
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
id=adapter_name,
|
||||
default_config_tmpl=default_config_tmpl,
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
|
||||
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
bot: CQHttp,
|
||||
event: Event | None,
|
||||
is_group: bool,
|
||||
session_id: str,
|
||||
session_id: str | None,
|
||||
messages: list[dict],
|
||||
):
|
||||
# session_id 必须是纯数字字符串
|
||||
session_id = int(session_id) if session_id.isdigit() else None
|
||||
session_id_int = (
|
||||
int(session_id) if session_id and session_id.isdigit() else None
|
||||
)
|
||||
|
||||
if is_group and isinstance(session_id, int):
|
||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
||||
elif not is_group and isinstance(session_id, int):
|
||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
||||
if is_group and isinstance(session_id_int, int):
|
||||
await bot.send_group_msg(group_id=session_id_int, message=messages)
|
||||
elif not is_group and isinstance(session_id_int, int):
|
||||
await bot.send_private_msg(user_id=session_id_int, message=messages)
|
||||
elif isinstance(event, Event): # 最后兜底
|
||||
await bot.send(event=event, message=messages)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||
)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if event.get("group_id"):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
|
||||
@param event: 事件对象
|
||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||
"""
|
||||
assert event.sender is not None
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(
|
||||
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
|
||||
if event["message_type"] == "group":
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
abm.group_id = str(event.group_id)
|
||||
abm.group = Group(str(event.group_id))
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return None
|
||||
raise ValueError(err)
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import dingtalk_stream
|
||||
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
outer_self = self
|
||||
|
||||
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
||||
async def process(self, message: dingtalk_stream.CallbackMessage):
|
||||
logger.debug(f"dingtalk: {message.data}")
|
||||
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||
abm = await self.convert_msg(im)
|
||||
await self.handle_msg(abm)
|
||||
abm = await outer_self.convert_msg(im)
|
||||
await outer_self.handle_msg(abm)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
self.client,
|
||||
)
|
||||
self.client_ = client # 用于 websockets 的 client
|
||||
self._shutdown_event: threading.Event | None = None
|
||||
|
||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||
if not dingtalk_id:
|
||||
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = []
|
||||
abm.message_str = ""
|
||||
abm.timestamp = int(message.create_at / 1000)
|
||||
abm.timestamp = int(cast(int, message.create_at) / 1000)
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if message.conversation_type == "2"
|
||||
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
nickname=message.sender_nick,
|
||||
)
|
||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||
abm.message_id = message.message_id
|
||||
abm.message_id = cast(str, message.message_id)
|
||||
abm.raw_message = message
|
||||
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
message_type: str = message.message_type
|
||||
message_type: str = cast(str, message.message_type)
|
||||
match message_type:
|
||||
case "text":
|
||||
abm.message_str = message.text.content.strip()
|
||||
abm.message.append(Plain(abm.message_str))
|
||||
case "richText":
|
||||
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
||||
contents: list[dict] = rtc.rich_text_list
|
||||
rtc: dingtalk_stream.RichTextContent = cast(
|
||||
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||
)
|
||||
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||
for content in contents:
|
||||
plains = ""
|
||||
if "text" in content:
|
||||
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
elif "type" in content and content["type"] == "picture":
|
||||
f_path = await self.download_ding_file(
|
||||
content["downloadCode"],
|
||||
message.robot_code,
|
||||
cast(str, message.robot_code),
|
||||
"jpg",
|
||||
)
|
||||
abm.message.append(Image.fromFileSystem(f_path))
|
||||
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
resp_data = await resp.json()
|
||||
download_url = resp_data["data"]["downloadUrl"]
|
||||
await download_file(download_url, f_path)
|
||||
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
logger.error(
|
||||
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
||||
)
|
||||
return None
|
||||
return ""
|
||||
return (await resp.json())["data"]["accessToken"]
|
||||
|
||||
async def handle_msg(self, abm: AstrBotMessage):
|
||||
@@ -250,8 +256,10 @@ class DingtalkPlatformAdapter(Platform):
|
||||
def monkey_patch_close():
|
||||
raise KeyboardInterrupt("Graceful shutdown")
|
||||
|
||||
if self.client_.websocket is not None:
|
||||
self.client_.open_connection = monkey_patch_close
|
||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
|
||||
def get_client(self):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import dingtalk_stream
|
||||
|
||||
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client.reply_markdown,
|
||||
"😄",
|
||||
markdown_str,
|
||||
self.message_obj.raw_message,
|
||||
cast(
|
||||
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
|
||||
),
|
||||
)
|
||||
logger.debug(f"send image: {ret}")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import discord
|
||||
|
||||
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
|
||||
# 回调函数
|
||||
self.on_message_received = None
|
||||
self.on_ready_once_callback = None
|
||||
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||
self._ready_once_fired = False
|
||||
|
||||
@override
|
||||
async def on_ready(self):
|
||||
"""当机器人成功连接并准备就绪时触发"""
|
||||
if self.user is None:
|
||||
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||
return
|
||||
|
||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||
logger.info("[Discord] 客户端已准备就绪。")
|
||||
|
||||
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_message_data(self, message: discord.Message) -> dict:
|
||||
"""从 discord.Message 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
is_mentioned = self.user in message.mentions
|
||||
return {
|
||||
"message": message,
|
||||
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
|
||||
|
||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||
"""从 discord.Interaction 创建数据字典"""
|
||||
if self.user is None:
|
||||
raise RuntimeError("Bot is not ready: self.user is None")
|
||||
|
||||
if interaction.user is None:
|
||||
raise ValueError("Interaction received without a valid user")
|
||||
|
||||
return {
|
||||
"interaction": interaction,
|
||||
"bot_id": str(self.user.id),
|
||||
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
|
||||
"type": "interaction",
|
||||
}
|
||||
|
||||
@override
|
||||
async def on_message(self, message: discord.Message):
|
||||
"""当接收到消息时触发"""
|
||||
if message.author.bot:
|
||||
|
||||
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: list[BaseMessageComponent] = None,
|
||||
timeout: float = None,
|
||||
components: list[BaseMessageComponent] | None = None,
|
||||
timeout: float | None = None,
|
||||
):
|
||||
self.components = components or []
|
||||
self.timeout = timeout
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import discord
|
||||
from discord.abc import Messageable
|
||||
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||
from discord.channel import DMChannel
|
||||
|
||||
from astrbot import logger
|
||||
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.client_self_id = None
|
||||
self.client_self_id: str | None = None
|
||||
self.registered_handlers = []
|
||||
# 指令注册相关
|
||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
"""通过会话发送消息"""
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 创建一个 message_obj 以便在 event 中使用
|
||||
message_obj = AstrBotMessage()
|
||||
if "_" in session.session_id:
|
||||
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
user_id=str(self.client_self_id),
|
||||
nickname=self.client.user.display_name,
|
||||
)
|
||||
message_obj.self_id = self.client_self_id
|
||||
message_obj.self_id = cast(str, self.client_self_id)
|
||||
message_obj.session_id = session.session_id
|
||||
message_obj.message = message_chain.chain
|
||||
|
||||
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
"discord",
|
||||
"Discord 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
default_config_tmpl=self.config,
|
||||
support_streaming_message=False,
|
||||
)
|
||||
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
|
||||
def _get_message_type(
|
||||
self,
|
||||
channel: Messageable,
|
||||
channel: Messageable | GuildChannel | PrivateChannel,
|
||||
guild_id: int | None = None,
|
||||
) -> MessageType:
|
||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
|
||||
return MessageType.FRIEND_MESSAGE
|
||||
return MessageType.GROUP_MESSAGE
|
||||
|
||||
def _get_channel_id(self, channel: Messageable) -> str:
|
||||
def _get_channel_id(
|
||||
self, channel: Messageable | GuildChannel | PrivateChannel
|
||||
) -> str:
|
||||
"""根据 channel 对象获取ID"""
|
||||
return str(getattr(channel, "id", None))
|
||||
|
||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||
"""将普通消息转换为 AstrBotMessage"""
|
||||
message: discord.Message = data["message"]
|
||||
message = data["message"]
|
||||
|
||||
content = message.content
|
||||
|
||||
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = message_chain
|
||||
abm.raw_message = message
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(message.channel.id)
|
||||
abm.message_id = str(message.id)
|
||||
return abm
|
||||
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
|
||||
interaction_followup_webhook=followup_webhook,
|
||||
)
|
||||
|
||||
if self.client.user is None:
|
||||
logger.error(
|
||||
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否为斜杠指令
|
||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||
|
||||
# 1. 优先处理斜杠指令
|
||||
if is_slash_command:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
self.commit_event(message_event)
|
||||
return
|
||||
|
||||
# 2. 处理普通消息(提及检测)
|
||||
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
|
||||
raw_message = message.raw_message
|
||||
if not isinstance(raw_message, discord.Message):
|
||||
logger.warning(
|
||||
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
|
||||
)
|
||||
return
|
||||
|
||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||
is_mention = False
|
||||
|
||||
# User Mention
|
||||
if (
|
||||
self.client
|
||||
and self.client.user
|
||||
and hasattr(message.raw_message, "mentions")
|
||||
):
|
||||
if self.client.user in message.raw_message.mentions:
|
||||
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
|
||||
if self.client.user in raw_message.mentions:
|
||||
is_mention = True
|
||||
|
||||
# Role Mention(Bot 拥有的角色被提及)
|
||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
||||
if not is_mention and raw_message.role_mentions:
|
||||
bot_member = None
|
||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
||||
if raw_message.guild:
|
||||
try:
|
||||
bot_member = message.raw_message.guild.get_member(
|
||||
bot_member = raw_message.guild.get_member(
|
||||
self.client.user.id,
|
||||
)
|
||||
except Exception:
|
||||
bot_member = None
|
||||
if bot_member and hasattr(bot_member, "roles"):
|
||||
bot_roles = set(bot_member.roles)
|
||||
mentioned_roles = set(message.raw_message.role_mentions)
|
||||
mentioned_roles = set(raw_message.role_mentions)
|
||||
if (
|
||||
bot_roles
|
||||
and mentioned_roles
|
||||
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
|
||||
):
|
||||
is_mention = True
|
||||
|
||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
||||
if is_slash_command or is_mention:
|
||||
# 如果是被@的消息,设置为唤醒状态
|
||||
if is_mention:
|
||||
message_event.is_wake = True
|
||||
message_event.is_at_or_wake_command = True
|
||||
|
||||
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message = [Plain(text=message_str_for_filter)]
|
||||
abm.raw_message = ctx.interaction
|
||||
abm.self_id = self.client_self_id
|
||||
abm.self_id = cast(str, self.client_self_id)
|
||||
abm.session_id = str(ctx.channel_id)
|
||||
abm.message_id = str(ctx.interaction.id)
|
||||
|
||||
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
|
||||
def _extract_command_info(
|
||||
event_filter: Any,
|
||||
handler_metadata: StarHandlerMetadata,
|
||||
) -> tuple[str, str, CommandFilter] | None:
|
||||
) -> tuple[str, str, CommandFilter | None] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
# is_group = False
|
||||
|
||||
@@ -4,8 +4,10 @@ import binascii
|
||||
from collections.abc import AsyncGenerator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import discord
|
||||
from discord.types.interactions import ComponentInteractionData
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
channel = await self._get_channel()
|
||||
if not channel:
|
||||
return
|
||||
if not isinstance(channel, discord.abc.Messageable):
|
||||
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
|
||||
return
|
||||
await channel.send(**kwargs)
|
||||
|
||||
except Exception as e:
|
||||
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
||||
async def _get_channel(
|
||||
self,
|
||||
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
|
||||
"""获取当前事件对应的频道对象"""
|
||||
try:
|
||||
channel_id = int(self.session_id)
|
||||
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
async def _parse_to_discord(
|
||||
self,
|
||||
message: MessageChain,
|
||||
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
|
||||
) -> tuple[
|
||||
str,
|
||||
list[discord.File],
|
||||
discord.ui.View | None,
|
||||
list[discord.Embed],
|
||||
str | int | None,
|
||||
]:
|
||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||
content_parts = []
|
||||
files = []
|
||||
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"add_reaction",
|
||||
):
|
||||
await self.message_obj.raw_message.add_reaction(emoji)
|
||||
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
|
||||
emoji
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||
|
||||
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.application_command
|
||||
)
|
||||
|
||||
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
return (
|
||||
hasattr(self.message_obj, "raw_message")
|
||||
and hasattr(self.message_obj.raw_message, "type")
|
||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
||||
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||
== discord.InteractionType.component
|
||||
)
|
||||
|
||||
def get_interaction_custom_id(self) -> str:
|
||||
"""获取交互组件的custom_id"""
|
||||
if self.is_button_interaction():
|
||||
try:
|
||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
||||
return cast(
|
||||
ComponentInteractionData,
|
||||
cast(discord.Interaction, self.message_obj.raw_message).data,
|
||||
).get("custom_id", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
return any(
|
||||
mention.id == int(self.message_obj.self_id)
|
||||
for mention in self.message_obj.raw_message.mentions
|
||||
for mention in cast(
|
||||
discord.Message, self.message_obj.raw_message
|
||||
).mentions
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
||||
self.message_obj.raw_message,
|
||||
"clean_content",
|
||||
):
|
||||
return self.message_obj.raw_message.clean_content
|
||||
return cast(discord.Message, self.message_obj.raw_message).clean_content
|
||||
return self.message_str
|
||||
|
||||
@@ -3,9 +3,14 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot import logger
|
||||
@@ -74,6 +79,10 @@ class LarkPlatformAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
@@ -114,14 +123,21 @@ class LarkPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
if event.event is None:
|
||||
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||
return
|
||||
message = event.event.message
|
||||
if message is None:
|
||||
logger.debug("[Lark] 事件中没有消息体(message is None)")
|
||||
return
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.timestamp = int(message.create_time) / 1000
|
||||
abm.timestamp = cast(int, message.create_time) // 1000
|
||||
abm.message = []
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
@@ -136,14 +152,28 @@ class LarkPlatformAdapter(Platform):
|
||||
at_list = {}
|
||||
if message.mentions:
|
||||
for m in message.mentions:
|
||||
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
||||
if m.id is None:
|
||||
continue
|
||||
# 飞书 open_id 可能是 None,这里做个防护
|
||||
open_id = m.id.open_id if m.id.open_id else ""
|
||||
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
|
||||
|
||||
if m.name == self.bot_name:
|
||||
if m.id.open_id is not None:
|
||||
abm.self_id = m.id.open_id
|
||||
|
||||
if message.content is None:
|
||||
logger.warning("[Lark] 消息内容为空")
|
||||
return
|
||||
|
||||
try:
|
||||
content_json_b = json.loads(message.content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
|
||||
return
|
||||
|
||||
if message.message_type == "text":
|
||||
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
||||
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
|
||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||
# at_users = re.findall(at_pattern, message_str_raw)
|
||||
# 拆分文本,去掉AT符号部分
|
||||
@@ -168,27 +198,47 @@ class LarkPlatformAdapter(Platform):
|
||||
content_json_b = _ls
|
||||
elif message.message_type == "image":
|
||||
content_json_b = [
|
||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
|
||||
{
|
||||
"tag": "img",
|
||||
"image_key": content_json_b.get("image_key"),
|
||||
"style": [],
|
||||
},
|
||||
]
|
||||
|
||||
if message.message_type in ("post", "image"):
|
||||
for comp in content_json_b:
|
||||
if comp["tag"] == "at":
|
||||
abm.message.append(at_list[comp["user_id"]])
|
||||
elif comp["tag"] == "text" and comp["text"].strip():
|
||||
if comp.get("tag") == "at":
|
||||
user_id = comp.get("user_id")
|
||||
if user_id in at_list:
|
||||
abm.message.append(at_list[user_id])
|
||||
elif comp.get("tag") == "text" and comp.get("text", "").strip():
|
||||
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||
elif comp["tag"] == "img":
|
||||
image_key = comp["image_key"]
|
||||
elif comp.get("tag") == "img":
|
||||
image_key = comp.get("image_key")
|
||||
if not image_key:
|
||||
continue
|
||||
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message.message_id)
|
||||
.message_id(cast(str, message.message_id))
|
||||
.file_key(image_key)
|
||||
.type("image")
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
continue
|
||||
|
||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法下载飞书图片: {image_key}")
|
||||
continue
|
||||
|
||||
if response.file is None:
|
||||
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
|
||||
continue
|
||||
|
||||
image_bytes = response.file.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode()
|
||||
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||
@@ -196,6 +246,19 @@ class LarkPlatformAdapter(Platform):
|
||||
for comp in abm.message:
|
||||
if isinstance(comp, Comp.Plain):
|
||||
abm.message_str += comp.text
|
||||
|
||||
if message.message_id is None:
|
||||
logger.error("[Lark] 消息缺少 message_id")
|
||||
return
|
||||
|
||||
if (
|
||||
event.event.sender is None
|
||||
or event.event.sender.sender_id is None
|
||||
or event.event.sender.sender_id.open_id is None
|
||||
):
|
||||
logger.error("[Lark] 消息发送者信息不完整")
|
||||
return
|
||||
|
||||
abm.message_id = message.message_id
|
||||
abm.raw_message = message
|
||||
abm.sender = MessageMember(
|
||||
@@ -235,5 +298,5 @@ class LarkPlatformAdapter(Platform):
|
||||
await self.client._disconnect()
|
||||
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||
|
||||
def get_client(self) -> lark.Client:
|
||||
def get_client(self) -> lark.ws.Client:
|
||||
return self.client
|
||||
|
||||
@@ -5,7 +5,15 @@ import uuid
|
||||
from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
ReplyMessageRequest,
|
||||
ReplyMessageRequestBody,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
file_path = image_file_path if image_file_path else ""
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file
|
||||
file_path = comp.file if comp.file else ""
|
||||
|
||||
if image_file is None:
|
||||
if not file_path:
|
||||
logger.error("[Lark] 图片路径为空,无法上传")
|
||||
continue
|
||||
try:
|
||||
image_file = open(file_path, "rb")
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开图片文件: {e}")
|
||||
continue
|
||||
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
|
||||
continue
|
||||
|
||||
response = await lark_client.im.v1.image.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
continue
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
|
||||
continue
|
||||
|
||||
image_key = response.data.image_key
|
||||
logger.debug(image_key)
|
||||
ret.append(_stage)
|
||||
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
await super().send(message)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
return
|
||||
|
||||
request = (
|
||||
CreateMessageReactionRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
if not isinstance(message.raw_message, dict):
|
||||
message.raw_message = {}
|
||||
message.raw_message["poll"] = poll
|
||||
message.poll = poll
|
||||
message.__setattr__("poll", poll)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
self,
|
||||
session: MessageSession,
|
||||
message_chain: MessageChain,
|
||||
) -> Awaitable[Any]:
|
||||
) -> None:
|
||||
if not self.api:
|
||||
logger.error("[Misskey] API 客户端未初始化")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import aiofiles
|
||||
import botpy
|
||||
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
ret = cast(
|
||||
message.Message,
|
||||
await self._post_send(stream=stream_payload),
|
||||
)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
return None
|
||||
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(
|
||||
|
||||
if not isinstance(
|
||||
source,
|
||||
(
|
||||
botpy.message.Message,
|
||||
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
botpy.message.DirectMessage,
|
||||
botpy.message.C2CMessage,
|
||||
),
|
||||
)
|
||||
):
|
||||
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
|
||||
return None
|
||||
|
||||
(
|
||||
plain_text,
|
||||
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
):
|
||||
return None
|
||||
|
||||
payload = {
|
||||
payload: dict = {
|
||||
"content": plain_text,
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
|
||||
ret = None
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
match source:
|
||||
case botpy.message.GroupMessage():
|
||||
if not source.group_openid:
|
||||
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
|
||||
return None
|
||||
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
group_openid=source.group_openid,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.C2CMessage:
|
||||
|
||||
case botpy.message.C2CMessage():
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(
|
||||
image_base64,
|
||||
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
**payload,
|
||||
)
|
||||
logger.debug(f"Message sent to C2C: {ret}")
|
||||
case botpy.message.Message:
|
||||
|
||||
case botpy.message.Message():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_message(
|
||||
channel_id=source.channel_id,
|
||||
**payload,
|
||||
)
|
||||
case botpy.message.DirectMessage:
|
||||
|
||||
case botpy.message.DirectMessage():
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
case _:
|
||||
pass
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"file_type": file_type,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
result = None
|
||||
if "openid" in kwargs:
|
||||
payload["openid"] = kwargs["openid"]
|
||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
if "group_openid" in kwargs:
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
elif "group_openid" in kwargs:
|
||||
payload["group_openid"] = kwargs["group_openid"]
|
||||
route = Route(
|
||||
"POST",
|
||||
"/v2/groups/{group_openid}/files",
|
||||
group_openid=kwargs["group_openid"],
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
else:
|
||||
raise ValueError("Invalid upload parameters")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to upload image, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return Media(
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
)
|
||||
|
||||
async def upload_group_and_c2c_record(
|
||||
self,
|
||||
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if result:
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"上传文件响应格式错误: {result}")
|
||||
return None
|
||||
|
||||
return Media(
|
||||
file_uuid=result.get("file_uuid"),
|
||||
file_info=result.get("file_info"),
|
||||
file_uuid=result["file_uuid"],
|
||||
file_info=result["file_info"],
|
||||
ttl=result.get("ttl", 0),
|
||||
file_id=result.get("id", ""),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传请求错误: {e}")
|
||||
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
message_reference: message.Reference | None = None,
|
||||
media: message.Media | None = None,
|
||||
msg_id: str | None = None,
|
||||
msg_seq: str = 1,
|
||||
msg_seq: int | None = 1,
|
||||
event_id: str | None = None,
|
||||
markdown: message.MarkdownPayload | None = None,
|
||||
keyboard: message.Keyboard | None = None,
|
||||
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
payload = locals()
|
||||
payload.pop("self", None)
|
||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
result = await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise RuntimeError(
|
||||
f"Failed to post c2c message, response is not dict: {result}"
|
||||
)
|
||||
|
||||
return message.Message(**result)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_to_qqofficial(message: MessageChain):
|
||||
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
elif i.file and i.file.startswith("base64://"):
|
||||
image_base64 = i.file
|
||||
else:
|
||||
elif i.file:
|
||||
image_base64 = file_to_base64(i.file)
|
||||
else:
|
||||
raise ValueError("Unsupported image file format")
|
||||
image_base64 = image_base64.removeprefix("base64://")
|
||||
elif isinstance(i, Record):
|
||||
if i.file:
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -44,7 +45,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unique_session: bool = platform_settings["unique_session"]
|
||||
qq_group = platform_config["enable_group_c2c"]
|
||||
guild_dm = platform_config["enable_guild_direct_message"]
|
||||
|
||||
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_from_qqofficial(
|
||||
message: botpy.message.Message | botpy.message.GroupMessage,
|
||||
message: botpy.message.Message
|
||||
| botpy.message.GroupMessage
|
||||
| botpy.message.DirectMessage
|
||||
| botpy.message.C2CMessage,
|
||||
message_type: MessageType,
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
abm.timestamp = int(time.time())
|
||||
abm.raw_message = message
|
||||
abm.message_id = message.id
|
||||
abm.tag = "qq_official"
|
||||
# abm.tag = "qq_official"
|
||||
msg: list[BaseMessageComponent] = []
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
||||
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
message,
|
||||
botpy.message.DirectMessage,
|
||||
):
|
||||
try:
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.self_id = str(message.mentions[0].id)
|
||||
except BaseException as _:
|
||||
else:
|
||||
abm.self_id = ""
|
||||
|
||||
plain_content = message.content.replace(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import botpy
|
||||
import botpy.message
|
||||
@@ -36,7 +36,9 @@ class botClient(Client):
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
self._commit(abm)
|
||||
|
||||
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
return PlatformMetadata(
|
||||
name="qq_official_webhook",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import quart
|
||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
|
||||
|
||||
if opcode == 13:
|
||||
# validation
|
||||
signed = await self.webhook_validation(data)
|
||||
signed = await self.webhook_validation(cast(dict, data))
|
||||
print(signed)
|
||||
return signed
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from quart import Quart, Response, request
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
@@ -66,7 +68,7 @@ class SlackWebhookClient:
|
||||
"""
|
||||
try:
|
||||
# 获取请求体和头部
|
||||
body = await req.get_data()
|
||||
body = cast(bytes, await req.get_data())
|
||||
event_data = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Verify Slack request signature
|
||||
@@ -139,9 +141,14 @@ class SlackSocketClient:
|
||||
self.event_handler = event_handler
|
||||
self.socket_client = None
|
||||
|
||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
||||
async def _handle_events(
|
||||
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
|
||||
):
|
||||
"""处理 Socket Mode 事件"""
|
||||
try:
|
||||
if self.socket_client is None:
|
||||
raise RuntimeError("Socket client is not initialized")
|
||||
|
||||
# 确认收到事件
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
await self.socket_client.send_socket_mode_response(response)
|
||||
|
||||
@@ -3,8 +3,7 @@ import base64
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import aiohttp
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
|
||||
self.metadata = PlatformMetadata(
|
||||
name="slack",
|
||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||
id=self.config.get("id"),
|
||||
id=cast(str, self.config.get("id")),
|
||||
support_streaming_message=False,
|
||||
)
|
||||
|
||||
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
|
||||
logger.debug(f"[slack] RawMessage {event}")
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = self.bot_self_id
|
||||
abm.self_id = cast(str, self.bot_self_id)
|
||||
|
||||
# 获取用户信息
|
||||
user_id = event.get("user", "")
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=user_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||
except Exception:
|
||||
user_name = user_id
|
||||
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
|
||||
channel_id = event.get("channel", "")
|
||||
try:
|
||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||
is_im = channel_info["channel"]["is_im"]
|
||||
is_im = cast(dict, channel_info["channel"])["is_im"]
|
||||
|
||||
if is_im:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
|
||||
for mention in mentions:
|
||||
try:
|
||||
mentioned_user = await self.web_client.users_info(user=mention)
|
||||
user_data = mentioned_user["user"]
|
||||
user_data = cast(dict, mentioned_user["user"])
|
||||
user_name = user_data.get("real_name") or user_data.get(
|
||||
"name",
|
||||
mention,
|
||||
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
|
||||
)
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
|
||||
async def run(self) -> Awaitable[Any]:
|
||||
async def run(self) -> None:
|
||||
self.bot_self_id = await self.get_bot_user_id()
|
||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Image):
|
||||
# upload file
|
||||
url = segment.url or segment.file
|
||||
if url.startswith("http"):
|
||||
if url and url.startswith("http"):
|
||||
return {
|
||||
"type": "image",
|
||||
"image_url": url,
|
||||
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||
}
|
||||
image_url = response["files"][0]["url_private"]
|
||||
image_url = cast(list, response["files"])[0]["url_private"]
|
||||
logger.debug(f"Slack file upload response: {response}")
|
||||
return {
|
||||
"type": "image",
|
||||
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||
}
|
||||
file_url = response["files"][0]["permalink"]
|
||||
file_url = cast(list, response["files"])[0]["permalink"]
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
|
||||
members = []
|
||||
for member_id in members_response["members"]:
|
||||
for member_id in cast(Iterable, members_response["members"]):
|
||||
try:
|
||||
user_info = await self.web_client.users_info(user=member_id)
|
||||
user_data = user_info["user"]
|
||||
user_data = cast(dict, user_info["user"])
|
||||
members.append(
|
||||
MessageMember(
|
||||
user_id=member_id,
|
||||
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
||||
# 如果获取用户信息失败,使用默认信息
|
||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||
|
||||
channel_data = channel_info["channel"]
|
||||
channel_data = cast(dict, channel_info["channel"])
|
||||
return Group(
|
||||
group_id=channel_id,
|
||||
group_name=channel_data.get("name", ""),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
import telegramify_markdown
|
||||
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
|
||||
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
|
||||
Reply,
|
||||
)
|
||||
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
"chat_id": user_name,
|
||||
}
|
||||
if has_reply:
|
||||
payload["reply_to_message_id"] = reply_message_id
|
||||
payload["reply_to_message_id"] = str(reply_message_id)
|
||||
if message_thread_id:
|
||||
payload["message_thread_id"] = message_thread_id
|
||||
|
||||
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
md_text = telegramify_markdown.markdownify(
|
||||
chunk,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await client.send_message(
|
||||
text=md_text,
|
||||
parse_mode="MarkdownV2",
|
||||
**payload,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 send failed: {e}. Using plain text instead.",
|
||||
)
|
||||
await client.send_message(text=chunk, **payload)
|
||||
await client.send_message(text=chunk, **cast(Any, payload))
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
await client.send_photo(photo=image_path, **cast(Any, payload))
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
await client.send_document(
|
||||
document=path, filename=name, **cast(Any, payload)
|
||||
)
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await client.send_voice(voice=path, **payload)
|
||||
await client.send_voice(voice=path, **cast(Any, payload))
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await self.client.send_photo(photo=image_path, **payload)
|
||||
await self.client.send_photo(
|
||||
photo=image_path, **cast(Any, payload)
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, i.name)
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
path = await i.get_file()
|
||||
name = i.name or os.path.basename(path)
|
||||
|
||||
await self.client.send_document(
|
||||
document=i.file,
|
||||
filename=i.name,
|
||||
**payload,
|
||||
document=path,
|
||||
filename=name,
|
||||
**cast(Any, payload),
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await self.client.send_voice(voice=path, **payload)
|
||||
await self.client.send_voice(voice=path, **cast(Any, payload))
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
msg = await self.client.send_message(
|
||||
text=delta, **cast(Any, payload)
|
||||
)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
try:
|
||||
markdown_text = telegramify_markdown.markdownify(
|
||||
delta,
|
||||
max_line_length=None,
|
||||
normalize_whitespace=False,
|
||||
)
|
||||
await self.client.edit_message_text(
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
def run(self) -> Coroutine[Any, Any, None]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.wxid: str | None = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(),
|
||||
"wechatpadpro_credentials.json",
|
||||
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def handle_websocket_message(self, message: str):
|
||||
async def handle_websocket_message(self, message: str | bytes):
|
||||
"""处理从 WebSocket 接收到的消息。"""
|
||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||
try:
|
||||
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
|
||||
if self.wxid is None:
|
||||
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
|
||||
return None
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = raw_message
|
||||
abm.message_id = str(raw_message.get("msg_id"))
|
||||
abm.timestamp = raw_message.get("create_time")
|
||||
abm.timestamp = cast(int, raw_message.get("create_time"))
|
||||
abm.self_id = self.wxid
|
||||
|
||||
if int(time.time()) - abm.timestamp > 180:
|
||||
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
content = raw_message.get("content", {}).get("str", "")
|
||||
push_content = raw_message.get("push_content", "")
|
||||
msg_type = raw_message.get("msg_type")
|
||||
msg_type = cast(int, raw_message.get("msg_type"))
|
||||
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
|
||||
from_user_name: str,
|
||||
to_user_name: str,
|
||||
msg_id: int,
|
||||
):
|
||||
) -> dict | None:
|
||||
"""下载原始图片。"""
|
||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||
params = {"key": self.auth_key}
|
||||
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
|
||||
# 图片消息
|
||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
msg_id = raw_message.get("msg_id")
|
||||
msg_id = cast(int, raw_message.get("msg_id"))
|
||||
image_resp = await self._download_raw_image(
|
||||
from_user_name,
|
||||
to_user_name,
|
||||
msg_id,
|
||||
)
|
||||
if image_resp is None:
|
||||
logger.error(f"下载图片失败: msg_id={msg_id}")
|
||||
return
|
||||
image_bs64_data = (
|
||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||
)
|
||||
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid = 0
|
||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||
new_msg_id = raw_message.get("new_msg_id")
|
||||
if new_msg_id is None:
|
||||
logger.error("语音消息缺少 new_msg_id")
|
||||
return
|
||||
data_parser = GeweDataParser(
|
||||
content=content,
|
||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
|
||||
)
|
||||
|
||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||
if voicemsg is None:
|
||||
logger.error("无法从 XML 解析 voicemsg 节点")
|
||||
return
|
||||
bufid = voicemsg.get("bufid") or "0"
|
||||
length = int(voicemsg.get("length") or 0)
|
||||
voice_resp = await self.download_voice(
|
||||
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
|
||||
bufid=bufid,
|
||||
length=length,
|
||||
)
|
||||
if voice_resp is None:
|
||||
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
|
||||
return
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
@@ -827,6 +843,7 @@ class WeChatPadProAdapter(Platform):
|
||||
try:
|
||||
if self.ws_handle_task:
|
||||
self.ws_handle_task.cancel()
|
||||
if self._shutdown_event is not None:
|
||||
self._shutdown_event.set()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
|
||||
|
||||
async def get_contact_details_list(
|
||||
self,
|
||||
room_wx_id_list: list[str] = None,
|
||||
user_names: list[str] = None,
|
||||
room_wx_id_list: list[str] | None = None,
|
||||
user_names: list[str] | None = None,
|
||||
) -> dict | None:
|
||||
"""获取联系人详情列表。"""
|
||||
if room_wx_id_list is None:
|
||||
|
||||
@@ -2,7 +2,8 @@ import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -40,7 +41,7 @@ else:
|
||||
class WecomServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.server.add_url_rule(
|
||||
"/callback/command",
|
||||
@@ -60,7 +61,7 @@ class WecomServer:
|
||||
config["corpid"].strip(),
|
||||
)
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
@@ -114,7 +115,7 @@ class WecomServer:
|
||||
logger.error("解密失败,签名异常,请检查配置。")
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
msg = cast(BaseMessage, parse_message(xml))
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
self.client.__setattr__("kf", self.wechat_kf_api)
|
||||
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
elif isinstance(msg, ImageMessage):
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.agent)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
|
||||
elif isinstance(msg, VoiceMessage):
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self.client.media.download,
|
||||
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(msg.id)
|
||||
abm.timestamp = int(cast(int | str, msg.time))
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype")
|
||||
external_userid = msg.get("external_userid")
|
||||
external_userid = cast(str, msg.get("external_userid"))
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
|
||||
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
if not isinstance(kf_message_api, WeChatKFMessage):
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
|
||||
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
@staticmethod
|
||||
async def _send(
|
||||
message_chain: MessageChain,
|
||||
message_chain: MessageChain | None,
|
||||
stream_id: str,
|
||||
queue_mgr: WecomAIQueueMgr,
|
||||
streaming: bool = False,
|
||||
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
async def send(self, message: MessageChain | None):
|
||||
"""发送消息"""
|
||||
raw = self.message_obj.raw_message
|
||||
assert isinstance(raw, dict), (
|
||||
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
stream_id = raw.get("stream_id", self.session_id)
|
||||
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
||||
await super().send(message)
|
||||
await super().send(MessageChain([]))
|
||||
|
||||
async def send_streaming(self, generator, use_fallback=False):
|
||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import quart
|
||||
from requests import Response
|
||||
@@ -36,7 +37,7 @@ else:
|
||||
class WeixinOfficialAccountServer:
|
||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||
self.server = quart.Quart(__name__)
|
||||
self.port = int(config.get("port"))
|
||||
self.port = int(cast(int | str, config.get("port")))
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
self.token = config.get("token")
|
||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
self.callback = None
|
||||
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
async def verify(self):
|
||||
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
|
||||
raise
|
||||
else:
|
||||
msg = parse_message(xml)
|
||||
if not msg:
|
||||
logger.error("解析失败。msg为None。")
|
||||
raise
|
||||
logger.info(f"解析成功: {msg}")
|
||||
|
||||
if self.callback:
|
||||
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||
|
||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||
# msgid -> Future
|
||||
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[msg.id]
|
||||
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.wexin_event_workers[msg.id] = future
|
||||
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
|
||||
await self.convert_message(msg, future)
|
||||
# I love shield so much!
|
||||
result = await asyncio.wait_for(
|
||||
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
60,
|
||||
) # wait for 60s
|
||||
logger.debug(f"Got future result: {result}")
|
||||
self.wexin_event_workers.pop(msg.id, None)
|
||||
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
|
||||
return result # xml. see weixin_offacc_event.py
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
async def convert_message(
|
||||
self,
|
||||
msg,
|
||||
future: asyncio.Future = None,
|
||||
future: asyncio.Future | None = None,
|
||||
) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if isinstance(msg, TextMessage):
|
||||
abm.message_str = msg.content
|
||||
abm.message_str = cast(str, msg.content)
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Plain(msg.content)]
|
||||
abm.message = [Plain(cast(str, msg.content))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "image":
|
||||
assert isinstance(msg, ImageMessage)
|
||||
abm.message_str = "[图片]"
|
||||
abm.self_id = str(msg.target)
|
||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif msg.type == "voice":
|
||||
assert isinstance(msg, VoiceMessage)
|
||||
@@ -306,14 +310,15 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
abm.sender = MessageMember(
|
||||
msg.source,
|
||||
msg.source,
|
||||
cast(str, msg.source),
|
||||
cast(str, msg.source),
|
||||
)
|
||||
abm.message_id = msg.id
|
||||
abm.timestamp = msg.time
|
||||
abm.message_id = str(cast(str | int, msg.id))
|
||||
abm.timestamp = cast(int, msg.time)
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
if future:
|
||||
future.set_result(None)
|
||||
return
|
||||
# 很不优雅 :(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from wechatpy import WeChatClient
|
||||
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
||||
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
||||
active_send_mode = cast(dict, message_obj.raw_message).get(
|
||||
"active_send_mode", False
|
||||
)
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = TextReply(
|
||||
content=chunk,
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = ImageReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
reply = VoiceReply(
|
||||
media_id=response["media_id"],
|
||||
message=self.message_obj.raw_message["message"],
|
||||
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||
)
|
||||
xml = reply.render()
|
||||
future = self.message_obj.raw_message["future"]
|
||||
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||
assert isinstance(future, asyncio.Future)
|
||||
future.set_result(xml)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -118,7 +118,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list[dict],
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> FuncTool:
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
@@ -140,7 +140,7 @@ class FunctionToolManager:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> None:
|
||||
"""添加函数调用工具
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from astrbot.core import astrbot_config, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -10,6 +11,7 @@ from .entities import ProviderType
|
||||
from .provider import (
|
||||
EmbeddingProvider,
|
||||
Provider,
|
||||
Providers,
|
||||
RerankProvider,
|
||||
STTProvider,
|
||||
TTSProvider,
|
||||
@@ -17,6 +19,11 @@ from .provider import (
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HasInitialize(Protocol):
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -48,7 +55,7 @@ class ProviderManager:
|
||||
"""加载的 Rerank Provider 的实例"""
|
||||
self.inst_map: dict[
|
||||
str,
|
||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||
Providers,
|
||||
] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
self.llm_tools = llm_tools
|
||||
@@ -123,15 +130,13 @@ class ProviderManager:
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
return self.inst_map.get(provider_id)
|
||||
|
||||
def get_using_provider(
|
||||
self,
|
||||
provider_type: ProviderType,
|
||||
umo=None,
|
||||
) -> Provider | STTProvider | TTSProvider | None:
|
||||
self, provider_type: ProviderType, umo=None
|
||||
) -> Providers | None:
|
||||
"""获取正在使用的提供商实例。
|
||||
|
||||
Args:
|
||||
@@ -191,7 +196,6 @@ class ProviderManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
# 设置默认提供商
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
@@ -210,15 +214,37 @@ class ProviderManager:
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
||||
|
||||
temp_provider = (
|
||||
self.inst_map.get(selected_provider_id)
|
||||
if isinstance(selected_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_provider_inst = (
|
||||
temp_provider if isinstance(temp_provider, Provider) else None
|
||||
)
|
||||
if not self.curr_provider_inst and self.provider_insts:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
||||
temp_stt = (
|
||||
self.inst_map.get(selected_stt_provider_id)
|
||||
if isinstance(selected_stt_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_stt_provider_inst = (
|
||||
temp_stt if isinstance(temp_stt, STTProvider) else None
|
||||
)
|
||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
||||
temp_tts = (
|
||||
self.inst_map.get(selected_tts_provider_id)
|
||||
if isinstance(selected_tts_provider_id, str)
|
||||
else None
|
||||
)
|
||||
self.curr_tts_provider_inst = (
|
||||
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
||||
)
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
@@ -358,11 +384,16 @@ class ProviderManager:
|
||||
|
||||
provider_metadata.id = provider_config["id"]
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
match provider_metadata.provider_type:
|
||||
case ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
if not issubclass(cls_type, STTProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of STTProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
@@ -377,15 +408,22 @@ class ProviderManager:
|
||||
if not self.curr_stt_provider_inst:
|
||||
self.curr_stt_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||
case ProviderType.TEXT_TO_SPEECH:
|
||||
# TTS 任务
|
||||
if not issubclass(cls_type, TTSProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.tts_provider_insts.append(inst)
|
||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
||||
if (
|
||||
self.provider_settings.get("provider_id")
|
||||
== provider_config["id"]
|
||||
):
|
||||
self.curr_tts_provider_inst = inst
|
||||
logger.info(
|
||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||
@@ -393,14 +431,18 @@ class ProviderManager:
|
||||
if not self.curr_tts_provider_inst:
|
||||
self.curr_tts_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
case ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
if not issubclass(cls_type, Provider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of Provider"
|
||||
)
|
||||
inst = cls_type(
|
||||
provider_config,
|
||||
self.provider_settings,
|
||||
)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
@@ -415,16 +457,30 @@ class ProviderManager:
|
||||
if not self.curr_provider_inst:
|
||||
self.curr_provider_inst = inst
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||
case ProviderType.EMBEDDING:
|
||||
if not issubclass(cls_type, EmbeddingProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.embedding_provider_insts.append(inst)
|
||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||
case ProviderType.RERANK:
|
||||
if not issubclass(cls_type, RerankProvider):
|
||||
raise TypeError(
|
||||
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
||||
)
|
||||
inst = cls_type(provider_config, self.provider_settings)
|
||||
if getattr(inst, "initialize", None):
|
||||
if isinstance(inst, HasInitialize):
|
||||
await inst.initialize()
|
||||
self.rerank_provider_insts.append(inst)
|
||||
case _:
|
||||
# 未知供应商抛出异常,确保inst初始化
|
||||
# Should be unreachable
|
||||
raise Exception(
|
||||
f"未知的提供商类型:{provider_metadata.provider_type}"
|
||||
)
|
||||
|
||||
self.inst_map[provider_config["id"]] = inst
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
Providers: TypeAlias = Union[
|
||||
"Provider",
|
||||
"STTProvider",
|
||||
"TTSProvider",
|
||||
"EmbeddingProvider",
|
||||
"RerankProvider",
|
||||
]
|
||||
|
||||
|
||||
class AbstractProvider(abc.ABC):
|
||||
"""Provider Abstract Class"""
|
||||
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
|
||||
"""
|
||||
...
|
||||
if False: # pragma: no cover - make this an async generator for typing
|
||||
yield None # type: ignore
|
||||
raise NotImplementedError()
|
||||
|
||||
async def pop_record(self, context: list):
|
||||
"""弹出 context 第一条非系统提示词对话记录"""
|
||||
|
||||
@@ -29,15 +29,24 @@ class OTTSProvider:
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(timeout=self.timeout)
|
||||
self._client = AsyncClient(timeout=self.timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _sync_time(self):
|
||||
try:
|
||||
@@ -90,6 +99,7 @@ class OTTSProvider:
|
||||
if attempt == self.retry_count - 1:
|
||||
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
||||
await asyncio.sleep(0.5 * (attempt + 1))
|
||||
raise RuntimeError("OTTS未返回音频文件")
|
||||
|
||||
|
||||
class AzureNativeProvider(TTSProvider):
|
||||
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
self.endpoint = (
|
||||
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
self.client = None
|
||||
self._client: AsyncClient | None = None
|
||||
self.token = None
|
||||
self.token_expire = 0
|
||||
self.voice_params = {
|
||||
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError(
|
||||
"Client not initialized. Please use 'async with' context."
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self.client = AsyncClient(
|
||||
self._client = AsyncClient(
|
||||
headers={
|
||||
"User-Agent": f"AstrBot/{VERSION}",
|
||||
"Content-Type": "application/ssml+xml",
|
||||
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _refresh_token(self):
|
||||
token_url = (
|
||||
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||
self.provider = self._parse_provider(key_value, provider_config)
|
||||
|
||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
||||
def _parse_provider(
|
||||
self, key_value: str, config: dict
|
||||
) -> OTTSProvider | AzureNativeProvider:
|
||||
if key_value.lower().startswith("other["):
|
||||
json_str = ""
|
||||
try:
|
||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
|
||||
Returns:
|
||||
重排序结果列表
|
||||
"""
|
||||
if not self.client:
|
||||
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
logger.warning("文档列表为空,返回空结果")
|
||||
return []
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"text": text,
|
||||
"messages": None,
|
||||
"api_key": self.chosen_api_key,
|
||||
"voice": self.voice or "Cherry",
|
||||
"text": text,
|
||||
}
|
||||
if not self.voice:
|
||||
logging.warning(
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=mp3_path, output=wav_path)
|
||||
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
|
||||
Args:
|
||||
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
pattern = r"^[a-fA-F0-9]{32}$"
|
||||
return bool(re.match(pattern, reference_id.strip()))
|
||||
|
||||
async def _generate_request(self, text: str) -> dict:
|
||||
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
||||
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
||||
if self.reference_id and self.reference_id.strip():
|
||||
# 验证reference_id格式
|
||||
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
text = await response.aread()
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
api_key: str = provider_config.get("embedding_api_key")
|
||||
api_base: str = provider_config.get("embedding_api_base")
|
||||
api_key: str = provider_config["embedding_api_key"]
|
||||
api_base: str = provider_config["embedding_api_base"]
|
||||
timeout: int = int(provider_config.get("timeout", 20))
|
||||
|
||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
model=self.model,
|
||||
contents=text,
|
||||
)
|
||||
assert result.embeddings is not None
|
||||
assert result.embeddings[0].values is not None
|
||||
return result.embeddings[0].values
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
try:
|
||||
result = await self.client.models.embed_content(
|
||||
model=self.model,
|
||||
contents=texts,
|
||||
contents=cast(types.ContentListUnion, text),
|
||||
)
|
||||
return [embedding.values for embedding in result.embeddings]
|
||||
assert result.embeddings is not None
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
for embedding in result.embeddings:
|
||||
assert embedding.values is not None
|
||||
embeddings.append(embedding.values)
|
||||
return embeddings
|
||||
except APIError as e:
|
||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -136,7 +137,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = []
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = self.get_model()
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
tools=cast(types.ToolListUnion | None, tool_list),
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
assert contents[-1].parts is not None
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
logger.debug(f"genai result: {result}")
|
||||
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
break
|
||||
|
||||
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
|
||||
return json.dumps(dict_body)
|
||||
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
||||
"""进行流式请求"""
|
||||
try:
|
||||
async with (
|
||||
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
data = json.loads(message[6:])
|
||||
if "extra_info" in data:
|
||||
continue
|
||||
audio = data.get("data", {}).get("audio")
|
||||
audio: str | None = data.get("data", {}).get(
|
||||
"audio"
|
||||
)
|
||||
if audio is not None:
|
||||
yield audio
|
||||
except json.JSONDecodeError:
|
||||
|
||||
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return embedding.data[0].embedding
|
||||
|
||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||
"""批量获取文本的嵌入"""
|
||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
||||
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
||||
return [item.embedding for item in embeddings.data]
|
||||
|
||||
def get_dim(self) -> int:
|
||||
|
||||
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if isinstance(tool_call, str):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
if tools is None:
|
||||
# 工具集未提供
|
||||
# Should be unreachable
|
||||
raise Exception("工具集未提供")
|
||||
for tool in tools.func_list:
|
||||
if (
|
||||
tool_call.type == "function"
|
||||
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from funasr_onnx import SenseVoiceSmall
|
||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("stt_model"))
|
||||
self.set_model(provider_config["stt_model"])
|
||||
self.model = None
|
||||
self.is_emotion = provider_config.get("is_emotion", False)
|
||||
|
||||
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
loop = asyncio.get_event_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None, # 使用默认的线程池
|
||||
lambda: self.model(audio_url, language="auto", use_itn=True),
|
||||
lambda: cast(SenseVoiceSmall, self.model)(
|
||||
audio_url, language="auto", use_itn=True
|
||||
),
|
||||
)
|
||||
|
||||
# res = self.model(audio_url, language="auto", use_itn=True)
|
||||
|
||||
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
|
||||
}
|
||||
if top_n is not None:
|
||||
payload["top_n"] = top_n
|
||||
assert self.client is not None
|
||||
async with self.client.post(
|
||||
f"{self.base_url}/v1/rerank",
|
||||
json=payload,
|
||||
|
||||
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
|
||||
async def _get_audio_format(self, file_path):
|
||||
# 定义要检测的头部字节
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import whisper
|
||||
|
||||
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("model"))
|
||||
self.set_model(provider_config["model"])
|
||||
self.model = None
|
||||
|
||||
async def initialize(self):
|
||||
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
if not self.model:
|
||||
raise RuntimeError("Whisper 模型未初始化")
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result["text"]
|
||||
return cast(str, result["text"])
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from typing import cast
|
||||
|
||||
from xinference_client.client.restful.async_restful_client import (
|
||||
AsyncClient as Client,
|
||||
)
|
||||
from xinference_client.client.restful.async_restful_client import (
|
||||
AsyncRESTfulRerankModelHandle,
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
|
||||
False,
|
||||
)
|
||||
self.client = None
|
||||
self.model = None
|
||||
self.model: AsyncRESTfulRerankModelHandle | None = None
|
||||
self.model_uid = None
|
||||
|
||||
async def initialize(self):
|
||||
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
|
||||
return
|
||||
|
||||
if self.model_uid:
|
||||
self.model = await self.client.get_model(self.model_uid)
|
||||
self.model = cast(
|
||||
AsyncRESTfulRerankModelHandle,
|
||||
await self.client.get_model(self.model_uid),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Xinference model: {e}")
|
||||
|
||||
@@ -285,7 +285,7 @@ class Context:
|
||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||
return self.provider_manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
||||
def get_using_provider(self, umo: str | None = None) -> Provider:
|
||||
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
Args:
|
||||
@@ -296,7 +296,7 @@ class Context:
|
||||
provider_type=ProviderType.CHAT_COMPLETION,
|
||||
umo=umo,
|
||||
)
|
||||
if prov and not isinstance(prov, Provider):
|
||||
if not isinstance(prov, Provider):
|
||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||
return prov
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import docstring_parser
|
||||
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||
from astrbot.core.agent.tool import FunctionTool
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
|
||||
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
|
||||
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
||||
def get_handler_full_name(
|
||||
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||
) -> str:
|
||||
"""获取 Handler 的全名"""
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
|
||||
def get_handler_or_create(
|
||||
handler: Callable[..., Awaitable[Any]],
|
||||
handler: Callable[
|
||||
...,
|
||||
Awaitable[MessageEventResult | str | None]
|
||||
| AsyncGenerator[MessageEventResult | str | None],
|
||||
],
|
||||
event_type: EventType,
|
||||
dont_add=False,
|
||||
**kwargs,
|
||||
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
for (
|
||||
sub_handle
|
||||
) in parent_register_commandable.parent_group.sub_command_filters:
|
||||
if isinstance(sub_handle, CommandGroupFilter):
|
||||
continue
|
||||
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
||||
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
||||
sub_handle_md = sub_handle.get_handler_md()
|
||||
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
|
||||
else:
|
||||
# 裸指令
|
||||
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
|
||||
assert isinstance(awaitable, Callable)
|
||||
handler_md = get_handler_or_create(
|
||||
awaitable,
|
||||
EventType.AdapterMessageEvent,
|
||||
@@ -237,7 +248,7 @@ class RegisteringCommandable:
|
||||
|
||||
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
||||
command: Callable[..., Callable[..., None]] = register_command
|
||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
||||
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
|
||||
|
||||
def __init__(self, parent_group: CommandGroupFilter):
|
||||
self.parent_group = parent_group
|
||||
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
||||
if kwargs.get("registering_agent"):
|
||||
registering_agent = kwargs["registering_agent"]
|
||||
|
||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||
def decorator(
|
||||
awaitable: Callable[
|
||||
...,
|
||||
AsyncGenerator[MessageEventResult | str | None]
|
||||
| Awaitable[MessageEventResult | str | None],
|
||||
],
|
||||
):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
func_doc = awaitable.__doc__ or ""
|
||||
docstring = docstring_parser.parse(func_doc)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
|
||||
for handler in self._handlers:
|
||||
print(handler.handler_full_name)
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnAstrBotLoadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnPlatformLoadedEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.AdapterMessageEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnLLMRequestEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnLLMResponseEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnDecoratingResultEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnCallingFuncToolEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: Literal[EventType.OnAfterMessageSentEvent],
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||
|
||||
@overload
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: EventType,
|
||||
only_activated=True,
|
||||
plugins_name: list[str] | None = None,
|
||||
) -> list[
|
||||
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||
]: ...
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self,
|
||||
event_type: EventType,
|
||||
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
|
||||
H = TypeVar("H", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata:
|
||||
class StarHandlerMetadata(Generic[H]):
|
||||
"""描述一个 Star 所注册的某一个 Handler。"""
|
||||
|
||||
event_type: EventType
|
||||
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
|
||||
handler_module_path: str
|
||||
"""Handler 所在的模块路径。"""
|
||||
|
||||
handler: Callable[..., Awaitable[Any]]
|
||||
handler: H
|
||||
"""Handler 的函数对象,应当是一个异步函数"""
|
||||
|
||||
event_filters: list[HandlerFilter]
|
||||
|
||||
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
async def check_update(
|
||||
self,
|
||||
url: str,
|
||||
current_version: str,
|
||||
url: str | None,
|
||||
current_version: str | None,
|
||||
consider_prerelease: bool = True,
|
||||
) -> ReleaseInfo:
|
||||
) -> ReleaseInfo | None:
|
||||
"""检查更新"""
|
||||
return await super().check_update(
|
||||
self.ASTRBOT_RELEASE_API,
|
||||
|
||||
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
return False
|
||||
|
||||
|
||||
def save_temp_img(img: Image.Image | str) -> str:
|
||||
def save_temp_img(img: Image.Image | bytes) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
|
||||
@@ -20,16 +20,16 @@ class SessionController:
|
||||
|
||||
def __init__(self):
|
||||
self.future = asyncio.Future()
|
||||
self.current_event: asyncio.Event = None
|
||||
self.current_event: asyncio.Event | None = None
|
||||
"""当前正在等待的所用的异步事件"""
|
||||
self.ts: float = None
|
||||
self.ts: float | None = None
|
||||
"""上次保持(keep)开始时的时间"""
|
||||
self.timeout: float | int = None
|
||||
self.timeout: float | int | None = None
|
||||
"""上次保持(keep)开始时的超时时间"""
|
||||
|
||||
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
||||
|
||||
def stop(self, error: Exception = None):
|
||||
def stop(self, error: Exception | None = None):
|
||||
"""立即结束这个会话"""
|
||||
if not self.future.done():
|
||||
if error:
|
||||
@@ -53,6 +53,8 @@ class SessionController:
|
||||
self.stop()
|
||||
return
|
||||
else:
|
||||
assert self.timeout is not None
|
||||
assert self.ts is not None
|
||||
left_timeout = self.timeout - (new_ts - self.ts)
|
||||
timeout = left_timeout + timeout
|
||||
if timeout <= 0:
|
||||
@@ -69,7 +71,7 @@ class SessionController:
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
||||
async def _holding(self, event: asyncio.Event, timeout: float):
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
@@ -108,7 +110,9 @@ class SessionWaiter:
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.session_filter = session_filter
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
self.handler: (
|
||||
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
||||
) = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
self.record_history_chains = record_history_chains
|
||||
@@ -119,7 +123,7 @@ class SessionWaiter:
|
||||
|
||||
async def register_wait(
|
||||
self,
|
||||
handler: Callable[[str], Awaitable[Any]],
|
||||
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
timeout: int = 30,
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
@@ -137,7 +141,7 @@ class SessionWaiter:
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self, error: Exception = None):
|
||||
def _cleanup(self, error: Exception | None = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
try:
|
||||
@@ -161,6 +165,7 @@ class SessionWaiter:
|
||||
)
|
||||
try:
|
||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||
assert session.handler is not None
|
||||
await session.handler(session.session_controller, event)
|
||||
except Exception as e:
|
||||
session.session_controller.stop(e)
|
||||
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
def decorator(
|
||||
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(
|
||||
event: AstrMessageEvent,
|
||||
session_filter: SessionFilter = None,
|
||||
session_filter: SessionFilter | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -53,6 +53,38 @@ class SharedPreferences:
|
||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||
return ret
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: str,
|
||||
default: _VT = None,
|
||||
) -> _VT: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: str,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
@overload
|
||||
async def session_get(
|
||||
self,
|
||||
umo: None,
|
||||
key: None,
|
||||
default: Any = None,
|
||||
) -> list[Preference]: ...
|
||||
|
||||
async def session_get(
|
||||
self,
|
||||
umo: str | None,
|
||||
|
||||
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class RenderStrategy(ABC):
|
||||
@abstractmethod
|
||||
def render(self, text: str, return_url: bool) -> str:
|
||||
async def render(self, text: str, return_url: bool) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render_custom_template(
|
||||
async def render_custom_template(
|
||||
self,
|
||||
tmpl_str: str,
|
||||
tmpl_data: dict,
|
||||
|
||||
@@ -20,7 +20,7 @@ class FontManager:
|
||||
_font_cache = {}
|
||||
|
||||
@classmethod
|
||||
def get_font(cls, size: int) -> ImageFont.FreeTypeFont:
|
||||
def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
|
||||
"""获取指定大小的字体,优先从缓存获取"""
|
||||
if size in cls._font_cache:
|
||||
return cls._font_cache[size]
|
||||
@@ -66,23 +66,17 @@ class TextMeasurer:
|
||||
"""测量文本尺寸的工具类"""
|
||||
|
||||
@staticmethod
|
||||
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
|
||||
def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
|
||||
"""获取文本的尺寸"""
|
||||
try:
|
||||
# PIL 9.0.0 以上版本
|
||||
return (
|
||||
font.getbbox(text)[2:]
|
||||
if hasattr(font, "getbbox")
|
||||
else font.getsize(text)
|
||||
)
|
||||
except Exception:
|
||||
# 兼容旧版本
|
||||
return font.getsize(text)
|
||||
|
||||
# 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
|
||||
left, top, right, bottom = font.getbbox("Hello world")
|
||||
return int(right - left), int(bottom - top)
|
||||
|
||||
@staticmethod
|
||||
def split_text_to_fit_width(
|
||||
text: str, font: ImageFont.FreeTypeFont, max_width: int
|
||||
) -> List[str]:
|
||||
text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
|
||||
) -> list[str]:
|
||||
"""将文本拆分为多行,确保每行不超过指定宽度"""
|
||||
lines = []
|
||||
if not text:
|
||||
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
|
||||
# 倾斜变换,使用仿射变换实现斜体效果
|
||||
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
|
||||
italic_img = text_img.transform(
|
||||
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC
|
||||
text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
|
||||
)
|
||||
|
||||
# 粘贴到原图像
|
||||
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
|
||||
class CodeBlockElement(MarkdownElement):
|
||||
"""代码块元素"""
|
||||
|
||||
def __init__(self, content: List[str]):
|
||||
def __init__(self, content: list[str]):
|
||||
super().__init__("\n".join(content))
|
||||
|
||||
def calculate_height(self, image_width: int, font_size: int) -> int:
|
||||
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
|
||||
def render(
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw: ImageDraw.Draw,
|
||||
draw: ImageDraw.ImageDraw,
|
||||
x: int,
|
||||
y: int,
|
||||
image_width: int,
|
||||
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
|
||||
if pasted_image.width > max_width:
|
||||
ratio = max_width / pasted_image.width
|
||||
new_size = (int(max_width), int(pasted_image.height * ratio))
|
||||
pasted_image = pasted_image.resize(new_size, Image.LANCZOS)
|
||||
pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 计算居中位置
|
||||
paste_x = x + (image_width - pasted_image.width) // 2 - 10
|
||||
@@ -705,7 +699,7 @@ class MarkdownParser:
|
||||
"""Markdown解析器,将文本解析为元素"""
|
||||
|
||||
@staticmethod
|
||||
async def parse(text: str) -> List[MarkdownElement]:
|
||||
async def parse(text: str) -> list[MarkdownElement]:
|
||||
elements = []
|
||||
lines = text.split("\n")
|
||||
|
||||
@@ -847,7 +841,7 @@ class MarkdownRenderer:
|
||||
self,
|
||||
font_size: int = 26,
|
||||
width: int = 800,
|
||||
bg_color: Tuple[int, int, int] = (255, 255, 255),
|
||||
bg_color: tuple[int, int, int] = (255, 255, 255),
|
||||
):
|
||||
self.font_size = font_size
|
||||
self.width = width
|
||||
|
||||
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=input_path, output=output_path)
|
||||
ff.convert(input_file=input_path, output_file=output_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
|
||||
|
||||
@@ -60,9 +60,12 @@ class VersionComparator:
|
||||
return -1
|
||||
if isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
if (isinstance(p1, int) and isinstance(p2, int)) or (
|
||||
isinstance(p1, str) and isinstance(p2, str)
|
||||
):
|
||||
if isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
if p1 < p2:
|
||||
return -1
|
||||
if isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
if p1 < p2:
|
||||
|
||||
@@ -4,7 +4,9 @@ import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import g, make_response, request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -424,7 +426,9 @@ class ChatRoute(Route):
|
||||
sender_name=username,
|
||||
)
|
||||
|
||||
response = await make_response(
|
||||
response = cast(
|
||||
QuartResponse,
|
||||
await make_response(
|
||||
stream(),
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
@@ -432,8 +436,9 @@ class ChatRoute(Route):
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
|
||||
response.timeout = None # fix SSE auto disconnect issue
|
||||
return response
|
||||
|
||||
async def delete_webchat_session(self):
|
||||
|
||||
@@ -3,6 +3,7 @@ import inspect
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from quart import request
|
||||
|
||||
@@ -26,7 +27,7 @@ from astrbot.core.star.star import star_registry
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
def try_cast(value: Any, type_: str):
|
||||
if type_ == "int":
|
||||
try:
|
||||
return int(value)
|
||||
@@ -505,9 +506,9 @@ class ConfigRoute(Route):
|
||||
if not isinstance(inst, EmbeddingProvider):
|
||||
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
||||
|
||||
# 初始化
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
init_fn = getattr(inst, "initialize", None)
|
||||
if inspect.iscoroutinefunction(init_fn):
|
||||
await init_fn()
|
||||
|
||||
# 获取嵌入向量维度
|
||||
vec = await inst.get_embedding("echo")
|
||||
@@ -777,7 +778,7 @@ class ConfigRoute(Route):
|
||||
return {"metadata": CONFIG_METADATA_2, "config": config}
|
||||
|
||||
async def _get_plugin_config(self, plugin_name: str):
|
||||
ret = {"metadata": None, "config": None}
|
||||
ret: dict = {"metadata": None, "config": None}
|
||||
|
||||
for plugin_md in star_registry:
|
||||
if plugin_md.name == plugin_name:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import make_response
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
@@ -39,7 +41,9 @@ class LogRoute(Route):
|
||||
if queue:
|
||||
self.log_broker.unregister(queue)
|
||||
|
||||
response = await make_response(
|
||||
response = cast(
|
||||
QuartResponse,
|
||||
await make_response(
|
||||
stream(),
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
@@ -47,6 +51,7 @@ class LogRoute(Route):
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
@@ -579,6 +579,10 @@ class PluginRoute(Route):
|
||||
logger.warning(f"插件 {plugin_name} 不存在")
|
||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||
|
||||
if not plugin_obj.root_dir_name:
|
||||
logger.warning(f"插件 {plugin_name} 目录不存在")
|
||||
return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
|
||||
|
||||
plugin_dir = os.path.join(
|
||||
self.plugin_manager.plugin_store_path,
|
||||
plugin_obj.root_dir_name or "",
|
||||
|
||||
@@ -12,6 +12,8 @@ class RouteContext:
|
||||
|
||||
|
||||
class Route:
|
||||
routes: list | dict
|
||||
|
||||
def __init__(self, context: RouteContext):
|
||||
self.app = context.app
|
||||
self.config = context.config
|
||||
|
||||
@@ -2,9 +2,12 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from flask.json.provider import DefaultJSONProvider
|
||||
from psutil._common import addr as psutil_addr
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
|
||||
@@ -21,7 +24,7 @@ from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
|
||||
APP: Quart = None
|
||||
APP: Quart
|
||||
|
||||
|
||||
class AstrBotDashboard:
|
||||
@@ -48,7 +51,7 @@ class AstrBotDashboard:
|
||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||
128 * 1024 * 1024
|
||||
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
||||
self.app.json.sort_keys = False
|
||||
cast(DefaultJSONProvider, self.app.json).sort_keys = False
|
||||
self.app.before_request(self.auth_middleware)
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
@@ -147,7 +150,7 @@ class AstrBotDashboard:
|
||||
"""获取占用端口的进程详细信息"""
|
||||
try:
|
||||
for conn in psutil.net_connections(kind="inet"):
|
||||
if conn.laddr.port == port:
|
||||
if cast(psutil_addr, conn.laddr).port == port:
|
||||
try:
|
||||
process = psutil.Process(conn.pid)
|
||||
# 获取详细信息
|
||||
|
||||
@@ -139,6 +139,11 @@ class ProcessLLMRequest:
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
if not event.message_obj.group:
|
||||
logger.error(
|
||||
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
|
||||
)
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
|
||||
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
from astrbot.api.message_components import File, Image
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
|
||||
@@ -224,6 +225,8 @@ class Main(star.Star):
|
||||
del self.user_waiting[uid]
|
||||
elif isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if image_url is None:
|
||||
raise ValueError("Image URL is None")
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
elif image_url.startswith("file:///"):
|
||||
@@ -240,6 +243,8 @@ class Main(star.Star):
|
||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||
if event.get_session_id() in self.user_file_msg_buffer:
|
||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
||||
if not request.prompt:
|
||||
request.prompt = ""
|
||||
request.prompt += f"\nUser provided files: {files}"
|
||||
|
||||
@filter.command_group("pi")
|
||||
@@ -477,7 +482,9 @@ class Main(star.Star):
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain = [File(name=file_name, file=file_path)]
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
import zoneinfo
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
@@ -62,13 +63,13 @@ class Main(star.Star):
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
elif "cron" in reminder:
|
||||
trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger="cron",
|
||||
trigger=trigger,
|
||||
id=id_,
|
||||
args=[group, reminder],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(reminder["cron"]),
|
||||
)
|
||||
|
||||
def check_is_outdated(self, reminder: dict):
|
||||
@@ -101,10 +102,10 @@ class Main(star.Star):
|
||||
async def reminder_tool(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
text: str = None,
|
||||
datetime_str: str = None,
|
||||
cron_expression: str = None,
|
||||
human_readable_cron: str = None,
|
||||
text: str | None = None,
|
||||
datetime_str: str | None = None,
|
||||
cron_expression: str | None = None,
|
||||
human_readable_cron: str | None = None,
|
||||
):
|
||||
"""Call this function when user is asking for setting a reminder.
|
||||
|
||||
@@ -139,17 +140,19 @@ class Main(star.Star):
|
||||
"id": str(uuid.uuid4()),
|
||||
}
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
"cron",
|
||||
trigger,
|
||||
id=d["id"],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(cron_expression),
|
||||
args=[event.unified_msg_origin, d],
|
||||
)
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
if datetime_str is None:
|
||||
raise ValueError("datetime_str cannot be None.")
|
||||
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(
|
||||
|
||||
@@ -3,7 +3,7 @@ import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
|
||||
@@ -45,13 +45,13 @@ class SearchEngine:
|
||||
self.page = 1
|
||||
self.headers = HEADERS
|
||||
|
||||
def _set_selector(self, selector: str) -> None:
|
||||
def _set_selector(self, selector: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_next_page(self):
|
||||
def _get_next_page(self, query: str):
|
||||
raise NotImplementedError
|
||||
|
||||
async def _get_html(self, url: str, data: dict = None) -> str:
|
||||
async def _get_html(self, url: str, data: dict | None = None) -> str:
|
||||
headers = self.headers
|
||||
headers["Referer"] = url
|
||||
headers["User-Agent"] = random.choice(USER_AGENTS)
|
||||
@@ -83,6 +83,9 @@ class SearchEngine:
|
||||
"""清理文本,去除空格、换行符等"""
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
def _get_url(self, tag: Tag) -> str:
|
||||
return self.tidy_text(tag.get_text())
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
query = urllib.parse.quote(query)
|
||||
|
||||
@@ -92,12 +95,16 @@ class SearchEngine:
|
||||
links = soup.select(self._set_selector("links"))
|
||||
results = []
|
||||
for link in links:
|
||||
title = self.tidy_text(
|
||||
link.select_one(self._set_selector("title")).text,
|
||||
)
|
||||
url = link.select_one(self._set_selector("url"))
|
||||
# Safely get the title text (select_one may return None)
|
||||
title_elem = link.select_one(self._set_selector("title"))
|
||||
title = ""
|
||||
if title_elem is not None:
|
||||
title = self.tidy_text(title_elem.get_text())
|
||||
|
||||
url_tag = link.select_one(self._set_selector("url"))
|
||||
snippet = ""
|
||||
if title and url:
|
||||
if title and url_tag:
|
||||
url = self._get_url(url_tag)
|
||||
results.append(SearchResult(title=title, url=url, snippet=snippet))
|
||||
return results[:num_results] if len(results) > num_results else results
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from . import USER_AGENT_BING, SearchEngine, SearchResult
|
||||
from . import USER_AGENT_BING, SearchEngine
|
||||
|
||||
|
||||
class Bing(SearchEngine):
|
||||
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
|
||||
self.base_url = base_url
|
||||
continue
|
||||
raise Exception("Bing search failed")
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
if not isinstance(result.url, str):
|
||||
result.url = result.url.text
|
||||
|
||||
return results
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import random
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from . import USER_AGENTS, SearchEngine, SearchResult
|
||||
|
||||
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
|
||||
url = f"{self.base_url}/web?query={query}"
|
||||
return await self._get_html(url, None)
|
||||
|
||||
def _get_url(self, tag: Tag) -> str:
|
||||
return cast(str, tag.get("href"))
|
||||
|
||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
result.url = result.url.get("href")
|
||||
if result.url.startswith("/link?"):
|
||||
result.url = self.base_url + result.url
|
||||
result.url = await self._parse_url(result.url)
|
||||
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
script = soup.find("script")
|
||||
if script:
|
||||
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(
|
||||
1,
|
||||
script_text = (
|
||||
script.string if script.string is not None else script.get_text()
|
||||
)
|
||||
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
|
||||
if match:
|
||||
url = match.group(1)
|
||||
return url
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Minimal type stubs for faiss used in this project.
|
||||
|
||||
This file only exposes a small subset of the faiss API that the
|
||||
project uses, including the runtime-monkeypatched signatures such as
|
||||
`Index.add_with_ids` so Pyright/Pylance stops reporting false positives.
|
||||
"""
|
||||
|
||||
from typing import Any, overload
|
||||
|
||||
import numpy as np
|
||||
|
||||
class Index:
|
||||
d: int
|
||||
ntotal: int
|
||||
code_size: int
|
||||
nprobe: int
|
||||
|
||||
def add(self, x: np.ndarray) -> None: ...
|
||||
def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ...
|
||||
def search(
|
||||
self,
|
||||
x: np.ndarray,
|
||||
k: int,
|
||||
*,
|
||||
params: Any = ...,
|
||||
D: np.ndarray | None = ...,
|
||||
I: np.ndarray | None = ...,
|
||||
) -> tuple[np.ndarray, np.ndarray]: ...
|
||||
def remove_ids(self, x: np.ndarray) -> int: ...
|
||||
@overload
|
||||
def reconstruct(self, key: int) -> np.ndarray: ...
|
||||
@overload
|
||||
def reconstruct(self, key: int, x: np.ndarray) -> None: ...
|
||||
def reconstruct(
|
||||
self, key: int, x: np.ndarray | None = ...
|
||||
) -> np.ndarray | None: ...
|
||||
@overload
|
||||
def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ...
|
||||
@overload
|
||||
def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ...
|
||||
def reconstruct_n(
|
||||
self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ...
|
||||
) -> np.ndarray | None: ...
|
||||
def range_search(
|
||||
self, x: np.ndarray, thresh: float, *, params: Any = ...
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
|
||||
def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ...
|
||||
def sa_encode(self, x: np.ndarray) -> np.ndarray: ...
|
||||
def sa_decode(self, codes: np.ndarray) -> np.ndarray: ...
|
||||
|
||||
class IndexFlatL2(Index):
|
||||
def __init__(self, d: int) -> None: ...
|
||||
|
||||
class IndexIDMap(Index):
|
||||
index: Index
|
||||
|
||||
def __init__(self, index: Index) -> None: ...
|
||||
|
||||
def read_index(path: str) -> Index: ...
|
||||
def write_index(index: Index, path: str | None = ...) -> None: ...
|
||||
def normalize_L2(x: np.ndarray) -> None: ...
|
||||
|
||||
# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers
|
||||
# expose `downcast_*` helpers to convert generic objects to these concrete
|
||||
# types). We keep these minimal — only the names are important for typing.
|
||||
class IndexBinary(Index):
|
||||
def __init__(self, d: int) -> None: ...
|
||||
|
||||
class InvertedLists:
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
class AdditiveQuantizer:
|
||||
pass
|
||||
|
||||
class Quantizer:
|
||||
pass
|
||||
|
||||
class VectorTransform:
|
||||
pass
|
||||
|
||||
# SWIG-provided downcast helpers (present in some faiss Python builds).
|
||||
def downcast_IndexBinary(obj: Any) -> IndexBinary: ...
|
||||
def downcast_InvertedLists(obj: Any) -> InvertedLists: ...
|
||||
def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ...
|
||||
def downcast_Quantizer(obj: Any) -> Quantizer: ...
|
||||
def downcast_VectorTransform(obj: Any) -> VectorTransform: ...
|
||||
def downcast_index(obj: Any) -> Index: ...
|
||||
|
||||
# version exposed by runtime
|
||||
__version__: str
|
||||
Reference in New Issue
Block a user