Files
AstrBot/astrbot/core/provider/provider.py
T
Soulter 5a11d8f0ee refactor: LLM response handling with reasoning content (#3632)
* refactor: LLM response handling with reasoning content

- Added a `show_reasoning` parameter to `run_agent` to control the display of reasoning content.
- Updated `LLMResponse` to include a `reasoning_content` field for storing reasoning text.
- Modified `WebChatMessageEvent` to handle and send reasoning content in streaming responses.
- Implemented reasoning extraction in various provider sources (e.g., OpenAI, Gemini).
- Updated the chat interface to display reasoning content in a collapsible format.
- Removed the deprecated `thinking_filter` package and its associated logic.
- Updated localization files to include new reasoning-related strings.

* feat: add Groq chat completion provider and associated configurations

* Update astrbot/core/provider/sources/gemini_source.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2025-11-15 21:31:03 +08:00

297 lines
10 KiB
Python

import abc
import asyncio
from collections.abc import AsyncGenerator
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
ProviderMeta,
RerankResult,
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
class AbstractProvider(abc.ABC):
    """Common base class shared by every provider kind.

    Keeps the raw provider configuration and the name of the currently
    selected model, and can describe itself as a ``ProviderMeta``.
    """

    def __init__(self, provider_config: dict) -> None:
        """Remember ``provider_config`` and start with an empty model name."""
        super().__init__()
        self.provider_config = provider_config
        self.model_name = ""

    def set_model(self, model_name: str):
        """Select ``model_name`` as the current model."""
        self.model_name = model_name

    def get_model(self) -> str:
        """Return the name of the currently selected model."""
        return self.model_name

    def meta(self) -> ProviderMeta:
        """Build the metadata record describing this provider.

        Returns:
            A ``ProviderMeta`` holding the configured id (``"default"`` when
            the config has no ``"id"`` key), the current model name, the
            configured type name, and the ``provider_type`` recorded for that
            type name in ``provider_cls_map``.

        Raises:
            KeyError: If the config has no ``"type"`` entry.
            ValueError: If ``provider_cls_map`` has no (truthy) entry for the
                configured type name.
        """
        type_name = self.provider_config["type"]
        registration = provider_cls_map.get(type_name)
        if not registration:
            raise ValueError(f"Provider type {type_name} not registered")
        return ProviderMeta(
            id=self.provider_config.get("id", "default"),
            model=self.get_model(),
            type=type_name,
            provider_type=registration.provider_type,
        )
class Provider(AbstractProvider):
    """Chat provider base class.

    Concrete providers implement key handling, model listing and the
    ``text_chat`` / ``text_chat_stream`` calls.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
    ) -> None:
        """Store the provider config (via the base class) and the settings."""
        super().__init__(provider_config)
        self.provider_settings = provider_settings

    @abc.abstractmethod
    def get_current_key(self) -> str:
        """Return the key currently in use."""
        raise NotImplementedError

    def get_keys(self) -> list[str]:
        """Return the provider keys from the config.

        Falls back to ``[""]`` when the ``"key"`` entry is missing or empty.
        """
        keys = self.provider_config.get("key", [""])
        return keys or [""]

    @abc.abstractmethod
    def set_key(self, key: str):
        """Set the key to use."""
        raise NotImplementedError

    @abc.abstractmethod
    async def get_models(self) -> list[str]:
        """Return the list of supported models."""
        raise NotImplementedError

    @abc.abstractmethod
    async def text_chat(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Get the LLM's text chat result using the current model.

        Args:
            prompt: The prompt. Use either this or ``contexts``; if both are
                given, the prompt (and any ``image_urls``) is appended to
                ``contexts`` as the latest record.
            session_id: Session ID (deprecated).
            image_urls: List of image URLs.
            func_tool: Tool set used for function calling.
            contexts: Conversation context. Use either this or ``prompt``.
            system_prompt: System prompt.
            tool_calls_result: Tool call results passed back to the LLM. See
                https://platform.openai.com/docs/guides/function-calling
            model: Model name (presumably overrides the current model for
                this call - confirm against the concrete providers).
            kwargs: Other parameters.

        Notes:
            - If ``image_urls`` is given, the images are attached to the
              conversation. An error is raised if the model does not support
              image input.
            - If ``func_tool`` is given, it is used for function calling. An
              error is raised if the model does not support function calling.
        """
        ...

    async def text_chat_stream(
        self,
        prompt: str | None = None,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: ToolSet | None = None,
        contexts: list[Message] | list[dict] | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Get the LLM's streaming text chat result using the current model.

        The complete result is yielded once more at the end of the stream.
        This base implementation only raises ``NotImplementedError``;
        providers that support streaming override it.

        Args:
            prompt: The prompt. Use either this or ``contexts``; if both are
                given, the prompt (and any ``image_urls``) is appended to
                ``contexts`` as the latest record.
            session_id: Session ID (deprecated).
            image_urls: List of image URLs.
            func_tool: Tool set used for function calling.
            contexts: Conversation context. Use either this or ``prompt``.
            system_prompt: System prompt.
            tool_calls_result: Tool call results passed back to the LLM. See
                https://platform.openai.com/docs/guides/function-calling
            model: Model name (presumably overrides the current model for
                this call - confirm against the concrete providers).
            kwargs: Other parameters.

        Raises:
            NotImplementedError: When iterated without being overridden.

        Notes:
            - If ``image_urls`` is given, the images are attached to the
              conversation. An error is raised if the model does not support
              image input.
            - If ``func_tool`` is given, it is used for function calling. An
              error is raised if the model does not support function calling.
        """
        # The unreachable ``yield`` makes this an actual async generator, so
        # ``async for`` over an un-overridden implementation fails with
        # NotImplementedError rather than an opaque TypeError about a
        # coroutine lacking ``__aiter__``.
        if False:  # pragma: no cover
            yield
        raise NotImplementedError

    async def pop_record(self, context: list):
        """Remove the first two non-system records from ``context`` in place.

        Records whose ``"role"`` is ``"system"`` are skipped. If fewer than
        two non-system records exist, all of them are removed.
        """
        indexes_to_pop = []
        for idx, record in enumerate(context):
            if record["role"] == "system":
                continue
            indexes_to_pop.append(idx)
            if len(indexes_to_pop) == 2:
                break

        # Pop from the back so the earlier indexes stay valid.
        for idx in reversed(indexes_to_pop):
            context.pop(idx)

    def _ensure_message_to_dicts(
        self,
        messages: list[dict] | list[Message] | None,
    ) -> list[dict]:
        """Convert a list of Message objects to a list of dictionaries.

        ``Message`` items are converted with ``model_dump()``; any other item
        is passed through unchanged. ``None`` or an empty list gives ``[]``.
        """
        if not messages:
            return []
        return [
            message.model_dump() if isinstance(message, Message) else message
            for message in messages
        ]
class STTProvider(AbstractProvider):
    """Base class for speech-to-text providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Keep the provider config and the provider settings on the instance."""
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def get_text(self, audio_url: str) -> str:
        """Return the text for the audio at ``audio_url``."""
        raise NotImplementedError
class TTSProvider(AbstractProvider):
    """Base class for text-to-speech providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Keep the provider config and the provider settings on the instance."""
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def get_audio(self, text: str) -> str:
        """Produce audio for ``text`` and return the audio file path."""
        raise NotImplementedError
class EmbeddingProvider(AbstractProvider):
    """Base class for embedding providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Keep the provider config and the provider settings on the instance."""
        super().__init__(provider_config)
        self.provider_config = provider_config
        self.provider_settings = provider_settings

    @abc.abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector of a single text."""
        ...

    @abc.abstractmethod
    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
        """Return the embedding vectors of several texts."""
        ...

    @abc.abstractmethod
    def get_dim(self) -> int:
        """Return the dimension of the embedding vectors."""
        ...

    async def get_embeddings_batch(
        self,
        texts: list[str],
        batch_size: int = 16,
        tasks_limit: int = 3,
        max_retries: int = 3,
        progress_callback=None,
    ) -> list[list[float]]:
        """Embed ``texts`` in batches, running several batches concurrently.

        The returned vectors are in the same order as ``texts`` regardless of
        which batch finishes first.

        Args:
            texts: The texts to embed.
            batch_size: Number of texts per ``get_embeddings`` call. Must be
                at least 1.
            tasks_limit: Maximum number of batches processed concurrently.
            max_retries: Maximum number of attempts per batch. Values below 1
                are treated as a single attempt so no batch is silently
                skipped.
            progress_callback: Optional async callable awaited as
                ``progress_callback(completed, total)`` after each batch
                succeeds.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: If at least one batch failed; the message aggregates
                the individual batch errors.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        semaphore = asyncio.Semaphore(tasks_limit)
        attempts = max(1, max_retries)
        total_count = len(texts)
        completed_count = 0

        async def process_batch(
            batch_idx: int,
            batch_texts: list[str],
        ) -> list[list[float]]:
            """Embed one batch with retries and return its vectors."""
            nonlocal completed_count
            async with semaphore:
                for attempt in range(attempts):
                    # Only the provider call is retried; bookkeeping and the
                    # progress callback run once, after a success.
                    try:
                        batch_embeddings = await self.get_embeddings(batch_texts)
                        break
                    except Exception as e:
                        if attempt == attempts - 1:
                            raise Exception(
                                f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}",
                            ) from e
                        # Exponential backoff before the next attempt.
                        await asyncio.sleep(2**attempt)
                completed_count += len(batch_texts)
                if progress_callback:
                    await progress_callback(completed_count, total_count)
                return batch_embeddings

        tasks = [
            process_batch(start // batch_size, texts[start : start + batch_size])
            for start in range(0, total_count, batch_size)
        ]

        # gather() returns results in task order, which keeps the embeddings
        # aligned with ``texts``; failures come back as exception objects.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        errors = [r for r in results if isinstance(r, BaseException)]
        if errors:
            error_msg = (
                f"{len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
            )
            raise Exception(error_msg)

        return [embedding for batch in results for embedding in batch]
class RerankProvider(AbstractProvider):
    """Base class for rerank providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        """Keep the provider config and the provider settings on the instance."""
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Return the rerank scores of ``documents`` for ``query``."""
        ...