f624971613
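* chore(core.utils): 🚨 Fix lint errors
* chore(core.provider): 🚨 Fix lint errors in the base classes
* chore(core.utils): Add the missing overloads for session_get()
* chore(core.provider): 🚨 Fix lint errors in implementations
* chore(core.platform): 🚨 Fix lint errors in the platform base class and webchat
* chore(core.platform): Fix lint errors in implementations
* fix(core.provider): Fix a circular call and an incorrect assert
* chore(core.platform): Fix lint in some implementations
* chore(core.provider): Add parameter types to Dify.text_chat_stream
* chore(core.pipeline): 🚨 Fix lint errors
* fix(core.slack): Add a missing import
* chore(core.utils): Fix the incorrect session_get declaration
* chore(core.platform): Remove the wildcard from the Lark adapter imports
* chore(core.db): Fix declarations and some logic
* chore(core.db): Add typings so that faiss parameters are recognized correctly
* chore(core): Fix declarations
* chore(core): Modify declarations
* chore: Add faiss declarations
* chore(dashboard): Modify the implementation to reduce errors
* chore(package): Modify some declarations and implementations to reduce errors
* chore(core): Add overloads for Handler to remove some asserts while still passing type checking
* chore(core.pipeline): Make Pipeline Scheduler's execute check the type instead of an attribute, so it passes static type checking
* chore(core.config): Add type annotations to pass type checking
* chore(core.message): Add a check to File._download_file to pass type checking
* fix: Replace assertions with conditional checks so that graceful shutdown is fault-tolerant
* refactor: Remove the asserts from the Discord client; use `is None` checks and raise exceptions instead
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: DiscordPlatformAdapter logs and returns when self.client.user is None; remove the assertion
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Strengthen null and exception checks around Lark and improve log output
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Replace assertions with conditional checks and add logging and error handling
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* chore: Remove useless LLM-generated comments
* refactor: Use File.get_file instead of the download logic, remove the assert, and provide a default filename
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Raise a runtime exception when the Slack Socket is not initialized; change the image URL empty check to a non-empty check
* refactor: Replace the assertions in WeChatPadProAdapter with null checks and add logging
* refactor: Use isinstance instead of assertions for type checks, to help static analysis
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Remove cast, access fields and dict keys directly, and fix port parsing
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Rewrite ProviderManager loading with match-case and raise TypeError so it passes type checking
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: In group_name_display, log an error and return if the group object is None
* fix: Replace the assert in _get_current_persona_id with an if guard that returns None
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Improve the plugin directory existence check and the image URL non-empty validation; update the JSON sorting config
* fix: Replace the assert on datetime_str with an explicit check that raises an exception
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Remove cast; use a runtime check and skip when no scheduler is found
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Remove cast; use isinstance to check for FaissVecDB and warn otherwise
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Remove the typing.cast import and validate file_ before getting the file's absolute path
* refactor: Remove typing.cast and simplify the content safety check call
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Make PlatformMetadata.id required and pass id at registration; remove cast
* refactor: Remove cast; use HasInitialize and isinstance for initialization
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Add an ID type check to ProviderManager.initialize so that a None ID does not make get fail
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Introduce _client and a client property on OTTSProvider and AzureNativeProvider to improve context management
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Add a model-not-initialized check to the Whisper self-hosted source and call transcribe directly
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Remove the unused cast import and simplify the platform_name assignment
* refactor: Import cast and apply cast(str, ...) to id for type safety
* fix: Make _id_to_sid return str, returning an empty string for empty values; apply cast to id and message_id
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Rework Discord handling: enforce type conversion, prefer slash commands, and improve mention detection
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* fix: Apply cast consistently when getting id, and raise an error when WeChat message parsing fails
* Revert "fix: Remove cast, access fields and dict keys directly, and fix port parsing"
This reverts commit 1cbfdf9d1b.
* fix: Return empty results from Bailian Rerank when the session is closed; initialize request.prompt to avoid concatenating a None value
* fix: Normalize search result links to strings; add a _get_url helper and adapt Bing/Sogo
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
* refactor: Adjust call_handler generics, Discord channel annotations, and FishAudioTTS API request types
* refactor: Use col(...) instead of column references and cast results to CursorResult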
* chore: ruff format
---------
Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
343 lines
12 KiB
Python
import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union

from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
    LLMResponse,
    ProviderMeta,
    RerankResult,
    ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path

Providers: TypeAlias = Union[
    "Provider",
    "STTProvider",
    "TTSProvider",
    "EmbeddingProvider",
    "RerankProvider",
]


class AbstractProvider(abc.ABC):
    """Provider Abstract Class"""

    def __init__(self, provider_config: dict) -> None:
        super().__init__()
        self.model_name = ""
        self.provider_config = provider_config

    def set_model(self, model_name: str):
        """Set the current model name"""
        self.model_name = model_name

    def get_model(self) -> str:
        """Get the current model name"""
        return self.model_name

    def meta(self) -> ProviderMeta:
        """Get the provider metadata"""
        provider_type_name = self.provider_config["type"]
        meta_data = provider_cls_map.get(provider_type_name)
        if not meta_data:
            raise ValueError(f"Provider type {provider_type_name} not registered")
        meta = ProviderMeta(
            id=self.provider_config.get("id", "default"),
            model=self.get_model(),
            type=provider_type_name,
            provider_type=meta_data.provider_type,
        )
        return meta

    async def test(self):
        """Test whether the provider is available.

        Raises:
            Exception: if the provider is not available
        """
        ...


class Provider(AbstractProvider):
    """Chat Provider"""

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        super().__init__(provider_config)
        self.provider_settings = provider_settings

    @abc.abstractmethod
    def get_current_key(self) -> str:
        raise NotImplementedError

    def get_keys(self) -> list[str]:
        """Get the provider's keys"""
        keys = self.provider_config.get("key", [""])
        return keys or [""]

    @abc.abstractmethod
    def set_key(self, key: str):
        raise NotImplementedError

    @abc.abstractmethod
    async def get_models(self) -> list[str]:
        """Get the list of supported models"""
        raise NotImplementedError

    @abc.abstractmethod
    async def text_chat(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Get a text chat result from the LLM, using the current model.

        Args:
            prompt: The prompt. Use either this or contexts; if both are given, prompt (and any image_urls) is appended to contexts as the latest record.
            session_id: Session ID (deprecated)
            image_urls: List of image URLs
            func_tool: Tool set
            contexts: Conversation context. Use either this or prompt.
            system_prompt: System prompt
            tool_calls_result: Tool call results to send back to the LLM. See: https://platform.openai.com/docs/guides/function-calling
            model: Model name
            kwargs: Other parameters

        Notes:
            - If image_urls is given, the images are attached to the conversation. If the model does not support image input, an error is raised.
            - If func_tool is given, it is used for function calling. If the model does not support function calling, an error is raised.
        """
        ...

    async def text_chat_stream(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Get a streaming text chat result from the LLM, using the current model. The complete result is yielded once more at the end of generation.

        Args:
            prompt: The prompt. Use either this or contexts; if both are given, prompt (and any image_urls) is appended to contexts as the latest record.
            session_id: Session ID (deprecated)
            image_urls: List of image URLs
            func_tool: Tool set
            contexts: Conversation context. Use either this or prompt.
            system_prompt: System prompt
            tool_calls_result: Tool call results to send back to the LLM. See: https://platform.openai.com/docs/guides/function-calling
            model: Model name
            kwargs: Other parameters

        Notes:
            - If image_urls is given, the images are attached to the conversation. If the model does not support image input, an error is raised.
            - If func_tool is given, it is used for function calling. If the model does not support function calling, an error is raised.
        """
        if False:  # pragma: no cover - make this an async generator for typing
            yield None  # type: ignore
        raise NotImplementedError()

    async def pop_record(self, context: list):
        """Remove the first two non-system records (typically the oldest user/assistant pair) from context."""
        popped = 0
        indices_to_pop = []
        for idx, record in enumerate(context):
            if record["role"] == "system":
                continue
            indices_to_pop.append(idx)
            popped += 1
            if popped == 2:
                break

        # Pop from the end so earlier indices stay valid.
        for idx in reversed(indices_to_pop):
            context.pop(idx)

    def _ensure_message_to_dicts(
        self,
        messages: list[dict] | list[Message] | None,
    ) -> list[dict]:
        """Convert a list of Message objects to a list of dictionaries."""
        if not messages:
            return []
        dicts: list[dict] = []
        for message in messages:
            if isinstance(message, Message):
                dicts.append(message.model_dump())
            else:
                dicts.append(message)

        return dicts

    async def test(self, timeout: float = 45.0):
        await asyncio.wait_for(
            self.text_chat(prompt="REPLY `PONG` ONLY"),
            timeout=timeout,
        )


class STTProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_text(self, audio_url: str) -> str:
        """Transcribe the audio to text"""
        raise NotImplementedError

    async def test(self):
        sample_audio_path = os.path.join(
            get_astrbot_path(),
            "samples",
            "stt_health_check.wav",
        )
        await self.get_text(sample_audio_path)


class TTSProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_audio(self, text: str) -> str:
        """Synthesize audio for the text and return the audio file path"""
        raise NotImplementedError

    async def test(self):
        await self.get_audio("hi")


class EmbeddingProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Get the embedding vector of a text"""
        ...

    @abc.abstractmethod
    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
        """Get the embedding vectors of a list of texts"""
        ...

    @abc.abstractmethod
    def get_dim(self) -> int:
        """Get the dimension of the vectors"""
        ...

    async def test(self):
        await self.get_embedding("astrbot")

    async def get_embeddings_batch(
        self,
        texts: list[str],
        batch_size: int = 16,
        tasks_limit: int = 3,
        max_retries: int = 3,
        progress_callback=None,
    ) -> list[list[float]]:
        """Get the embedding vectors of texts, processing them in batches to save memory.

        Args:
            texts: List of texts
            batch_size: Number of texts per batch
            tasks_limit: Maximum number of concurrent tasks
            max_retries: Maximum number of attempts per batch before giving up
            progress_callback: Async progress callback, awaited with (current, total)

        Returns:
            List of vectors, in the same order as texts
        """
        semaphore = asyncio.Semaphore(tasks_limit)
        # Batches finish in arbitrary order, so store each result by batch index;
        # appending on completion would misalign vectors and texts.
        batch_results: dict[int, list[list[float]]] = {}
        failed_batches: list[tuple[int, list[str]]] = []
        completed_count = 0
        total_count = len(texts)

        async def process_batch(batch_idx: int, batch_texts: list[str]):
            nonlocal completed_count
            async with semaphore:
                for attempt in range(max_retries):
                    try:
                        batch_embeddings = await self.get_embeddings(batch_texts)
                        batch_results[batch_idx] = batch_embeddings
                        completed_count += len(batch_texts)
                        if progress_callback:
                            await progress_callback(completed_count, total_count)
                        return
                    except Exception as e:
                        if attempt == max_retries - 1:
                            # The last attempt failed; record the failed batch
                            failed_batches.append((batch_idx, batch_texts))
                            raise Exception(
                                f"Batch {batch_idx} failed after {max_retries} attempts: {e!s}",
                            ) from e
                        # Wait before retrying, with exponential backoff
                        await asyncio.sleep(2**attempt)

        tasks = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]
            batch_idx = i // batch_size
            tasks.append(process_batch(batch_idx, batch_texts))

        # Collect the results of all tasks, including failed ones
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Check whether any task failed
        errors = [r for r in results if isinstance(r, Exception)]
        if errors:
            error_msg = (
                f"{len(errors)} batch(es) failed: {'; '.join(str(e) for e in errors)}"
            )
            raise Exception(error_msg)

        # Reassemble the vectors in input order
        return [
            embedding
            for batch_idx in sorted(batch_results)
            for embedding in batch_results[batch_idx]
        ]


class RerankProvider(AbstractProvider):
    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Get the rerank scores of the documents for the query"""
        ...

    async def test(self):
        result = await self.rerank("Apple", documents=["apple", "banana"])
        if not result:
            raise Exception("Rerank provider test failed, no results returned")
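`EmbeddingProvider.get_embeddings_batch` combines three techniques: fixed-size batching, an `asyncio.Semaphore` that caps in-flight requests at `tasks_limit`, and retries with exponential backoff. Because `asyncio.gather` lets batches finish in any order, each batch's vectors must be stored by batch index rather than appended as they arrive, or the output will not line up with `texts`. The standalone sketch below shows that pattern. `embed_in_batches`, `fake_embed`, `EmbedFn`, and the `backoff_base` parameter are hypothetical names introduced for illustration (the method itself sleeps `2**attempt` seconds), and `fake_embed` stands in for a real `get_embeddings` call.

```python
import asyncio
import random
from collections.abc import Awaitable, Callable

# Hypothetical signature of a batch embedding call such as get_embeddings().
EmbedFn = Callable[[list[str]], Awaitable[list[list[float]]]]


async def embed_in_batches(
    texts: list[str],
    embed: EmbedFn,
    batch_size: int = 16,
    tasks_limit: int = 3,
    max_retries: int = 3,
    backoff_base: float = 1.0,
) -> list[list[float]]:
    """Embed texts in batches with bounded concurrency, preserving input order."""
    semaphore = asyncio.Semaphore(tasks_limit)
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # One slot per batch, so completion order cannot change output order.
    results: list[list[list[float]]] = [[] for _ in batches]

    async def run(idx: int, batch: list[str]) -> None:
        async with semaphore:  # at most tasks_limit batches in flight
            for attempt in range(max_retries):
                try:
                    results[idx] = await embed(batch)
                    return
                except Exception as e:
                    if attempt == max_retries - 1:
                        raise RuntimeError(
                            f"batch {idx} failed after {max_retries} attempts: {e}"
                        ) from e
                    # Exponential backoff: 1x, 2x, 4x, ... backoff_base seconds.
                    await asyncio.sleep(backoff_base * 2**attempt)

    outcomes = await asyncio.gather(
        *(run(i, b) for i, b in enumerate(batches)), return_exceptions=True
    )
    errors = [o for o in outcomes if isinstance(o, BaseException)]
    if errors:
        raise RuntimeError(f"{len(errors)} batch(es) failed: {errors}")
    # Flatten the per-batch slots back into one list aligned with texts.
    return [vec for batch_vecs in results for vec in batch_vecs]


async def fake_embed(batch: list[str]) -> list[list[float]]:
    # Hypothetical embedder: a random delay scrambles completion order.
    await asyncio.sleep(random.random() / 100)
    return [[float(len(t))] for t in batch]


if __name__ == "__main__":
    texts = ["a" * n for n in range(1, 11)]
    vectors = asyncio.run(embed_in_batches(texts, fake_embed, batch_size=3))
    print([v[0] for v in vectors])
```

Running the sketch prints the values 1.0 through 10.0 in input order, even though the fake batches complete in random order. Like the original method, the sketch sleeps while still holding the semaphore. This keeps a failing batch's slot occupied during backoff, so a struggling backend does not receive extra concurrent requests while retries are pending.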