chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint
* chore(core.provider): 🚨 修复基类错误Lint
* chore(core.utils): 补全session_get()的重载
* chore(core.provider): 🚨 修正实现错误Lint
* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint
* chore(core.platform): 修正错误实现Lint
* fix(core.provider): 修复循环调用和错误assert
* chore(core.platform): 修复部分实现Lint
* chore(core.provider): 补充Dify.text_chat_stream的参数类型
* chore(core.pipeline): 🚨 修复错误Lint
* fix(core.slack): 补充遗漏导入
* chore(core.utils): 修复错误的session_get声明
* chore(core.platform): 移除Lark adapter import中的wildcard
* chore(core.db): 修复声明和部分逻辑
* chore(core.db): 添加typings,使faiss参数能被正确识别。
* chore(core): 修复声明
* chore(core): 修改声明
* chore: 补充faiss声明
* chore(dashboard): 修改实现,减少报错
* chore(package): 修改部分声明与实现,减少报错
* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查
* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查
* chore(core.config): 添加类型标注,通过类型检查
* chore(core.message): 为File._download_file添加检查,通过类型检查
* fix: 将断言改为条件判断以实现优雅关闭的容错性
* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 增强 Lark 相关空值/异常检查并完善日志输出
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将断言替换为条件检查并加入日志与错误处理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: 移除LLM生成的无用注释
* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断
* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志
* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 去除cast,直接使用字段与字典访问,修正端口解析
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: group_name_display 时若 group 对象为空则记录错误并返回
* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置
* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_
* refactor: 移除 typing.cast,简化内容安全检查调用
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast
* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值
* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全
* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错
* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"
This reverts commit 1cbfdf9d1b.
* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接
* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型
* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
This commit is contained in:
@@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
llm_resp_result = None
|
llm_resp_result = None
|
||||||
|
|
||||||
async for llm_response in self._iter_llm_responses():
|
async for llm_response in self._iter_llm_responses():
|
||||||
assert isinstance(llm_response, LLMResponse)
|
|
||||||
if llm_response.is_chunk:
|
if llm_response.is_chunk:
|
||||||
if llm_response.result_chain:
|
if llm_response.result_chain:
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any, Generic
|
from typing import Any, Generic
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
@@ -7,6 +7,8 @@ from deprecated import deprecated
|
|||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from astrbot.core.message.message_event_result import MessageEventResult
|
||||||
|
|
||||||
from .run_context import ContextWrapper, TContext
|
from .run_context import ContextWrapper, TContext
|
||||||
|
|
||||||
ParametersType = dict[str, Any]
|
ParametersType = dict[str, Any]
|
||||||
@@ -38,7 +40,10 @@ class ToolSchema:
|
|||||||
class FunctionTool(ToolSchema, Generic[TContext]):
|
class FunctionTool(ToolSchema, Generic[TContext]):
|
||||||
"""A callable tool, for function calling."""
|
"""A callable tool, for function calling."""
|
||||||
|
|
||||||
handler: Callable[..., Awaitable[Any]] | None = None
|
handler: (
|
||||||
|
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
|
||||||
|
| None
|
||||||
|
) = None
|
||||||
"""a callable that implements the tool's functionality. It should be an async function."""
|
"""a callable that implements the tool's functionality. It should be an async function."""
|
||||||
|
|
||||||
handler_module_path: str | None = None
|
handler_module_path: str | None = None
|
||||||
|
|||||||
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
|
|
||||||
async def call_local_llm_tool(
|
async def call_local_llm_tool(
|
||||||
context: ContextWrapper[AstrAgentContext],
|
context: ContextWrapper[AstrAgentContext],
|
||||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
handler: T.Callable[
|
||||||
|
...,
|
||||||
|
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
|
||||||
|
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
|
||||||
|
],
|
||||||
method_name: str,
|
method_name: str,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
|
|||||||
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
config_path: str
|
||||||
|
default_config: dict
|
||||||
|
schema: dict | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config_path: str = ASTRBOT_CONFIG_PATH,
|
config_path: str = ASTRBOT_CONFIG_PATH,
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ class AstrBotCoreLifecycle:
|
|||||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||||
extra_tasks = []
|
extra_tasks = []
|
||||||
for task in self.star_context._register_tasks:
|
for task in self.star_context._register_tasks:
|
||||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
|
||||||
|
|
||||||
tasks_ = [event_bus_task, *extra_tasks]
|
tasks_ = [event_bus_task, *extra_tasks]
|
||||||
for task in tasks_:
|
for task in tasks_:
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
|
|||||||
echo=False,
|
echo=False,
|
||||||
future=True,
|
future=True,
|
||||||
)
|
)
|
||||||
self.AsyncSessionLocal = sessionmaker(
|
self.AsyncSessionLocal = async_sessionmaker(
|
||||||
self.engine,
|
self.engine,
|
||||||
class_=AsyncSession,
|
class_=AsyncSession,
|
||||||
expire_on_commit=False,
|
expire_on_commit=False,
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ async def migration_conversation_table(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
if ":" not in conv.user_id:
|
if ":" not in conv.user_id:
|
||||||
continue
|
continue
|
||||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||||
@@ -207,6 +208,7 @@ async def migration_webchat_data(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
if ":" in conv.user_id:
|
if ":" in conv.user_id:
|
||||||
continue
|
continue
|
||||||
platform_id = "webchat"
|
platform_id = "webchat"
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class SQLiteDatabase:
|
|||||||
conn.text_factory = str
|
conn.text_factory = str
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def _exec_sql(self, sql: str, params: tuple = None):
|
def _exec_sql(self, sql: str, params: tuple | None = None):
|
||||||
conn = self.conn
|
conn = self.conn
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
@@ -224,9 +224,11 @@ class SQLiteDatabase:
|
|||||||
|
|
||||||
c.close()
|
c.close()
|
||||||
|
|
||||||
return Stats(platform, [], [])
|
return Stats(platform)
|
||||||
|
|
||||||
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
def get_conversation_by_user_id(
|
||||||
|
self, user_id: str, cid: str
|
||||||
|
) -> Conversation | None:
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
@@ -258,7 +260,7 @@ class SQLiteDatabase:
|
|||||||
(user_id, cid, history, updated_at, created_at),
|
(user_id, cid, history, updated_at, created_at),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conversations(self, user_id: str) -> tuple:
|
def get_conversations(self, user_id: str) -> list[Conversation]:
|
||||||
try:
|
try:
|
||||||
c = self.conn.cursor()
|
c = self.conn.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except sqlite3.ProgrammingError:
|
||||||
|
|||||||
+16
-15
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
|
|||||||
Note: In astrbot v4, we moved `platform` table to here.
|
Note: In astrbot v4, we moved `platform` table to here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "platform_stats" # type: ignore
|
__tablename__: str = "platform_stats"
|
||||||
|
|
||||||
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||||
timestamp: datetime = Field(nullable=False)
|
timestamp: datetime = Field(nullable=False)
|
||||||
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
|
|||||||
|
|
||||||
|
|
||||||
class ConversationV2(SQLModel, table=True):
|
class ConversationV2(SQLModel, table=True):
|
||||||
__tablename__ = "conversations" # type: ignore
|
__tablename__: str = "conversations"
|
||||||
|
|
||||||
inner_conversation_id: int = Field(
|
inner_conversation_id: int | None = Field(
|
||||||
|
default=None,
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
sa_column_kwargs={"autoincrement": True},
|
sa_column_kwargs={"autoincrement": True},
|
||||||
)
|
)
|
||||||
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
|
|||||||
It can be used to customize the behavior of LLMs.
|
It can be used to customize the behavior of LLMs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "personas" # type: ignore
|
__tablename__: str = "personas"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
|
|||||||
class Preference(SQLModel, table=True):
|
class Preference(SQLModel, table=True):
|
||||||
"""This class represents preferences for bots."""
|
"""This class represents preferences for bots."""
|
||||||
|
|
||||||
__tablename__ = "preferences" # type: ignore
|
__tablename__: str = "preferences"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
|
|||||||
or platform-specific messages.
|
or platform-specific messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "platform_message_history" # type: ignore
|
__tablename__: str = "platform_message_history"
|
||||||
|
|
||||||
id: int | None = Field(
|
id: int | None = Field(
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
|
|||||||
Each session can have multiple conversations (对话) associated with it.
|
Each session can have multiple conversations (对话) associated with it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "platform_sessions" # type: ignore
|
__tablename__: str = "platform_sessions"
|
||||||
|
|
||||||
inner_id: int | None = Field(
|
inner_id: int | None = Field(
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
|
|||||||
Attachments can be images, files, or other media types.
|
Attachments can be images, files, or other media types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "attachments" # type: ignore
|
__tablename__: str = "attachments"
|
||||||
|
|
||||||
inner_attachment_id: int | None = Field(
|
inner_attachment_id: int | None = Field(
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
@@ -261,17 +262,17 @@ class Personality(TypedDict):
|
|||||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt: str = ""
|
prompt: str
|
||||||
name: str = ""
|
name: str
|
||||||
begin_dialogs: list[str] = []
|
begin_dialogs: list[str]
|
||||||
mood_imitation_dialogs: list[str] = []
|
mood_imitation_dialogs: list[str]
|
||||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||||
tools: list[str] | None = None
|
tools: list[str] | None
|
||||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||||
|
|
||||||
# cache
|
# cache
|
||||||
_begin_dialogs_processed: list[dict] = []
|
_begin_dialogs_processed: list[dict]
|
||||||
_mood_imitation_dialogs_processed: str = ""
|
_mood_imitation_dialogs_processed: str
|
||||||
|
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import threading
|
|||||||
import typing as T
|
import typing as T
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import CursorResult
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||||
|
|
||||||
@@ -489,7 +490,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
async with self.get_db() as session:
|
async with self.get_db() as session:
|
||||||
session: AsyncSession
|
session: AsyncSession
|
||||||
query = select(Attachment).where(
|
query = select(Attachment).where(
|
||||||
Attachment.attachment_id.in_(attachment_ids)
|
col(Attachment.attachment_id).in_(attachment_ids)
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
@@ -505,7 +506,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
query = delete(Attachment).where(
|
query = delete(Attachment).where(
|
||||||
col(Attachment.attachment_id) == attachment_id
|
col(Attachment.attachment_id) == attachment_id
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = T.cast(CursorResult, await session.execute(query))
|
||||||
return result.rowcount > 0
|
return result.rowcount > 0
|
||||||
|
|
||||||
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
async def delete_attachments(self, attachment_ids: list[str]) -> int:
|
||||||
@@ -521,7 +522,7 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
query = delete(Attachment).where(
|
query = delete(Attachment).where(
|
||||||
col(Attachment.attachment_id).in_(attachment_ids)
|
col(Attachment.attachment_id).in_(attachment_ids)
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = T.cast(CursorResult, await session.execute(query))
|
||||||
return result.rowcount
|
return result.rowcount
|
||||||
|
|
||||||
async def insert_persona(
|
async def insert_persona(
|
||||||
|
|||||||
@@ -90,4 +90,6 @@ class EmbeddingStorage:
|
|||||||
path (str): 保存索引的路径
|
path (str): 保存索引的路径
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if self.index is None:
|
||||||
|
return
|
||||||
faiss.write_index(self.index, self.path)
|
faiss.write_index(self.index, self.path)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class EventBus:
|
|||||||
self,
|
self,
|
||||||
event_queue: Queue,
|
event_queue: Queue,
|
||||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
astrbot_config_mgr: AstrBotConfigManager,
|
||||||
):
|
):
|
||||||
self.event_queue = event_queue # 事件队列
|
self.event_queue = event_queue # 事件队列
|
||||||
# abconf uuid -> scheduler
|
# abconf uuid -> scheduler
|
||||||
@@ -40,6 +40,11 @@ class EventBus:
|
|||||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||||
self._print_event(event, conf_info["name"])
|
self._print_event(event, conf_info["name"])
|
||||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||||
|
if not scheduler:
|
||||||
|
logger.error(
|
||||||
|
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
|
||||||
|
)
|
||||||
|
continue
|
||||||
asyncio.create_task(scheduler.execute(event))
|
asyncio.create_task(scheduler.execute(event))
|
||||||
|
|
||||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||||
|
|||||||
@@ -166,7 +166,11 @@ class RetrievalManager:
|
|||||||
# 5. Rerank
|
# 5. Rerank
|
||||||
first_rerank = None
|
first_rerank = None
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
vec_db = kb_options[kb_id]["vec_db"]
|
||||||
|
if not isinstance(vec_db, FaissVecDB):
|
||||||
|
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
|
||||||
|
continue
|
||||||
|
|
||||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||||
if (
|
if (
|
||||||
vec_db
|
vec_db
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
|
|||||||
class BaseMessageComponent(BaseModel):
|
class BaseMessageComponent(BaseModel):
|
||||||
type: ComponentType
|
type: ComponentType
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def toDict(self):
|
def toDict(self):
|
||||||
data = {}
|
data = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
|
|||||||
id: int | None = 0 # 忽略
|
id: int | None = 0 # 忽略
|
||||||
name: str | None = "" # qq昵称
|
name: str | None = "" # qq昵称
|
||||||
uin: str | None = "0" # qq号
|
uin: str | None = "0" # qq号
|
||||||
content: list[BaseMessageComponent] | None = []
|
content: list[BaseMessageComponent] = []
|
||||||
seq: str | list | None = "" # 忽略
|
seq: str | list | None = "" # 忽略
|
||||||
time: int | None = 0 # 忽略
|
time: int | None = 0 # 忽略
|
||||||
|
|
||||||
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
|
|||||||
ret["messages"].append(d)
|
ret["messages"].append(d)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def to_dict(self):
|
async def to_dict(self) -> dict:
|
||||||
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
|
||||||
ret = {"messages": []}
|
ret = {"messages": []}
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
@@ -714,12 +717,15 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
if self.url:
|
if self.url:
|
||||||
await self._download_file()
|
await self._download_file()
|
||||||
return os.path.abspath(self.file_)
|
if self.file_:
|
||||||
|
return os.path.abspath(self.file_)
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _download_file(self):
|
async def _download_file(self):
|
||||||
"""下载文件"""
|
"""下载文件"""
|
||||||
|
if not self.url:
|
||||||
|
raise ValueError("Download failed: No URL provided in File component.")
|
||||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
os.makedirs(download_dir, exist_ok=True)
|
os.makedirs(download_dir, exist_ok=True)
|
||||||
if self.name:
|
if self.name:
|
||||||
|
|||||||
@@ -98,8 +98,8 @@ class PersonaManager:
|
|||||||
self,
|
self,
|
||||||
persona_id: str,
|
persona_id: str,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
begin_dialogs: list[str] = None,
|
begin_dialogs: list[str] | None = None,
|
||||||
tools: list[str] = None,
|
tools: list[str] | None = None,
|
||||||
) -> Persona:
|
) -> Persona:
|
||||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||||
if await self.db.get_persona_by_id(persona_id):
|
if await self.db.get_persona_by_id(persona_id):
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
|
|||||||
self,
|
self,
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
check_text: str | None = None,
|
check_text: str | None = None,
|
||||||
) -> None | AsyncGenerator[None, None]:
|
) -> AsyncGenerator[None, None]:
|
||||||
"""检查内容安全"""
|
"""检查内容安全"""
|
||||||
text = check_text if check_text else event.get_message_str()
|
text = check_text if check_text else event.get_message_str()
|
||||||
ok, info = self.strategy_selector.check(text)
|
ok, info = self.strategy_selector.check(text)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
|||||||
|
|
||||||
async def call_handler(
|
async def call_handler(
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
handler: T.Callable[..., T.Awaitable[T.Any]],
|
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T.AsyncGenerator[T.Any, None]:
|
) -> T.AsyncGenerator[T.Any, None]:
|
||||||
@@ -91,6 +91,7 @@ async def call_event_hook(
|
|||||||
)
|
)
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
try:
|
try:
|
||||||
|
assert inspect.iscoroutinefunction(handler.handler)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class StarRequestSubStage(Stage):
|
|||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
) -> AsyncGenerator[None, None]:
|
) -> AsyncGenerator[Any, None]:
|
||||||
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
||||||
"activated_handlers",
|
"activated_handlers",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
|
|||||||
):
|
):
|
||||||
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
|
||||||
if (
|
if (
|
||||||
event.get_result() and not event.get_result().is_stopped()
|
event.get_result() and not event.is_stopped()
|
||||||
) or not event.get_result():
|
) or not event.get_result():
|
||||||
async for _ in self.agent_sub_stage.process(event):
|
async for _ in self.agent_sub_stage.process(event):
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -117,7 +117,9 @@ class RespondStage(Stage):
|
|||||||
if not self.enable_seg:
|
if not self.enable_seg:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.only_llm_result and not event.get_result().is_llm_result():
|
if (result := event.get_result()) is None:
|
||||||
|
return False
|
||||||
|
if self.only_llm_result and result.is_llm_result():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if event.get_platform_name() in [
|
if event.get_platform_name() in [
|
||||||
@@ -185,7 +187,7 @@ class RespondStage(Stage):
|
|||||||
if isinstance(component, Comp.File) and component.file:
|
if isinstance(component, Comp.File) and component.file:
|
||||||
# 支持 File 消息段的路径映射。
|
# 支持 File 消息段的路径映射。
|
||||||
component.file = path_Mapping(mappings, component.file)
|
component.file = path_Mapping(mappings, component.file)
|
||||||
event.get_result().chain[idx] = component
|
result.chain[idx] = component
|
||||||
|
|
||||||
# 检查消息链是否为空
|
# 检查消息链是否为空
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from astrbot.core import file_token_service, html_renderer, logger
|
from astrbot.core import file_token_service, html_renderer, logger
|
||||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||||
from astrbot.core.message.message_event_result import ResultContentType
|
from astrbot.core.message.message_event_result import ResultContentType
|
||||||
|
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.platform.message_type import MessageType
|
from astrbot.core.platform.message_type import MessageType
|
||||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||||
@@ -130,11 +131,13 @@ class ResultDecorateStage(Stage):
|
|||||||
for comp in result.chain:
|
for comp in result.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
text += comp.text
|
text += comp.text
|
||||||
async for _ in self.content_safe_check_stage.process(
|
|
||||||
event,
|
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
|
||||||
check_text=text,
|
async for _ in self.content_safe_check_stage.process(
|
||||||
):
|
event,
|
||||||
yield
|
check_text=text,
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
# 发送消息前事件钩子
|
# 发送消息前事件钩子
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
@@ -151,7 +154,8 @@ class ResultDecorateStage(Stage):
|
|||||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
|
||||||
)
|
)
|
||||||
await handler.handler(event)
|
await handler.handler(event)
|
||||||
if event.get_result() is None or not event.get_result().chain:
|
|
||||||
|
if (result := event.get_result()) is None or not result.chain:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
|
|||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.platform import AstrMessageEvent
|
from astrbot.core.platform import AstrMessageEvent
|
||||||
|
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
|
||||||
|
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
|
||||||
|
WecomAIBotMessageEvent,
|
||||||
|
)
|
||||||
|
|
||||||
from . import STAGES_ORDER
|
from . import STAGES_ORDER
|
||||||
from .context import PipelineContext
|
from .context import PipelineContext
|
||||||
@@ -78,7 +82,7 @@ class PipelineScheduler:
|
|||||||
await self._process_stages(event)
|
await self._process_stages(event)
|
||||||
|
|
||||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||||
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
|
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
|
||||||
await event.send(None)
|
await event.send(None)
|
||||||
|
|
||||||
logger.debug("pipeline 执行完毕。")
|
logger.debug("pipeline 执行完毕。")
|
||||||
|
|||||||
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
|
|
||||||
def get_sender_name(self) -> str:
|
def get_sender_name(self) -> str:
|
||||||
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
"""获取消息发送者的名称。(可能会返回空字符串)"""
|
||||||
return self.message_obj.sender.nickname
|
if isinstance(self.message_obj.sender.nickname, str):
|
||||||
|
return self.message_obj.sender.nickname
|
||||||
|
return ""
|
||||||
|
|
||||||
def set_extra(self, key, value):
|
def set_extra(self, key, value):
|
||||||
"""设置额外的信息。"""
|
"""设置额外的信息。"""
|
||||||
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
self.call_llm = call_llm
|
self.call_llm = call_llm
|
||||||
|
|
||||||
def get_result(self) -> MessageEventResult:
|
def get_result(self) -> MessageEventResult | None:
|
||||||
"""获取消息事件的结果。"""
|
"""获取消息事件的结果。"""
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
func_tool_manager=None,
|
func_tool_manager=None,
|
||||||
session_id: str = None,
|
session_id: str = "",
|
||||||
image_urls: list[str] | None = None,
|
image_urls: list[str] | None = None,
|
||||||
contexts: list | None = None,
|
contexts: list | None = None,
|
||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class AstrBotMessage:
|
|||||||
self_id: str # 机器人的识别id
|
self_id: str # 机器人的识别id
|
||||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||||
message_id: str # 消息id
|
message_id: str # 消息id
|
||||||
group: Group # 群组
|
group: Group | None # 群组
|
||||||
sender: MessageMember # 发送者
|
sender: MessageMember # 发送者
|
||||||
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||||
message_str: str # 最直观的纯文本消息字符串
|
message_str: str # 最直观的纯文本消息字符串
|
||||||
@@ -78,7 +78,7 @@ class AstrBotMessage:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
@group_id.setter
|
@group_id.setter
|
||||||
def group_id(self, value: str):
|
def group_id(self, value: str | None):
|
||||||
"""设置 group_id"""
|
"""设置 group_id"""
|
||||||
if value:
|
if value:
|
||||||
if self.group:
|
if self.group:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import uuid
|
import uuid
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Coroutine
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -100,7 +100,7 @@ class Platform(abc.ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Coroutine[Any, Any, None]:
|
||||||
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
"""得到一个平台的运行实例,需要返回一个协程对象。"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class Platform(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
session: MessageSesion,
|
session: MessageSesion,
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
):
|
) -> None:
|
||||||
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
|
||||||
|
|
||||||
异步方法。
|
异步方法。
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class PlatformMetadata:
|
|||||||
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
|
||||||
description: str
|
description: str
|
||||||
"""平台的描述"""
|
"""平台的描述"""
|
||||||
id: str | None = None
|
id: str
|
||||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||||
|
|
||||||
default_config_tmpl: dict | None = None
|
default_config_tmpl: dict | None = None
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def register_platform_adapter(
|
|||||||
pm = PlatformMetadata(
|
pm = PlatformMetadata(
|
||||||
name=adapter_name,
|
name=adapter_name,
|
||||||
description=desc,
|
description=desc,
|
||||||
|
id=adapter_name,
|
||||||
default_config_tmpl=default_config_tmpl,
|
default_config_tmpl=default_config_tmpl,
|
||||||
adapter_display_name=adapter_display_name,
|
adapter_display_name=adapter_display_name,
|
||||||
logo_path=logo_path,
|
logo_path=logo_path,
|
||||||
|
|||||||
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
|||||||
bot: CQHttp,
|
bot: CQHttp,
|
||||||
event: Event | None,
|
event: Event | None,
|
||||||
is_group: bool,
|
is_group: bool,
|
||||||
session_id: str,
|
session_id: str | None,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
):
|
):
|
||||||
# session_id 必须是纯数字字符串
|
# session_id 必须是纯数字字符串
|
||||||
session_id = int(session_id) if session_id.isdigit() else None
|
session_id_int = (
|
||||||
|
int(session_id) if session_id and session_id.isdigit() else None
|
||||||
|
)
|
||||||
|
|
||||||
if is_group and isinstance(session_id, int):
|
if is_group and isinstance(session_id_int, int):
|
||||||
await bot.send_group_msg(group_id=session_id, message=messages)
|
await bot.send_group_msg(group_id=session_id_int, message=messages)
|
||||||
elif not is_group and isinstance(session_id, int):
|
elif not is_group and isinstance(session_id_int, int):
|
||||||
await bot.send_private_msg(user_id=session_id, message=messages)
|
await bot.send_private_msg(user_id=session_id_int, message=messages)
|
||||||
elif isinstance(event, Event): # 最后兜底
|
elif isinstance(event, Event): # 最后兜底
|
||||||
await bot.send(event=event, message=messages)
|
await bot.send(event=event, message=messages)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from aiocqhttp import CQHttp, Event
|
from aiocqhttp import CQHttp, Event
|
||||||
from aiocqhttp.exceptions import ActionFailed
|
from aiocqhttp.exceptions import ActionFailed
|
||||||
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
name="aiocqhttp",
|
name="aiocqhttp",
|
||||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
support_streaming_message=False,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
|
|||||||
"""OneBot V11 请求类事件"""
|
"""OneBot V11 请求类事件"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
abm.sender = MessageMember(
|
||||||
|
user_id=str(event.user_id), nickname=str(event.user_id)
|
||||||
|
)
|
||||||
abm.type = MessageType.OTHER_MESSAGE
|
abm.type = MessageType.OTHER_MESSAGE
|
||||||
if event.get("group_id"):
|
if event.get("group_id"):
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
@param event: 事件对象
|
@param event: 事件对象
|
||||||
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
|
||||||
"""
|
"""
|
||||||
|
assert event.sender is not None
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
if event["message_type"] == "group":
|
if event["message_type"] == "group":
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
|
abm.group = Group(str(event.group_id))
|
||||||
abm.group.group_name = event.get("group_name", "N/A")
|
abm.group.group_name = event.get("group_name", "N/A")
|
||||||
elif event["message_type"] == "private":
|
elif event["message_type"] == "private":
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
|
|||||||
await self.bot.send(event, err)
|
await self.bot.send(event, err)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"回复消息失败: {e}")
|
logger.error(f"回复消息失败: {e}")
|
||||||
return None
|
raise ValueError(err)
|
||||||
|
|
||||||
# 按消息段类型类型适配
|
# 按消息段类型类型适配
|
||||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import dingtalk_stream
|
import dingtalk_stream
|
||||||
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
self.client_id = platform_config["client_id"]
|
self.client_id = platform_config["client_id"]
|
||||||
self.client_secret = platform_config["client_secret"]
|
self.client_secret = platform_config["client_secret"]
|
||||||
|
|
||||||
|
outer_self = self
|
||||||
|
|
||||||
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
|
||||||
async def process(self_, message: dingtalk_stream.CallbackMessage):
|
async def process(self, message: dingtalk_stream.CallbackMessage):
|
||||||
logger.debug(f"dingtalk: {message.data}")
|
logger.debug(f"dingtalk: {message.data}")
|
||||||
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
|
||||||
abm = await self.convert_msg(im)
|
abm = await outer_self.convert_msg(im)
|
||||||
await self.handle_msg(abm)
|
await outer_self.handle_msg(abm)
|
||||||
|
|
||||||
return AckMessage.STATUS_OK, "OK"
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
self.client,
|
self.client,
|
||||||
)
|
)
|
||||||
self.client_ = client # 用于 websockets 的 client
|
self.client_ = client # 用于 websockets 的 client
|
||||||
|
self._shutdown_event: threading.Event | None = None
|
||||||
|
|
||||||
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
def _id_to_sid(self, dingtalk_id: str | None) -> str:
|
||||||
if not dingtalk_id:
|
if not dingtalk_id:
|
||||||
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
name="dingtalk",
|
name="dingtalk",
|
||||||
description="钉钉机器人官方 API 适配器",
|
description="钉钉机器人官方 API 适配器",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
support_streaming_message=False,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.message = []
|
abm.message = []
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
abm.timestamp = int(message.create_at / 1000)
|
abm.timestamp = int(cast(int, message.create_at) / 1000)
|
||||||
abm.type = (
|
abm.type = (
|
||||||
MessageType.GROUP_MESSAGE
|
MessageType.GROUP_MESSAGE
|
||||||
if message.conversation_type == "2"
|
if message.conversation_type == "2"
|
||||||
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
nickname=message.sender_nick,
|
nickname=message.sender_nick,
|
||||||
)
|
)
|
||||||
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
abm.self_id = self._id_to_sid(message.chatbot_user_id)
|
||||||
abm.message_id = message.message_id
|
abm.message_id = cast(str, message.message_id)
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
|
|
||||||
if abm.type == MessageType.GROUP_MESSAGE:
|
if abm.type == MessageType.GROUP_MESSAGE:
|
||||||
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
else:
|
else:
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
|
|
||||||
message_type: str = message.message_type
|
message_type: str = cast(str, message.message_type)
|
||||||
match message_type:
|
match message_type:
|
||||||
case "text":
|
case "text":
|
||||||
abm.message_str = message.text.content.strip()
|
abm.message_str = message.text.content.strip()
|
||||||
abm.message.append(Plain(abm.message_str))
|
abm.message.append(Plain(abm.message_str))
|
||||||
case "richText":
|
case "richText":
|
||||||
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
|
rtc: dingtalk_stream.RichTextContent = cast(
|
||||||
contents: list[dict] = rtc.rich_text_list
|
dingtalk_stream.RichTextContent, message.rich_text_content
|
||||||
|
)
|
||||||
|
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
|
||||||
for content in contents:
|
for content in contents:
|
||||||
plains = ""
|
plains = ""
|
||||||
if "text" in content:
|
if "text" in content:
|
||||||
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
elif "type" in content and content["type"] == "picture":
|
elif "type" in content and content["type"] == "picture":
|
||||||
f_path = await self.download_ding_file(
|
f_path = await self.download_ding_file(
|
||||||
content["downloadCode"],
|
content["downloadCode"],
|
||||||
message.robot_code,
|
cast(str, message.robot_code),
|
||||||
"jpg",
|
"jpg",
|
||||||
)
|
)
|
||||||
abm.message.append(Image.fromFileSystem(f_path))
|
abm.message.append(Image.fromFileSystem(f_path))
|
||||||
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
|
||||||
)
|
)
|
||||||
return None
|
return ""
|
||||||
resp_data = await resp.json()
|
resp_data = await resp.json()
|
||||||
download_url = resp_data["data"]["downloadUrl"]
|
download_url = resp_data["data"]["downloadUrl"]
|
||||||
await download_file(download_url, f_path)
|
await download_file(download_url, f_path)
|
||||||
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
|
||||||
)
|
)
|
||||||
return None
|
return ""
|
||||||
return (await resp.json())["data"]["accessToken"]
|
return (await resp.json())["data"]["accessToken"]
|
||||||
|
|
||||||
async def handle_msg(self, abm: AstrBotMessage):
|
async def handle_msg(self, abm: AstrBotMessage):
|
||||||
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
|
|||||||
def monkey_patch_close():
|
def monkey_patch_close():
|
||||||
raise KeyboardInterrupt("Graceful shutdown")
|
raise KeyboardInterrupt("Graceful shutdown")
|
||||||
|
|
||||||
self.client_.open_connection = monkey_patch_close
|
if self.client_.websocket is not None:
|
||||||
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
self.client_.open_connection = monkey_patch_close
|
||||||
self._shutdown_event.set()
|
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
|
||||||
|
if self._shutdown_event is not None:
|
||||||
|
self._shutdown_event.set()
|
||||||
|
|
||||||
def get_client(self):
|
def get_client(self):
|
||||||
return self.client
|
return self.client
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import dingtalk_stream
|
import dingtalk_stream
|
||||||
|
|
||||||
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|||||||
client.reply_markdown,
|
client.reply_markdown,
|
||||||
segment.text,
|
segment.text,
|
||||||
segment.text,
|
segment.text,
|
||||||
self.message_obj.raw_message,
|
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||||
)
|
)
|
||||||
elif isinstance(segment, Comp.Image):
|
elif isinstance(segment, Comp.Image):
|
||||||
markdown_str = ""
|
markdown_str = ""
|
||||||
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|||||||
client.reply_markdown,
|
client.reply_markdown,
|
||||||
"😄",
|
"😄",
|
||||||
markdown_str,
|
markdown_str,
|
||||||
self.message_obj.raw_message,
|
cast(
|
||||||
|
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
|
||||||
|
),
|
||||||
)
|
)
|
||||||
logger.debug(f"send image: {ret}")
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
|
|||||||
super().__init__(intents=intents, proxy=proxy)
|
super().__init__(intents=intents, proxy=proxy)
|
||||||
|
|
||||||
# 回调函数
|
# 回调函数
|
||||||
self.on_message_received = None
|
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
|
||||||
self.on_ready_once_callback = None
|
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
|
||||||
self._ready_once_fired = False
|
self._ready_once_fired = False
|
||||||
|
|
||||||
@override
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""当机器人成功连接并准备就绪时触发"""
|
"""当机器人成功连接并准备就绪时触发"""
|
||||||
|
if self.user is None:
|
||||||
|
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
|
||||||
logger.info("[Discord] 客户端已准备就绪。")
|
logger.info("[Discord] 客户端已准备就绪。")
|
||||||
|
|
||||||
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
|
|||||||
|
|
||||||
def _create_message_data(self, message: discord.Message) -> dict:
|
def _create_message_data(self, message: discord.Message) -> dict:
|
||||||
"""从 discord.Message 创建数据字典"""
|
"""从 discord.Message 创建数据字典"""
|
||||||
|
if self.user is None:
|
||||||
|
raise RuntimeError("Bot is not ready: self.user is None")
|
||||||
|
|
||||||
is_mentioned = self.user in message.mentions
|
is_mentioned = self.user in message.mentions
|
||||||
return {
|
return {
|
||||||
"message": message,
|
"message": message,
|
||||||
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
|
|||||||
|
|
||||||
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
|
||||||
"""从 discord.Interaction 创建数据字典"""
|
"""从 discord.Interaction 创建数据字典"""
|
||||||
|
if self.user is None:
|
||||||
|
raise RuntimeError("Bot is not ready: self.user is None")
|
||||||
|
|
||||||
|
if interaction.user is None:
|
||||||
|
raise ValueError("Interaction received without a valid user")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"interaction": interaction,
|
"interaction": interaction,
|
||||||
"bot_id": str(self.user.id),
|
"bot_id": str(self.user.id),
|
||||||
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
|
|||||||
"type": "interaction",
|
"type": "interaction",
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
"""当接收到消息时触发"""
|
"""当接收到消息时触发"""
|
||||||
if message.author.bot:
|
if message.author.bot:
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
components: list[BaseMessageComponent] = None,
|
components: list[BaseMessageComponent] | None = None,
|
||||||
timeout: float = None,
|
timeout: float | None = None,
|
||||||
):
|
):
|
||||||
self.components = components or []
|
self.components = components or []
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.abc import Messageable
|
from discord.abc import GuildChannel, Messageable, PrivateChannel
|
||||||
from discord.channel import DMChannel
|
from discord.channel import DMChannel
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(platform_config, event_queue)
|
super().__init__(platform_config, event_queue)
|
||||||
self.settings = platform_settings
|
self.settings = platform_settings
|
||||||
self.client_self_id = None
|
self.client_self_id: str | None = None
|
||||||
self.registered_handlers = []
|
self.registered_handlers = []
|
||||||
# 指令注册相关
|
# 指令注册相关
|
||||||
self.enable_command_register = self.config.get("discord_command_register", True)
|
self.enable_command_register = self.config.get("discord_command_register", True)
|
||||||
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
):
|
):
|
||||||
"""通过会话发送消息"""
|
"""通过会话发送消息"""
|
||||||
|
if self.client.user is None:
|
||||||
|
logger.error(
|
||||||
|
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 创建一个 message_obj 以便在 event 中使用
|
# 创建一个 message_obj 以便在 event 中使用
|
||||||
message_obj = AstrBotMessage()
|
message_obj = AstrBotMessage()
|
||||||
if "_" in session.session_id:
|
if "_" in session.session_id:
|
||||||
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
user_id=str(self.client_self_id),
|
user_id=str(self.client_self_id),
|
||||||
nickname=self.client.user.display_name,
|
nickname=self.client.user.display_name,
|
||||||
)
|
)
|
||||||
message_obj.self_id = self.client_self_id
|
message_obj.self_id = cast(str, self.client_self_id)
|
||||||
message_obj.session_id = session.session_id
|
message_obj.session_id = session.session_id
|
||||||
message_obj.message = message_chain.chain
|
message_obj.message = message_chain.chain
|
||||||
|
|
||||||
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
"discord",
|
"discord",
|
||||||
"Discord 适配器",
|
"Discord 适配器",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
default_config_tmpl=self.config,
|
default_config_tmpl=self.config,
|
||||||
support_streaming_message=False,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
|
|
||||||
def _get_message_type(
|
def _get_message_type(
|
||||||
self,
|
self,
|
||||||
channel: Messageable,
|
channel: Messageable | GuildChannel | PrivateChannel,
|
||||||
guild_id: int | None = None,
|
guild_id: int | None = None,
|
||||||
) -> MessageType:
|
) -> MessageType:
|
||||||
"""根据 channel 对象和 guild_id 判断消息类型"""
|
"""根据 channel 对象和 guild_id 判断消息类型"""
|
||||||
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
return MessageType.FRIEND_MESSAGE
|
return MessageType.FRIEND_MESSAGE
|
||||||
return MessageType.GROUP_MESSAGE
|
return MessageType.GROUP_MESSAGE
|
||||||
|
|
||||||
def _get_channel_id(self, channel: Messageable) -> str:
|
def _get_channel_id(
|
||||||
|
self, channel: Messageable | GuildChannel | PrivateChannel
|
||||||
|
) -> str:
|
||||||
"""根据 channel 对象获取ID"""
|
"""根据 channel 对象获取ID"""
|
||||||
return str(getattr(channel, "id", None))
|
return str(getattr(channel, "id", None))
|
||||||
|
|
||||||
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
|
||||||
"""将普通消息转换为 AstrBotMessage"""
|
"""将普通消息转换为 AstrBotMessage"""
|
||||||
message: discord.Message = data["message"]
|
message = data["message"]
|
||||||
|
|
||||||
content = message.content
|
content = message.content
|
||||||
|
|
||||||
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
abm.message = message_chain
|
abm.message = message_chain
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
abm.self_id = self.client_self_id
|
abm.self_id = cast(str, self.client_self_id)
|
||||||
abm.session_id = str(message.channel.id)
|
abm.session_id = str(message.channel.id)
|
||||||
abm.message_id = str(message.id)
|
abm.message_id = str(message.id)
|
||||||
return abm
|
return abm
|
||||||
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
interaction_followup_webhook=followup_webhook,
|
interaction_followup_webhook=followup_webhook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.client.user is None:
|
||||||
|
logger.error(
|
||||||
|
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 检查是否为斜杠指令
|
# 检查是否为斜杠指令
|
||||||
is_slash_command = message_event.interaction_followup_webhook is not None
|
is_slash_command = message_event.interaction_followup_webhook is not None
|
||||||
|
|
||||||
|
# 1. 优先处理斜杠指令
|
||||||
|
if is_slash_command:
|
||||||
|
message_event.is_wake = True
|
||||||
|
message_event.is_at_or_wake_command = True
|
||||||
|
self.commit_event(message_event)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 处理普通消息(提及检测)
|
||||||
|
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
|
||||||
|
raw_message = message.raw_message
|
||||||
|
if not isinstance(raw_message, discord.Message):
|
||||||
|
logger.warning(
|
||||||
|
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
# 检查是否被@(User Mention 或 Bot 拥有的 Role Mention)
|
||||||
is_mention = False
|
is_mention = False
|
||||||
|
|
||||||
# User Mention
|
# User Mention
|
||||||
if (
|
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
|
||||||
self.client
|
if self.client.user in raw_message.mentions:
|
||||||
and self.client.user
|
is_mention = True
|
||||||
and hasattr(message.raw_message, "mentions")
|
|
||||||
):
|
|
||||||
if self.client.user in message.raw_message.mentions:
|
|
||||||
is_mention = True
|
|
||||||
# Role Mention(Bot 拥有的角色被提及)
|
# Role Mention(Bot 拥有的角色被提及)
|
||||||
if not is_mention and hasattr(message.raw_message, "role_mentions"):
|
if not is_mention and raw_message.role_mentions:
|
||||||
bot_member = None
|
bot_member = None
|
||||||
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
|
if raw_message.guild:
|
||||||
try:
|
try:
|
||||||
bot_member = message.raw_message.guild.get_member(
|
bot_member = raw_message.guild.get_member(
|
||||||
self.client.user.id,
|
self.client.user.id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
bot_member = None
|
bot_member = None
|
||||||
if bot_member and hasattr(bot_member, "roles"):
|
if bot_member and hasattr(bot_member, "roles"):
|
||||||
bot_roles = set(bot_member.roles)
|
bot_roles = set(bot_member.roles)
|
||||||
mentioned_roles = set(message.raw_message.role_mentions)
|
mentioned_roles = set(raw_message.role_mentions)
|
||||||
if (
|
if (
|
||||||
bot_roles
|
bot_roles
|
||||||
and mentioned_roles
|
and mentioned_roles
|
||||||
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
):
|
):
|
||||||
is_mention = True
|
is_mention = True
|
||||||
|
|
||||||
# 如果是斜杠指令或被@的消息,设置为唤醒状态
|
# 如果是被@的消息,设置为唤醒状态
|
||||||
if is_slash_command or is_mention:
|
if is_mention:
|
||||||
message_event.is_wake = True
|
message_event.is_wake = True
|
||||||
message_event.is_at_or_wake_command = True
|
message_event.is_at_or_wake_command = True
|
||||||
|
|
||||||
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
abm.message = [Plain(text=message_str_for_filter)]
|
abm.message = [Plain(text=message_str_for_filter)]
|
||||||
abm.raw_message = ctx.interaction
|
abm.raw_message = ctx.interaction
|
||||||
abm.self_id = self.client_self_id
|
abm.self_id = cast(str, self.client_self_id)
|
||||||
abm.session_id = str(ctx.channel_id)
|
abm.session_id = str(ctx.channel_id)
|
||||||
abm.message_id = str(ctx.interaction.id)
|
abm.message_id = str(ctx.interaction.id)
|
||||||
|
|
||||||
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
|
|||||||
def _extract_command_info(
|
def _extract_command_info(
|
||||||
event_filter: Any,
|
event_filter: Any,
|
||||||
handler_metadata: StarHandlerMetadata,
|
handler_metadata: StarHandlerMetadata,
|
||||||
) -> tuple[str, str, CommandFilter] | None:
|
) -> tuple[str, str, CommandFilter | None] | None:
|
||||||
"""从事件过滤器中提取指令信息"""
|
"""从事件过滤器中提取指令信息"""
|
||||||
cmd_name = None
|
cmd_name = None
|
||||||
# is_group = False
|
# is_group = False
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import binascii
|
|||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
from discord.types.interactions import ComponentInteractionData
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
channel = await self._get_channel()
|
channel = await self._get_channel()
|
||||||
if not channel:
|
if not channel:
|
||||||
return
|
return
|
||||||
|
if not isinstance(channel, discord.abc.Messageable):
|
||||||
|
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
|
||||||
|
return
|
||||||
await channel.send(**kwargs)
|
await channel.send(**kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
await self.send(buffer)
|
await self.send(buffer)
|
||||||
return await super().send_streaming(generator, use_fallback)
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
async def _get_channel(self) -> discord.abc.Messageable | None:
|
async def _get_channel(
|
||||||
|
self,
|
||||||
|
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
|
||||||
"""获取当前事件对应的频道对象"""
|
"""获取当前事件对应的频道对象"""
|
||||||
try:
|
try:
|
||||||
channel_id = int(self.session_id)
|
channel_id = int(self.session_id)
|
||||||
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
async def _parse_to_discord(
|
async def _parse_to_discord(
|
||||||
self,
|
self,
|
||||||
message: MessageChain,
|
message: MessageChain,
|
||||||
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
|
) -> tuple[
|
||||||
|
str,
|
||||||
|
list[discord.File],
|
||||||
|
discord.ui.View | None,
|
||||||
|
list[discord.Embed],
|
||||||
|
str | int | None,
|
||||||
|
]:
|
||||||
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
"""将 MessageChain 解析为 Discord 发送所需的内容"""
|
||||||
content_parts = []
|
content_parts = []
|
||||||
files = []
|
files = []
|
||||||
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
self.message_obj.raw_message,
|
self.message_obj.raw_message,
|
||||||
"add_reaction",
|
"add_reaction",
|
||||||
):
|
):
|
||||||
await self.message_obj.raw_message.add_reaction(emoji)
|
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
|
||||||
|
emoji
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Discord] 添加反应失败: {e}")
|
logger.error(f"[Discord] 添加反应失败: {e}")
|
||||||
|
|
||||||
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
return (
|
return (
|
||||||
hasattr(self.message_obj, "raw_message")
|
hasattr(self.message_obj, "raw_message")
|
||||||
and hasattr(self.message_obj.raw_message, "type")
|
and hasattr(self.message_obj.raw_message, "type")
|
||||||
and self.message_obj.raw_message.type
|
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||||
== discord.InteractionType.application_command
|
== discord.InteractionType.application_command
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
return (
|
return (
|
||||||
hasattr(self.message_obj, "raw_message")
|
hasattr(self.message_obj, "raw_message")
|
||||||
and hasattr(self.message_obj.raw_message, "type")
|
and hasattr(self.message_obj.raw_message, "type")
|
||||||
and self.message_obj.raw_message.type == discord.InteractionType.component
|
and cast(discord.Interaction, self.message_obj.raw_message).type
|
||||||
|
== discord.InteractionType.component
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_interaction_custom_id(self) -> str:
|
def get_interaction_custom_id(self) -> str:
|
||||||
"""获取交互组件的custom_id"""
|
"""获取交互组件的custom_id"""
|
||||||
if self.is_button_interaction():
|
if self.is_button_interaction():
|
||||||
try:
|
try:
|
||||||
return self.message_obj.raw_message.data.get("custom_id", "")
|
return cast(
|
||||||
|
ComponentInteractionData,
|
||||||
|
cast(discord.Interaction, self.message_obj.raw_message).data,
|
||||||
|
).get("custom_id", "")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return ""
|
return ""
|
||||||
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
):
|
):
|
||||||
return any(
|
return any(
|
||||||
mention.id == int(self.message_obj.self_id)
|
mention.id == int(self.message_obj.self_id)
|
||||||
for mention in self.message_obj.raw_message.mentions
|
for mention in cast(
|
||||||
|
discord.Message, self.message_obj.raw_message
|
||||||
|
).mentions
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
self.message_obj.raw_message,
|
self.message_obj.raw_message,
|
||||||
"clean_content",
|
"clean_content",
|
||||||
):
|
):
|
||||||
return self.message_obj.raw_message.clean_content
|
return cast(discord.Message, self.message_obj.raw_message).clean_content
|
||||||
return self.message_str
|
return self.message_str
|
||||||
|
|||||||
@@ -3,9 +3,14 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import (
|
||||||
|
CreateMessageRequest,
|
||||||
|
CreateMessageRequestBody,
|
||||||
|
GetMessageResourceRequest,
|
||||||
|
)
|
||||||
|
|
||||||
import astrbot.api.message_components as Comp
|
import astrbot.api.message_components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
@@ -74,6 +79,10 @@ class LarkPlatformAdapter(Platform):
|
|||||||
session: MessageSesion,
|
session: MessageSesion,
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
):
|
):
|
||||||
|
if self.lark_api.im is None:
|
||||||
|
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||||
|
return
|
||||||
|
|
||||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||||
wrapped = {
|
wrapped = {
|
||||||
"zh_cn": {
|
"zh_cn": {
|
||||||
@@ -114,14 +123,21 @@ class LarkPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
name="lark",
|
name="lark",
|
||||||
description="飞书机器人官方 API 适配器",
|
description="飞书机器人官方 API 适配器",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
support_streaming_message=False,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||||
|
if event.event is None:
|
||||||
|
logger.debug("[Lark] 收到空事件(event.event is None)")
|
||||||
|
return
|
||||||
message = event.event.message
|
message = event.event.message
|
||||||
|
if message is None:
|
||||||
|
logger.debug("[Lark] 事件中没有消息体(message is None)")
|
||||||
|
return
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.timestamp = int(message.create_time) / 1000
|
abm.timestamp = cast(int, message.create_time) // 1000
|
||||||
abm.message = []
|
abm.message = []
|
||||||
abm.type = (
|
abm.type = (
|
||||||
MessageType.GROUP_MESSAGE
|
MessageType.GROUP_MESSAGE
|
||||||
@@ -136,14 +152,28 @@ class LarkPlatformAdapter(Platform):
|
|||||||
at_list = {}
|
at_list = {}
|
||||||
if message.mentions:
|
if message.mentions:
|
||||||
for m in message.mentions:
|
for m in message.mentions:
|
||||||
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
|
if m.id is None:
|
||||||
if m.name == self.bot_name:
|
continue
|
||||||
abm.self_id = m.id.open_id
|
# 飞书 open_id 可能是 None,这里做个防护
|
||||||
|
open_id = m.id.open_id if m.id.open_id else ""
|
||||||
|
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
|
||||||
|
|
||||||
content_json_b = json.loads(message.content)
|
if m.name == self.bot_name:
|
||||||
|
if m.id.open_id is not None:
|
||||||
|
abm.self_id = m.id.open_id
|
||||||
|
|
||||||
|
if message.content is None:
|
||||||
|
logger.warning("[Lark] 消息内容为空")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_json_b = json.loads(message.content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
|
||||||
|
return
|
||||||
|
|
||||||
if message.message_type == "text":
|
if message.message_type == "text":
|
||||||
message_str_raw = content_json_b["text"] # 带有 @ 的消息
|
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
|
||||||
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
|
||||||
# at_users = re.findall(at_pattern, message_str_raw)
|
# at_users = re.findall(at_pattern, message_str_raw)
|
||||||
# 拆分文本,去掉AT符号部分
|
# 拆分文本,去掉AT符号部分
|
||||||
@@ -168,27 +198,47 @@ class LarkPlatformAdapter(Platform):
|
|||||||
content_json_b = _ls
|
content_json_b = _ls
|
||||||
elif message.message_type == "image":
|
elif message.message_type == "image":
|
||||||
content_json_b = [
|
content_json_b = [
|
||||||
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
|
{
|
||||||
|
"tag": "img",
|
||||||
|
"image_key": content_json_b.get("image_key"),
|
||||||
|
"style": [],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
if message.message_type in ("post", "image"):
|
if message.message_type in ("post", "image"):
|
||||||
for comp in content_json_b:
|
for comp in content_json_b:
|
||||||
if comp["tag"] == "at":
|
if comp.get("tag") == "at":
|
||||||
abm.message.append(at_list[comp["user_id"]])
|
user_id = comp.get("user_id")
|
||||||
elif comp["tag"] == "text" and comp["text"].strip():
|
if user_id in at_list:
|
||||||
|
abm.message.append(at_list[user_id])
|
||||||
|
elif comp.get("tag") == "text" and comp.get("text", "").strip():
|
||||||
abm.message.append(Comp.Plain(comp["text"].strip()))
|
abm.message.append(Comp.Plain(comp["text"].strip()))
|
||||||
elif comp["tag"] == "img":
|
elif comp.get("tag") == "img":
|
||||||
image_key = comp["image_key"]
|
image_key = comp.get("image_key")
|
||||||
|
if not image_key:
|
||||||
|
continue
|
||||||
|
|
||||||
request = (
|
request = (
|
||||||
GetMessageResourceRequest.builder()
|
GetMessageResourceRequest.builder()
|
||||||
.message_id(message.message_id)
|
.message_id(cast(str, message.message_id))
|
||||||
.file_key(image_key)
|
.file_key(image_key)
|
||||||
.type("image")
|
.type("image")
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.lark_api.im is None:
|
||||||
|
logger.error("[Lark] API Client im 模块未初始化")
|
||||||
|
continue
|
||||||
|
|
||||||
response = await self.lark_api.im.v1.message_resource.aget(request)
|
response = await self.lark_api.im.v1.message_resource.aget(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(f"无法下载飞书图片: {image_key}")
|
logger.error(f"无法下载飞书图片: {image_key}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.file is None:
|
||||||
|
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
|
||||||
|
continue
|
||||||
|
|
||||||
image_bytes = response.file.read()
|
image_bytes = response.file.read()
|
||||||
image_base64 = base64.b64encode(image_bytes).decode()
|
image_base64 = base64.b64encode(image_bytes).decode()
|
||||||
abm.message.append(Comp.Image.fromBase64(image_base64))
|
abm.message.append(Comp.Image.fromBase64(image_base64))
|
||||||
@@ -196,6 +246,19 @@ class LarkPlatformAdapter(Platform):
|
|||||||
for comp in abm.message:
|
for comp in abm.message:
|
||||||
if isinstance(comp, Comp.Plain):
|
if isinstance(comp, Comp.Plain):
|
||||||
abm.message_str += comp.text
|
abm.message_str += comp.text
|
||||||
|
|
||||||
|
if message.message_id is None:
|
||||||
|
logger.error("[Lark] 消息缺少 message_id")
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
event.event.sender is None
|
||||||
|
or event.event.sender.sender_id is None
|
||||||
|
or event.event.sender.sender_id.open_id is None
|
||||||
|
):
|
||||||
|
logger.error("[Lark] 消息发送者信息不完整")
|
||||||
|
return
|
||||||
|
|
||||||
abm.message_id = message.message_id
|
abm.message_id = message.message_id
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
@@ -235,5 +298,5 @@ class LarkPlatformAdapter(Platform):
|
|||||||
await self.client._disconnect()
|
await self.client._disconnect()
|
||||||
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
logger.info("飞书(Lark) 适配器已被优雅地关闭")
|
||||||
|
|
||||||
def get_client(self) -> lark.Client:
|
def get_client(self) -> lark.ws.Client:
|
||||||
return self.client
|
return self.client
|
||||||
|
|||||||
@@ -5,7 +5,15 @@ import uuid
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
import lark_oapi as lark
|
import lark_oapi as lark
|
||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import (
|
||||||
|
CreateImageRequest,
|
||||||
|
CreateImageRequestBody,
|
||||||
|
CreateMessageReactionRequest,
|
||||||
|
CreateMessageReactionRequestBody,
|
||||||
|
Emoji,
|
||||||
|
ReplyMessageRequest,
|
||||||
|
ReplyMessageRequestBody,
|
||||||
|
)
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
file_path = comp.file.replace("file:///", "")
|
file_path = comp.file.replace("file:///", "")
|
||||||
elif comp.file and comp.file.startswith("http"):
|
elif comp.file and comp.file.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(comp.file)
|
image_file_path = await download_image_by_url(comp.file)
|
||||||
file_path = image_file_path
|
file_path = image_file_path if image_file_path else ""
|
||||||
elif comp.file and comp.file.startswith("base64://"):
|
elif comp.file and comp.file.startswith("base64://"):
|
||||||
base64_str = comp.file.removeprefix("base64://")
|
base64_str = comp.file.removeprefix("base64://")
|
||||||
image_data = base64.b64decode(base64_str)
|
image_data = base64.b64decode(base64_str)
|
||||||
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
with open(file_path, "wb") as f:
|
with open(file_path, "wb") as f:
|
||||||
f.write(BytesIO(image_data).getvalue())
|
f.write(BytesIO(image_data).getvalue())
|
||||||
else:
|
else:
|
||||||
file_path = comp.file
|
file_path = comp.file if comp.file else ""
|
||||||
|
|
||||||
if image_file is None:
|
if image_file is None:
|
||||||
image_file = open(file_path, "rb")
|
if not file_path:
|
||||||
|
logger.error("[Lark] 图片路径为空,无法上传")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
image_file = open(file_path, "rb")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Lark] 无法打开图片文件: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
request = (
|
request = (
|
||||||
CreateImageRequest.builder()
|
CreateImageRequest.builder()
|
||||||
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if lark_client.im is None:
|
||||||
|
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
|
||||||
|
continue
|
||||||
|
|
||||||
response = await lark_client.im.v1.image.acreate(request)
|
response = await lark_client.im.v1.image.acreate(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.data is None:
|
||||||
|
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
|
||||||
|
continue
|
||||||
|
|
||||||
image_key = response.data.image_key
|
image_key = response.data.image_key
|
||||||
logger.debug(image_key)
|
logger.debug(image_key)
|
||||||
ret.append(_stage)
|
ret.append(_stage)
|
||||||
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.bot.im is None:
|
||||||
|
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||||
|
return
|
||||||
|
|
||||||
response = await self.bot.im.v1.message.areply(request)
|
response = await self.bot.im.v1.message.areply(request)
|
||||||
|
|
||||||
if not response.success():
|
if not response.success():
|
||||||
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|
||||||
async def react(self, emoji: str):
|
async def react(self, emoji: str):
|
||||||
|
if self.bot.im is None:
|
||||||
|
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||||
|
return
|
||||||
|
|
||||||
request = (
|
request = (
|
||||||
CreateMessageReactionRequest.builder()
|
CreateMessageReactionRequest.builder()
|
||||||
.message_id(self.message_obj.message_id)
|
.message_id(self.message_obj.message_id)
|
||||||
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self.bot.im.v1.message_reaction.acreate(request)
|
response = await self.bot.im.v1.message_reaction.acreate(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from collections.abc import Awaitable
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import astrbot.api.message_components as Comp
|
import astrbot.api.message_components as Comp
|
||||||
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
|
|||||||
if not isinstance(message.raw_message, dict):
|
if not isinstance(message.raw_message, dict):
|
||||||
message.raw_message = {}
|
message.raw_message = {}
|
||||||
message.raw_message["poll"] = poll
|
message.raw_message["poll"] = poll
|
||||||
message.poll = poll
|
message.__setattr__("poll", poll)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
|
|||||||
self,
|
self,
|
||||||
session: MessageSession,
|
session: MessageSession,
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain,
|
||||||
) -> Awaitable[Any]:
|
) -> None:
|
||||||
if not self.api:
|
if not self.api:
|
||||||
logger.error("[Misskey] API 客户端未初始化")
|
logger.error("[Misskey] API 客户端未初始化")
|
||||||
return await super().send_by_session(session, message_chain)
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import base64
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import botpy
|
import botpy
|
||||||
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
time_since_last_edit = current_time - last_edit_time
|
time_since_last_edit = current_time - last_edit_time
|
||||||
|
|
||||||
if time_since_last_edit >= throttle_interval:
|
if time_since_last_edit >= throttle_interval:
|
||||||
ret = await self._post_send(stream=stream_payload)
|
ret = cast(
|
||||||
|
message.Message,
|
||||||
|
await self._post_send(stream=stream_payload),
|
||||||
|
)
|
||||||
stream_payload["index"] += 1
|
stream_payload["index"] += 1
|
||||||
stream_payload["id"] = ret["id"]
|
stream_payload["id"] = ret["id"]
|
||||||
last_edit_time = asyncio.get_event_loop().time()
|
last_edit_time = asyncio.get_event_loop().time()
|
||||||
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
source = self.message_obj.raw_message
|
source = self.message_obj.raw_message
|
||||||
assert isinstance(
|
|
||||||
|
if not isinstance(
|
||||||
source,
|
source,
|
||||||
(
|
(
|
||||||
botpy.message.Message,
|
botpy.message.Message,
|
||||||
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
botpy.message.DirectMessage,
|
botpy.message.DirectMessage,
|
||||||
botpy.message.C2CMessage,
|
botpy.message.C2CMessage,
|
||||||
),
|
),
|
||||||
)
|
):
|
||||||
|
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
|
||||||
|
return None
|
||||||
|
|
||||||
(
|
(
|
||||||
plain_text,
|
plain_text,
|
||||||
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
payload = {
|
payload: dict = {
|
||||||
"content": plain_text,
|
"content": plain_text,
|
||||||
"msg_id": self.message_obj.message_id,
|
"msg_id": self.message_obj.message_id,
|
||||||
}
|
}
|
||||||
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
ret = None
|
ret = None
|
||||||
|
|
||||||
match type(source):
|
match source:
|
||||||
case botpy.message.GroupMessage:
|
case botpy.message.GroupMessage():
|
||||||
|
if not source.group_openid:
|
||||||
|
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
|
||||||
|
return None
|
||||||
|
|
||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(
|
media = await self.upload_group_and_c2c_image(
|
||||||
image_base64,
|
image_base64,
|
||||||
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
group_openid=source.group_openid,
|
group_openid=source.group_openid,
|
||||||
**payload,
|
**payload,
|
||||||
)
|
)
|
||||||
case botpy.message.C2CMessage:
|
|
||||||
|
case botpy.message.C2CMessage():
|
||||||
if image_base64:
|
if image_base64:
|
||||||
media = await self.upload_group_and_c2c_image(
|
media = await self.upload_group_and_c2c_image(
|
||||||
image_base64,
|
image_base64,
|
||||||
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
**payload,
|
**payload,
|
||||||
)
|
)
|
||||||
logger.debug(f"Message sent to C2C: {ret}")
|
logger.debug(f"Message sent to C2C: {ret}")
|
||||||
case botpy.message.Message:
|
|
||||||
|
case botpy.message.Message():
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
ret = await self.bot.api.post_message(
|
ret = await self.bot.api.post_message(
|
||||||
channel_id=source.channel_id,
|
channel_id=source.channel_id,
|
||||||
**payload,
|
**payload,
|
||||||
)
|
)
|
||||||
case botpy.message.DirectMessage:
|
|
||||||
|
case botpy.message.DirectMessage():
|
||||||
if image_path:
|
if image_path:
|
||||||
payload["file_image"] = image_path
|
payload["file_image"] = image_path
|
||||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
pass
|
||||||
|
|
||||||
await super().send(self.send_buffer)
|
await super().send(self.send_buffer)
|
||||||
|
|
||||||
self.send_buffer = None
|
self.send_buffer = None
|
||||||
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
"file_type": file_type,
|
"file_type": file_type,
|
||||||
"srv_send_msg": False,
|
"srv_send_msg": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result = None
|
||||||
if "openid" in kwargs:
|
if "openid" in kwargs:
|
||||||
payload["openid"] = kwargs["openid"]
|
payload["openid"] = kwargs["openid"]
|
||||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
if "group_openid" in kwargs:
|
elif "group_openid" in kwargs:
|
||||||
payload["group_openid"] = kwargs["group_openid"]
|
payload["group_openid"] = kwargs["group_openid"]
|
||||||
route = Route(
|
route = Route(
|
||||||
"POST",
|
"POST",
|
||||||
"/v2/groups/{group_openid}/files",
|
"/v2/groups/{group_openid}/files",
|
||||||
group_openid=kwargs["group_openid"],
|
group_openid=kwargs["group_openid"],
|
||||||
)
|
)
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid upload parameters")
|
||||||
|
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to upload image, response is not dict: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Media(
|
||||||
|
file_uuid=result["file_uuid"],
|
||||||
|
file_info=result["file_info"],
|
||||||
|
ttl=result.get("ttl", 0),
|
||||||
|
)
|
||||||
|
|
||||||
async def upload_group_and_c2c_record(
|
async def upload_group_and_c2c_record(
|
||||||
self,
|
self,
|
||||||
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
result = await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
logger.error(f"上传文件响应格式错误: {result}")
|
||||||
|
return None
|
||||||
|
|
||||||
return Media(
|
return Media(
|
||||||
file_uuid=result.get("file_uuid"),
|
file_uuid=result["file_uuid"],
|
||||||
file_info=result.get("file_info"),
|
file_info=result["file_info"],
|
||||||
ttl=result.get("ttl", 0),
|
ttl=result.get("ttl", 0),
|
||||||
file_id=result.get("id", ""),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"上传请求错误: {e}")
|
logger.error(f"上传请求错误: {e}")
|
||||||
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
message_reference: message.Reference | None = None,
|
message_reference: message.Reference | None = None,
|
||||||
media: message.Media | None = None,
|
media: message.Media | None = None,
|
||||||
msg_id: str | None = None,
|
msg_id: str | None = None,
|
||||||
msg_seq: str = 1,
|
msg_seq: int | None = 1,
|
||||||
event_id: str | None = None,
|
event_id: str | None = None,
|
||||||
markdown: message.MarkdownPayload | None = None,
|
markdown: message.MarkdownPayload | None = None,
|
||||||
keyboard: message.Keyboard | None = None,
|
keyboard: message.Keyboard | None = None,
|
||||||
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
payload = locals()
|
payload = locals()
|
||||||
payload.pop("self", None)
|
payload.pop("self", None)
|
||||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||||
return await self.bot.api._http.request(route, json=payload)
|
result = await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to post c2c message, response is not dict: {result}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return message.Message(**result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _parse_to_qqofficial(message: MessageChain):
|
async def _parse_to_qqofficial(message: MessageChain):
|
||||||
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
image_base64 = file_to_base64(image_file_path)
|
image_base64 = file_to_base64(image_file_path)
|
||||||
elif i.file and i.file.startswith("base64://"):
|
elif i.file and i.file.startswith("base64://"):
|
||||||
image_base64 = i.file
|
image_base64 = i.file
|
||||||
else:
|
elif i.file:
|
||||||
image_base64 = file_to_base64(i.file)
|
image_base64 = file_to_base64(i.file)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported image file format")
|
||||||
image_base64 = image_base64.removeprefix("base64://")
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
if i.file:
|
if i.file:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import botpy
|
import botpy
|
||||||
import botpy.message
|
import botpy.message
|
||||||
@@ -44,7 +45,9 @@ class botClient(Client):
|
|||||||
MessageType.GROUP_MESSAGE,
|
MessageType.GROUP_MESSAGE,
|
||||||
)
|
)
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
abm.sender.user_id
|
||||||
|
if self.platform.unique_session
|
||||||
|
else cast(str, message.group_openid)
|
||||||
)
|
)
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
|
|
||||||
self.appid = platform_config["appid"]
|
self.appid = platform_config["appid"]
|
||||||
self.secret = platform_config["secret"]
|
self.secret = platform_config["secret"]
|
||||||
self.unique_session = platform_settings["unique_session"]
|
self.unique_session: bool = platform_settings["unique_session"]
|
||||||
qq_group = platform_config["enable_group_c2c"]
|
qq_group = platform_config["enable_group_c2c"]
|
||||||
guild_dm = platform_config["enable_guild_direct_message"]
|
guild_dm = platform_config["enable_guild_direct_message"]
|
||||||
|
|
||||||
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
name="qq_official",
|
name="qq_official",
|
||||||
description="QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_from_qqofficial(
|
def _parse_from_qqofficial(
|
||||||
message: botpy.message.Message | botpy.message.GroupMessage,
|
message: botpy.message.Message
|
||||||
|
| botpy.message.GroupMessage
|
||||||
|
| botpy.message.DirectMessage
|
||||||
|
| botpy.message.C2CMessage,
|
||||||
message_type: MessageType,
|
message_type: MessageType,
|
||||||
):
|
):
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
abm.timestamp = int(time.time())
|
abm.timestamp = int(time.time())
|
||||||
abm.raw_message = message
|
abm.raw_message = message
|
||||||
abm.message_id = message.id
|
abm.message_id = message.id
|
||||||
abm.tag = "qq_official"
|
# abm.tag = "qq_official"
|
||||||
msg: list[BaseMessageComponent] = []
|
msg: list[BaseMessageComponent] = []
|
||||||
|
|
||||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
if isinstance(message, botpy.message.GroupMessage) or isinstance(
|
||||||
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
|
|||||||
message,
|
message,
|
||||||
botpy.message.DirectMessage,
|
botpy.message.DirectMessage,
|
||||||
):
|
):
|
||||||
try:
|
if isinstance(message, botpy.message.Message):
|
||||||
abm.self_id = str(message.mentions[0].id)
|
abm.self_id = str(message.mentions[0].id)
|
||||||
except BaseException as _:
|
else:
|
||||||
abm.self_id = ""
|
abm.self_id = ""
|
||||||
|
|
||||||
plain_content = message.content.replace(
|
plain_content = message.content.replace(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import botpy
|
import botpy
|
||||||
import botpy.message
|
import botpy.message
|
||||||
@@ -36,7 +36,9 @@ class botClient(Client):
|
|||||||
MessageType.GROUP_MESSAGE,
|
MessageType.GROUP_MESSAGE,
|
||||||
)
|
)
|
||||||
abm.session_id = (
|
abm.session_id = (
|
||||||
abm.sender.user_id if self.platform.unique_session else message.group_openid
|
abm.sender.user_id
|
||||||
|
if self.platform.unique_session
|
||||||
|
else cast(str, message.group_openid)
|
||||||
)
|
)
|
||||||
self._commit(abm)
|
self._commit(abm)
|
||||||
|
|
||||||
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
|||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
name="qq_official_webhook",
|
name="qq_official_webhook",
|
||||||
description="QQ 机器人官方 API 适配器",
|
description="QQ 机器人官方 API 适配器",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
|
||||||
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
|
|||||||
|
|
||||||
if opcode == 13:
|
if opcode == 13:
|
||||||
# validation
|
# validation
|
||||||
signed = await self.webhook_validation(data)
|
signed = await self.webhook_validation(cast(dict, data))
|
||||||
print(signed)
|
print(signed)
|
||||||
return signed
|
return signed
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import hmac
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from quart import Quart, Response, request
|
from quart import Quart, Response, request
|
||||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||||
|
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
|
||||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||||
from slack_sdk.web.async_client import AsyncWebClient
|
from slack_sdk.web.async_client import AsyncWebClient
|
||||||
@@ -66,7 +68,7 @@ class SlackWebhookClient:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取请求体和头部
|
# 获取请求体和头部
|
||||||
body = await req.get_data()
|
body = cast(bytes, await req.get_data())
|
||||||
event_data = json.loads(body.decode("utf-8"))
|
event_data = json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
# Verify Slack request signature
|
# Verify Slack request signature
|
||||||
@@ -139,9 +141,14 @@ class SlackSocketClient:
|
|||||||
self.event_handler = event_handler
|
self.event_handler = event_handler
|
||||||
self.socket_client = None
|
self.socket_client = None
|
||||||
|
|
||||||
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
|
async def _handle_events(
|
||||||
|
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
|
||||||
|
):
|
||||||
"""处理 Socket Mode 事件"""
|
"""处理 Socket Mode 事件"""
|
||||||
try:
|
try:
|
||||||
|
if self.socket_client is None:
|
||||||
|
raise RuntimeError("Socket client is not initialized")
|
||||||
|
|
||||||
# 确认收到事件
|
# 确认收到事件
|
||||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||||
await self.socket_client.send_socket_mode_response(response)
|
await self.socket_client.send_socket_mode_response(response)
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ import base64
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable
|
from typing import Any, cast
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||||
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
|
|||||||
self.metadata = PlatformMetadata(
|
self.metadata = PlatformMetadata(
|
||||||
name="slack",
|
name="slack",
|
||||||
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
|
||||||
id=self.config.get("id"),
|
id=cast(str, self.config.get("id")),
|
||||||
support_streaming_message=False,
|
support_streaming_message=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
|
|||||||
logger.debug(f"[slack] RawMessage {event}")
|
logger.debug(f"[slack] RawMessage {event}")
|
||||||
|
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = self.bot_self_id
|
abm.self_id = cast(str, self.bot_self_id)
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
user_id = event.get("user", "")
|
user_id = event.get("user", "")
|
||||||
try:
|
try:
|
||||||
user_info = await self.web_client.users_info(user=user_id)
|
user_info = await self.web_client.users_info(user=user_id)
|
||||||
user_data = user_info["user"]
|
user_data = cast(dict, user_info["user"])
|
||||||
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
user_name = user_data.get("real_name") or user_data.get("name", user_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
user_name = user_id
|
user_name = user_id
|
||||||
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
|
|||||||
channel_id = event.get("channel", "")
|
channel_id = event.get("channel", "")
|
||||||
try:
|
try:
|
||||||
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
channel_info = await self.web_client.conversations_info(channel=channel_id)
|
||||||
is_im = channel_info["channel"]["is_im"]
|
is_im = cast(dict, channel_info["channel"])["is_im"]
|
||||||
|
|
||||||
if is_im:
|
if is_im:
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
|
|||||||
for mention in mentions:
|
for mention in mentions:
|
||||||
try:
|
try:
|
||||||
mentioned_user = await self.web_client.users_info(user=mention)
|
mentioned_user = await self.web_client.users_info(user=mention)
|
||||||
user_data = mentioned_user["user"]
|
user_data = cast(dict, mentioned_user["user"])
|
||||||
user_name = user_data.get("real_name") or user_data.get(
|
user_name = user_data.get("real_name") or user_data.get(
|
||||||
"name",
|
"name",
|
||||||
mention,
|
mention,
|
||||||
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
|
|||||||
)
|
)
|
||||||
raise Exception(f"下载文件失败: {resp.status}")
|
raise Exception(f"下载文件失败: {resp.status}")
|
||||||
|
|
||||||
async def run(self) -> Awaitable[Any]:
|
async def run(self) -> None:
|
||||||
self.bot_self_id = await self.get_bot_user_id()
|
self.bot_self_id = await self.get_bot_user_id()
|
||||||
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, Iterable
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from slack_sdk.web.async_client import AsyncWebClient
|
from slack_sdk.web.async_client import AsyncWebClient
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
if isinstance(segment, Image):
|
if isinstance(segment, Image):
|
||||||
# upload file
|
# upload file
|
||||||
url = segment.url or segment.file
|
url = segment.url or segment.file
|
||||||
if url.startswith("http"):
|
if url and url.startswith("http"):
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image_url": url,
|
"image_url": url,
|
||||||
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
"type": "section",
|
"type": "section",
|
||||||
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
"text": {"type": "mrkdwn", "text": "图片上传失败"},
|
||||||
}
|
}
|
||||||
image_url = response["files"][0]["url_private"]
|
image_url = cast(list, response["files"])[0]["url_private"]
|
||||||
logger.debug(f"Slack file upload response: {response}")
|
logger.debug(f"Slack file upload response: {response}")
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
"type": "section",
|
"type": "section",
|
||||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||||
}
|
}
|
||||||
file_url = response["files"][0]["permalink"]
|
file_url = cast(list, response["files"])[0]["permalink"]
|
||||||
return {
|
return {
|
||||||
"type": "section",
|
"type": "section",
|
||||||
"text": {
|
"text": {
|
||||||
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
members = []
|
members = []
|
||||||
for member_id in members_response["members"]:
|
for member_id in cast(Iterable, members_response["members"]):
|
||||||
try:
|
try:
|
||||||
user_info = await self.web_client.users_info(user=member_id)
|
user_info = await self.web_client.users_info(user=member_id)
|
||||||
user_data = user_info["user"]
|
user_data = cast(dict, user_info["user"])
|
||||||
members.append(
|
members.append(
|
||||||
MessageMember(
|
MessageMember(
|
||||||
user_id=member_id,
|
user_id=member_id,
|
||||||
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
# 如果获取用户信息失败,使用默认信息
|
# 如果获取用户信息失败,使用默认信息
|
||||||
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
members.append(MessageMember(user_id=member_id, nickname=member_id))
|
||||||
|
|
||||||
channel_data = channel_info["channel"]
|
channel_data = cast(dict, channel_info["channel"])
|
||||||
return Group(
|
return Group(
|
||||||
group_id=channel_id,
|
group_id=channel_id,
|
||||||
group_name=channel_data.get("name", ""),
|
group_name=channel_data.get("name", ""),
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import telegramify_markdown
|
import telegramify_markdown
|
||||||
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
|
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
|
||||||
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
|
|||||||
Reply,
|
Reply,
|
||||||
)
|
)
|
||||||
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
||||||
from astrbot.core.utils.io import download_file
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramPlatformEvent(AstrMessageEvent):
|
class TelegramPlatformEvent(AstrMessageEvent):
|
||||||
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
"chat_id": user_name,
|
"chat_id": user_name,
|
||||||
}
|
}
|
||||||
if has_reply:
|
if has_reply:
|
||||||
payload["reply_to_message_id"] = reply_message_id
|
payload["reply_to_message_id"] = str(reply_message_id)
|
||||||
if message_thread_id:
|
if message_thread_id:
|
||||||
payload["message_thread_id"] = message_thread_id
|
payload["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
md_text = telegramify_markdown.markdownify(
|
md_text = telegramify_markdown.markdownify(
|
||||||
chunk,
|
chunk,
|
||||||
max_line_length=None,
|
|
||||||
normalize_whitespace=False,
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await client.send_message(
|
await client.send_message(
|
||||||
text=md_text,
|
text=md_text,
|
||||||
parse_mode="MarkdownV2",
|
parse_mode="MarkdownV2",
|
||||||
**payload,
|
**cast(Any, payload),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"MarkdownV2 send failed: {e}. Using plain text instead.",
|
f"MarkdownV2 send failed: {e}. Using plain text instead.",
|
||||||
)
|
)
|
||||||
await client.send_message(text=chunk, **payload)
|
await client.send_message(text=chunk, **cast(Any, payload))
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
image_path = await i.convert_to_file_path()
|
image_path = await i.convert_to_file_path()
|
||||||
await client.send_photo(photo=image_path, **payload)
|
await client.send_photo(photo=image_path, **cast(Any, payload))
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
if i.file.startswith("https://"):
|
path = await i.get_file()
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
name = i.name or os.path.basename(path)
|
||||||
path = os.path.join(temp_dir, i.name)
|
await client.send_document(
|
||||||
await download_file(i.file, path)
|
document=path, filename=name, **cast(Any, payload)
|
||||||
i.file = path
|
)
|
||||||
|
|
||||||
await client.send_document(document=i.file, filename=i.name, **payload)
|
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
path = await i.convert_to_file_path()
|
path = await i.convert_to_file_path()
|
||||||
await client.send_voice(voice=path, **payload)
|
await client.send_voice(voice=path, **cast(Any, payload))
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||||
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
delta += i.text
|
delta += i.text
|
||||||
elif isinstance(i, Image):
|
elif isinstance(i, Image):
|
||||||
image_path = await i.convert_to_file_path()
|
image_path = await i.convert_to_file_path()
|
||||||
await self.client.send_photo(photo=image_path, **payload)
|
await self.client.send_photo(
|
||||||
|
photo=image_path, **cast(Any, payload)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
elif isinstance(i, File):
|
elif isinstance(i, File):
|
||||||
if i.file.startswith("https://"):
|
path = await i.get_file()
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
name = i.name or os.path.basename(path)
|
||||||
path = os.path.join(temp_dir, i.name)
|
|
||||||
await download_file(i.file, path)
|
|
||||||
i.file = path
|
|
||||||
|
|
||||||
await self.client.send_document(
|
await self.client.send_document(
|
||||||
document=i.file,
|
document=path,
|
||||||
filename=i.name,
|
filename=name,
|
||||||
**payload,
|
**cast(Any, payload),
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
path = await i.convert_to_file_path()
|
path = await i.convert_to_file_path()
|
||||||
await self.client.send_voice(voice=path, **payload)
|
await self.client.send_voice(voice=path, **cast(Any, payload))
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.warning(f"不支持的消息类型: {type(i)}")
|
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||||
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
# delta 长度一般不会大于 4096,因此这里直接发送
|
# delta 长度一般不会大于 4096,因此这里直接发送
|
||||||
try:
|
try:
|
||||||
msg = await self.client.send_message(text=delta, **payload)
|
msg = await self.client.send_message(
|
||||||
|
text=delta, **cast(Any, payload)
|
||||||
|
)
|
||||||
current_content = delta
|
current_content = delta
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
markdown_text = telegramify_markdown.markdownify(
|
markdown_text = telegramify_markdown.markdownify(
|
||||||
delta,
|
delta,
|
||||||
max_line_length=None,
|
|
||||||
normalize_whitespace=False,
|
normalize_whitespace=False,
|
||||||
)
|
)
|
||||||
await self.client.edit_message_text(
|
await self.client.edit_message_text(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Callable, Coroutine
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
|
|||||||
abm.raw_message = data
|
abm.raw_message = data
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
def run(self) -> Awaitable[Any]:
|
def run(self) -> Coroutine[Any, Any, None]:
|
||||||
async def callback(data: tuple):
|
async def callback(data: tuple):
|
||||||
abm = await self.convert_message(data)
|
abm = await self.convert_message(data)
|
||||||
await self.handle_msg(abm)
|
await self.handle_msg(abm)
|
||||||
|
|||||||
@@ -101,9 +101,9 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain | None):
|
||||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||||
await super().send(message)
|
await super().send(MessageChain([]))
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
final_data = ""
|
final_data = ""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import anyio
|
import anyio
|
||||||
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
)
|
)
|
||||||
self.base_url = f"http://{self.host}:{self.port}"
|
self.base_url = f"http://{self.host}:{self.port}"
|
||||||
self.auth_key = None # 用于保存生成的授权码
|
self.auth_key = None # 用于保存生成的授权码
|
||||||
self.wxid = None # 用于保存登录成功后的 wxid
|
self.wxid: str | None = None # 用于保存登录成功后的 wxid
|
||||||
self.credentials_file = os.path.join(
|
self.credentials_file = os.path.join(
|
||||||
get_astrbot_data_path(),
|
get_astrbot_data_path(),
|
||||||
"wechatpadpro_credentials.json",
|
"wechatpadpro_credentials.json",
|
||||||
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
async def handle_websocket_message(self, message: str):
|
async def handle_websocket_message(self, message: str | bytes):
|
||||||
"""处理从 WebSocket 接收到的消息。"""
|
"""处理从 WebSocket 接收到的消息。"""
|
||||||
logger.debug(f"收到 WebSocket 消息: {message}")
|
logger.debug(f"收到 WebSocket 消息: {message}")
|
||||||
try:
|
try:
|
||||||
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
|
|||||||
|
|
||||||
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
|
||||||
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
|
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
|
||||||
|
if self.wxid is None:
|
||||||
|
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
|
||||||
|
return None
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.raw_message = raw_message
|
abm.raw_message = raw_message
|
||||||
abm.message_id = str(raw_message.get("msg_id"))
|
abm.message_id = str(raw_message.get("msg_id"))
|
||||||
abm.timestamp = raw_message.get("create_time")
|
abm.timestamp = cast(int, raw_message.get("create_time"))
|
||||||
abm.self_id = self.wxid
|
abm.self_id = self.wxid
|
||||||
|
|
||||||
if int(time.time()) - abm.timestamp > 180:
|
if int(time.time()) - abm.timestamp > 180:
|
||||||
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||||
content = raw_message.get("content", {}).get("str", "")
|
content = raw_message.get("content", {}).get("str", "")
|
||||||
push_content = raw_message.get("push_content", "")
|
push_content = raw_message.get("push_content", "")
|
||||||
msg_type = raw_message.get("msg_type")
|
msg_type = cast(int, raw_message.get("msg_type"))
|
||||||
|
|
||||||
abm.message_str = ""
|
abm.message_str = ""
|
||||||
abm.message = []
|
abm.message = []
|
||||||
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
|
|||||||
from_user_name: str,
|
from_user_name: str,
|
||||||
to_user_name: str,
|
to_user_name: str,
|
||||||
msg_id: int,
|
msg_id: int,
|
||||||
):
|
) -> dict | None:
|
||||||
"""下载原始图片。"""
|
"""下载原始图片。"""
|
||||||
url = f"{self.base_url}/message/GetMsgBigImg"
|
url = f"{self.base_url}/message/GetMsgBigImg"
|
||||||
params = {"key": self.auth_key}
|
params = {"key": self.auth_key}
|
||||||
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
|
|||||||
# 图片消息
|
# 图片消息
|
||||||
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
|
||||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||||
msg_id = raw_message.get("msg_id")
|
msg_id = cast(int, raw_message.get("msg_id"))
|
||||||
image_resp = await self._download_raw_image(
|
image_resp = await self._download_raw_image(
|
||||||
from_user_name,
|
from_user_name,
|
||||||
to_user_name,
|
to_user_name,
|
||||||
msg_id,
|
msg_id,
|
||||||
)
|
)
|
||||||
|
if image_resp is None:
|
||||||
|
logger.error(f"下载图片失败: msg_id={msg_id}")
|
||||||
|
return
|
||||||
image_bs64_data = (
|
image_bs64_data = (
|
||||||
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
|
||||||
)
|
)
|
||||||
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
|
|||||||
bufid = 0
|
bufid = 0
|
||||||
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
|
||||||
new_msg_id = raw_message.get("new_msg_id")
|
new_msg_id = raw_message.get("new_msg_id")
|
||||||
|
if new_msg_id is None:
|
||||||
|
logger.error("语音消息缺少 new_msg_id")
|
||||||
|
return
|
||||||
data_parser = GeweDataParser(
|
data_parser = GeweDataParser(
|
||||||
content=content,
|
content=content,
|
||||||
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
|
||||||
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
voicemsg = data_parser._format_to_xml().find("voicemsg")
|
||||||
|
if voicemsg is None:
|
||||||
|
logger.error("无法从 XML 解析 voicemsg 节点")
|
||||||
|
return
|
||||||
bufid = voicemsg.get("bufid") or "0"
|
bufid = voicemsg.get("bufid") or "0"
|
||||||
length = int(voicemsg.get("length") or 0)
|
length = int(voicemsg.get("length") or 0)
|
||||||
voice_resp = await self.download_voice(
|
voice_resp = await self.download_voice(
|
||||||
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
|
|||||||
bufid=bufid,
|
bufid=bufid,
|
||||||
length=length,
|
length=length,
|
||||||
)
|
)
|
||||||
|
if voice_resp is None:
|
||||||
|
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
|
||||||
|
return
|
||||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||||
if voice_bs64_data:
|
if voice_bs64_data:
|
||||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||||
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
|
|||||||
try:
|
try:
|
||||||
if self.ws_handle_task:
|
if self.ws_handle_task:
|
||||||
self.ws_handle_task.cancel()
|
self.ws_handle_task.cancel()
|
||||||
self._shutdown_event.set()
|
if self._shutdown_event is not None:
|
||||||
|
self._shutdown_event.set()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
|
|||||||
|
|
||||||
async def get_contact_details_list(
|
async def get_contact_details_list(
|
||||||
self,
|
self,
|
||||||
room_wx_id_list: list[str] = None,
|
room_wx_id_list: list[str] | None = None,
|
||||||
user_names: list[str] = None,
|
user_names: list[str] | None = None,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
"""获取联系人详情列表。"""
|
"""获取联系人详情列表。"""
|
||||||
if room_wx_id_list is None:
|
if room_wx_id_list is None:
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
from requests import Response
|
from requests import Response
|
||||||
@@ -40,7 +41,7 @@ else:
|
|||||||
class WecomServer:
|
class WecomServer:
|
||||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.port = int(config.get("port"))
|
self.port = int(cast(str, config.get("port")))
|
||||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
self.server.add_url_rule(
|
self.server.add_url_rule(
|
||||||
"/callback/command",
|
"/callback/command",
|
||||||
@@ -60,7 +61,7 @@ class WecomServer:
|
|||||||
config["corpid"].strip(),
|
config["corpid"].strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.callback = None
|
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||||
self.shutdown_event = asyncio.Event()
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def verify(self):
|
async def verify(self):
|
||||||
@@ -114,7 +115,7 @@ class WecomServer:
|
|||||||
logger.error("解密失败,签名异常,请检查配置。")
|
logger.error("解密失败,签名异常,请检查配置。")
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
msg = parse_message(xml)
|
msg = cast(BaseMessage, parse_message(xml))
|
||||||
logger.info(f"解析成功: {msg}")
|
logger.info(f"解析成功: {msg}")
|
||||||
|
|
||||||
if self.callback:
|
if self.callback:
|
||||||
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
|
|||||||
# inject
|
# inject
|
||||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||||
self.client.kf = self.wechat_kf_api
|
self.client.__setattr__("kf", self.wechat_kf_api)
|
||||||
self.client.kf_message = self.wechat_kf_message_api
|
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
|
||||||
|
|
||||||
self.client.API_BASE_URL = self.api_base_url
|
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||||
|
|
||||||
async def callback(msg: BaseMessage):
|
async def callback(msg: BaseMessage):
|
||||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||||
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
if msg.type == "text":
|
if isinstance(msg, TextMessage):
|
||||||
assert isinstance(msg, TextMessage)
|
|
||||||
abm.message_str = msg.content
|
abm.message_str = msg.content
|
||||||
abm.self_id = str(msg.agent)
|
abm.self_id = str(msg.agent)
|
||||||
abm.message = [Plain(msg.content)]
|
abm.message = [Plain(msg.content)]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(msg.id)
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = int(cast(int | str, msg.time))
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
elif msg.type == "image":
|
elif isinstance(msg, ImageMessage):
|
||||||
assert isinstance(msg, ImageMessage)
|
|
||||||
abm.message_str = "[图片]"
|
abm.message_str = "[图片]"
|
||||||
abm.self_id = str(msg.agent)
|
abm.self_id = str(msg.agent)
|
||||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
abm.message = [Image(file=msg.image, url=msg.image)]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(msg.id)
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = int(cast(int | str, msg.time))
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
elif msg.type == "voice":
|
elif isinstance(msg, VoiceMessage):
|
||||||
assert isinstance(msg, VoiceMessage)
|
|
||||||
|
|
||||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
self.client.media.download,
|
self.client.media.download,
|
||||||
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
|
|||||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(msg.id)
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = int(cast(int | str, msg.time))
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
else:
|
else:
|
||||||
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
|
|||||||
|
|
||||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||||
msgtype = msg.get("msgtype")
|
msgtype = msg.get("msgtype")
|
||||||
external_userid = msg.get("external_userid")
|
external_userid = cast(str, msg.get("external_userid"))
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.raw_message = msg
|
abm.raw_message = msg
|
||||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||||
|
|||||||
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
|
|||||||
if is_wechat_kf:
|
if is_wechat_kf:
|
||||||
# 微信客服
|
# 微信客服
|
||||||
kf_message_api = getattr(self.client, "kf_message", None)
|
kf_message_api = getattr(self.client, "kf_message", None)
|
||||||
if not kf_message_api:
|
if not isinstance(kf_message_api, WeChatKFMessage):
|
||||||
logger.warning("未找到微信客服发送消息方法。")
|
logger.warning("未找到微信客服发送消息方法。")
|
||||||
return
|
return
|
||||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
|
||||||
user_id = self.get_sender_id()
|
user_id = self.get_sender_id()
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _send(
|
async def _send(
|
||||||
message_chain: MessageChain,
|
message_chain: MessageChain | None,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
queue_mgr: WecomAIQueueMgr,
|
queue_mgr: WecomAIQueueMgr,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain | None):
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
raw = self.message_obj.raw_message
|
raw = self.message_obj.raw_message
|
||||||
assert isinstance(raw, dict), (
|
assert isinstance(raw, dict), (
|
||||||
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
stream_id = raw.get("stream_id", self.session_id)
|
stream_id = raw.get("stream_id", self.session_id)
|
||||||
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
|
||||||
await super().send(message)
|
await super().send(MessageChain([]))
|
||||||
|
|
||||||
async def send_streaming(self, generator, use_fallback=False):
|
async def send_streaming(self, generator, use_fallback=False):
|
||||||
"""流式发送消息,参考webchat的send_streaming设计"""
|
"""流式发送消息,参考webchat的send_streaming设计"""
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
from requests import Response
|
from requests import Response
|
||||||
@@ -36,7 +37,7 @@ else:
|
|||||||
class WeixinOfficialAccountServer:
|
class WeixinOfficialAccountServer:
|
||||||
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
def __init__(self, event_queue: asyncio.Queue, config: dict):
|
||||||
self.server = quart.Quart(__name__)
|
self.server = quart.Quart(__name__)
|
||||||
self.port = int(config.get("port"))
|
self.port = int(cast(int | str, config.get("port")))
|
||||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||||
self.token = config.get("token")
|
self.token = config.get("token")
|
||||||
self.encoding_aes_key = config.get("encoding_aes_key")
|
self.encoding_aes_key = config.get("encoding_aes_key")
|
||||||
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
|
|||||||
|
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
|
|
||||||
self.callback = None
|
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
|
||||||
self.shutdown_event = asyncio.Event()
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
async def verify(self):
|
async def verify(self):
|
||||||
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
msg = parse_message(xml)
|
msg = parse_message(xml)
|
||||||
|
if not msg:
|
||||||
|
logger.error("解析失败。msg为None。")
|
||||||
|
raise
|
||||||
logger.info(f"解析成功: {msg}")
|
logger.info(f"解析成功: {msg}")
|
||||||
|
|
||||||
if self.callback:
|
if self.callback:
|
||||||
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
self.config["secret"].strip(),
|
self.config["secret"].strip(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.API_BASE_URL = self.api_base_url
|
self.client.__setattr__("API_BASE_URL", self.api_base_url)
|
||||||
|
|
||||||
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
|
||||||
# msgid -> Future
|
# msgid -> Future
|
||||||
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
await self.convert_message(msg, None)
|
await self.convert_message(msg, None)
|
||||||
else:
|
else:
|
||||||
if msg.id in self.wexin_event_workers:
|
if msg.id in self.wexin_event_workers:
|
||||||
future = self.wexin_event_workers[msg.id]
|
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||||
else:
|
else:
|
||||||
future = asyncio.get_event_loop().create_future()
|
future = asyncio.get_event_loop().create_future()
|
||||||
self.wexin_event_workers[msg.id] = future
|
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
60,
|
60,
|
||||||
) # wait for 60s
|
) # wait for 60s
|
||||||
logger.debug(f"Got future result: {result}")
|
logger.debug(f"Got future result: {result}")
|
||||||
self.wexin_event_workers.pop(msg.id, None)
|
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
|
||||||
return result # xml. see weixin_offacc_event.py
|
return result # xml. see weixin_offacc_event.py
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
async def convert_message(
|
async def convert_message(
|
||||||
self,
|
self,
|
||||||
msg,
|
msg,
|
||||||
future: asyncio.Future = None,
|
future: asyncio.Future | None = None,
|
||||||
) -> AstrBotMessage | None:
|
) -> AstrBotMessage | None:
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
if isinstance(msg, TextMessage):
|
if isinstance(msg, TextMessage):
|
||||||
abm.message_str = msg.content
|
abm.message_str = cast(str, msg.content)
|
||||||
abm.self_id = str(msg.target)
|
abm.self_id = str(msg.target)
|
||||||
abm.message = [Plain(msg.content)]
|
abm.message = [Plain(cast(str, msg.content))]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(cast(str | int, msg.id))
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = cast(int, msg.time)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
elif msg.type == "image":
|
elif msg.type == "image":
|
||||||
assert isinstance(msg, ImageMessage)
|
assert isinstance(msg, ImageMessage)
|
||||||
abm.message_str = "[图片]"
|
abm.message_str = "[图片]"
|
||||||
abm.self_id = str(msg.target)
|
abm.self_id = str(msg.target)
|
||||||
abm.message = [Image(file=msg.image, url=msg.image)]
|
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(cast(str | int, msg.id))
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = cast(int, msg.time)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
elif msg.type == "voice":
|
elif msg.type == "voice":
|
||||||
assert isinstance(msg, VoiceMessage)
|
assert isinstance(msg, VoiceMessage)
|
||||||
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
abm.message = [Record(file=path_wav, url=path_wav)]
|
abm.message = [Record(file=path_wav, url=path_wav)]
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
msg.source,
|
cast(str, msg.source),
|
||||||
)
|
)
|
||||||
abm.message_id = msg.id
|
abm.message_id = str(cast(str | int, msg.id))
|
||||||
abm.timestamp = msg.time
|
abm.timestamp = cast(int, msg.time)
|
||||||
abm.session_id = abm.sender.user_id
|
abm.session_id = abm.sender.user_id
|
||||||
else:
|
else:
|
||||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||||
future.set_result(None)
|
if future:
|
||||||
|
future.set_result(None)
|
||||||
return
|
return
|
||||||
# 很不优雅 :(
|
# 很不优雅 :(
|
||||||
abm.raw_message = {
|
abm.raw_message = {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from wechatpy import WeChatClient
|
from wechatpy import WeChatClient
|
||||||
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
from wechatpy.replies import ImageReply, TextReply, VoiceReply
|
||||||
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
message_obj = self.message_obj
|
message_obj = self.message_obj
|
||||||
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
|
active_send_mode = cast(dict, message_obj.raw_message).get(
|
||||||
|
"active_send_mode", False
|
||||||
|
)
|
||||||
for comp in message.chain:
|
for comp in message.chain:
|
||||||
if isinstance(comp, Plain):
|
if isinstance(comp, Plain):
|
||||||
# Split long text messages if needed
|
# Split long text messages if needed
|
||||||
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
reply = TextReply(
|
reply = TextReply(
|
||||||
content=chunk,
|
content=chunk,
|
||||||
message=self.message_obj.raw_message["message"],
|
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||||
)
|
)
|
||||||
xml = reply.render()
|
xml = reply.render()
|
||||||
future = self.message_obj.raw_message["future"]
|
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||||
assert isinstance(future, asyncio.Future)
|
assert isinstance(future, asyncio.Future)
|
||||||
future.set_result(xml)
|
future.set_result(xml)
|
||||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||||
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
reply = ImageReply(
|
reply = ImageReply(
|
||||||
media_id=response["media_id"],
|
media_id=response["media_id"],
|
||||||
message=self.message_obj.raw_message["message"],
|
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||||
)
|
)
|
||||||
xml = reply.render()
|
xml = reply.render()
|
||||||
future = self.message_obj.raw_message["future"]
|
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||||
assert isinstance(future, asyncio.Future)
|
assert isinstance(future, asyncio.Future)
|
||||||
future.set_result(xml)
|
future.set_result(xml)
|
||||||
|
|
||||||
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
else:
|
else:
|
||||||
reply = VoiceReply(
|
reply = VoiceReply(
|
||||||
media_id=response["media_id"],
|
media_id=response["media_id"],
|
||||||
message=self.message_obj.raw_message["message"],
|
message=cast(dict, self.message_obj.raw_message)["message"],
|
||||||
)
|
)
|
||||||
xml = reply.render()
|
xml = reply.render()
|
||||||
future = self.message_obj.raw_message["future"]
|
future = cast(dict, self.message_obj.raw_message)["future"]
|
||||||
assert isinstance(future, asyncio.Future)
|
assert isinstance(future, asyncio.Future)
|
||||||
future.set_result(xml)
|
future.set_result(xml)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -118,7 +118,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list[dict],
|
func_args: list[dict],
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||||
) -> FuncTool:
|
) -> FuncTool:
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -140,7 +140,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""添加函数调用工具
|
"""添加函数调用工具
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
from astrbot.core import astrbot_config, logger, sp
|
from astrbot.core import astrbot_config, logger, sp
|
||||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||||
@@ -10,6 +11,7 @@ from .entities import ProviderType
|
|||||||
from .provider import (
|
from .provider import (
|
||||||
EmbeddingProvider,
|
EmbeddingProvider,
|
||||||
Provider,
|
Provider,
|
||||||
|
Providers,
|
||||||
RerankProvider,
|
RerankProvider,
|
||||||
STTProvider,
|
STTProvider,
|
||||||
TTSProvider,
|
TTSProvider,
|
||||||
@@ -17,6 +19,11 @@ from .provider import (
|
|||||||
from .register import llm_tools, provider_cls_map
|
from .register import llm_tools, provider_cls_map
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class HasInitialize(Protocol):
|
||||||
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -48,7 +55,7 @@ class ProviderManager:
|
|||||||
"""加载的 Rerank Provider 的实例"""
|
"""加载的 Rerank Provider 的实例"""
|
||||||
self.inst_map: dict[
|
self.inst_map: dict[
|
||||||
str,
|
str,
|
||||||
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
Providers,
|
||||||
] = {}
|
] = {}
|
||||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
@@ -123,15 +130,13 @@ class ProviderManager:
|
|||||||
self.curr_provider_inst = prov
|
self.curr_provider_inst = prov
|
||||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||||
|
|
||||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||||
"""根据提供商 ID 获取提供商实例"""
|
"""根据提供商 ID 获取提供商实例"""
|
||||||
return self.inst_map.get(provider_id)
|
return self.inst_map.get(provider_id)
|
||||||
|
|
||||||
def get_using_provider(
|
def get_using_provider(
|
||||||
self,
|
self, provider_type: ProviderType, umo=None
|
||||||
provider_type: ProviderType,
|
) -> Providers | None:
|
||||||
umo=None,
|
|
||||||
) -> Provider | STTProvider | TTSProvider | None:
|
|
||||||
"""获取正在使用的提供商实例。
|
"""获取正在使用的提供商实例。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -191,7 +196,6 @@ class ProviderManager:
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
|
|
||||||
# 设置默认提供商
|
|
||||||
selected_provider_id = sp.get(
|
selected_provider_id = sp.get(
|
||||||
"curr_provider",
|
"curr_provider",
|
||||||
self.provider_settings.get("default_provider_id"),
|
self.provider_settings.get("default_provider_id"),
|
||||||
@@ -210,15 +214,37 @@ class ProviderManager:
|
|||||||
scope="global",
|
scope="global",
|
||||||
scope_id="global",
|
scope_id="global",
|
||||||
)
|
)
|
||||||
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
|
|
||||||
|
temp_provider = (
|
||||||
|
self.inst_map.get(selected_provider_id)
|
||||||
|
if isinstance(selected_provider_id, str)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.curr_provider_inst = (
|
||||||
|
temp_provider if isinstance(temp_provider, Provider) else None
|
||||||
|
)
|
||||||
if not self.curr_provider_inst and self.provider_insts:
|
if not self.curr_provider_inst and self.provider_insts:
|
||||||
self.curr_provider_inst = self.provider_insts[0]
|
self.curr_provider_inst = self.provider_insts[0]
|
||||||
|
|
||||||
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
|
temp_stt = (
|
||||||
|
self.inst_map.get(selected_stt_provider_id)
|
||||||
|
if isinstance(selected_stt_provider_id, str)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.curr_stt_provider_inst = (
|
||||||
|
temp_stt if isinstance(temp_stt, STTProvider) else None
|
||||||
|
)
|
||||||
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
if not self.curr_stt_provider_inst and self.stt_provider_insts:
|
||||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||||
|
|
||||||
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
|
temp_tts = (
|
||||||
|
self.inst_map.get(selected_tts_provider_id)
|
||||||
|
if isinstance(selected_tts_provider_id, str)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.curr_tts_provider_inst = (
|
||||||
|
temp_tts if isinstance(temp_tts, TTSProvider) else None
|
||||||
|
)
|
||||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||||
|
|
||||||
@@ -358,73 +384,103 @@ class ProviderManager:
|
|||||||
|
|
||||||
provider_metadata.id = provider_config["id"]
|
provider_metadata.id = provider_config["id"]
|
||||||
|
|
||||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
match provider_metadata.provider_type:
|
||||||
# STT 任务
|
case ProviderType.SPEECH_TO_TEXT:
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
# STT 任务
|
||||||
|
if not issubclass(cls_type, STTProvider):
|
||||||
|
raise TypeError(
|
||||||
|
f"Provider class {cls_type} is not a subclass of STTProvider"
|
||||||
|
)
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if isinstance(inst, HasInitialize):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
|
|
||||||
self.stt_provider_insts.append(inst)
|
self.stt_provider_insts.append(inst)
|
||||||
if (
|
if (
|
||||||
self.provider_stt_settings.get("provider_id")
|
self.provider_stt_settings.get("provider_id")
|
||||||
== provider_config["id"]
|
== provider_config["id"]
|
||||||
):
|
):
|
||||||
self.curr_stt_provider_inst = inst
|
self.curr_stt_provider_inst = inst
|
||||||
logger.info(
|
logger.info(
|
||||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
||||||
|
)
|
||||||
|
if not self.curr_stt_provider_inst:
|
||||||
|
self.curr_stt_provider_inst = inst
|
||||||
|
|
||||||
|
case ProviderType.TEXT_TO_SPEECH:
|
||||||
|
# TTS 任务
|
||||||
|
if not issubclass(cls_type, TTSProvider):
|
||||||
|
raise TypeError(
|
||||||
|
f"Provider class {cls_type} is not a subclass of TTSProvider"
|
||||||
|
)
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
|
||||||
|
if isinstance(inst, HasInitialize):
|
||||||
|
await inst.initialize()
|
||||||
|
|
||||||
|
self.tts_provider_insts.append(inst)
|
||||||
|
if (
|
||||||
|
self.provider_settings.get("provider_id")
|
||||||
|
== provider_config["id"]
|
||||||
|
):
|
||||||
|
self.curr_tts_provider_inst = inst
|
||||||
|
logger.info(
|
||||||
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
||||||
|
)
|
||||||
|
if not self.curr_tts_provider_inst:
|
||||||
|
self.curr_tts_provider_inst = inst
|
||||||
|
|
||||||
|
case ProviderType.CHAT_COMPLETION:
|
||||||
|
# 文本生成任务
|
||||||
|
if not issubclass(cls_type, Provider):
|
||||||
|
raise TypeError(
|
||||||
|
f"Provider class {cls_type} is not a subclass of Provider"
|
||||||
|
)
|
||||||
|
inst = cls_type(
|
||||||
|
provider_config,
|
||||||
|
self.provider_settings,
|
||||||
)
|
)
|
||||||
if not self.curr_stt_provider_inst:
|
|
||||||
self.curr_stt_provider_inst = inst
|
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
if isinstance(inst, HasInitialize):
|
||||||
# TTS 任务
|
await inst.initialize()
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
self.provider_insts.append(inst)
|
||||||
await inst.initialize()
|
if (
|
||||||
|
self.provider_settings.get("default_provider_id")
|
||||||
|
== provider_config["id"]
|
||||||
|
):
|
||||||
|
self.curr_provider_inst = inst
|
||||||
|
logger.info(
|
||||||
|
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
||||||
|
)
|
||||||
|
if not self.curr_provider_inst:
|
||||||
|
self.curr_provider_inst = inst
|
||||||
|
|
||||||
self.tts_provider_insts.append(inst)
|
case ProviderType.EMBEDDING:
|
||||||
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
if not issubclass(cls_type, EmbeddingProvider):
|
||||||
self.curr_tts_provider_inst = inst
|
raise TypeError(
|
||||||
logger.info(
|
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
|
||||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
)
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
if isinstance(inst, HasInitialize):
|
||||||
|
await inst.initialize()
|
||||||
|
self.embedding_provider_insts.append(inst)
|
||||||
|
case ProviderType.RERANK:
|
||||||
|
if not issubclass(cls_type, RerankProvider):
|
||||||
|
raise TypeError(
|
||||||
|
f"Provider class {cls_type} is not a subclass of RerankProvider"
|
||||||
|
)
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
if isinstance(inst, HasInitialize):
|
||||||
|
await inst.initialize()
|
||||||
|
self.rerank_provider_insts.append(inst)
|
||||||
|
case _:
|
||||||
|
# 未知供应商抛出异常,确保inst初始化
|
||||||
|
# Should be unreachable
|
||||||
|
raise Exception(
|
||||||
|
f"未知的提供商类型:{provider_metadata.provider_type}"
|
||||||
)
|
)
|
||||||
if not self.curr_tts_provider_inst:
|
|
||||||
self.curr_tts_provider_inst = inst
|
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
|
||||||
# 文本生成任务
|
|
||||||
inst = cls_type(
|
|
||||||
provider_config,
|
|
||||||
self.provider_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
|
||||||
await inst.initialize()
|
|
||||||
|
|
||||||
self.provider_insts.append(inst)
|
|
||||||
if (
|
|
||||||
self.provider_settings.get("default_provider_id")
|
|
||||||
== provider_config["id"]
|
|
||||||
):
|
|
||||||
self.curr_provider_inst = inst
|
|
||||||
logger.info(
|
|
||||||
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
||||||
)
|
|
||||||
if not self.curr_provider_inst:
|
|
||||||
self.curr_provider_inst = inst
|
|
||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
|
||||||
if getattr(inst, "initialize", None):
|
|
||||||
await inst.initialize()
|
|
||||||
self.embedding_provider_insts.append(inst)
|
|
||||||
elif provider_metadata.provider_type == ProviderType.RERANK:
|
|
||||||
inst = cls_type(provider_config, self.provider_settings)
|
|
||||||
if getattr(inst, "initialize", None):
|
|
||||||
await inst.initialize()
|
|
||||||
self.rerank_provider_insts.append(inst)
|
|
||||||
|
|
||||||
self.inst_map[provider_config["id"]] = inst
|
self.inst_map[provider_config["id"]] = inst
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import abc
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import TypeAlias, Union
|
||||||
|
|
||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import Message
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
|
|||||||
from astrbot.core.provider.register import provider_cls_map
|
from astrbot.core.provider.register import provider_cls_map
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||||
|
|
||||||
|
Providers: TypeAlias = Union[
|
||||||
|
"Provider",
|
||||||
|
"STTProvider",
|
||||||
|
"TTSProvider",
|
||||||
|
"EmbeddingProvider",
|
||||||
|
"RerankProvider",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AbstractProvider(abc.ABC):
|
class AbstractProvider(abc.ABC):
|
||||||
"""Provider Abstract Class"""
|
"""Provider Abstract Class"""
|
||||||
@@ -142,7 +151,9 @@ class Provider(AbstractProvider):
|
|||||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
...
|
if False: # pragma: no cover - make this an async generator for typing
|
||||||
|
yield None # type: ignore
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def pop_record(self, context: list):
|
async def pop_record(self, context: list):
|
||||||
"""弹出 context 第一条非系统提示词对话记录"""
|
"""弹出 context 第一条非系统提示词对话记录"""
|
||||||
|
|||||||
@@ -29,15 +29,24 @@ class OTTSProvider:
|
|||||||
self.last_sync_time = 0
|
self.last_sync_time = 0
|
||||||
self.timeout = Timeout(10.0)
|
self.timeout = Timeout(10.0)
|
||||||
self.retry_count = 3
|
self.retry_count = 3
|
||||||
self.client = None
|
self._client: AsyncClient | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> AsyncClient:
|
||||||
|
if self._client is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Client not initialized. Please use 'async with' context."
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.client = AsyncClient(timeout=self.timeout)
|
self._client = AsyncClient(timeout=self.timeout)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
if self.client:
|
if self._client:
|
||||||
await self.client.aclose()
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
async def _sync_time(self):
|
async def _sync_time(self):
|
||||||
try:
|
try:
|
||||||
@@ -90,6 +99,7 @@ class OTTSProvider:
|
|||||||
if attempt == self.retry_count - 1:
|
if attempt == self.retry_count - 1:
|
||||||
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
|
||||||
await asyncio.sleep(0.5 * (attempt + 1))
|
await asyncio.sleep(0.5 * (attempt + 1))
|
||||||
|
raise RuntimeError("OTTS未返回音频文件")
|
||||||
|
|
||||||
|
|
||||||
class AzureNativeProvider(TTSProvider):
|
class AzureNativeProvider(TTSProvider):
|
||||||
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
self.endpoint = (
|
self.endpoint = (
|
||||||
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||||
)
|
)
|
||||||
self.client = None
|
self._client: AsyncClient | None = None
|
||||||
self.token = None
|
self.token = None
|
||||||
self.token_expire = 0
|
self.token_expire = 0
|
||||||
self.voice_params = {
|
self.voice_params = {
|
||||||
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> AsyncClient:
|
||||||
|
if self._client is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Client not initialized. Please use 'async with' context."
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.client = AsyncClient(
|
self._client = AsyncClient(
|
||||||
headers={
|
headers={
|
||||||
"User-Agent": f"AstrBot/{VERSION}",
|
"User-Agent": f"AstrBot/{VERSION}",
|
||||||
"Content-Type": "application/ssml+xml",
|
"Content-Type": "application/ssml+xml",
|
||||||
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
if self.client:
|
if self._client:
|
||||||
await self.client.aclose()
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
async def _refresh_token(self):
|
async def _refresh_token(self):
|
||||||
token_url = (
|
token_url = (
|
||||||
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
|
|||||||
key_value = provider_config.get("azure_tts_subscription_key", "")
|
key_value = provider_config.get("azure_tts_subscription_key", "")
|
||||||
self.provider = self._parse_provider(key_value, provider_config)
|
self.provider = self._parse_provider(key_value, provider_config)
|
||||||
|
|
||||||
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
|
def _parse_provider(
|
||||||
|
self, key_value: str, config: dict
|
||||||
|
) -> OTTSProvider | AzureNativeProvider:
|
||||||
if key_value.lower().startswith("other["):
|
if key_value.lower().startswith("other["):
|
||||||
|
json_str = ""
|
||||||
try:
|
try:
|
||||||
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
|
||||||
if not match:
|
if not match:
|
||||||
|
|||||||
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
|
|||||||
Returns:
|
Returns:
|
||||||
重排序结果列表
|
重排序结果列表
|
||||||
"""
|
"""
|
||||||
|
if not self.client:
|
||||||
|
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
|
||||||
|
return []
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
logger.warning("文档列表为空,返回空结果")
|
logger.warning("文档列表为空,返回空结果")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
super().__init__(provider_config, provider_settings)
|
super().__init__(provider_config, provider_settings)
|
||||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||||
self.set_model(provider_config.get("model"))
|
self.set_model(provider_config["model"])
|
||||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||||
dashscope.api_key = self.chosen_api_key
|
dashscope.api_key = self.chosen_api_key
|
||||||
|
|
||||||
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
|||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"text": text,
|
"messages": None,
|
||||||
"api_key": self.chosen_api_key,
|
"api_key": self.chosen_api_key,
|
||||||
"voice": self.voice or "Cherry",
|
"voice": self.voice or "Cherry",
|
||||||
|
"text": text,
|
||||||
}
|
}
|
||||||
if not self.voice:
|
if not self.voice:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
|
|||||||
from pyffmpeg import FFmpeg
|
from pyffmpeg import FFmpeg
|
||||||
|
|
||||||
ff = FFmpeg()
|
ff = FFmpeg()
|
||||||
ff.convert(input=mp3_path, output=wav_path)
|
ff.convert(input_file=mp3_path, output_file=wav_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||||
# use ffmpeg command line
|
# use ffmpeg command line
|
||||||
|
|||||||
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||||
}
|
}
|
||||||
self.set_model(provider_config.get("model"))
|
self.set_model(provider_config["model"])
|
||||||
|
|
||||||
async def _get_reference_id_by_character(self, character: str) -> str:
|
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||||
"""获取角色的reference_id
|
"""获取角色的reference_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
pattern = r"^[a-fA-F0-9]{32}$"
|
pattern = r"^[a-fA-F0-9]{32}$"
|
||||||
return bool(re.match(pattern, reference_id.strip()))
|
return bool(re.match(pattern, reference_id.strip()))
|
||||||
|
|
||||||
async def _generate_request(self, text: str) -> dict:
|
async def _generate_request(self, text: str) -> ServeTTSRequest:
|
||||||
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
|
||||||
if self.reference_id and self.reference_id.strip():
|
if self.reference_id and self.reference_id.strip():
|
||||||
# 验证reference_id格式
|
# 验证reference_id格式
|
||||||
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
async for chunk in response.aiter_bytes():
|
async for chunk in response.aiter_bytes():
|
||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
return path
|
return path
|
||||||
text = await response.aread()
|
body = await response.aread()
|
||||||
|
text = body.decode("utf-8", errors="replace")
|
||||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
from google.genai.errors import APIError
|
from google.genai.errors import APIError
|
||||||
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|||||||
self.provider_config = provider_config
|
self.provider_config = provider_config
|
||||||
self.provider_settings = provider_settings
|
self.provider_settings = provider_settings
|
||||||
|
|
||||||
api_key: str = provider_config.get("embedding_api_key")
|
api_key: str = provider_config["embedding_api_key"]
|
||||||
api_base: str = provider_config.get("embedding_api_base")
|
api_base: str = provider_config["embedding_api_base"]
|
||||||
timeout: int = int(provider_config.get("timeout", 20))
|
timeout: int = int(provider_config.get("timeout", 20))
|
||||||
|
|
||||||
http_options = types.HttpOptions(timeout=timeout * 1000)
|
http_options = types.HttpOptions(timeout=timeout * 1000)
|
||||||
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
contents=text,
|
contents=text,
|
||||||
)
|
)
|
||||||
|
assert result.embeddings is not None
|
||||||
|
assert result.embeddings[0].values is not None
|
||||||
return result.embeddings[0].values
|
return result.embeddings[0].values
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
|
||||||
|
|
||||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||||
"""批量获取文本的嵌入"""
|
"""批量获取文本的嵌入"""
|
||||||
try:
|
try:
|
||||||
result = await self.client.models.embed_content(
|
result = await self.client.models.embed_content(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
contents=texts,
|
contents=cast(types.ContentListUnion, text),
|
||||||
)
|
)
|
||||||
return [embedding.values for embedding in result.embeddings]
|
assert result.embeddings is not None
|
||||||
|
|
||||||
|
embeddings: list[list[float]] = []
|
||||||
|
for embedding in result.embeddings:
|
||||||
|
assert embedding.values is not None
|
||||||
|
embeddings.append(embedding.values)
|
||||||
|
return embeddings
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
@@ -136,7 +137,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||||
modalities = ["Text"]
|
modalities = ["Text"]
|
||||||
|
|
||||||
tool_list = []
|
tool_list: list[types.Tool] | None = []
|
||||||
model_name = self.get_model()
|
model_name = self.get_model()
|
||||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||||
native_search = self.provider_config.get("gm_native_search", False)
|
native_search = self.provider_config.get("gm_native_search", False)
|
||||||
@@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
logprobs=payloads.get("logprobs"),
|
logprobs=payloads.get("logprobs"),
|
||||||
seed=payloads.get("seed"),
|
seed=payloads.get("seed"),
|
||||||
response_modalities=modalities,
|
response_modalities=modalities,
|
||||||
tools=tool_list,
|
tools=cast(types.ToolListUnion | None, tool_list),
|
||||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||||
thinking_config=(
|
thinking_config=(
|
||||||
types.ThinkingConfig(
|
types.ThinkingConfig(
|
||||||
@@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
content_cls: type[types.Content],
|
content_cls: type[types.Content],
|
||||||
) -> None:
|
) -> None:
|
||||||
if contents and isinstance(contents[-1], content_cls):
|
if contents and isinstance(contents[-1], content_cls):
|
||||||
|
assert contents[-1].parts is not None
|
||||||
contents[-1].parts.extend(part)
|
contents[-1].parts.extend(part)
|
||||||
else:
|
else:
|
||||||
contents.append(content_cls(parts=part))
|
contents.append(content_cls(parts=part))
|
||||||
@@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
)
|
)
|
||||||
result = await self.client.models.generate_content(
|
result = await self.client.models.generate_content(
|
||||||
model=self.get_model(),
|
model=self.get_model(),
|
||||||
contents=conversation,
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
logger.debug(f"genai result: {result}")
|
logger.debug(f"genai result: {result}")
|
||||||
@@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
)
|
)
|
||||||
result = await self.client.models.generate_content_stream(
|
result = await self.client.models.generate_content_stream(
|
||||||
model=self.get_model(),
|
model=self.get_model(),
|
||||||
contents=conversation,
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|||||||
|
|
||||||
return json.dumps(dict_body)
|
return json.dumps(dict_body)
|
||||||
|
|
||||||
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
|
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
|
||||||
"""进行流式请求"""
|
"""进行流式请求"""
|
||||||
try:
|
try:
|
||||||
async with (
|
async with (
|
||||||
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
|||||||
data = json.loads(message[6:])
|
data = json.loads(message[6:])
|
||||||
if "extra_info" in data:
|
if "extra_info" in data:
|
||||||
continue
|
continue
|
||||||
audio = data.get("data", {}).get("audio")
|
audio: str | None = data.get("data", {}).get(
|
||||||
|
"audio"
|
||||||
|
)
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
yield audio
|
yield audio
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|||||||
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|||||||
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
embedding = await self.client.embeddings.create(input=text, model=self.model)
|
||||||
return embedding.data[0].embedding
|
return embedding.data[0].embedding
|
||||||
|
|
||||||
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
|
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
|
||||||
"""批量获取文本的嵌入"""
|
"""批量获取文本的嵌入"""
|
||||||
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
|
embeddings = await self.client.embeddings.create(input=text, model=self.model)
|
||||||
return [item.embedding for item in embeddings.data]
|
return [item.embedding for item in embeddings.data]
|
||||||
|
|
||||||
def get_dim(self) -> int:
|
def get_dim(self) -> int:
|
||||||
|
|||||||
@@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
if isinstance(tool_call, str):
|
if isinstance(tool_call, str):
|
||||||
# workaround for #1359
|
# workaround for #1359
|
||||||
tool_call = json.loads(tool_call)
|
tool_call = json.loads(tool_call)
|
||||||
|
if tools is None:
|
||||||
|
# 工具集未提供
|
||||||
|
# Should be unreachable
|
||||||
|
raise Exception("工具集未提供")
|
||||||
for tool in tools.func_list:
|
for tool in tools.func_list:
|
||||||
if (
|
if (
|
||||||
tool_call.type == "function"
|
tool_call.type == "function"
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from funasr_onnx import SenseVoiceSmall
|
from funasr_onnx import SenseVoiceSmall
|
||||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||||
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(provider_config, provider_settings)
|
super().__init__(provider_config, provider_settings)
|
||||||
self.set_model(provider_config.get("stt_model"))
|
self.set_model(provider_config["stt_model"])
|
||||||
self.model = None
|
self.model = None
|
||||||
self.is_emotion = provider_config.get("is_emotion", False)
|
self.is_emotion = provider_config.get("is_emotion", False)
|
||||||
|
|
||||||
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
res = await loop.run_in_executor(
|
res = await loop.run_in_executor(
|
||||||
None, # 使用默认的线程池
|
None, # 使用默认的线程池
|
||||||
lambda: self.model(audio_url, language="auto", use_itn=True),
|
lambda: cast(SenseVoiceSmall, self.model)(
|
||||||
|
audio_url, language="auto", use_itn=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# res = self.model(audio_url, language="auto", use_itn=True)
|
# res = self.model(audio_url, language="auto", use_itn=True)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
|
|||||||
}
|
}
|
||||||
if top_n is not None:
|
if top_n is not None:
|
||||||
payload["top_n"] = top_n
|
payload["top_n"] = top_n
|
||||||
|
assert self.client is not None
|
||||||
async with self.client.post(
|
async with self.client.post(
|
||||||
f"{self.base_url}/v1/rerank",
|
f"{self.base_url}/v1/rerank",
|
||||||
json=payload,
|
json=payload,
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
|||||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.set_model(provider_config.get("model"))
|
self.set_model(provider_config["model"])
|
||||||
|
|
||||||
async def _get_audio_format(self, file_path):
|
async def _get_audio_format(self, file_path):
|
||||||
# 定义要检测的头部字节
|
# 定义要检测的头部字节
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import whisper
|
import whisper
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
provider_settings: dict,
|
provider_settings: dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(provider_config, provider_settings)
|
super().__init__(provider_config, provider_settings)
|
||||||
self.set_model(provider_config.get("model"))
|
self.set_model(provider_config["model"])
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
|||||||
await tencent_silk_to_wav(audio_url, output_path)
|
await tencent_silk_to_wav(audio_url, output_path)
|
||||||
audio_url = output_path
|
audio_url = output_path
|
||||||
|
|
||||||
|
if not self.model:
|
||||||
|
raise RuntimeError("Whisper 模型未初始化")
|
||||||
|
|
||||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||||
return result["text"]
|
return cast(str, result["text"])
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
from xinference_client.client.restful.async_restful_client import (
|
from xinference_client.client.restful.async_restful_client import (
|
||||||
AsyncClient as Client,
|
AsyncClient as Client,
|
||||||
)
|
)
|
||||||
|
from xinference_client.client.restful.async_restful_client import (
|
||||||
|
AsyncRESTfulRerankModelHandle,
|
||||||
|
)
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
|
||||||
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
|
|||||||
False,
|
False,
|
||||||
)
|
)
|
||||||
self.client = None
|
self.client = None
|
||||||
self.model = None
|
self.model: AsyncRESTfulRerankModelHandle | None = None
|
||||||
self.model_uid = None
|
self.model_uid = None
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.model_uid:
|
if self.model_uid:
|
||||||
self.model = await self.client.get_model(self.model_uid)
|
self.model = cast(
|
||||||
|
AsyncRESTfulRerankModelHandle,
|
||||||
|
await self.client.get_model(self.model_uid),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize Xinference model: {e}")
|
logger.error(f"Failed to initialize Xinference model: {e}")
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ class Context:
|
|||||||
"""获取所有用于 Embedding 任务的 Provider。"""
|
"""获取所有用于 Embedding 任务的 Provider。"""
|
||||||
return self.provider_manager.embedding_provider_insts
|
return self.provider_manager.embedding_provider_insts
|
||||||
|
|
||||||
def get_using_provider(self, umo: str | None = None) -> Provider | None:
|
def get_using_provider(self, umo: str | None = None) -> Provider:
|
||||||
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -296,7 +296,7 @@ class Context:
|
|||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
if prov and not isinstance(prov, Provider):
|
if not isinstance(prov, Provider):
|
||||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||||
return prov
|
return prov
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import docstring_parser
|
import docstring_parser
|
||||||
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
|
|||||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||||
from astrbot.core.agent.tool import FunctionTool
|
from astrbot.core.agent.tool import FunctionTool
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
|
from astrbot.core.message.message_event_result import MessageEventResult
|
||||||
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
|
|
||||||
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
|
|||||||
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
|
||||||
|
|
||||||
|
|
||||||
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
def get_handler_full_name(
|
||||||
|
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
|
||||||
|
) -> str:
|
||||||
"""获取 Handler 的全名"""
|
"""获取 Handler 的全名"""
|
||||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||||
|
|
||||||
|
|
||||||
def get_handler_or_create(
|
def get_handler_or_create(
|
||||||
handler: Callable[..., Awaitable[Any]],
|
handler: Callable[
|
||||||
|
...,
|
||||||
|
Awaitable[MessageEventResult | str | None]
|
||||||
|
| AsyncGenerator[MessageEventResult | str | None],
|
||||||
|
],
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
dont_add=False,
|
dont_add=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|||||||
for (
|
for (
|
||||||
sub_handle
|
sub_handle
|
||||||
) in parent_register_commandable.parent_group.sub_command_filters:
|
) in parent_register_commandable.parent_group.sub_command_filters:
|
||||||
|
if isinstance(sub_handle, CommandGroupFilter):
|
||||||
|
continue
|
||||||
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
# 所有符合fullname一致的子指令handle添加自定义过滤器。
|
||||||
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
|
||||||
sub_handle_md = sub_handle.get_handler_md()
|
sub_handle_md = sub_handle.get_handler_md()
|
||||||
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# 裸指令
|
# 裸指令
|
||||||
|
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
|
||||||
|
assert isinstance(awaitable, Callable)
|
||||||
handler_md = get_handler_or_create(
|
handler_md = get_handler_or_create(
|
||||||
awaitable,
|
awaitable,
|
||||||
EventType.AdapterMessageEvent,
|
EventType.AdapterMessageEvent,
|
||||||
@@ -237,7 +248,7 @@ class RegisteringCommandable:
|
|||||||
|
|
||||||
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
|
||||||
command: Callable[..., Callable[..., None]] = register_command
|
command: Callable[..., Callable[..., None]] = register_command
|
||||||
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
|
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
|
||||||
|
|
||||||
def __init__(self, parent_group: CommandGroupFilter):
|
def __init__(self, parent_group: CommandGroupFilter):
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
|
|||||||
if kwargs.get("registering_agent"):
|
if kwargs.get("registering_agent"):
|
||||||
registering_agent = kwargs["registering_agent"]
|
registering_agent = kwargs["registering_agent"]
|
||||||
|
|
||||||
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
def decorator(
|
||||||
|
awaitable: Callable[
|
||||||
|
...,
|
||||||
|
AsyncGenerator[MessageEventResult | str | None]
|
||||||
|
| Awaitable[MessageEventResult | str | None],
|
||||||
|
],
|
||||||
|
):
|
||||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||||
func_doc = awaitable.__doc__ or ""
|
func_doc = awaitable.__doc__ or ""
|
||||||
docstring = docstring_parser.parse(func_doc)
|
docstring = docstring_parser.parse(func_doc)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, Literal, TypeVar, overload
|
||||||
|
|
||||||
from .filter import HandlerFilter
|
from .filter import HandlerFilter
|
||||||
from .star import star_map
|
from .star import star_map
|
||||||
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
for handler in self._handlers:
|
for handler in self._handlers:
|
||||||
print(handler.handler_full_name)
|
print(handler.handler_full_name)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnAstrBotLoadedEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnPlatformLoadedEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.AdapterMessageEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[
|
||||||
|
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||||
|
]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnLLMRequestEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnLLMResponseEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnDecoratingResultEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnCallingFuncToolEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[
|
||||||
|
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||||
|
]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: Literal[EventType.OnAfterMessageSentEvent],
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_handlers_by_event_type(
|
||||||
|
self,
|
||||||
|
event_type: EventType,
|
||||||
|
only_activated=True,
|
||||||
|
plugins_name: list[str] | None = None,
|
||||||
|
) -> list[
|
||||||
|
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
|
||||||
|
]: ...
|
||||||
|
|
||||||
def get_handlers_by_event_type(
|
def get_handlers_by_event_type(
|
||||||
self,
|
self,
|
||||||
event_type: EventType,
|
event_type: EventType,
|
||||||
@@ -111,8 +189,11 @@ class EventType(enum.Enum):
|
|||||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||||
|
|
||||||
|
|
||||||
|
H = TypeVar("H", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StarHandlerMetadata:
|
class StarHandlerMetadata(Generic[H]):
|
||||||
"""描述一个 Star 所注册的某一个 Handler。"""
|
"""描述一个 Star 所注册的某一个 Handler。"""
|
||||||
|
|
||||||
event_type: EventType
|
event_type: EventType
|
||||||
@@ -127,7 +208,7 @@ class StarHandlerMetadata:
|
|||||||
handler_module_path: str
|
handler_module_path: str
|
||||||
"""Handler 所在的模块路径。"""
|
"""Handler 所在的模块路径。"""
|
||||||
|
|
||||||
handler: Callable[..., Awaitable[Any]]
|
handler: H
|
||||||
"""Handler 的函数对象,应当是一个异步函数"""
|
"""Handler 的函数对象,应当是一个异步函数"""
|
||||||
|
|
||||||
event_filters: list[HandlerFilter]
|
event_filters: list[HandlerFilter]
|
||||||
|
|||||||
@@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator):
|
|||||||
|
|
||||||
async def check_update(
|
async def check_update(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str | None,
|
||||||
current_version: str,
|
current_version: str | None,
|
||||||
consider_prerelease: bool = True,
|
consider_prerelease: bool = True,
|
||||||
) -> ReleaseInfo:
|
) -> ReleaseInfo | None:
|
||||||
"""检查更新"""
|
"""检查更新"""
|
||||||
return await super().check_update(
|
return await super().check_update(
|
||||||
self.ASTRBOT_RELEASE_API,
|
self.ASTRBOT_RELEASE_API,
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def save_temp_img(img: Image.Image | str) -> str:
|
def save_temp_img(img: Image.Image | bytes) -> str:
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
# 获得文件创建时间,清除超过 12 小时的
|
# 获得文件创建时间,清除超过 12 小时的
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -20,16 +20,16 @@ class SessionController:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.future = asyncio.Future()
|
self.future = asyncio.Future()
|
||||||
self.current_event: asyncio.Event = None
|
self.current_event: asyncio.Event | None = None
|
||||||
"""当前正在等待的所用的异步事件"""
|
"""当前正在等待的所用的异步事件"""
|
||||||
self.ts: float = None
|
self.ts: float | None = None
|
||||||
"""上次保持(keep)开始时的时间"""
|
"""上次保持(keep)开始时的时间"""
|
||||||
self.timeout: float | int = None
|
self.timeout: float | int | None = None
|
||||||
"""上次保持(keep)开始时的超时时间"""
|
"""上次保持(keep)开始时的超时时间"""
|
||||||
|
|
||||||
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
||||||
|
|
||||||
def stop(self, error: Exception = None):
|
def stop(self, error: Exception | None = None):
|
||||||
"""立即结束这个会话"""
|
"""立即结束这个会话"""
|
||||||
if not self.future.done():
|
if not self.future.done():
|
||||||
if error:
|
if error:
|
||||||
@@ -53,6 +53,8 @@ class SessionController:
|
|||||||
self.stop()
|
self.stop()
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
assert self.timeout is not None
|
||||||
|
assert self.ts is not None
|
||||||
left_timeout = self.timeout - (new_ts - self.ts)
|
left_timeout = self.timeout - (new_ts - self.ts)
|
||||||
timeout = left_timeout + timeout
|
timeout = left_timeout + timeout
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
@@ -69,7 +71,7 @@ class SessionController:
|
|||||||
|
|
||||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||||
|
|
||||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
async def _holding(self, event: asyncio.Event, timeout: float):
|
||||||
"""等待事件结束或超时"""
|
"""等待事件结束或超时"""
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(event.wait(), timeout)
|
await asyncio.wait_for(event.wait(), timeout)
|
||||||
@@ -108,7 +110,9 @@ class SessionWaiter:
|
|||||||
):
|
):
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.session_filter = session_filter
|
self.session_filter = session_filter
|
||||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
self.handler: (
|
||||||
|
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
||||||
|
) = None # 处理函数
|
||||||
|
|
||||||
self.session_controller = SessionController()
|
self.session_controller = SessionController()
|
||||||
self.record_history_chains = record_history_chains
|
self.record_history_chains = record_history_chains
|
||||||
@@ -119,7 +123,7 @@ class SessionWaiter:
|
|||||||
|
|
||||||
async def register_wait(
|
async def register_wait(
|
||||||
self,
|
self,
|
||||||
handler: Callable[[str], Awaitable[Any]],
|
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""等待外部输入并处理"""
|
"""等待外部输入并处理"""
|
||||||
@@ -137,7 +141,7 @@ class SessionWaiter:
|
|||||||
finally:
|
finally:
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _cleanup(self, error: Exception = None):
|
def _cleanup(self, error: Exception | None = None):
|
||||||
"""清理会话"""
|
"""清理会话"""
|
||||||
USER_SESSIONS.pop(self.session_id, None)
|
USER_SESSIONS.pop(self.session_id, None)
|
||||||
try:
|
try:
|
||||||
@@ -161,6 +165,7 @@ class SessionWaiter:
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||||
|
assert session.handler is not None
|
||||||
await session.handler(session.session_controller, event)
|
await session.handler(session.session_controller, event)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.session_controller.stop(e)
|
session.session_controller.stop(e)
|
||||||
@@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
|||||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
def decorator(
|
||||||
|
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||||
|
):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(
|
async def wrapper(
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
session_filter: SessionFilter = None,
|
session_filter: SessionFilter | None = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -53,6 +53,38 @@ class SharedPreferences:
|
|||||||
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
ret = await self.db_helper.get_preferences(scope, scope_id, key)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def session_get(
|
||||||
|
self,
|
||||||
|
umo: str,
|
||||||
|
key: str,
|
||||||
|
default: _VT = None,
|
||||||
|
) -> _VT: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def session_get(
|
||||||
|
self,
|
||||||
|
umo: None,
|
||||||
|
key: str,
|
||||||
|
default: Any = None,
|
||||||
|
) -> list[Preference]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def session_get(
|
||||||
|
self,
|
||||||
|
umo: str,
|
||||||
|
key: None,
|
||||||
|
default: Any = None,
|
||||||
|
) -> list[Preference]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def session_get(
|
||||||
|
self,
|
||||||
|
umo: None,
|
||||||
|
key: None,
|
||||||
|
default: Any = None,
|
||||||
|
) -> list[Preference]: ...
|
||||||
|
|
||||||
async def session_get(
|
async def session_get(
|
||||||
self,
|
self,
|
||||||
umo: str | None,
|
umo: str | None,
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
class RenderStrategy(ABC):
|
class RenderStrategy(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def render(self, text: str, return_url: bool) -> str:
|
async def render(self, text: str, return_url: bool) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def render_custom_template(
|
async def render_custom_template(
|
||||||
self,
|
self,
|
||||||
tmpl_str: str,
|
tmpl_str: str,
|
||||||
tmpl_data: dict,
|
tmpl_data: dict,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class FontManager:
|
|||||||
_font_cache = {}
|
_font_cache = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_font(cls, size: int) -> ImageFont.FreeTypeFont:
|
def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont:
|
||||||
"""获取指定大小的字体,优先从缓存获取"""
|
"""获取指定大小的字体,优先从缓存获取"""
|
||||||
if size in cls._font_cache:
|
if size in cls._font_cache:
|
||||||
return cls._font_cache[size]
|
return cls._font_cache[size]
|
||||||
@@ -66,23 +66,17 @@ class TextMeasurer:
|
|||||||
"""测量文本尺寸的工具类"""
|
"""测量文本尺寸的工具类"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
|
def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]:
|
||||||
"""获取文本的尺寸"""
|
"""获取文本的尺寸"""
|
||||||
try:
|
|
||||||
# PIL 9.0.0 以上版本
|
# 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0
|
||||||
return (
|
left, top, right, bottom = font.getbbox("Hello world")
|
||||||
font.getbbox(text)[2:]
|
return int(right - left), int(bottom - top)
|
||||||
if hasattr(font, "getbbox")
|
|
||||||
else font.getsize(text)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# 兼容旧版本
|
|
||||||
return font.getsize(text)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def split_text_to_fit_width(
|
def split_text_to_fit_width(
|
||||||
text: str, font: ImageFont.FreeTypeFont, max_width: int
|
text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""将文本拆分为多行,确保每行不超过指定宽度"""
|
"""将文本拆分为多行,确保每行不超过指定宽度"""
|
||||||
lines = []
|
lines = []
|
||||||
if not text:
|
if not text:
|
||||||
@@ -126,7 +120,7 @@ class MarkdownElement(ABC):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -152,7 +146,7 @@ class TextElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement):
|
|||||||
# 倾斜变换,使用仿射变换实现斜体效果
|
# 倾斜变换,使用仿射变换实现斜体效果
|
||||||
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
|
# 变换矩阵: [1, 0.2, 0, 0, 1, 0]
|
||||||
italic_img = text_img.transform(
|
italic_img = text_img.transform(
|
||||||
text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC
|
text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC
|
||||||
)
|
)
|
||||||
|
|
||||||
# 粘贴到原图像
|
# 粘贴到原图像
|
||||||
@@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement):
|
|||||||
class CodeBlockElement(MarkdownElement):
|
class CodeBlockElement(MarkdownElement):
|
||||||
"""代码块元素"""
|
"""代码块元素"""
|
||||||
|
|
||||||
def __init__(self, content: List[str]):
|
def __init__(self, content: list[str]):
|
||||||
super().__init__("\n".join(content))
|
super().__init__("\n".join(content))
|
||||||
|
|
||||||
def calculate_height(self, image_width: int, font_size: int) -> int:
|
def calculate_height(self, image_width: int, font_size: int) -> int:
|
||||||
@@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -667,7 +661,7 @@ class ImageElement(MarkdownElement):
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
draw: ImageDraw.Draw,
|
draw: ImageDraw.ImageDraw,
|
||||||
x: int,
|
x: int,
|
||||||
y: int,
|
y: int,
|
||||||
image_width: int,
|
image_width: int,
|
||||||
@@ -686,7 +680,7 @@ class ImageElement(MarkdownElement):
|
|||||||
if pasted_image.width > max_width:
|
if pasted_image.width > max_width:
|
||||||
ratio = max_width / pasted_image.width
|
ratio = max_width / pasted_image.width
|
||||||
new_size = (int(max_width), int(pasted_image.height * ratio))
|
new_size = (int(max_width), int(pasted_image.height * ratio))
|
||||||
pasted_image = pasted_image.resize(new_size, Image.LANCZOS)
|
pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# 计算居中位置
|
# 计算居中位置
|
||||||
paste_x = x + (image_width - pasted_image.width) // 2 - 10
|
paste_x = x + (image_width - pasted_image.width) // 2 - 10
|
||||||
@@ -705,7 +699,7 @@ class MarkdownParser:
|
|||||||
"""Markdown解析器,将文本解析为元素"""
|
"""Markdown解析器,将文本解析为元素"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def parse(text: str) -> List[MarkdownElement]:
|
async def parse(text: str) -> list[MarkdownElement]:
|
||||||
elements = []
|
elements = []
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
|
|
||||||
@@ -847,7 +841,7 @@ class MarkdownRenderer:
|
|||||||
self,
|
self,
|
||||||
font_size: int = 26,
|
font_size: int = 26,
|
||||||
width: int = 800,
|
width: int = 800,
|
||||||
bg_color: Tuple[int, int, int] = (255, 255, 255),
|
bg_color: tuple[int, int, int] = (255, 255, 255),
|
||||||
):
|
):
|
||||||
self.font_size = font_size
|
self.font_size = font_size
|
||||||
self.width = width
|
self.width = width
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str:
|
|||||||
from pyffmpeg import FFmpeg
|
from pyffmpeg import FFmpeg
|
||||||
|
|
||||||
ff = FFmpeg()
|
ff = FFmpeg()
|
||||||
ff.convert(input=input_path, output=output_path)
|
ff.convert(input_file=input_path, output_file=output_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||||
|
|
||||||
|
|||||||
@@ -60,9 +60,12 @@ class VersionComparator:
|
|||||||
return -1
|
return -1
|
||||||
if isinstance(p1, str) and isinstance(p2, int):
|
if isinstance(p1, str) and isinstance(p2, int):
|
||||||
return 1
|
return 1
|
||||||
if (isinstance(p1, int) and isinstance(p2, int)) or (
|
if isinstance(p1, int) and isinstance(p2, int):
|
||||||
isinstance(p1, str) and isinstance(p2, str)
|
if p1 > p2:
|
||||||
):
|
return 1
|
||||||
|
if p1 < p2:
|
||||||
|
return -1
|
||||||
|
if isinstance(p1, str) and isinstance(p2, str):
|
||||||
if p1 > p2:
|
if p1 > p2:
|
||||||
return 1
|
return 1
|
||||||
if p1 < p2:
|
if p1 < p2:
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import mimetypes
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from quart import Response as QuartResponse
|
||||||
from quart import g, make_response, request, send_file
|
from quart import g, make_response, request, send_file
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
@@ -424,16 +426,19 @@ class ChatRoute(Route):
|
|||||||
sender_name=username,
|
sender_name=username,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await make_response(
|
response = cast(
|
||||||
stream(),
|
QuartResponse,
|
||||||
{
|
await make_response(
|
||||||
"Content-Type": "text/event-stream",
|
stream(),
|
||||||
"Cache-Control": "no-cache",
|
{
|
||||||
"Transfer-Encoding": "chunked",
|
"Content-Type": "text/event-stream",
|
||||||
"Connection": "keep-alive",
|
"Cache-Control": "no-cache",
|
||||||
},
|
"Transfer-Encoding": "chunked",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue]
|
response.timeout = None # fix SSE auto disconnect issue
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def delete_webchat_session(self):
|
async def delete_webchat_session(self):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from quart import request
|
from quart import request
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ from astrbot.core.star.star import star_registry
|
|||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
def try_cast(value: str, type_: str):
|
def try_cast(value: Any, type_: str):
|
||||||
if type_ == "int":
|
if type_ == "int":
|
||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
@@ -505,9 +506,9 @@ class ConfigRoute(Route):
|
|||||||
if not isinstance(inst, EmbeddingProvider):
|
if not isinstance(inst, EmbeddingProvider):
|
||||||
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
return Response().error("提供商不是 EmbeddingProvider 类型").__dict__
|
||||||
|
|
||||||
# 初始化
|
init_fn = getattr(inst, "initialize", None)
|
||||||
if getattr(inst, "initialize", None):
|
if inspect.iscoroutinefunction(init_fn):
|
||||||
await inst.initialize()
|
await init_fn()
|
||||||
|
|
||||||
# 获取嵌入向量维度
|
# 获取嵌入向量维度
|
||||||
vec = await inst.get_embedding("echo")
|
vec = await inst.get_embedding("echo")
|
||||||
@@ -777,7 +778,7 @@ class ConfigRoute(Route):
|
|||||||
return {"metadata": CONFIG_METADATA_2, "config": config}
|
return {"metadata": CONFIG_METADATA_2, "config": config}
|
||||||
|
|
||||||
async def _get_plugin_config(self, plugin_name: str):
|
async def _get_plugin_config(self, plugin_name: str):
|
||||||
ret = {"metadata": None, "config": None}
|
ret: dict = {"metadata": None, "config": None}
|
||||||
|
|
||||||
for plugin_md in star_registry:
|
for plugin_md in star_registry:
|
||||||
if plugin_md.name == plugin_name:
|
if plugin_md.name == plugin_name:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from quart import Response as QuartResponse
|
||||||
from quart import make_response
|
from quart import make_response
|
||||||
|
|
||||||
from astrbot.core import LogBroker, logger
|
from astrbot.core import LogBroker, logger
|
||||||
@@ -39,14 +41,17 @@ class LogRoute(Route):
|
|||||||
if queue:
|
if queue:
|
||||||
self.log_broker.unregister(queue)
|
self.log_broker.unregister(queue)
|
||||||
|
|
||||||
response = await make_response(
|
response = cast(
|
||||||
stream(),
|
QuartResponse,
|
||||||
{
|
await make_response(
|
||||||
"Content-Type": "text/event-stream",
|
stream(),
|
||||||
"Cache-Control": "no-cache",
|
{
|
||||||
"Connection": "keep-alive",
|
"Content-Type": "text/event-stream",
|
||||||
"Transfer-Encoding": "chunked",
|
"Cache-Control": "no-cache",
|
||||||
},
|
"Connection": "keep-alive",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
response.timeout = None
|
response.timeout = None
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -579,6 +579,10 @@ class PluginRoute(Route):
|
|||||||
logger.warning(f"插件 {plugin_name} 不存在")
|
logger.warning(f"插件 {plugin_name} 不存在")
|
||||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||||
|
|
||||||
|
if not plugin_obj.root_dir_name:
|
||||||
|
logger.warning(f"插件 {plugin_name} 目录不存在")
|
||||||
|
return Response().error(f"插件 {plugin_name} 目录不存在").__dict__
|
||||||
|
|
||||||
plugin_dir = os.path.join(
|
plugin_dir = os.path.join(
|
||||||
self.plugin_manager.plugin_store_path,
|
self.plugin_manager.plugin_store_path,
|
||||||
plugin_obj.root_dir_name or "",
|
plugin_obj.root_dir_name or "",
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ class RouteContext:
|
|||||||
|
|
||||||
|
|
||||||
class Route:
|
class Route:
|
||||||
|
routes: list | dict
|
||||||
|
|
||||||
def __init__(self, context: RouteContext):
|
def __init__(self, context: RouteContext):
|
||||||
self.app = context.app
|
self.app = context.app
|
||||||
self.config = context.config
|
self.config = context.config
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import psutil
|
import psutil
|
||||||
|
from flask.json.provider import DefaultJSONProvider
|
||||||
|
from psutil._common import addr as psutil_addr
|
||||||
from quart import Quart, g, jsonify, request
|
from quart import Quart, g, jsonify, request
|
||||||
from quart.logging import default_handler
|
from quart.logging import default_handler
|
||||||
|
|
||||||
@@ -21,7 +24,7 @@ from .routes.route import Response, RouteContext
|
|||||||
from .routes.session_management import SessionManagementRoute
|
from .routes.session_management import SessionManagementRoute
|
||||||
from .routes.t2i import T2iRoute
|
from .routes.t2i import T2iRoute
|
||||||
|
|
||||||
APP: Quart = None
|
APP: Quart
|
||||||
|
|
||||||
|
|
||||||
class AstrBotDashboard:
|
class AstrBotDashboard:
|
||||||
@@ -48,7 +51,7 @@ class AstrBotDashboard:
|
|||||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||||
128 * 1024 * 1024
|
128 * 1024 * 1024
|
||||||
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
||||||
self.app.json.sort_keys = False
|
cast(DefaultJSONProvider, self.app.json).sort_keys = False
|
||||||
self.app.before_request(self.auth_middleware)
|
self.app.before_request(self.auth_middleware)
|
||||||
# token 用于验证请求
|
# token 用于验证请求
|
||||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||||
@@ -147,7 +150,7 @@ class AstrBotDashboard:
|
|||||||
"""获取占用端口的进程详细信息"""
|
"""获取占用端口的进程详细信息"""
|
||||||
try:
|
try:
|
||||||
for conn in psutil.net_connections(kind="inet"):
|
for conn in psutil.net_connections(kind="inet"):
|
||||||
if conn.laddr.port == port:
|
if cast(psutil_addr, conn.laddr).port == port:
|
||||||
try:
|
try:
|
||||||
process = psutil.Process(conn.pid)
|
process = psutil.Process(conn.pid)
|
||||||
# 获取详细信息
|
# 获取详细信息
|
||||||
|
|||||||
@@ -139,6 +139,11 @@ class ProcessLLMRequest:
|
|||||||
|
|
||||||
# group name identifier
|
# group name identifier
|
||||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||||
|
if not event.message_obj.group:
|
||||||
|
logger.error(
|
||||||
|
f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}"
|
||||||
|
)
|
||||||
|
return
|
||||||
group_name = event.message_obj.group.group_name
|
group_name = event.message_obj.group.group_name
|
||||||
if group_name:
|
if group_name:
|
||||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from astrbot.api import llm_tool, logger, star
|
|||||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||||
from astrbot.api.message_components import File, Image
|
from astrbot.api.message_components import File, Image
|
||||||
from astrbot.api.provider import ProviderRequest
|
from astrbot.api.provider import ProviderRequest
|
||||||
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||||
|
|
||||||
@@ -224,6 +225,8 @@ class Main(star.Star):
|
|||||||
del self.user_waiting[uid]
|
del self.user_waiting[uid]
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
image_url = comp.url if comp.url else comp.file
|
image_url = comp.url if comp.url else comp.file
|
||||||
|
if image_url is None:
|
||||||
|
raise ValueError("Image URL is None")
|
||||||
if image_url.startswith("http"):
|
if image_url.startswith("http"):
|
||||||
image_path = await download_image_by_url(image_url)
|
image_path = await download_image_by_url(image_url)
|
||||||
elif image_url.startswith("file:///"):
|
elif image_url.startswith("file:///"):
|
||||||
@@ -240,6 +243,8 @@ class Main(star.Star):
|
|||||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||||
if event.get_session_id() in self.user_file_msg_buffer:
|
if event.get_session_id() in self.user_file_msg_buffer:
|
||||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
files = self.user_file_msg_buffer[event.get_session_id()]
|
||||||
|
if not request.prompt:
|
||||||
|
request.prompt = ""
|
||||||
request.prompt += f"\nUser provided files: {files}"
|
request.prompt += f"\nUser provided files: {files}"
|
||||||
|
|
||||||
@filter.command_group("pi")
|
@filter.command_group("pi")
|
||||||
@@ -477,7 +482,9 @@ class Main(star.Star):
|
|||||||
# file_s3_url = await self.file_upload(file_path)
|
# file_s3_url = await self.file_upload(file_path)
|
||||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||||
file_name = os.path.basename(file_path)
|
file_name = os.path.basename(file_path)
|
||||||
chain = [File(name=file_name, file=file_path)]
|
chain: list[BaseMessageComponent] = [
|
||||||
|
File(name=file_name, file=file_path)
|
||||||
|
]
|
||||||
yield event.set_result(MessageEventResult(chain=chain))
|
yield event.set_result(MessageEventResult(chain=chain))
|
||||||
|
|
||||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import uuid
|
|||||||
import zoneinfo
|
import zoneinfo
|
||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
|
||||||
from astrbot.api import llm_tool, logger, star
|
from astrbot.api import llm_tool, logger, star
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||||
@@ -62,13 +63,13 @@ class Main(star.Star):
|
|||||||
misfire_grace_time=60,
|
misfire_grace_time=60,
|
||||||
)
|
)
|
||||||
elif "cron" in reminder:
|
elif "cron" in reminder:
|
||||||
|
trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"]))
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._reminder_callback,
|
self._reminder_callback,
|
||||||
trigger="cron",
|
trigger=trigger,
|
||||||
id=id_,
|
id=id_,
|
||||||
args=[group, reminder],
|
args=[group, reminder],
|
||||||
misfire_grace_time=60,
|
misfire_grace_time=60,
|
||||||
**self._parse_cron_expr(reminder["cron"]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_is_outdated(self, reminder: dict):
|
def check_is_outdated(self, reminder: dict):
|
||||||
@@ -101,10 +102,10 @@ class Main(star.Star):
|
|||||||
async def reminder_tool(
|
async def reminder_tool(
|
||||||
self,
|
self,
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
text: str = None,
|
text: str | None = None,
|
||||||
datetime_str: str = None,
|
datetime_str: str | None = None,
|
||||||
cron_expression: str = None,
|
cron_expression: str | None = None,
|
||||||
human_readable_cron: str = None,
|
human_readable_cron: str | None = None,
|
||||||
):
|
):
|
||||||
"""Call this function when user is asking for setting a reminder.
|
"""Call this function when user is asking for setting a reminder.
|
||||||
|
|
||||||
@@ -139,17 +140,19 @@ class Main(star.Star):
|
|||||||
"id": str(uuid.uuid4()),
|
"id": str(uuid.uuid4()),
|
||||||
}
|
}
|
||||||
self.reminder_data[event.unified_msg_origin].append(d)
|
self.reminder_data[event.unified_msg_origin].append(d)
|
||||||
|
trigger = CronTrigger(**self._parse_cron_expr(cron_expression))
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._reminder_callback,
|
self._reminder_callback,
|
||||||
"cron",
|
trigger,
|
||||||
id=d["id"],
|
id=d["id"],
|
||||||
misfire_grace_time=60,
|
misfire_grace_time=60,
|
||||||
**self._parse_cron_expr(cron_expression),
|
|
||||||
args=[event.unified_msg_origin, d],
|
args=[event.unified_msg_origin, d],
|
||||||
)
|
)
|
||||||
if human_readable_cron:
|
if human_readable_cron:
|
||||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||||
else:
|
else:
|
||||||
|
if datetime_str is None:
|
||||||
|
raise ValueError("datetime_str cannot be None.")
|
||||||
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
|
d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())}
|
||||||
self.reminder_data[event.unified_msg_origin].append(d)
|
self.reminder_data[event.unified_msg_origin].append(d)
|
||||||
datetime_scheduled = datetime.datetime.strptime(
|
datetime_scheduled = datetime.datetime.strptime(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import urllib.parse
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup, Tag
|
||||||
|
|
||||||
HEADERS = {
|
HEADERS = {
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
|
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0",
|
||||||
@@ -45,13 +45,13 @@ class SearchEngine:
|
|||||||
self.page = 1
|
self.page = 1
|
||||||
self.headers = HEADERS
|
self.headers = HEADERS
|
||||||
|
|
||||||
def _set_selector(self, selector: str) -> None:
|
def _set_selector(self, selector: str) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_next_page(self):
|
def _get_next_page(self, query: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _get_html(self, url: str, data: dict = None) -> str:
|
async def _get_html(self, url: str, data: dict | None = None) -> str:
|
||||||
headers = self.headers
|
headers = self.headers
|
||||||
headers["Referer"] = url
|
headers["Referer"] = url
|
||||||
headers["User-Agent"] = random.choice(USER_AGENTS)
|
headers["User-Agent"] = random.choice(USER_AGENTS)
|
||||||
@@ -83,6 +83,9 @@ class SearchEngine:
|
|||||||
"""清理文本,去除空格、换行符等"""
|
"""清理文本,去除空格、换行符等"""
|
||||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||||
|
|
||||||
|
def _get_url(self, tag: Tag) -> str:
|
||||||
|
return self.tidy_text(tag.get_text())
|
||||||
|
|
||||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||||
query = urllib.parse.quote(query)
|
query = urllib.parse.quote(query)
|
||||||
|
|
||||||
@@ -92,12 +95,16 @@ class SearchEngine:
|
|||||||
links = soup.select(self._set_selector("links"))
|
links = soup.select(self._set_selector("links"))
|
||||||
results = []
|
results = []
|
||||||
for link in links:
|
for link in links:
|
||||||
title = self.tidy_text(
|
# Safely get the title text (select_one may return None)
|
||||||
link.select_one(self._set_selector("title")).text,
|
title_elem = link.select_one(self._set_selector("title"))
|
||||||
)
|
title = ""
|
||||||
url = link.select_one(self._set_selector("url"))
|
if title_elem is not None:
|
||||||
|
title = self.tidy_text(title_elem.get_text())
|
||||||
|
|
||||||
|
url_tag = link.select_one(self._set_selector("url"))
|
||||||
snippet = ""
|
snippet = ""
|
||||||
if title and url:
|
if title and url_tag:
|
||||||
|
url = self._get_url(url_tag)
|
||||||
results.append(SearchResult(title=title, url=url, snippet=snippet))
|
results.append(SearchResult(title=title, url=url, snippet=snippet))
|
||||||
return results[:num_results] if len(results) > num_results else results
|
return results[:num_results] if len(results) > num_results else results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from . import USER_AGENT_BING, SearchEngine, SearchResult
|
from . import USER_AGENT_BING, SearchEngine
|
||||||
|
|
||||||
|
|
||||||
class Bing(SearchEngine):
|
class Bing(SearchEngine):
|
||||||
@@ -28,11 +28,3 @@ class Bing(SearchEngine):
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
continue
|
continue
|
||||||
raise Exception("Bing search failed")
|
raise Exception("Bing search failed")
|
||||||
|
|
||||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
|
||||||
results = await super().search(query, num_results)
|
|
||||||
for result in results:
|
|
||||||
if not isinstance(result.url, str):
|
|
||||||
result.url = result.url.text
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup, Tag
|
||||||
|
|
||||||
from . import USER_AGENTS, SearchEngine, SearchResult
|
from . import USER_AGENTS, SearchEngine, SearchResult
|
||||||
|
|
||||||
@@ -26,10 +27,12 @@ class Sogo(SearchEngine):
|
|||||||
url = f"{self.base_url}/web?query={query}"
|
url = f"{self.base_url}/web?query={query}"
|
||||||
return await self._get_html(url, None)
|
return await self._get_html(url, None)
|
||||||
|
|
||||||
|
def _get_url(self, tag: Tag) -> str:
|
||||||
|
return cast(str, tag.get("href"))
|
||||||
|
|
||||||
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
async def search(self, query: str, num_results: int) -> list[SearchResult]:
|
||||||
results = await super().search(query, num_results)
|
results = await super().search(query, num_results)
|
||||||
for result in results:
|
for result in results:
|
||||||
result.url = result.url.get("href")
|
|
||||||
if result.url.startswith("/link?"):
|
if result.url.startswith("/link?"):
|
||||||
result.url = self.base_url + result.url
|
result.url = self.base_url + result.url
|
||||||
result.url = await self._parse_url(result.url)
|
result.url = await self._parse_url(result.url)
|
||||||
@@ -40,7 +43,10 @@ class Sogo(SearchEngine):
|
|||||||
soup = BeautifulSoup(html, "html.parser")
|
soup = BeautifulSoup(html, "html.parser")
|
||||||
script = soup.find("script")
|
script = soup.find("script")
|
||||||
if script:
|
if script:
|
||||||
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(
|
script_text = (
|
||||||
1,
|
script.string if script.string is not None else script.get_text()
|
||||||
)
|
)
|
||||||
|
match = re.search(r'window.location.replace\("(.+?)"\)', script_text)
|
||||||
|
if match:
|
||||||
|
url = match.group(1)
|
||||||
return url
|
return url
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
"""Minimal type stubs for faiss used in this project.
|
||||||
|
|
||||||
|
This file only exposes a small subset of the faiss API that the
|
||||||
|
project uses, including the runtime-monkeypatched signatures such as
|
||||||
|
`Index.add_with_ids` so Pyright/Pylance stops reporting false positives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, overload
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Index:
|
||||||
|
d: int
|
||||||
|
ntotal: int
|
||||||
|
code_size: int
|
||||||
|
nprobe: int
|
||||||
|
|
||||||
|
def add(self, x: np.ndarray) -> None: ...
|
||||||
|
def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ...
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
x: np.ndarray,
|
||||||
|
k: int,
|
||||||
|
*,
|
||||||
|
params: Any = ...,
|
||||||
|
D: np.ndarray | None = ...,
|
||||||
|
I: np.ndarray | None = ...,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]: ...
|
||||||
|
def remove_ids(self, x: np.ndarray) -> int: ...
|
||||||
|
@overload
|
||||||
|
def reconstruct(self, key: int) -> np.ndarray: ...
|
||||||
|
@overload
|
||||||
|
def reconstruct(self, key: int, x: np.ndarray) -> None: ...
|
||||||
|
def reconstruct(
|
||||||
|
self, key: int, x: np.ndarray | None = ...
|
||||||
|
) -> np.ndarray | None: ...
|
||||||
|
@overload
|
||||||
|
def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ...
|
||||||
|
@overload
|
||||||
|
def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ...
|
||||||
|
def reconstruct_n(
|
||||||
|
self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ...
|
||||||
|
) -> np.ndarray | None: ...
|
||||||
|
def range_search(
|
||||||
|
self, x: np.ndarray, thresh: float, *, params: Any = ...
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ...
|
||||||
|
def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ...
|
||||||
|
def sa_encode(self, x: np.ndarray) -> np.ndarray: ...
|
||||||
|
def sa_decode(self, codes: np.ndarray) -> np.ndarray: ...
|
||||||
|
|
||||||
|
class IndexFlatL2(Index):
|
||||||
|
def __init__(self, d: int) -> None: ...
|
||||||
|
|
||||||
|
class IndexIDMap(Index):
|
||||||
|
index: Index
|
||||||
|
|
||||||
|
def __init__(self, index: Index) -> None: ...
|
||||||
|
|
||||||
|
def read_index(path: str) -> Index: ...
|
||||||
|
def write_index(index: Index, path: str | None = ...) -> None: ...
|
||||||
|
def normalize_L2(x: np.ndarray) -> None: ...
|
||||||
|
|
||||||
|
# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers
|
||||||
|
# expose `downcast_*` helpers to convert generic objects to these concrete
|
||||||
|
# types). We keep these minimal — only the names are important for typing.
|
||||||
|
class IndexBinary(Index):
|
||||||
|
def __init__(self, d: int) -> None: ...
|
||||||
|
|
||||||
|
class InvertedLists:
|
||||||
|
def __len__(self) -> int: ...
|
||||||
|
|
||||||
|
class AdditiveQuantizer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Quantizer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class VectorTransform:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# SWIG-provided downcast helpers (present in some faiss Python builds).
|
||||||
|
def downcast_IndexBinary(obj: Any) -> IndexBinary: ...
|
||||||
|
def downcast_InvertedLists(obj: Any) -> InvertedLists: ...
|
||||||
|
def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ...
|
||||||
|
def downcast_Quantizer(obj: Any) -> Quantizer: ...
|
||||||
|
def downcast_VectorTransform(obj: Any) -> VectorTransform: ...
|
||||||
|
def downcast_index(obj: Any) -> Index: ...
|
||||||
|
|
||||||
|
# version exposed by runtime
|
||||||
|
__version__: str
|
||||||
Reference in New Issue
Block a user