Merge pull request #2097 from SheffeyG/fix-status-checking

Fix: add status checking for embedding model providers
This commit is contained in:
Soulter
2025-07-13 16:19:37 +08:00
committed by GitHub
3 changed files with 62 additions and 37 deletions
+2 -2
View File
@@ -7,7 +7,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Personality, Provider, STTProvider, TTSProvider
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
@@ -93,7 +93,7 @@ class ProviderManager:
"""加载的 Speech To Text Provider 的实例"""
self.tts_provider_insts: List[TTSProvider] = []
"""加载的 Text To Speech Provider 的实例"""
self.embedding_provider_insts: List[Provider] = []
self.embedding_provider_insts: List[EmbeddingProvider] = []
"""加载的 Embedding Provider 的实例"""
self.inst_map: dict[str, Provider] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
+5 -1
View File
@@ -2,7 +2,7 @@ from asyncio import Queue
from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -141,6 +141,10 @@ class Context:
"""获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str = None) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
+55 -34
View File
@@ -1,3 +1,4 @@
import numbers
import typing
import traceback
from .route import Route, Response, RouteContext
@@ -187,6 +188,7 @@ class ConfigRoute(Route):
"""辅助函数:测试单个 provider 的可用性"""
meta = provider.meta()
provider_name = provider.provider_config.get("id", "Unknown Provider")
provider_type = provider.provider_config.get("provider_type", "Unknown Type")
logger.debug(f"Got provider meta: {meta}")
if not provider_name and meta:
provider_name = meta.id
@@ -197,6 +199,7 @@ class ConfigRoute(Route):
"model": getattr(meta, "model", "Unknown Model"),
"type": getattr(meta, "type", "Unknown Type"),
"name": provider_name,
"provider_type": provider_type,
"status": "unavailable", # 默认为不可用
"error": None,
}
@@ -204,40 +207,60 @@ class ConfigRoute(Route):
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
)
try:
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
)
logger.debug(f"Received response from {status_info['name']}: {response}")
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if hasattr(response, "completion_text") and response.completion_text:
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
if status_info["provider_type"] == "chat_completion":
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
response = await asyncio.wait_for(
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
)
logger.debug(f"Received response from {status_info['name']}: {response}")
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
if response is not None:
status_info["status"] = "available"
response_text_snippet = ""
if hasattr(response, "completion_text") and response.completion_text:
response_text_snippet = (
response.completion_text[:70] + "..."
if len(response.completion_text) > 70
else response.completion_text
)
elif hasattr(response, "result_chain") and response.result_chain:
try:
response_text_snippet = (
response.result_chain.get_plain_text()[:70] + "..."
if len(response.result_chain.get_plain_text()) > 70
else response.result_chain.get_plain_text()
)
except Exception as _:
pass
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
)
else:
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
)
elif status_info["provider_type"] == "embedding":
logger.debug(f"Sending 'astrbot' to embedding provider: {status_info['name']}")
response = await asyncio.wait_for(
provider.get_embedding("astrbot"), timeout=45.0
)
logger.debug(f"Received response from {status_info['name']}: {response}")
# 若返回向量则认为该嵌入模型可用
if response and isinstance(response, typing.Iterable) and all(isinstance(x, numbers.Number) for x in response):
status_info["status"] = "available"
logger.info(
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{str(response)[:10]}...'"
)
else:
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
status_info["error"] = (
"Test call returned None, but expected an LLMResponse object."
f"Status checking for provider type '{status_info['type']}' not implemented."
)
logger.warning(
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
f"Provider {status_info['name']}'s status checking not implemented yet"
)
except asyncio.TimeoutError:
@@ -263,7 +286,7 @@ class ConfigRoute(Route):
# 记录更详细的traceback信息,但只在是严重错误时
if status_code == 500:
log_fn(traceback.format_exc())
return Response().error(message, status_code=status_code).__dict__
return Response().error(message).__dict__
async def check_one_provider_status(self):
"""API: check a single LLM Provider's status by id"""
@@ -274,11 +297,9 @@ class ConfigRoute(Route):
logger.info(f"API call: /config/provider/check_one id={provider_id}")
try:
all_providers = self.core_lifecycle.star_context.get_all_providers()
all_providers += self.core_lifecycle.star_context.get_all_embedding_providers()
# replace manual loop with next(filter(...))
target = next(
(p for p in all_providers if p.provider_config.get("id") == provider_id),
None
)
target = next(filter(lambda p: p.provider_config.get("id") == provider_id, all_providers), None)
if not target:
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning)