AstrBot/astrbot/core/provider/provider.py
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
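* chore(core.utils): 🚨 fix lint errors

* chore(core.provider): 🚨 fix lint errors in base classes

* chore(core.utils): complete the overloads of session_get()

* chore(core.provider): 🚨 fix lint errors in implementations

* chore(core.platform): 🚨 fix lint errors in the platform base class and webchat

* chore(core.platform): fix lint errors in implementations

* fix(core.provider): fix circular calls and incorrect asserts

* chore(core.platform): fix lint in some implementations

* chore(core.provider): add parameter types to Dify.text_chat_stream

* chore(core.pipeline): 🚨 fix lint errors

* fix(core.slack): add missing import

* chore(core.utils): fix incorrect session_get declaration

* chore(core.platform): remove the wildcard from the Lark adapter import

* chore(core.db): fix declarations and some logic

* chore(core.db): add typings so that faiss parameters are recognized correctly

* chore(core): fix declarations

* chore(core): update declarations

* chore: add faiss declarations

* chore(dashboard): adjust implementation to reduce errors

* chore(package): adjust some declarations and implementations to reduce errors

* chore(core): add overloads for Handler to remove some asserts while still passing type checking

* chore(core.pipeline): make Pipeline Scheduler's execute check types instead of attributes, to pass static type checking

* chore(core.config): add type annotations to pass type checking

* chore(core.message): add a check to File._download_file to pass type checking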

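* fix: replace assertions with conditional checks for fault tolerance during graceful shutdown

* refactor: remove asserts from the discord client; use `if ... is None` checks and raise exceptions instead

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter logs and returns when self.client.user is None; remove assertion

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: strengthen null/exception checks for Lark and improve log output

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: replace assertions with conditional checks and add logging and error handling

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: remove useless LLM-generated comments

* refactor: use File.get_file in place of the download logic, remove assert, and provide a default filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: raise a runtime exception when the Slack Socket is not initialized; change the image URL empty check to a non-empty check

* refactor: replace assertions in WeChatPadProAdapter with null checks and add logging

* refactor: use isinstance instead of assertions for type checks, to aid static checking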

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: remove cast, use field and dict access directly, fix port parsing

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: rewrite ProviderManager loading with match-case and raise TypeError to pass type checking

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: in group_name_display, log an error and return if the group object is empty

* fix: replace the assert in _get_current_persona_id with an if guard that returns None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: improve the plugin directory existence check and image URL non-empty validation; update JSON sorting config

* fix: replace the assert on datetime_str with an explicit check that raises an exception

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove cast; use a runtime check and skip when no scheduler is found

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove cast; check FaissVecDB with isinstance and warn

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: remove the typing.cast import and validate file_ before getting the file's absolute path

* refactor: remove typing.cast; simplify the content safety check call

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: make PlatformMetadata.id required and pass id at registration; remove cast

* refactor: remove cast; use HasInitialize and isinstance for initialization

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: add an ID type check to ProviderManager.initialize so that None does not make get fail

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

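* refactor: introduce _client and client properties in OTTSProvider and AzureNativeProvider to improve context management

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: add a model-not-initialized check to the self-hosted Whisper source and call transcribe directly

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: remove unused cast import and simplify platform_name assignment

* refactor: import cast and use cast(str, ...) on id to improve type safety

* fix: make _id_to_sid return str, returning an empty string for null values; use cast on id and message_id

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: refactor Discord handling: enforce type conversion, prioritize slash commands, and improve mention detection

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: consistently apply cast when getting id, and raise an error when WeChat message parsing fails

* Revert "fix: remove cast, use field and dict access directly, fix port parsing"

This reverts commit 1cbfdf9d1b.

* fix: return empty results when the Bailian (百炼) Rerank session is closed; initialize request.prompt to avoid concatenating null values

* fix: normalize search result links to strings, add a _get_url helper, and adapt Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: adjust call_handler generics, Discord channel annotations, and FishAudioTTS API request types

* refactor: use col(...) instead of column references and cast results to CursorResult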

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00

343 lines
12 KiB
Python

```python
import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union

from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
    LLMResponse,
    ProviderMeta,
    RerankResult,
    ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path

Providers: TypeAlias = Union[
    "Provider",
    "STTProvider",
    "TTSProvider",
    "EmbeddingProvider",
    "RerankProvider",
]


class AbstractProvider(abc.ABC):
    """Provider Abstract Class"""

    def __init__(self, provider_config: dict) -> None:
        super().__init__()
        self.model_name = ""
        self.provider_config = provider_config

    def set_model(self, model_name: str):
        """Set the current model name"""
        self.model_name = model_name

    def get_model(self) -> str:
        """Get the current model name"""
        return self.model_name

    def meta(self) -> ProviderMeta:
        """Get the provider metadata"""
        provider_type_name = self.provider_config["type"]
        meta_data = provider_cls_map.get(provider_type_name)
        if not meta_data:
            raise ValueError(f"Provider type {provider_type_name} not registered")
        meta = ProviderMeta(
            id=self.provider_config.get("id", "default"),
            model=self.get_model(),
            type=provider_type_name,
            provider_type=meta_data.provider_type,
        )
        return meta

    async def test(self):
        """Test whether the provider is available.

        Raises:
            Exception: if the provider is not available
        """
        ...
```
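`meta()` looks up the configured `type` in `provider_cls_map`, raises `ValueError` for unregistered types, and falls back to `"default"` when the config has no `id`. The sketch below reproduces that lookup with hypothetical stand-ins (`ToyRegistration`, `ToyMeta`, `REGISTRY`, and the `"toy_chat"` type name); the real `ProviderMeta` and registry entries live in `astrbot.core.provider.entities` and `astrbot.core.provider.register` and may carry more fields.

```python
from dataclasses import dataclass


@dataclass
class ToyRegistration:
    # Hypothetical registry entry; only the field meta() reads is modelled
    provider_type: str


@dataclass
class ToyMeta:
    id: str
    model: str
    type: str
    provider_type: str


# Hypothetical registry keyed by the config's "type" value
REGISTRY: dict[str, ToyRegistration] = {
    "toy_chat": ToyRegistration(provider_type="chat"),
}


def build_meta(provider_config: dict, model_name: str) -> ToyMeta:
    type_name = provider_config["type"]
    registration = REGISTRY.get(type_name)
    if not registration:
        raise ValueError(f"Provider type {type_name} not registered")
    return ToyMeta(
        # Same fallback as meta(): a missing "id" becomes "default"
        id=provider_config.get("id", "default"),
        model=model_name,
        type=type_name,
        provider_type=registration.provider_type,
    )
```

Failing at lookup time means a typo in a provider's `type` surfaces as a clear error instead of a partially populated metadata object.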
```python
class Provider(AbstractProvider):
    """Chat Provider"""

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config)
        self.provider_settings = provider_settings

    @abc.abstractmethod
    def get_current_key(self) -> str:
        raise NotImplementedError

    def get_keys(self) -> list[str]:
        """Get the provider's API keys"""
        keys = self.provider_config.get("key", [""])
        return keys or [""]

    @abc.abstractmethod
    def set_key(self, key: str):
        raise NotImplementedError

    @abc.abstractmethod
    async def get_models(self) -> list[str]:
        """Get the list of supported models"""
        raise NotImplementedError

    @abc.abstractmethod
    async def text_chat(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Get a text chat result from the LLM, using the current model.

        Args:
            prompt: The prompt. Use either this or contexts; if both are given, prompt (and any image_urls) is appended to contexts as the latest record.
            session_id: Session ID (deprecated)
            image_urls: List of image URLs
            func_tool: Tool set
            contexts: Conversation context. Use either this or prompt.
            system_prompt: System prompt
            tool_calls_result: Tool call results to send back to the LLM. See: https://platform.openai.com/docs/guides/function-calling
            model: Model name
            kwargs: Other parameters

        Notes:
            - If image_urls is given, the images are attached to the request. If the model does not support image input, an error is raised.
            - If func_tool is given, it is used for function calling. If the model does not support function calling, an error is raised.
        """
        ...

    async def text_chat_stream(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Get a streaming text chat result from the LLM, using the current model. The complete result is yielded once more at the end of generation.

        Args:
            prompt: The prompt. Use either this or contexts; if both are given, prompt (and any image_urls) is appended to contexts as the latest record.
            session_id: Session ID (deprecated)
            image_urls: List of image URLs
            func_tool: Tool set
            contexts: Conversation context. Use either this or prompt.
            system_prompt: System prompt
            tool_calls_result: Tool call results to send back to the LLM. See: https://platform.openai.com/docs/guides/function-calling
            model: Model name
            kwargs: Other parameters

        Notes:
            - If image_urls is given, the images are attached to the request. If the model does not support image input, an error is raised.
            - If func_tool is given, it is used for function calling. If the model does not support function calling, an error is raised.
        """
        if False:  # pragma: no cover - make this an async generator for typing
            yield None  # type: ignore
        raise NotImplementedError()

    async def pop_record(self, context: list):
        """Pop the first two non-system records from context, in place. System records are kept."""
        poped = 0
        indexs_to_pop = []
        for idx, record in enumerate(context):
            if record["role"] == "system":
                continue
            indexs_to_pop.append(idx)
            poped += 1
            if poped == 2:
                break
        for idx in reversed(indexs_to_pop):
            context.pop(idx)

    def _ensure_message_to_dicts(
        self,
        messages: list[dict] | list[Message] | None,
    ) -> list[dict]:
        """Convert a list of Message objects to a list of dictionaries."""
        if not messages:
            return []
        dicts: list[dict] = []
        for message in messages:
            if isinstance(message, Message):
                dicts.append(message.model_dump())
            else:
                dicts.append(message)
        return dicts

    async def test(self, timeout: float = 45.0):
        await asyncio.wait_for(
            self.text_chat(prompt="REPLY `PONG` ONLY"),
            timeout=timeout,
        )
```
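`pop_record` trims history by deleting the first two non-system records and leaving system prompts in place. It collects the indices first and deletes them in reverse, so removing one record does not shift the index of the next. A synchronous standalone copy of that logic (the name `pop_oldest_turn` is hypothetical) shows the effect on a small context:

```python
def pop_oldest_turn(context: list[dict]) -> None:
    # Collect indices of the first two non-system records
    indexes_to_pop: list[int] = []
    for idx, record in enumerate(context):
        if record["role"] == "system":
            continue
        indexes_to_pop.append(idx)
        if len(indexes_to_pop) == 2:
            break
    # Delete from the end so the remaining collected indices stay valid
    for idx in reversed(indexes_to_pop):
        context.pop(idx)


context = [
    {"role": "system", "content": "You are AstrBot."},
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
    {"role": "user", "content": "what time is it?"},
]
pop_oldest_turn(context)
print([r["content"] for r in context])
# ['You are AstrBot.', 'what time is it?']
```

Calling it repeatedly drops the oldest records two at a time until only system records remain.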
```python
class STTProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_text(self, audio_url: str) -> str:
        """Transcribe the audio into text"""
        raise NotImplementedError

    async def test(self):
        sample_audio_path = os.path.join(
            get_astrbot_path(),
            "samples",
            "stt_health_check.wav",
        )
        await self.get_text(sample_audio_path)
```
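`STTProvider.test()` passes a local file path to `get_text(audio_url)`, so an implementation presumably has to accept both local paths and remote URLs. A minimal sketch of how an implementation might tell the two apart before reading or downloading the audio; `classify_audio_source` is a hypothetical helper, not part of AstrBot:

```python
from urllib.parse import urlparse


def classify_audio_source(audio_url: str) -> str:
    """Return "remote" for http(s) URLs and "local" for anything else."""
    # urlparse lowercases the scheme, so "HTTPS://..." is handled too
    scheme = urlparse(audio_url).scheme
    if scheme in ("http", "https"):
        return "remote"
    # Windows paths such as "C:\\a.wav" parse with a one-letter scheme ("c"),
    # so every non-http(s) input is treated as a local path
    return "local"
```

Checking for `http`/`https` explicitly, rather than for "has a scheme", is what keeps Windows drive-letter paths on the local branch.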
```python
class TTSProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_audio(self, text: str) -> str:
        """Synthesize audio for the text and return the audio file path"""
        raise NotImplementedError

    async def test(self):
        await self.get_audio("hi")
```
```python
class EmbeddingProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Get the embedding vector of a text"""
        ...

    @abc.abstractmethod
    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
        """Get the embedding vectors of a list of texts"""
        ...

    @abc.abstractmethod
    def get_dim(self) -> int:
        """Get the dimension of the embedding vectors"""
        ...

    async def test(self):
        await self.get_embedding("astrbot")

    async def get_embeddings_batch(
        self,
        texts: list[str],
        batch_size: int = 16,
        tasks_limit: int = 3,
        max_retries: int = 3,
        progress_callback=None,
    ) -> list[list[float]]:
        """Get embeddings for a list of texts, processing them in batches to save memory.

        Args:
            texts: List of texts
            batch_size: Number of texts per batch
            tasks_limit: Maximum number of batches processed concurrently
            max_retries: Maximum number of attempts per batch
            progress_callback: Async callback awaited with (current, total) after each batch completes

        Returns:
            List of vectors, in the same order as texts
        """
        semaphore = asyncio.Semaphore(tasks_limit)
        # Batches can finish out of order, so store results by batch index
        embeddings_by_batch: dict[int, list[list[float]]] = {}
        failed_batches: list[tuple[int, list[str]]] = []
        completed_count = 0
        total_count = len(texts)

        async def process_batch(batch_idx: int, batch_texts: list[str]):
            nonlocal completed_count
            async with semaphore:
                for attempt in range(max_retries):
                    try:
                        batch_embeddings = await self.get_embeddings(batch_texts)
                        embeddings_by_batch[batch_idx] = batch_embeddings
                        completed_count += len(batch_texts)
                        if progress_callback:
                            await progress_callback(completed_count, total_count)
                        return
                    except Exception as e:
                        if attempt == max_retries - 1:
                            # Last attempt failed: record the failed batch
                            failed_batches.append((batch_idx, batch_texts))
                            raise Exception(
                                f"Batch {batch_idx} failed after {max_retries} attempts: {e!s}",
                            ) from e
                        # Wait before retrying, with exponential backoff
                        await asyncio.sleep(2**attempt)

        tasks = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]
            batch_idx = i // batch_size
            tasks.append(process_batch(batch_idx, batch_texts))

        # Collect the results of all tasks, including failed ones
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Check whether any task failed
        errors = [r for r in results if isinstance(r, Exception)]
        if errors:
            error_msg = (
                f"{len(errors)} batch(es) failed: {'; '.join(str(e) for e in errors)}"
            )
            raise Exception(error_msg)

        # Reassemble the vectors in input order
        return [
            embedding
            for idx in sorted(embeddings_by_batch)
            for embedding in embeddings_by_batch[idx]
        ]
```
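`get_embeddings_batch` splits the input into batches, caps concurrency with an `asyncio.Semaphore`, and retries a failing batch with exponential backoff (1 s, 2 s, ...). Because concurrent batches can finish in any order, results have to be keyed by batch index to stay aligned with `texts`. The standalone sketch below shows the same pattern with a hypothetical `fake_embed` function whose random latency lets batches complete out of order; it is simplified so that a batch's final error propagates directly instead of being aggregated.

```python
import asyncio
import random


async def fake_embed(batch: list[str]) -> list[list[float]]:
    # Hypothetical embedding call: random latency, vector = [len(text)]
    await asyncio.sleep(random.uniform(0, 0.02))
    return [[float(len(t))] for t in batch]


async def embed_in_batches(
    texts: list[str], batch_size: int = 2, tasks_limit: int = 3, max_retries: int = 3
) -> list[list[float]]:
    semaphore = asyncio.Semaphore(tasks_limit)
    by_batch: dict[int, list[list[float]]] = {}

    async def run(batch_idx: int, batch: list[str]) -> None:
        async with semaphore:  # at most `tasks_limit` batches in flight
            for attempt in range(max_retries):
                try:
                    by_batch[batch_idx] = await fake_embed(batch)
                    return
                except Exception:
                    if attempt == max_retries - 1:
                        raise
                    await asyncio.sleep(2**attempt)  # exponential backoff

    await asyncio.gather(
        *(run(i // batch_size, texts[i : i + batch_size])
          for i in range(0, len(texts), batch_size))
    )
    # Reassemble in input order, regardless of completion order
    return [vec for idx in sorted(by_batch) for vec in by_batch[idx]]


texts = ["a", "bb", "ccc", "dddd", "eeeee"]
vectors = asyncio.run(embed_in_batches(texts))
print(vectors)
# [[1.0], [2.0], [3.0], [4.0], [5.0]]
```

Appending to a shared list as each batch completes would pair vectors with the wrong texts whenever a later batch returns first; sorting by batch index avoids that.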
```python
class RerankProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Get rerank scores of the documents against the query"""
        ...

    async def test(self):
        result = await self.rerank("Apple", documents=["apple", "banana"])
        if not result:
            raise Exception("Rerank provider test failed, no results returned")
```