chore: fix bunches of type checking errors (#3213)

* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
Dt8333
2025-12-09 14:13:47 +08:00
committed by GitHub
parent aa6d07afcc
commit f624971613
96 changed files with 1218 additions and 524 deletions
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None llm_resp_result = None
async for llm_response in self._iter_llm_responses(): async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk: if llm_response.is_chunk:
if llm_response.result_chain: if llm_response.result_chain:
yield AgentResponse( yield AgentResponse(
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic from typing import Any, Generic
import jsonschema import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any] ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]): class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling.""" """A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function.""" """a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None handler_module_path: str | None = None
+5 -1
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool( async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext], context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str, method_name: str,
*args, *args,
**kwargs, **kwargs,
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
""" """
config_path: str
default_config: dict
schema: dict | None
def __init__( def __init__(
self, self,
config_path: str = ASTRBOT_CONFIG_PATH, config_path: str = ASTRBOT_CONFIG_PATH,
+1 -1
View File
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行 # 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = [] extra_tasks = []
for task in self.star_context._register_tasks: for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks] tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_: for task in tasks_:
+2 -3
View File
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from deprecated import deprecated from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
echo=False, echo=False,
future=True, future=True,
) )
self.AsyncSessionLocal = sessionmaker( self.AsyncSessionLocal = async_sessionmaker(
self.engine, self.engine,
class_=AsyncSession, class_=AsyncSession,
expire_on_commit=False, expire_on_commit=False,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" not in conv.user_id: if ":" not in conv.user_id:
continue continue
session = MessageSesion.from_str(session_str=conv.user_id) session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" in conv.user_id: if ":" in conv.user_id:
continue continue
platform_id = "webchat" platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str conn.text_factory = str
return conn return conn
def _exec_sql(self, sql: str, params: tuple = None): def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn conn = self.conn
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close() c.close()
return Stats(platform, [], []) return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at), (user_id, cid, history, updated_at, created_at),
) )
def get_conversations(self, user_id: str) -> tuple: def get_conversations(self, user_id: str) -> list[Conversation]:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
+16 -15
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here. Note: In astrbot v4, we moved `platform` table to here.
""" """
__tablename__ = "platform_stats" # type: ignore __tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False) timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True): class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore __tablename__: str = "conversations"
inner_conversation_id: int = Field( inner_conversation_id: int | None = Field(
default=None,
primary_key=True, primary_key=True,
sa_column_kwargs={"autoincrement": True}, sa_column_kwargs={"autoincrement": True},
) )
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs. It can be used to customize the behavior of LLMs.
""" """
__tablename__ = "personas" # type: ignore __tablename__: str = "personas"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True): class Preference(SQLModel, table=True):
"""This class represents preferences for bots.""" """This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore __tablename__: str = "preferences"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages. or platform-specific messages.
""" """
__tablename__ = "platform_message_history" # type: ignore __tablename__: str = "platform_message_history"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it. Each session can have multiple conversations (对话) associated with it.
""" """
__tablename__ = "platform_sessions" # type: ignore __tablename__: str = "platform_sessions"
inner_id: int | None = Field( inner_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types. Attachments can be images, files, or other media types.
""" """
__tablename__ = "attachments" # type: ignore __tablename__: str = "attachments"
inner_attachment_id: int | None = Field( inner_attachment_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -261,17 +262,17 @@ class Personality(TypedDict):
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
""" """
prompt: str = "" prompt: str
name: str = "" name: str
begin_dialogs: list[str] = [] begin_dialogs: list[str]
mood_imitation_dialogs: list[str] = [] mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" """工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache # cache
_begin_dialogs_processed: list[dict] = [] _begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str = "" _mood_imitation_dialogs_processed: str
# ==== # ====
+4 -3
View File
@@ -3,6 +3,7 @@ import threading
import typing as T import typing as T
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
query = select(Attachment).where( query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = await session.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id col(Attachment.attachment_id) == attachment_id
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0 return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int: async def delete_attachments(self, attachment_ids: list[str]) -> int:
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount return result.rowcount
async def insert_persona( async def insert_persona(
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径 path (str): 保存索引的路径
""" """
if self.index is None:
return
faiss.write_index(self.index, self.path) faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self, self,
event_queue: Queue, event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler], pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None, astrbot_config_mgr: AstrBotConfigManager,
): ):
self.event_queue = event_queue # 事件队列 self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler # abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"]) self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event)) asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str): def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank # 5. Rerank
first_rerank = None first_rerank = None
for kb_id in kb_ids: for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"] rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if ( if (
vec_db vec_db
+9 -3
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel): class BaseMessageComponent(BaseModel):
type: ComponentType type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self): def toDict(self):
data = {} data = {}
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略 id: int | None = 0 # 忽略
name: str | None = "" # qq昵称 name: str | None = "" # qq昵称
uin: str | None = "0" # qq号 uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = [] content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略 seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略 time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d) ret["messages"].append(d)
return ret return ret
async def to_dict(self): async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []} ret = {"messages": []}
for node in self.nodes: for node in self.nodes:
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
if self.url: if self.url:
await self._download_file() await self._download_file()
return os.path.abspath(self.file_) if self.file_:
return os.path.abspath(self.file_)
return "" return ""
async def _download_file(self): async def _download_file(self):
"""下载文件""" """下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
if self.name: if self.name:
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self, self,
persona_id: str, persona_id: str,
system_prompt: str, system_prompt: str,
begin_dialogs: list[str] = None, begin_dialogs: list[str] | None = None,
tools: list[str] = None, tools: list[str] | None = None,
) -> Persona: ) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id): if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
check_text: str | None = None, check_text: str | None = None,
) -> None | AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""检查内容安全""" """检查内容安全"""
text = check_text if check_text else event.get_message_str() text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text) ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler( async def call_handler(
event: AstrMessageEvent, event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args, *args,
**kwargs, **kwargs,
) -> T.AsyncGenerator[T.Any, None]: ) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
) )
for handler in handlers: for handler in handlers:
try: try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug( logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
) )
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
async def process( async def process(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra( activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers", "activated_handlers",
) )
+1 -1
View File
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
): ):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if ( if (
event.get_result() and not event.get_result().is_stopped() event.get_result() and not event.is_stopped()
) or not event.get_result(): ) or not event.get_result():
async for _ in self.agent_sub_stage.process(event): async for _ in self.agent_sub_stage.process(event):
yield yield
+4 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg: if not self.enable_seg:
return False return False
if self.only_llm_result and not event.get_result().is_llm_result(): if (result := event.get_result()) is None:
return False
if self.only_llm_result and result.is_llm_result():
return False return False
if event.get_platform_name() in [ if event.get_platform_name() in [
@@ -185,7 +187,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file: if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。 # 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file) component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component result.chain[idx] = component
# 检查消息链是否为空 # 检查消息链是否为空
try: try:
+10 -6
View File
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -130,11 +131,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain: for comp in result.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
text += comp.text text += comp.text
async for _ in self.content_safe_check_stage.process(
event, if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
check_text=text, async for _ in self.content_safe_check_stage.process(
): event,
yield check_text=text,
):
yield
# 发送消息前事件钩子 # 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -151,7 +154,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
) )
await handler.handler(event) await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug( logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
) )
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER from . import STAGES_ORDER
from .context import PipelineContext from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event) await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None) await event.send(None)
logger.debug("pipeline 执行完毕。") logger.debug("pipeline 执行完毕。")
+5 -3
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str: def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)""" """获取消息发送者的名称。(可能会返回空字符串)"""
return self.message_obj.sender.nickname if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value): def set_extra(self, key, value):
"""设置额外的信息。""" """设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
""" """
self.call_llm = call_llm self.call_llm = call_llm
def get_result(self) -> MessageEventResult: def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。""" """获取消息事件的结果。"""
return self._result return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self, self,
prompt: str, prompt: str,
func_tool_manager=None, func_tool_manager=None,
session_id: str = None, session_id: str = "",
image_urls: list[str] | None = None, image_urls: list[str] | None = None,
contexts: list | None = None, contexts: list | None = None,
system_prompt: str = "", system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。 session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id message_id: str # 消息id
group: Group # 群组 group: Group | None # 群组
sender: MessageMember # 发送者 sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串 message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return "" return ""
@group_id.setter @group_id.setter
def group_id(self, value: str): def group_id(self, value: str | None):
"""设置 group_id""" """设置 group_id"""
if value: if value:
if self.group: if self.group:
+3 -3
View File
@@ -1,7 +1,7 @@
import abc import abc
import uuid import uuid
from asyncio import Queue from asyncio import Queue
from collections.abc import Awaitable from collections.abc import Coroutine
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -100,7 +100,7 @@ class Platform(abc.ABC):
} }
@abc.abstractmethod @abc.abstractmethod
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。""" """得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError raise NotImplementedError
@@ -116,7 +116,7 @@ class Platform(abc.ABC):
self, self,
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法。 异步方法。
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" """平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str description: str
"""平台的描述""" """平台的描述"""
id: str | None = None id: str
"""平台的唯一标识符,用于配置中识别特定平台""" """平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata( pm = PlatformMetadata(
name=adapter_name, name=adapter_name,
description=desc, description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl, default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name, adapter_display_name=adapter_display_name,
logo_path=logo_path, logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp, bot: CQHttp,
event: Event | None, event: Event | None,
is_group: bool, is_group: bool,
session_id: str, session_id: str | None,
messages: list[dict], messages: list[dict],
): ):
# session_id 必须是纯数字字符串 # session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int): if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id, message=messages) await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id, int): elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id, message=messages) await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底 elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages) await bot.send(event=event, message=messages)
else: else:
@@ -4,7 +4,7 @@ import logging
import time import time
import uuid import uuid
from collections.abc import Awaitable from collections.abc import Awaitable
from typing import Any from typing import Any, cast
from aiocqhttp import CQHttp, Event from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed from aiocqhttp.exceptions import ActionFailed
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="aiocqhttp", name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件""" """OneBot V11 请求类事件"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"): if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象 @param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
""" """
assert event.sender is not None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group": if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id) abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A") abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private": elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err) await self.bot.send(event, err)
except BaseException as e: except BaseException as e:
logger.error(f"回复消息失败: {e}") logger.error(f"回复消息失败: {e}")
return None raise ValueError(err)
# 按消息段类型类型适配 # 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -2,6 +2,7 @@ import asyncio
import os import os
import threading import threading
import uuid import uuid
from typing import cast
import aiohttp import aiohttp
import dingtalk_stream import dingtalk_stream
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
self.client_id = platform_config["client_id"] self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"] self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler): class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage): async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}") logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data) im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im) abm = await outer_self.convert_msg(im)
await self.handle_msg(abm) await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK" return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
self.client, self.client,
) )
self.client_ = client # 用于 websockets 的 client self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str: def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id: if not dingtalk_id:
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="dingtalk", name="dingtalk",
description="钉钉机器人官方 API 适配器", description="钉钉机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage() abm = AstrBotMessage()
abm.message = [] abm.message = []
abm.message_str = "" abm.message_str = ""
abm.timestamp = int(message.create_at / 1000) abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
if message.conversation_type == "2" if message.conversation_type == "2"
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick, nickname=message.sender_nick,
) )
abm.self_id = self._id_to_sid(message.chatbot_user_id) abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id abm.message_id = cast(str, message.message_id)
abm.raw_message = message abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE: if abm.type == MessageType.GROUP_MESSAGE:
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else: else:
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
message_type: str = message.message_type message_type: str = cast(str, message.message_type)
match message_type: match message_type:
case "text": case "text":
abm.message_str = message.text.content.strip() abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str)) abm.message.append(Plain(abm.message_str))
case "richText": case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content rtc: dingtalk_stream.RichTextContent = cast(
contents: list[dict] = rtc.rich_text_list dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents: for content in contents:
plains = "" plains = ""
if "text" in content: if "text" in content:
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture": elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file( f_path = await self.download_ding_file(
content["downloadCode"], content["downloadCode"],
message.robot_code, cast(str, message.robot_code),
"jpg", "jpg",
) )
abm.message.append(Image.fromFileSystem(f_path)) abm.message.append(Image.fromFileSystem(f_path))
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}", f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
) )
return None return ""
resp_data = await resp.json() resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"] download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path) await download_file(download_url, f_path)
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
) )
return None return ""
return (await resp.json())["data"]["accessToken"] return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage): async def handle_msg(self, abm: AstrBotMessage):
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
def monkey_patch_close(): def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown") raise KeyboardInterrupt("Graceful shutdown")
self.client_.open_connection = monkey_patch_close if self.client_.websocket is not None:
await self.client_.websocket.close(code=1000, reason="Graceful shutdown") self.client_.open_connection = monkey_patch_close
self._shutdown_event.set() await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()
def get_client(self): def get_client(self):
return self.client return self.client
@@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import cast
import dingtalk_stream import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
segment.text, segment.text,
segment.text, segment.text,
self.message_obj.raw_message, cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
) )
elif isinstance(segment, Comp.Image): elif isinstance(segment, Comp.Image):
markdown_str = "" markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
"😄", "😄",
markdown_str, markdown_str,
self.message_obj.raw_message, cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
) )
logger.debug(f"send image: {ret}") logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys import sys
from collections.abc import Awaitable, Callable
import discord import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy) super().__init__(intents=intents, proxy=proxy)
# 回调函数 # 回调函数
self.on_message_received = None self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback = None self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False self._ready_once_fired = False
@override
async def on_ready(self): async def on_ready(self):
"""当机器人成功连接并准备就绪时触发""" """当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。") logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict: def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典""" """从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions is_mentioned = self.user in message.mentions
return { return {
"message": message, "message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict: def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典""" """从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return { return {
"interaction": interaction, "interaction": interaction,
"bot_id": str(self.user.id), "bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction", "type": "interaction",
} }
@override
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
"""当接收到消息时触发""" """当接收到消息时触发"""
if message.author.bot: if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__( def __init__(
self, self,
components: list[BaseMessageComponent] = None, components: list[BaseMessageComponent] | None = None,
timeout: float = None, timeout: float | None = None,
): ):
self.components = components or [] self.components = components or []
self.timeout = timeout self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio import asyncio
import re import re
import sys import sys
from typing import Any from typing import Any, cast
import discord import discord
from discord.abc import Messageable from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel from discord.channel import DMChannel
from astrbot import logger from astrbot import logger
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
) -> None: ) -> None:
super().__init__(platform_config, event_queue) super().__init__(platform_config, event_queue)
self.settings = platform_settings self.settings = platform_settings
self.client_self_id = None self.client_self_id: str | None = None
self.registered_handlers = [] self.registered_handlers = []
# 指令注册相关 # 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True) self.enable_command_register = self.config.get("discord_command_register", True)
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain, message_chain: MessageChain,
): ):
"""通过会话发送消息""" """通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用 # 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage() message_obj = AstrBotMessage()
if "_" in session.session_id: if "_" in session.session_id:
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id), user_id=str(self.client_self_id),
nickname=self.client.user.display_name, nickname=self.client.user.display_name,
) )
message_obj.self_id = self.client_self_id message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id message_obj.session_id = session.session_id
message_obj.message = message_chain.chain message_obj.message = message_chain.chain
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
"discord", "discord",
"Discord 适配器", "Discord 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
default_config_tmpl=self.config, default_config_tmpl=self.config,
support_streaming_message=False, support_streaming_message=False,
) )
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type( def _get_message_type(
self, self,
channel: Messageable, channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None, guild_id: int | None = None,
) -> MessageType: ) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型""" """根据 channel 对象和 guild_id 判断消息类型"""
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str: def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID""" """根据 channel 对象获取ID"""
return str(getattr(channel, "id", None)) return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage""" """将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"] message = data["message"]
content = message.content content = message.content
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = message_chain abm.message = message_chain
abm.raw_message = message abm.raw_message = message
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id) abm.session_id = str(message.channel.id)
abm.message_id = str(message.id) abm.message_id = str(message.id)
return abm return abm
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook, interaction_followup_webhook=followup_webhook,
) )
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令 # 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention # 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False is_mention = False
# User Mention # User Mention
if ( # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
self.client if self.client.user in raw_message.mentions:
and self.client.user is_mention = True
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及) # Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"): if not is_mention and raw_message.role_mentions:
bot_member = None bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild: if raw_message.guild:
try: try:
bot_member = message.raw_message.guild.get_member( bot_member = raw_message.guild.get_member(
self.client.user.id, self.client.user.id,
) )
except Exception: except Exception:
bot_member = None bot_member = None
if bot_member and hasattr(bot_member, "roles"): if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles) bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions) mentioned_roles = set(raw_message.role_mentions)
if ( if (
bot_roles bot_roles
and mentioned_roles and mentioned_roles
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
): ):
is_mention = True is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态 # 如果是被@的消息,设置为唤醒状态
if is_slash_command or is_mention: if is_mention:
message_event.is_wake = True message_event.is_wake = True
message_event.is_at_or_wake_command = True message_event.is_at_or_wake_command = True
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = [Plain(text=message_str_for_filter)] abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id) abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id) abm.message_id = str(ctx.interaction.id)
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info( def _extract_command_info(
event_filter: Any, event_filter: Any,
handler_metadata: StarHandlerMetadata, handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None: ) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息""" """从事件过滤器中提取指令信息"""
cmd_name = None cmd_name = None
# is_group = False # is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import cast
import discord import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel() channel = await self._get_channel()
if not channel: if not channel:
return return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs) await channel.send(**kwargs)
except Exception as e: except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer) await self.send(buffer)
return await super().send_streaming(generator, use_fallback) return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None: async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象""" """获取当前事件对应的频道对象"""
try: try:
channel_id = int(self.session_id) channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord( async def _parse_to_discord(
self, self,
message: MessageChain, message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: ) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容""" """将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = [] content_parts = []
files = [] files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"add_reaction", "add_reaction",
): ):
await self.message_obj.raw_message.add_reaction(emoji) await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e: except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}") logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command == discord.InteractionType.application_command
) )
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
) )
def get_interaction_custom_id(self) -> str: def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id""" """获取交互组件的custom_id"""
if self.is_button_interaction(): if self.is_button_interaction():
try: try:
return self.message_obj.raw_message.data.get("custom_id", "") return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception: except Exception:
pass pass
return "" return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
): ):
return any( return any(
mention.id == int(self.message_obj.self_id) mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
) )
return False return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"clean_content", "clean_content",
): ):
return self.message_obj.raw_message.clean_content return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str return self.message_str
@@ -3,9 +3,14 @@ import base64
import json import json
import re import re
import uuid import uuid
from typing import cast
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
from astrbot import logger from astrbot import logger
@@ -74,6 +79,10 @@ class LarkPlatformAdapter(Platform):
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = { wrapped = {
"zh_cn": { "zh_cn": {
@@ -114,14 +123,21 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="lark", name="lark",
description="飞书机器人官方 API 适配器", description="飞书机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage() abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000 abm.timestamp = cast(int, message.create_time) // 1000
abm.message = [] abm.message = []
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
@@ -136,14 +152,28 @@ class LarkPlatformAdapter(Platform):
at_list = {} at_list = {}
if message.mentions: if message.mentions:
for m in message.mentions: for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.id is None:
if m.name == self.bot_name: continue
abm.self_id = m.id.open_id # 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
content_json_b = json.loads(message.content) if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text": if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息 message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw) # at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分 # 拆分文本,去掉AT符号部分
@@ -168,27 +198,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls content_json_b = _ls
elif message.message_type == "image": elif message.message_type == "image":
content_json_b = [ content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}, {
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
] ]
if message.message_type in ("post", "image"): if message.message_type in ("post", "image"):
for comp in content_json_b: for comp in content_json_b:
if comp["tag"] == "at": if comp.get("tag") == "at":
abm.message.append(at_list[comp["user_id"]]) user_id = comp.get("user_id")
elif comp["tag"] == "text" and comp["text"].strip(): if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip())) abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img": elif comp.get("tag") == "img":
image_key = comp["image_key"] image_key = comp.get("image_key")
if not image_key:
continue
request = ( request = (
GetMessageResourceRequest.builder() GetMessageResourceRequest.builder()
.message_id(message.message_id) .message_id(cast(str, message.message_id))
.file_key(image_key) .file_key(image_key)
.type("image") .type("image")
.build() .build()
) )
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request) response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success(): if not response.success():
logger.error(f"无法下载飞书图片: {image_key}") logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read() image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode() image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64)) abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -196,6 +246,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message: for comp in abm.message:
if isinstance(comp, Comp.Plain): if isinstance(comp, Comp.Plain):
abm.message_str += comp.text abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id abm.message_id = message.message_id
abm.raw_message = message abm.raw_message = message
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -235,5 +298,5 @@ class LarkPlatformAdapter(Platform):
await self.client._disconnect() await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭") logger.info("飞书(Lark) 适配器已被优雅地关闭")
def get_client(self) -> lark.Client: def get_client(self) -> lark.ws.Client:
return self.client return self.client
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO from io import BytesIO
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "") file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"): elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file) image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"): elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://") base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue()) f.write(BytesIO(image_data).getvalue())
else: else:
file_path = comp.file file_path = comp.file if comp.file else ""
if image_file is None: if image_file is None:
image_file = open(file_path, "rb") if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = ( request = (
CreateImageRequest.builder() CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request) response = await lark_client.im.v1.image.acreate(request)
if not response.success(): if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}") logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key image_key = response.data.image_key
logger.debug(image_key) logger.debug(image_key)
ret.append(_stage) ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build() .build()
) )
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request) response = await self.bot.im.v1.message.areply(request)
if not response.success(): if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def react(self, emoji: str): async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = ( request = (
CreateMessageReactionRequest.builder() CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id) .message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
response = await self.bot.im.v1.message_reaction.acreate(request) response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success(): if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -1,7 +1,6 @@
import asyncio import asyncio
import os import os
import random import random
from collections.abc import Awaitable
from typing import Any from typing import Any
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict): if not isinstance(message.raw_message, dict):
message.raw_message = {} message.raw_message = {}
message.raw_message["poll"] = poll message.raw_message["poll"] = poll
message.poll = poll message.__setattr__("poll", poll)
except Exception: except Exception:
pass pass
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self, self,
session: MessageSession, session: MessageSession,
message_chain: MessageChain, message_chain: MessageChain,
) -> Awaitable[Any]: ) -> None:
if not self.api: if not self.api:
logger.error("[Misskey] API 客户端未初始化") logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain) return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os import os
import random import random
import uuid import uuid
from typing import cast
import aiofiles import aiofiles
import botpy import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval: if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload) ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1 stream_payload["index"] += 1
stream_payload["id"] = ret["id"] stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time() last_edit_time = asyncio.get_event_loop().time()
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None return None
source = self.message_obj.raw_message source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source, source,
( (
botpy.message.Message, botpy.message.Message,
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage, botpy.message.DirectMessage,
botpy.message.C2CMessage, botpy.message.C2CMessage,
), ),
) ):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
( (
plain_text, plain_text,
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
): ):
return None return None
payload = { payload: dict = {
"content": plain_text, "content": plain_text,
"msg_id": self.message_obj.message_id, "msg_id": self.message_obj.message_id,
} }
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None ret = None
match type(source): match source:
case botpy.message.GroupMessage: case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid, group_openid=source.group_openid,
**payload, **payload,
) )
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload, **payload,
) )
logger.debug(f"Message sent to C2C: {ret}") logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_message( ret = await self.bot.api.post_message(
channel_id=source.channel_id, channel_id=source.channel_id,
**payload, **payload,
) )
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer) await super().send(self.send_buffer)
self.send_buffer = None self.send_buffer = None
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type, "file_type": file_type,
"srv_send_msg": False, "srv_send_msg": False,
} }
result = None
if "openid" in kwargs: if "openid" in kwargs:
payload["openid"] = kwargs["openid"] payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs: elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"] payload["group_openid"] = kwargs["group_openid"]
route = Route( route = Route(
"POST", "POST",
"/v2/groups/{group_openid}/files", "/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"], group_openid=kwargs["group_openid"],
) )
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record( async def upload_group_and_c2c_record(
self, self,
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if result: if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media( return Media(
file_uuid=result.get("file_uuid"), file_uuid=result["file_uuid"],
file_info=result.get("file_info"), file_info=result["file_info"],
ttl=result.get("ttl", 0), ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
) )
except Exception as e: except Exception as e:
logger.error(f"上传请求错误: {e}") logger.error(f"上传请求错误: {e}")
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None, message_reference: message.Reference | None = None,
media: message.Media | None = None, media: message.Media | None = None,
msg_id: str | None = None, msg_id: str | None = None,
msg_seq: str = 1, msg_seq: int | None = 1,
event_id: str | None = None, event_id: str | None = None,
markdown: message.MarkdownPayload | None = None, markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None, keyboard: message.Keyboard | None = None,
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals() payload = locals()
payload.pop("self", None) payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid) route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod @staticmethod
async def _parse_to_qqofficial(message: MessageChain): async def _parse_to_qqofficial(message: MessageChain):
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path) image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"): elif i.file and i.file.startswith("base64://"):
image_base64 = i.file image_base64 = i.file
else: elif i.file:
image_base64 = file_to_base64(i.file) image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://") image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record): elif isinstance(i, Record):
if i.file: if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
import time import time
from typing import cast
import botpy import botpy
import botpy.message import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"] self.appid = platform_config["appid"]
self.secret = platform_config["secret"] self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"] self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"] qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"] guild_dm = platform_config["enable_guild_direct_message"]
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official", name="qq_official",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
@staticmethod @staticmethod
def _parse_from_qqofficial( def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage, message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType, message_type: MessageType,
): ):
abm = AstrBotMessage() abm = AstrBotMessage()
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.raw_message = message abm.raw_message = message
abm.message_id = message.id abm.message_id = message.id
abm.tag = "qq_official" # abm.tag = "qq_official"
msg: list[BaseMessageComponent] = [] msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance( if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message, message,
botpy.message.DirectMessage, botpy.message.DirectMessage,
): ):
try: if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id) abm.self_id = str(message.mentions[0].id)
except BaseException as _: else:
abm.self_id = "" abm.self_id = ""
plain_content = message.content.replace( plain_content = message.content.replace(
@@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Any from typing import Any, cast
import botpy import botpy
import botpy.message import botpy.message
@@ -36,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official_webhook", name="qq_official_webhook",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
async def run(self): async def run(self):
@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import cast
import quart import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13: if opcode == 13:
# validation # validation
signed = await self.webhook_validation(data) signed = await self.webhook_validation(cast(dict, data))
print(signed) print(signed)
return signed return signed
@@ -4,9 +4,11 @@ import hmac
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -66,7 +68,7 @@ class SlackWebhookClient:
""" """
try: try:
# 获取请求体和头部 # 获取请求体和头部
body = await req.get_data() body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8")) event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature # Verify Slack request signature
@@ -139,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler self.event_handler = event_handler
self.socket_client = None self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest): async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件""" """处理 Socket Mode 事件"""
try: try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件 # 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id) response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response) await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re import re
import time import time
import uuid import uuid
from collections.abc import Awaitable from typing import Any, cast
from typing import Any
import aiohttp import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="slack", name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}") logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = self.bot_self_id abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息 # 获取用户信息
user_id = event.get("user", "") user_id = event.get("user", "")
try: try:
user_info = await self.web_client.users_info(user=user_id) user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id) user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception: except Exception:
user_name = user_id user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "") channel_id = event.get("channel", "")
try: try:
channel_info = await self.web_client.conversations_info(channel=channel_id) channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"] is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im: if is_im:
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions: for mention in mentions:
try: try:
mentioned_user = await self.web_client.users_info(user=mention) mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"] user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get( user_name = user_data.get("real_name") or user_data.get(
"name", "name",
mention, mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
) )
raise Exception(f"下载文件失败: {resp.status}") raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]: async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id() self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}") logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -1,6 +1,7 @@
import asyncio import asyncio
import re import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
if isinstance(segment, Image): if isinstance(segment, Image):
# upload file # upload file
url = segment.url or segment.file url = segment.url or segment.file
if url.startswith("http"): if url and url.startswith("http"):
return { return {
"type": "image", "type": "image",
"image_url": url, "image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"}, "text": {"type": "mrkdwn", "text": "图片上传失败"},
} }
image_url = response["files"][0]["url_private"] image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}") logger.debug(f"Slack file upload response: {response}")
return { return {
"type": "image", "type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"}, "text": {"type": "mrkdwn", "text": "文件上传失败"},
} }
file_url = response["files"][0]["permalink"] file_url = cast(list, response["files"])[0]["permalink"]
return { return {
"type": "section", "type": "section",
"text": { "text": {
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
) )
members = [] members = []
for member_id in members_response["members"]: for member_id in cast(Iterable, members_response["members"]):
try: try:
user_info = await self.web_client.users_info(user=member_id) user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
members.append( members.append(
MessageMember( MessageMember(
user_id=member_id, user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息 # 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id)) members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"] channel_data = cast(dict, channel_info["channel"])
return Group( return Group(
group_id=channel_id, group_id=channel_id,
group_name=channel_data.get("name", ""), group_name=channel_data.get("name", ""),
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import re import re
from typing import Any, cast
import telegramify_markdown import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply, Reply,
) )
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent): class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name, "chat_id": user_name,
} }
if has_reply: if has_reply:
payload["reply_to_message_id"] = reply_message_id payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id: if message_thread_id:
payload["message_thread_id"] = message_thread_id payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
md_text = telegramify_markdown.markdownify( md_text = telegramify_markdown.markdownify(
chunk, chunk,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await client.send_message( await client.send_message(
text=md_text, text=md_text,
parse_mode="MarkdownV2", parse_mode="MarkdownV2",
**payload, **cast(Any, payload),
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.", f"MarkdownV2 send failed: {e}. Using plain text instead.",
) )
await client.send_message(text=chunk, **payload) await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload) await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name) await client.send_document(
await download_file(i.file, path) document=path, filename=name, **cast(Any, payload)
i.file = path )
await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload) await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE: if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text delta += i.text
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload) await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue continue
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await self.client.send_document( await self.client.send_document(
document=i.file, document=path,
filename=i.name, filename=name,
**payload, **cast(Any, payload),
) )
continue continue
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload) await self.client.send_voice(voice=path, **cast(Any, payload))
continue continue
else: else:
logger.warning(f"不支持的消息类型: {type(i)}") logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else: else:
# delta 长度一般不会大于 4096,因此这里直接发送 # delta 长度一般不会大于 4096,因此这里直接发送
try: try:
msg = await self.client.send_message(text=delta, **payload) msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta current_content = delta
except Exception as e: except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}") logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
markdown_text = telegramify_markdown.markdownify( markdown_text = telegramify_markdown.markdownify(
delta, delta,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await self.client.edit_message_text( await self.client.edit_message_text(
@@ -2,7 +2,7 @@ import asyncio
import os import os
import time import time
import uuid import uuid
from collections.abc import Awaitable, Callable from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
from astrbot import logger from astrbot import logger
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
abm.raw_message = data abm.raw_message = data
return abm return abm
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple): async def callback(data: tuple):
abm = await self.convert_message(data) abm = await self.convert_message(data)
await self.handle_msg(abm) await self.handle_msg(abm)
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id) await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
final_data = "" final_data = ""
@@ -4,6 +4,7 @@ import json
import os import os
import time import time
import traceback import traceback
from typing import cast
import aiohttp import aiohttp
import anyio import anyio
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
) )
self.base_url = f"http://{self.host}:{self.port}" self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码 self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join( self.credentials_file = os.path.join(
get_astrbot_data_path(), get_astrbot_data_path(),
"wechatpadpro_credentials.json", "wechatpadpro_credentials.json",
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
) )
await asyncio.sleep(5) await asyncio.sleep(5)
async def handle_websocket_message(self, message: str): async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。""" """处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}") logger.debug(f"收到 WebSocket 消息: {message}")
try: try:
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" """将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = raw_message abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id")) abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time") abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180: if int(time.time()) - abm.timestamp > 180:
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "") content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "") push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type") msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = "" abm.message_str = ""
abm.message = [] abm.message = []
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str, from_user_name: str,
to_user_name: str, to_user_name: str,
msg_id: int, msg_id: int,
): ) -> dict | None:
"""下载原始图片。""" """下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg" url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key} params = {"key": self.auth_key}
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息 # 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "") from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id") msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image( image_resp = await self._download_raw_image(
from_user_name, from_user_name,
to_user_name, to_user_name,
msg_id, msg_id,
) )
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = ( image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
) )
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0 bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id") new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser( data_parser = GeweDataParser(
content=content, content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
) )
voicemsg = data_parser._format_to_xml().find("voicemsg") voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0" bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0) length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice( voice_resp = await self.download_voice(
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid, bufid=bufid,
length=length, length=length,
) )
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data: if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data) voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try: try:
if self.ws_handle_task: if self.ws_handle_task:
self.ws_handle_task.cancel() self.ws_handle_task.cancel()
self._shutdown_event.set() if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception: except Exception:
pass pass
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list( async def get_contact_details_list(
self, self,
room_wx_id_list: list[str] = None, room_wx_id_list: list[str] | None = None,
user_names: list[str] = None, user_names: list[str] | None = None,
) -> dict | None: ) -> dict | None:
"""获取联系人详情列表。""" """获取联系人详情列表。"""
if room_wx_id_list is None: if room_wx_id_list is None:
@@ -2,7 +2,8 @@ import asyncio
import os import os
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -40,7 +41,7 @@ else:
class WecomServer: class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule( self.server.add_url_rule(
"/callback/command", "/callback/command",
@@ -60,7 +61,7 @@ class WecomServer:
config["corpid"].strip(), config["corpid"].strip(),
) )
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,7 +115,7 @@ class WecomServer:
logger.error("解密失败,签名异常,请检查配置。") logger.error("解密失败,签名异常,请检查配置。")
raise raise
else: else:
msg = parse_message(xml) msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject # inject
self.wechat_kf_api = WeChatKF(client=self.client) self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client) self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api self.client.__setattr__("kf", self.wechat_kf_api)
self.client.kf_message = self.wechat_kf_message_api self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage): async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if msg.type == "text": if isinstance(msg, TextMessage):
assert isinstance(msg, TextMessage)
abm.message_str = msg.content abm.message_str = msg.content
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)] abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "image": elif isinstance(msg, ImageMessage):
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "voice": elif isinstance(msg, VoiceMessage):
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
else: else:
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype") msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid") external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = msg abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf: if is_wechat_kf:
# 微信客服 # 微信客服
kf_message_api = getattr(self.client, "kf_message", None) kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api: if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。") logger.warning("未找到微信客服发送消息方法。")
return return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id() user_id = self.get_sender_id()
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod @staticmethod
async def _send( async def _send(
message_chain: MessageChain, message_chain: MessageChain | None,
stream_id: str, stream_id: str,
queue_mgr: WecomAIQueueMgr, queue_mgr: WecomAIQueueMgr,
streaming: bool = False, streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
"""发送消息""" """发送消息"""
raw = self.message_obj.raw_message raw = self.message_obj.raw_message
assert isinstance(raw, dict), ( assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
) )
stream_id = raw.get("stream_id", self.session_id) stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False): async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计""" """流式发送消息,参考webchat的send_streaming设计"""
@@ -1,7 +1,8 @@
import asyncio import asyncio
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -36,7 +37,7 @@ else:
class WeixinOfficialAccountServer: class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token") self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key") self.encoding_aes_key = config.get("encoding_aes_key")
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue self.event_queue = event_queue
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
raise raise
else: else:
msg = parse_message(xml) msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.config["secret"].strip(), self.config["secret"].strip(),
) )
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重 # 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future # msgid -> Future
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None) await self.convert_message(msg, None)
else: else:
if msg.id in self.wexin_event_workers: if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id] future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}") logger.debug(f"duplicate message id checked: {msg.id}")
else: else:
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
result = await asyncio.wait_for( result = await asyncio.wait_for(
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60, 60,
) # wait for 60s ) # wait for 60s
logger.debug(f"Got future result: {result}") logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None) self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def convert_message( async def convert_message(
self, self,
msg, msg,
future: asyncio.Future = None, future: asyncio.Future | None = None,
) -> AstrBotMessage | None: ) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if isinstance(msg, TextMessage): if isinstance(msg, TextMessage):
abm.message_str = msg.content abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)] abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "image": elif msg.type == "image":
assert isinstance(msg, ImageMessage) assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "voice": elif msg.type == "voice":
assert isinstance(msg, VoiceMessage) assert isinstance(msg, VoiceMessage)
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
else: else:
logger.warning(f"暂未实现的事件: {msg.type}") logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None) if future:
future.set_result(None)
return return
# 很不优雅 :( # 很不优雅 :(
abm.raw_message = { abm.raw_message = {
@@ -1,5 +1,6 @@
import asyncio import asyncio
import uuid import uuid
from typing import cast
from wechatpy import WeChatClient from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
message_obj = self.message_obj message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False) active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
# Split long text messages if needed # Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = TextReply( reply = TextReply(
content=chunk, content=chunk,
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = ImageReply( reply = ImageReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = VoiceReply( reply = VoiceReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy import copy
import json import json
import os import os
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import aiohttp import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list[dict], func_args: list[dict],
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool: ) -> FuncTool:
params = { params = {
"type": "object", # hard-coded here "type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list, func_args: list,
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None: ) -> None:
"""添加函数调用工具 """添加函数调用工具
+126 -70
View File
@@ -1,5 +1,6 @@
import asyncio import asyncio
import traceback import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +11,7 @@ from .entities import ProviderType
from .provider import ( from .provider import (
EmbeddingProvider, EmbeddingProvider,
Provider, Provider,
Providers,
RerankProvider, RerankProvider,
STTProvider, STTProvider,
TTSProvider, TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager: class ProviderManager:
def __init__( def __init__(
self, self,
@@ -48,7 +55,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例""" """加载的 Rerank Provider 的实例"""
self.inst_map: dict[ self.inst_map: dict[
str, str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, Providers,
] = {} ] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例""" """Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools self.llm_tools = llm_tools
@@ -123,15 +130,13 @@ class ProviderManager:
self.curr_provider_inst = prov self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global") sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None: async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例""" """根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id) return self.inst_map.get(provider_id)
def get_using_provider( def get_using_provider(
self, self, provider_type: ProviderType, umo=None
provider_type: ProviderType, ) -> Providers | None:
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。 """获取正在使用的提供商实例。
Args: Args:
@@ -191,7 +196,6 @@ class ProviderManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(e) logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get( selected_provider_id = sp.get(
"curr_provider", "curr_provider",
self.provider_settings.get("default_provider_id"), self.provider_settings.get("default_provider_id"),
@@ -210,15 +214,37 @@ class ProviderManager:
scope="global", scope="global",
scope_id="global", scope_id="global",
) )
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts: if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts: if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts: if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst = self.tts_provider_insts[0]
@@ -358,73 +384,103 @@ class ProviderManager:
provider_metadata.id = provider_config["id"] provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: match provider_metadata.provider_type:
# STT 任务 case ProviderType.SPEECH_TO_TEXT:
inst = cls_type(provider_config, self.provider_settings) # STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.stt_provider_insts.append(inst) self.stt_provider_insts.append(inst)
if ( if (
self.provider_stt_settings.get("provider_id") self.provider_stt_settings.get("provider_id")
== provider_config["id"] == provider_config["id"]
): ):
self.curr_stt_provider_inst = inst self.curr_stt_provider_inst = inst
logger.info( logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
case ProviderType.TEXT_TO_SPEECH:
# TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
case ProviderType.CHAT_COMPLETION:
# 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type(
provider_config,
self.provider_settings,
) )
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: if isinstance(inst, HasInitialize):
# TTS 任务 await inst.initialize()
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): self.provider_insts.append(inst)
await inst.initialize() if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.tts_provider_insts.append(inst) case ProviderType.EMBEDDING:
if self.provider_settings.get("provider_id") == provider_config["id"]: if not issubclass(cls_type, EmbeddingProvider):
self.curr_tts_provider_inst = inst raise TypeError(
logger.info( f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", )
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.embedding_provider_insts.append(inst)
case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
) )
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
provider_config,
self.provider_settings,
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst self.inst_map[provider_config["id"]] = inst
except Exception as e: except Exception as e:
+12 -1
View File
@@ -2,6 +2,7 @@ import abc
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
"""Provider Abstract Class""" """Provider Abstract Class"""
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误 - 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
""" """
... if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list): async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录""" """弹出 context 第一条非系统提示词对话记录"""
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0 self.last_sync_time = 0
self.timeout = Timeout(10.0) self.timeout = Timeout(10.0)
self.retry_count = 3 self.retry_count = 3
self.client = None self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout) self._client = AsyncClient(timeout=self.timeout)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _sync_time(self): async def _sync_time(self):
try: try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1: if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider): class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = ( self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
) )
self.client = None self._client: AsyncClient | None = None
self.token = None self.token = None
self.token_expire = 0 self.token_expire = 0
self.voice_params = { self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"), "volume": provider_config.get("azure_tts_volume", "100"),
} }
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient( self._client = AsyncClient(
headers={ headers={
"User-Agent": f"AstrBot/{VERSION}", "User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml", "Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _refresh_token(self): async def _refresh_token(self):
token_url = ( token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "") key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config) self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider: def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["): if key_value.lower().startswith("other["):
json_str = ""
try: try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL) match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match: if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns: Returns:
重排序结果列表 重排序结果列表
""" """
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents: if not documents:
logger.warning("文档列表为空,返回空结果") logger.warning("文档列表为空,返回空结果")
return [] return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "") self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = { kwargs = {
"model": model, "model": model,
"text": text, "messages": None,
"api_key": self.chosen_api_key, "api_key": self.chosen_api_key,
"voice": self.voice or "Cherry", "voice": self.voice or "Cherry",
"text": text,
} }
if not self.voice: if not self.voice:
logging.warning( logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg from pyffmpeg import FFmpeg
ff = FFmpeg() ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path) ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e: except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line # use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = { self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}", "Authorization": f"Bearer {self.chosen_api_key}",
} }
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_reference_id_by_character(self, character: str) -> str: async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id """获取角色的reference_id
Args: Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$" pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip())) return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict: async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询 # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip(): if self.reference_id and self.reference_id.strip():
# 验证reference_id格式 # 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes(): async for chunk in response.aiter_bytes():
f.write(chunk) f.write(chunk)
return path return path
text = await response.aread() body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}") raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
from google.genai.errors import APIError from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config self.provider_config = provider_config
self.provider_settings = provider_settings self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key") api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config.get("embedding_api_base") api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20)) timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000) http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model, model=self.model,
contents=text, contents=text,
) )
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return result.embeddings[0].values return result.embeddings[0].values
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}") raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
try: try:
result = await self.client.models.embed_content( result = await self.client.models.embed_content(
model=self.model, model=self.model,
contents=texts, contents=cast(types.ContentListUnion, text),
) )
return [embedding.values for embedding in result.embeddings] assert result.embeddings is not None
embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
embeddings.append(embedding.values)
return embeddings
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
@@ -4,6 +4,7 @@ import json
import logging import logging
import random import random
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
@@ -136,7 +137,7 @@ class ProviderGoogleGenAI(Provider):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态") logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"] modalities = ["Text"]
tool_list = [] tool_list: list[types.Tool] | None = []
model_name = self.get_model() model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False) native_search = self.provider_config.get("gm_native_search", False)
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"), logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"), seed=payloads.get("seed"),
response_modalities=modalities, response_modalities=modalities,
tools=tool_list, tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None, safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=( thinking_config=(
types.ThinkingConfig( types.ThinkingConfig(
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content], content_cls: type[types.Content],
) -> None: ) -> None:
if contents and isinstance(contents[-1], content_cls): if contents and isinstance(contents[-1], content_cls):
assert contents[-1].parts is not None
contents[-1].parts.extend(part) contents[-1].parts.extend(part)
else: else:
contents.append(content_cls(parts=part)) contents.append(content_cls(parts=part))
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
) )
result = await self.client.models.generate_content( result = await self.client.models.generate_content(
model=self.get_model(), model=self.get_model(),
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
logger.debug(f"genai result: {result}") logger.debug(f"genai result: {result}")
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
) )
result = await self.client.models.generate_content_stream( result = await self.client.models.generate_content_stream(
model=self.get_model(), model=self.get_model(),
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
break break
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body) return json.dumps(dict_body)
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求""" """进行流式请求"""
try: try:
async with ( async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:]) data = json.loads(message[6:])
if "extra_info" in data: if "extra_info" in data:
continue continue
audio = data.get("data", {}).get("audio") audio: str | None = data.get("data", {}).get(
"audio"
)
if audio is not None: if audio is not None:
yield audio yield audio
except json.JSONDecodeError: except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model) embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model) embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data] return [item.embedding for item in embeddings.data]
def get_dim(self) -> int: def get_dim(self) -> int:
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(tool_call, str): if isinstance(tool_call, str):
# workaround for #1359 # workaround for #1359
tool_call = json.loads(tool_call) tool_call = json.loads(tool_call)
if tools is None:
# 工具集未提供
# Should be unreachable
raise Exception("工具集未提供")
for tool in tools.func_list: for tool in tools.func_list:
if ( if (
tool_call.type == "function" tool_call.type == "function"
@@ -7,6 +7,7 @@ import asyncio
import os import os
import re import re
from datetime import datetime from datetime import datetime
from typing import cast
from funasr_onnx import SenseVoiceSmall from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("stt_model")) self.set_model(provider_config["stt_model"])
self.model = None self.model = None
self.is_emotion = provider_config.get("is_emotion", False) self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
res = await loop.run_in_executor( res = await loop.run_in_executor(
None, # 使用默认的线程池 None, # 使用默认的线程池
lambda: self.model(audio_url, language="auto", use_itn=True), lambda: cast(SenseVoiceSmall, self.model)(
audio_url, language="auto", use_itn=True
),
) )
# res = self.model(audio_url, language="auto", use_itn=True) # res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
} }
if top_n is not None: if top_n is not None:
payload["top_n"] = top_n payload["top_n"] = top_n
assert self.client is not None
async with self.client.post( async with self.client.post(
f"{self.base_url}/v1/rerank", f"{self.base_url}/v1/rerank",
json=payload, json=payload,
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN), timeout=provider_config.get("timeout", NOT_GIVEN),
) )
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_audio_format(self, file_path): async def _get_audio_format(self, file_path):
# 定义要检测的头部字节 # 定义要检测的头部字节
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import uuid import uuid
from typing import cast
import whisper import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.model = None self.model = None
async def initialize(self): async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path) await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path audio_url = output_path
if not self.model:
raise RuntimeError("Whisper 模型未初始化")
result = await loop.run_in_executor(None, self.model.transcribe, audio_url) result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result["text"] return cast(str, result["text"])
@@ -1,6 +1,11 @@
from typing import cast
from xinference_client.client.restful.async_restful_client import ( from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client, AsyncClient as Client,
) )
from xinference_client.client.restful.async_restful_client import (
AsyncRESTfulRerankModelHandle,
)
from astrbot import logger from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False, False,
) )
self.client = None self.client = None
self.model = None self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None self.model_uid = None
async def initialize(self): async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return return
if self.model_uid: if self.model_uid:
self.model = await self.client.get_model(self.model_uid) self.model = cast(
AsyncRESTfulRerankModelHandle,
await self.client.get_model(self.model_uid),
)
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}") logger.error(f"Failed to initialize Xinference model: {e}")
+2 -2
View File
@@ -285,7 +285,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。""" """获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None: def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args: Args:
@@ -296,7 +296,7 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION, provider_type=ProviderType.CHAT_COMPLETION,
umo=umo, umo=umo,
) )
if prov and not isinstance(prov, Provider): if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型") raise ValueError("返回的 Provider 不是 Provider 类型")
return prov return prov
+22 -5
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import docstring_parser import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名""" """获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}" return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create( def get_handler_or_create(
handler: Callable[..., Awaitable[Any]], handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType, event_type: EventType,
dont_add=False, dont_add=False,
**kwargs, **kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for ( for (
sub_handle sub_handle
) in parent_register_commandable.parent_group.sub_command_filters: ) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。 # 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md() sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else: else:
# 裸指令 # 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create( handler_md = get_handler_or_create(
awaitable, awaitable,
EventType.AdapterMessageEvent, EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter): def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"): if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"] registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__ llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or "" func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc) docstring = docstring_parser.parse(func_doc)
+85 -4
View File
@@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter from .filter import HandlerFilter
from .star import star_map from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers: for handler in self._handlers:
print(handler.handler_full_name) print(handler.handler_full_name)
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAstrBotLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPlatformLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.AdapterMessageEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMRequestEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMResponseEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnDecoratingResultEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnCallingFuncToolEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAfterMessageSentEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
def get_handlers_by_event_type( def get_handlers_by_event_type(
self, self,
event_type: EventType, event_type: EventType,
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后 OnAfterMessageSentEvent = enum.auto() # 发送消息后
H = TypeVar("H", bound=Callable[..., Any])
@dataclass @dataclass
class StarHandlerMetadata: class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。""" """描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType event_type: EventType
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
handler_module_path: str handler_module_path: str
"""Handler 所在的模块路径。""" """Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]] handler: H
"""Handler 的函数对象,应当是一个异步函数""" """Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter] event_filters: list[HandlerFilter]
+3 -3
View File
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
async def check_update( async def check_update(
self, self,
url: str, url: str | None,
current_version: str, current_version: str | None,
consider_prerelease: bool = True, consider_prerelease: bool = True,
) -> ReleaseInfo: ) -> ReleaseInfo | None:
"""检查更新""" """检查更新"""
return await super().check_update( return await super().check_update(
self.ASTRBOT_RELEASE_API, self.ASTRBOT_RELEASE_API,
+1 -1
View File
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
return False return False
def save_temp_img(img: Image.Image | str) -> str: def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp") temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的 # 获得文件创建时间,清除超过 12 小时的
try: try:
+17 -10
View File
@@ -20,16 +20,16 @@ class SessionController:
def __init__(self): def __init__(self):
self.future = asyncio.Future() self.future = asyncio.Future()
self.current_event: asyncio.Event = None self.current_event: asyncio.Event | None = None
"""当前正在等待的所用的异步事件""" """当前正在等待的所用的异步事件"""
self.ts: float = None self.ts: float | None = None
"""上次保持(keep)开始时的时间""" """上次保持(keep)开始时的时间"""
self.timeout: float | int = None self.timeout: float | int | None = None
"""上次保持(keep)开始时的超时时间""" """上次保持(keep)开始时的超时时间"""
self.history_chains: list[list[Comp.BaseMessageComponent]] = [] self.history_chains: list[list[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception = None): def stop(self, error: Exception | None = None):
"""立即结束这个会话""" """立即结束这个会话"""
if not self.future.done(): if not self.future.done():
if error: if error:
@@ -53,6 +53,8 @@ class SessionController:
self.stop() self.stop()
return return
else: else:
assert self.timeout is not None
assert self.ts is not None
left_timeout = self.timeout - (new_ts - self.ts) left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout timeout = left_timeout + timeout
if timeout <= 0: if timeout <= 0:
@@ -69,7 +71,7 @@ class SessionController:
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: int): async def _holding(self, event: asyncio.Event, timeout: float):
"""等待事件结束或超时""" """等待事件结束或超时"""
try: try:
await asyncio.wait_for(event.wait(), timeout) await asyncio.wait_for(event.wait(), timeout)
@@ -108,7 +110,9 @@ class SessionWaiter:
): ):
self.session_id = session_id self.session_id = session_id
self.session_filter = session_filter self.session_filter = session_filter
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 self.handler: (
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
) = None # 处理函数
self.session_controller = SessionController() self.session_controller = SessionController()
self.record_history_chains = record_history_chains self.record_history_chains = record_history_chains
@@ -119,7 +123,7 @@ class SessionWaiter:
async def register_wait( async def register_wait(
self, self,
handler: Callable[[str], Awaitable[Any]], handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
timeout: int = 30, timeout: int = 30,
) -> Any: ) -> Any:
"""等待外部输入并处理""" """等待外部输入并处理"""
@@ -137,7 +141,7 @@ class SessionWaiter:
finally: finally:
self._cleanup() self._cleanup()
def _cleanup(self, error: Exception = None): def _cleanup(self, error: Exception | None = None):
"""清理会话""" """清理会话"""
USER_SESSIONS.pop(self.session_id, None) USER_SESSIONS.pop(self.session_id, None)
try: try:
@@ -161,6 +165,7 @@ class SessionWaiter:
) )
try: try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
assert session.handler is not None
await session.handler(session.session_controller, event) await session.handler(session.session_controller, event)
except Exception as e: except Exception as e:
session.session_controller.stop(e) session.session_controller.stop(e)
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
:param record_history_chain: 是否自动记录历史消息链可以通过 controller.get_history_chains() 获取深拷贝 :param record_history_chain: 是否自动记录历史消息链可以通过 controller.get_history_chains() 获取深拷贝
""" """
def decorator(func: Callable[[str], Awaitable[Any]]): def decorator(
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
):
@functools.wraps(func) @functools.wraps(func)
async def wrapper( async def wrapper(
event: AstrMessageEvent, event: AstrMessageEvent,
session_filter: SessionFilter = None, session_filter: SessionFilter | None = None,
*args, *args,
**kwargs, **kwargs,
): ):
+32
View File
@@ -53,6 +53,38 @@ class SharedPreferences:
ret = await self.db_helper.get_preferences(scope, scope_id, key) ret = await self.db_helper.get_preferences(scope, scope_id, key)
return ret return ret
@overload
async def session_get(
self,
umo: str,
key: str,
default: _VT = None,
) -> _VT: ...
@overload
async def session_get(
self,
umo: None,
key: str,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: str,
key: None,
default: Any = None,
) -> list[Preference]: ...
@overload
async def session_get(
self,
umo: None,
key: None,
default: Any = None,
) -> list[Preference]: ...
async def session_get( async def session_get(
self, self,
umo: str | None, umo: str | None,
+2 -2
View File
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
class RenderStrategy(ABC): class RenderStrategy(ABC):
@abstractmethod @abstractmethod
def render(self, text: str, return_url: bool) -> str: async def render(self, text: str, return_url: bool) -> str:
pass pass
@abstractmethod @abstractmethod
def render_custom_template( async def render_custom_template(
self, self,
tmpl_str: str, tmpl_str: str,
tmpl_data: dict, tmpl_data: dict,
+25 -31
View File
@@ -20,7 +20,7 @@ class FontManager:
_font_cache = {} _font_cache = {}
@classmethod @classmethod
def get_font(cls, size: int) -> ImageFont.FreeTypeFont: def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
"""获取指定大小的字体,优先从缓存获取""" """获取指定大小的字体,优先从缓存获取"""
if size in cls._font_cache: if size in cls._font_cache:
return cls._font_cache[size] return cls._font_cache[size]
@@ -66,23 +66,17 @@ class TextMeasurer:
"""测量文本尺寸的工具类""" """测量文本尺寸的工具类"""
@staticmethod @staticmethod
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
"""获取文本的尺寸""" """获取文本的尺寸"""
try:
# PIL 9.0.0 以上版本 # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
return ( left, top, right, bottom = font.getbbox("Hello world")
font.getbbox(text)[2:] return int(right - left), int(bottom - top)
if hasattr(font, "getbbox")
else font.getsize(text)
)
except Exception:
# 兼容旧版本
return font.getsize(text)
@staticmethod @staticmethod
def split_text_to_fit_width( def split_text_to_fit_width(
text: str, font: ImageFont.FreeTypeFont, max_width: int text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
) -> List[str]: ) -> list[str]:
"""将文本拆分为多行,确保每行不超过指定宽度""" """将文本拆分为多行,确保每行不超过指定宽度"""
lines = [] lines = []
if not text: if not text:
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
# 倾斜变换,使用仿射变换实现斜体效果 # 倾斜变换,使用仿射变换实现斜体效果
# 变换矩阵: [1, 0.2, 0, 0, 1, 0] # 变换矩阵: [1, 0.2, 0, 0, 1, 0]
italic_img = text_img.transform( italic_img = text_img.transform(
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
) )
# 粘贴到原图像 # 粘贴到原图像
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
class CodeBlockElement(MarkdownElement): class CodeBlockElement(MarkdownElement):
"""代码块元素""" """代码块元素"""
def __init__(self, content: List[str]): def __init__(self, content: list[str]):
super().__init__("\n".join(content)) super().__init__("\n".join(content))
def calculate_height(self, image_width: int, font_size: int) -> int: def calculate_height(self, image_width: int, font_size: int) -> int:
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
def render( def render(
self, self,
image: Image.Image, image: Image.Image,
draw: ImageDraw.Draw, draw: ImageDraw.ImageDraw,
x: int, x: int,
y: int, y: int,
image_width: int, image_width: int,
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
if pasted_image.width > max_width: if pasted_image.width > max_width:
ratio = max_width / pasted_image.width ratio = max_width / pasted_image.width
new_size = (int(max_width), int(pasted_image.height * ratio)) new_size = (int(max_width), int(pasted_image.height * ratio))
pasted_image = pasted_image.resize(new_size, Image.LANCZOS) pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
# 计算居中位置 # 计算居中位置
paste_x = x + (image_width - pasted_image.width) // 2 - 10 paste_x = x + (image_width - pasted_image.width) // 2 - 10
@@ -705,7 +699,7 @@ class MarkdownParser:
"""Markdown解析器,将文本解析为元素""" """Markdown解析器,将文本解析为元素"""
@staticmethod @staticmethod
async def parse(text: str) -> List[MarkdownElement]: async def parse(text: str) -> list[MarkdownElement]:
elements = [] elements = []
lines = text.split("\n") lines = text.split("\n")
@@ -847,7 +841,7 @@ class MarkdownRenderer:
self, self,
font_size: int = 26, font_size: int = 26,
width: int = 800, width: int = 800,
bg_color: Tuple[int, int, int] = (255, 255, 255), bg_color: tuple[int, int, int] = (255, 255, 255),
): ):
self.font_size = font_size self.font_size = font_size
self.width = width self.width = width
+1 -1
View File
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
from pyffmpeg import FFmpeg from pyffmpeg import FFmpeg
ff = FFmpeg() ff = FFmpeg()
ff.convert(input=input_path, output=output_path) ff.convert(input_file=input_path, output_file=output_path)
except Exception as e: except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
+6 -3
View File
@@ -60,9 +60,12 @@ class VersionComparator:
return -1 return -1
if isinstance(p1, str) and isinstance(p2, int): if isinstance(p1, str) and isinstance(p2, int):
return 1 return 1
if (isinstance(p1, int) and isinstance(p2, int)) or ( if isinstance(p1, int) and isinstance(p2, int):
isinstance(p1, str) and isinstance(p2, str) if p1 > p2:
): return 1
if p1 < p2:
return -1
if isinstance(p1, str) and isinstance(p2, str):
if p1 > p2: if p1 > p2:
return 1 return 1
if p1 < p2: if p1 < p2:
+14 -9
View File
@@ -4,7 +4,9 @@ import mimetypes
import os import os
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import cast
from quart import Response as QuartResponse
from quart import g, make_response, request, send_file from quart import g, make_response, request, send_file
from astrbot.core import logger from astrbot.core import logger
@@ -424,16 +426,19 @@ class ChatRoute(Route):
sender_name=username, sender_name=username,
) )
response = await make_response( response = cast(
stream(), QuartResponse,
{ await make_response(
"Content-Type": "text/event-stream", stream(),
"Cache-Control": "no-cache", {
"Transfer-Encoding": "chunked", "Content-Type": "text/event-stream",
"Connection": "keep-alive", "Cache-Control": "no-cache",
}, "Transfer-Encoding": "chunked",
"Connection": "keep-alive",
},
),
) )
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue] response.timeout = None # fix SSE auto disconnect issue
return response return response
async def delete_webchat_session(self): async def delete_webchat_session(self):
+6 -5
View File
@@ -3,6 +3,7 @@ import inspect
import os import os
import traceback import traceback
import uuid import uuid
from typing import Any
from quart import request from quart import request
@@ -26,7 +27,7 @@ from astrbot.core.star.star import star_registry
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
def try_cast(value: str, type_: str): def try_cast(value: Any, type_: str):
if type_ == "int": if type_ == "int":
try: try:
return int(value) return int(value)
@@ -505,9 +506,9 @@ class ConfigRoute(Route):
if not isinstance(inst, EmbeddingProvider): if not isinstance(inst, EmbeddingProvider):
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
# 初始化 init_fn = getattr(inst, "initialize", None)
if getattr(inst, "initialize", None): if inspect.iscoroutinefunction(init_fn):
await inst.initialize() await init_fn()
# 获取嵌入向量维度 # 获取嵌入向量维度
vec = await inst.get_embedding("echo") vec = await inst.get_embedding("echo")
@@ -777,7 +778,7 @@ class ConfigRoute(Route):
return {"metadata": CONFIG_METADATA_2, "config": config} return {"metadata": CONFIG_METADATA_2, "config": config}
async def _get_plugin_config(self, plugin_name: str): async def _get_plugin_config(self, plugin_name: str):
ret = {"metadata": None, "config": None} ret: dict = {"metadata": None, "config": None}
for plugin_md in star_registry: for plugin_md in star_registry:
if plugin_md.name == plugin_name: if plugin_md.name == plugin_name:
+13 -8
View File
@@ -1,6 +1,8 @@
import asyncio import asyncio
import json import json
from typing import cast
from quart import Response as QuartResponse
from quart import make_response from quart import make_response
from astrbot.core import LogBroker, logger from astrbot.core import LogBroker, logger
@@ -39,14 +41,17 @@ class LogRoute(Route):
if queue: if queue:
self.log_broker.unregister(queue) self.log_broker.unregister(queue)
response = await make_response( response = cast(
stream(), QuartResponse,
{ await make_response(
"Content-Type": "text/event-stream", stream(),
"Cache-Control": "no-cache", {
"Connection": "keep-alive", "Content-Type": "text/event-stream",
"Transfer-Encoding": "chunked", "Cache-Control": "no-cache",
}, "Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
),
) )
response.timeout = None response.timeout = None
return response return response
+4
View File
@@ -579,6 +579,10 @@ class PluginRoute(Route):
logger.warning(f"插件 {plugin_name} 不存在") logger.warning(f"插件 {plugin_name} 不存在")
return Response().error(f"插件 {plugin_name} 不存在").__dict__ return Response().error(f"插件 {plugin_name} 不存在").__dict__
if not plugin_obj.root_dir_name:
logger.warning(f"插件 {plugin_name} 目录不存在")
return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
plugin_dir = os.path.join( plugin_dir = os.path.join(
self.plugin_manager.plugin_store_path, self.plugin_manager.plugin_store_path,
plugin_obj.root_dir_name or "", plugin_obj.root_dir_name or "",
+2
View File
@@ -12,6 +12,8 @@ class RouteContext:
class Route: class Route:
routes: list | dict
def __init__(self, context: RouteContext): def __init__(self, context: RouteContext):
self.app = context.app self.app = context.app
self.config = context.config self.config = context.config
+6 -3
View File
@@ -2,9 +2,12 @@ import asyncio
import logging import logging
import os import os
import socket import socket
from typing import cast
import jwt import jwt
import psutil import psutil
from flask.json.provider import DefaultJSONProvider
from psutil._common import addr as psutil_addr
from quart import Quart, g, jsonify, request from quart import Quart, g, jsonify, request
from quart.logging import default_handler from quart.logging import default_handler
@@ -21,7 +24,7 @@ from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute from .routes.session_management import SessionManagementRoute
from .routes.t2i import T2iRoute from .routes.t2i import T2iRoute
APP: Quart = None APP: Quart
class AstrBotDashboard: class AstrBotDashboard:
@@ -48,7 +51,7 @@ class AstrBotDashboard:
self.app.config["MAX_CONTENT_LENGTH"] = ( self.app.config["MAX_CONTENT_LENGTH"] = (
128 * 1024 * 1024 128 * 1024 * 1024
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
self.app.json.sort_keys = False cast(DefaultJSONProvider, self.app.json).sort_keys = False
self.app.before_request(self.auth_middleware) self.app.before_request(self.auth_middleware)
# token 用于验证请求 # token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler) logging.getLogger(self.app.name).removeHandler(default_handler)
@@ -147,7 +150,7 @@ class AstrBotDashboard:
"""获取占用端口的进程详细信息""" """获取占用端口的进程详细信息"""
try: try:
for conn in psutil.net_connections(kind="inet"): for conn in psutil.net_connections(kind="inet"):
if conn.laddr.port == port: if cast(psutil_addr, conn.laddr).port == port:
try: try:
process = psutil.Process(conn.pid) process = psutil.Process(conn.pid)
# 获取详细信息 # 获取详细信息
+5
View File
@@ -139,6 +139,11 @@ class ProcessLLMRequest:
# group name identifier # group name identifier
if cfg.get("group_name_display") and event.message_obj.group_id: if cfg.get("group_name_display") and event.message_obj.group_id:
if not event.message_obj.group:
logger.error(
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
)
return
group_name = event.message_obj.group.group_name group_name = event.message_obj.group.group_name
if group_name: if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n" req.system_prompt += f"\nGroup name: {group_name}\n"
+8 -1
View File
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.message_components import File, Image from astrbot.api.message_components import File, Image
from astrbot.api.provider import ProviderRequest from astrbot.api.provider import ProviderRequest
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url from astrbot.core.utils.io import download_file, download_image_by_url
@@ -224,6 +225,8 @@ class Main(star.Star):
del self.user_waiting[uid] del self.user_waiting[uid]
elif isinstance(comp, Image): elif isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file image_url = comp.url if comp.url else comp.file
if image_url is None:
raise ValueError("Image URL is None")
if image_url.startswith("http"): if image_url.startswith("http"):
image_path = await download_image_by_url(image_url) image_path = await download_image_by_url(image_url)
elif image_url.startswith("file:///"): elif image_url.startswith("file:///"):
@@ -240,6 +243,8 @@ class Main(star.Star):
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer: if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()] files = self.user_file_msg_buffer[event.get_session_id()]
if not request.prompt:
request.prompt = ""
request.prompt += f"\nUser provided files: {files}" request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi") @filter.command_group("pi")
@@ -477,7 +482,9 @@ class Main(star.Star):
# file_s3_url = await self.file_upload(file_path) # file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}") # logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_path)] chain: list[BaseMessageComponent] = [
File(name=file_name, file=file_path)
]
yield event.set_result(MessageEventResult(chain=chain)) yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log or "[Error]: " in log: elif "Traceback (most recent call last)" in log or "[Error]: " in log:
+11 -8
View File
@@ -5,6 +5,7 @@ import uuid
import zoneinfo import zoneinfo
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from astrbot.api import llm_tool, logger, star from astrbot.api import llm_tool, logger, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
@@ -62,13 +63,13 @@ class Main(star.Star):
misfire_grace_time=60, misfire_grace_time=60,
) )
elif "cron" in reminder: elif "cron" in reminder:
trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
self.scheduler.add_job( self.scheduler.add_job(
self._reminder_callback, self._reminder_callback,
trigger="cron", trigger=trigger,
id=id_, id=id_,
args=[group, reminder], args=[group, reminder],
misfire_grace_time=60, misfire_grace_time=60,
**self._parse_cron_expr(reminder["cron"]),
) )
def check_is_outdated(self, reminder: dict): def check_is_outdated(self, reminder: dict):
@@ -101,10 +102,10 @@ class Main(star.Star):
async def reminder_tool( async def reminder_tool(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
text: str = None, text: str | None = None,
datetime_str: str = None, datetime_str: str | None = None,
cron_expression: str = None, cron_expression: str | None = None,
human_readable_cron: str = None, human_readable_cron: str | None = None,
): ):
"""Call this function when user is asking for setting a reminder. """Call this function when user is asking for setting a reminder.
@@ -139,17 +140,19 @@ class Main(star.Star):
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
} }
self.reminder_data[event.unified_msg_origin].append(d) self.reminder_data[event.unified_msg_origin].append(d)
trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
self.scheduler.add_job( self.scheduler.add_job(
self._reminder_callback, self._reminder_callback,
"cron", trigger,
id=d["id"], id=d["id"],
misfire_grace_time=60, misfire_grace_time=60,
**self._parse_cron_expr(cron_expression),
args=[event.unified_msg_origin, d], args=[event.unified_msg_origin, d],
) )
if human_readable_cron: if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})" reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else: else:
if datetime_str is None:
raise ValueError("datetime_str cannot be None.")
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())} d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
self.reminder_data[event.unified_msg_origin].append(d) self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime( datetime_scheduled = datetime.datetime.strptime(
+16 -9
View File
@@ -3,7 +3,7 @@ import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
from aiohttp import ClientSession from aiohttp import ClientSession
from bs4 import BeautifulSoup from bs4 import BeautifulSoup, Tag
HEADERS = { HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0", "User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
@@ -45,13 +45,13 @@ class SearchEngine:
self.page = 1 self.page = 1
self.headers = HEADERS self.headers = HEADERS
def _set_selector(self, selector: str) -> None: def _set_selector(self, selector: str) -> str:
raise NotImplementedError raise NotImplementedError
def _get_next_page(self): def _get_next_page(self, query: str):
raise NotImplementedError raise NotImplementedError
async def _get_html(self, url: str, data: dict = None) -> str: async def _get_html(self, url: str, data: dict | None = None) -> str:
headers = self.headers headers = self.headers
headers["Referer"] = url headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS) headers["User-Agent"] = random.choice(USER_AGENTS)
@@ -83,6 +83,9 @@ class SearchEngine:
"""清理文本,去除空格、换行符等""" """清理文本,去除空格、换行符等"""
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def _get_url(self, tag: Tag) -> str:
return self.tidy_text(tag.get_text())
async def search(self, query: str, num_results: int) -> list[SearchResult]: async def search(self, query: str, num_results: int) -> list[SearchResult]:
query = urllib.parse.quote(query) query = urllib.parse.quote(query)
@@ -92,12 +95,16 @@ class SearchEngine:
links = soup.select(self._set_selector("links")) links = soup.select(self._set_selector("links"))
results = [] results = []
for link in links: for link in links:
title = self.tidy_text( # Safely get the title text (select_one may return None)
link.select_one(self._set_selector("title")).text, title_elem = link.select_one(self._set_selector("title"))
) title = ""
url = link.select_one(self._set_selector("url")) if title_elem is not None:
title = self.tidy_text(title_elem.get_text())
url_tag = link.select_one(self._set_selector("url"))
snippet = "" snippet = ""
if title and url: if title and url_tag:
url = self._get_url(url_tag)
results.append(SearchResult(title=title, url=url, snippet=snippet)) results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results return results[:num_results] if len(results) > num_results else results
except Exception as e: except Exception as e:
+1 -9
View File
@@ -1,4 +1,4 @@
from . import USER_AGENT_BING, SearchEngine, SearchResult from . import USER_AGENT_BING, SearchEngine
class Bing(SearchEngine): class Bing(SearchEngine):
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
self.base_url = base_url self.base_url = base_url
continue continue
raise Exception("Bing search failed") raise Exception("Bing search failed")
async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if not isinstance(result.url, str):
result.url = result.url.text
return results
+10 -4
View File
@@ -1,7 +1,8 @@
import random import random
import re import re
from typing import cast
from bs4 import BeautifulSoup from bs4 import BeautifulSoup, Tag
from . import USER_AGENTS, SearchEngine, SearchResult from . import USER_AGENTS, SearchEngine, SearchResult
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
url = f"{self.base_url}/web?query={query}" url = f"{self.base_url}/web?query={query}"
return await self._get_html(url, None) return await self._get_html(url, None)
def _get_url(self, tag: Tag) -> str:
return cast(str, tag.get("href"))
async def search(self, query: str, num_results: int) -> list[SearchResult]: async def search(self, query: str, num_results: int) -> list[SearchResult]:
results = await super().search(query, num_results) results = await super().search(query, num_results)
for result in results: for result in results:
result.url = result.url.get("href")
if result.url.startswith("/link?"): if result.url.startswith("/link?"):
result.url = self.base_url + result.url result.url = self.base_url + result.url
result.url = await self._parse_url(result.url) result.url = await self._parse_url(result.url)
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
soup = BeautifulSoup(html, "html.parser") soup = BeautifulSoup(html, "html.parser")
script = soup.find("script") script = soup.find("script")
if script: if script:
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group( script_text = (
1, script.string if script.string is not None else script.get_text()
) )
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
if match:
url = match.group(1)
return url return url
+90
View File
@@ -0,0 +1,90 @@
"""Minimal type stubs for faiss used in this project.
This file only exposes a small subset of the faiss API that the
project uses, including the runtime-monkeypatched signatures such as
`Index.add_with_ids` so Pyright/Pylance stops reporting false positives.
"""
from typing import Any, overload
import numpy as np
class Index:
d: int
ntotal: int
code_size: int
nprobe: int
def add(self, x: np.ndarray) -> None: ...
def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ...
def search(
self,
x: np.ndarray,
k: int,
*,
params: Any = ...,
D: np.ndarray | None = ...,
I: np.ndarray | None = ...,
) -> tuple[np.ndarray, np.ndarray]: ...
def remove_ids(self, x: np.ndarray) -> int: ...
@overload
def reconstruct(self, key: int) -> np.ndarray: ...
@overload
def reconstruct(self, key: int, x: np.ndarray) -> None: ...
def reconstruct(
self, key: int, x: np.ndarray | None = ...
) -> np.ndarray | None: ...
@overload
def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ...
@overload
def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ...
def reconstruct_n(
self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ...
) -> np.ndarray | None: ...
def range_search(
self, x: np.ndarray, thresh: float, *, params: Any = ...
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ...
def sa_encode(self, x: np.ndarray) -> np.ndarray: ...
def sa_decode(self, codes: np.ndarray) -> np.ndarray: ...
class IndexFlatL2(Index):
def __init__(self, d: int) -> None: ...
class IndexIDMap(Index):
index: Index
def __init__(self, index: Index) -> None: ...
def read_index(path: str) -> Index: ...
def write_index(index: Index, path: str | None = ...) -> None: ...
def normalize_L2(x: np.ndarray) -> None: ...
# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers
# expose `downcast_*` helpers to convert generic objects to these concrete
# types). We keep these minimal — only the names are important for typing.
class IndexBinary(Index):
def __init__(self, d: int) -> None: ...
class InvertedLists:
def __len__(self) -> int: ...
class AdditiveQuantizer:
pass
class Quantizer:
pass
class VectorTransform:
pass
# SWIG-provided downcast helpers (present in some faiss Python builds).
def downcast_IndexBinary(obj: Any) -> IndexBinary: ...
def downcast_InvertedLists(obj: Any) -> InvertedLists: ...
def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ...
def downcast_Quantizer(obj: Any) -> Quantizer: ...
def downcast_VectorTransform(obj: Any) -> VectorTransform: ...
def downcast_index(obj: Any) -> Index: ...
# version exposed by runtime
__version__: str