Merge pull request #2097 from SheffeyG/fix-status-checking
Fix: add status checking for embedding model providers
This commit is contained in:
@@ -7,7 +7,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ class ProviderManager:
|
||||
"""加载的 Speech To Text Provider 的实例"""
|
||||
self.tts_provider_insts: List[TTSProvider] = []
|
||||
"""加载的 Text To Speech Provider 的实例"""
|
||||
self.embedding_provider_insts: List[Provider] = []
|
||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||
"""加载的 Embedding Provider 的实例"""
|
||||
self.inst_map: dict[str, Provider] = {}
|
||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||
|
||||
@@ -2,7 +2,7 @@ from asyncio import Queue
|
||||
from typing import List, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider
|
||||
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -141,6 +141,10 @@ class Context:
|
||||
"""获取所有用于 STT 任务的 Provider。"""
|
||||
return self.provider_manager.stt_provider_insts
|
||||
|
||||
def get_all_embedding_providers(self) -> List[EmbeddingProvider]:
    """Return all providers that were loaded for embedding tasks.

    The list is the provider manager's own ``embedding_provider_insts``
    attribute, returned as-is (no copy is made).
    """
    manager = self.provider_manager
    return manager.embedding_provider_insts
|
||||
|
||||
def get_using_provider(self, umo: str = None) -> Provider:
|
||||
"""
|
||||
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import numbers
|
||||
import typing
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
@@ -187,6 +188,7 @@ class ConfigRoute(Route):
|
||||
"""辅助函数:测试单个 provider 的可用性"""
|
||||
meta = provider.meta()
|
||||
provider_name = provider.provider_config.get("id", "Unknown Provider")
|
||||
provider_type = provider.provider_config.get("provider_type", "Unknown Type")
|
||||
logger.debug(f"Got provider meta: {meta}")
|
||||
if not provider_name and meta:
|
||||
provider_name = meta.id
|
||||
@@ -197,6 +199,7 @@ class ConfigRoute(Route):
|
||||
"model": getattr(meta, "model", "Unknown Model"),
|
||||
"type": getattr(meta, "type", "Unknown Type"),
|
||||
"name": provider_name,
|
||||
"provider_type": provider_type,
|
||||
"status": "unavailable", # 默认为不可用
|
||||
"error": None,
|
||||
}
|
||||
@@ -204,40 +207,60 @@ class ConfigRoute(Route):
|
||||
f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})"
|
||||
)
|
||||
try:
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if hasattr(response, "completion_text") and response.completion_text:
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
if status_info["provider_type"] == "chat_completion":
|
||||
logger.debug(f"Sending 'Ping' to provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用
|
||||
if response is not None:
|
||||
status_info["status"] = "available"
|
||||
response_text_snippet = ""
|
||||
if hasattr(response, "completion_text") and response.completion_text:
|
||||
response_text_snippet = (
|
||||
response.completion_text[:70] + "..."
|
||||
if len(response.completion_text) > 70
|
||||
else response.completion_text
|
||||
)
|
||||
elif hasattr(response, "result_chain") and response.result_chain:
|
||||
try:
|
||||
response_text_snippet = (
|
||||
response.result_chain.get_plain_text()[:70] + "..."
|
||||
if len(response.result_chain.get_plain_text()) > 70
|
||||
else response.result_chain.get_plain_text()
|
||||
)
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'"
|
||||
)
|
||||
else:
|
||||
# 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
)
|
||||
elif status_info["provider_type"] == "embedding":
|
||||
logger.debug(f"Sending 'astrbot' to embedding provider: {status_info['name']}")
|
||||
response = await asyncio.wait_for(
|
||||
provider.get_embedding("astrbot"), timeout=45.0
|
||||
)
|
||||
logger.debug(f"Received response from {status_info['name']}: {response}")
|
||||
# 若返回向量则认为该嵌入模型可用
|
||||
if response and isinstance(response, typing.Iterable) and all(isinstance(x, numbers.Number) for x in response):
|
||||
status_info["status"] = "available"
|
||||
logger.info(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{str(response)[:10]}...'"
|
||||
)
|
||||
else:
|
||||
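# Fallback for provider types other than "chat_completion" and "embedding", whose status check is not implemented (before this change, this branch covered text_chat unexpectedly returning None)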
|
||||
status_info["error"] = (
|
||||
"Test call returned None, but expected an LLMResponse object."
|
||||
f"Status checking for provider type '{status_info['type']}' not implemented."
|
||||
)
|
||||
logger.warning(
|
||||
f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None."
|
||||
f"Provider {status_info['name']}'s status checking not implemented yet"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@@ -263,7 +286,7 @@ class ConfigRoute(Route):
|
||||
# 记录更详细的traceback信息,但只在是严重错误时
|
||||
if status_code == 500:
|
||||
log_fn(traceback.format_exc())
|
||||
return Response().error(message, status_code=status_code).__dict__
|
||||
return Response().error(message).__dict__
|
||||
|
||||
async def check_one_provider_status(self):
|
||||
"""API: check a single LLM Provider's status by id"""
|
||||
@@ -274,11 +297,9 @@ class ConfigRoute(Route):
|
||||
logger.info(f"API call: /config/provider/check_one id={provider_id}")
|
||||
try:
|
||||
all_providers = self.core_lifecycle.star_context.get_all_providers()
|
||||
all_providers += self.core_lifecycle.star_context.get_all_embedding_providers()
|
||||
# replace manual loop with next(filter(...))
|
||||
target = next(
|
||||
(p for p in all_providers if p.provider_config.get("id") == provider_id),
|
||||
None
|
||||
)
|
||||
target = next(filter(lambda p: p.provider_config.get("id") == provider_id, all_providers), None)
|
||||
if not target:
|
||||
return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user