diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 3ae5f2bd9..d8c2b1400 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,5 +1,6 @@ import abc import asyncio +import os from collections.abc import AsyncGenerator from astrbot.core.agent.message import Message @@ -11,6 +12,7 @@ from astrbot.core.provider.entities import ( ToolCallsResult, ) from astrbot.core.provider.register import provider_cls_map +from astrbot.core.utils.astrbot_path import get_astrbot_path class AbstractProvider(abc.ABC): @@ -43,6 +45,14 @@ class AbstractProvider(abc.ABC): ) return meta + async def test(self) -> bool: + """test the provider is a + + Returns: + bool: the provider is available + """ + return True + class Provider(AbstractProvider): """Chat Provider""" @@ -165,6 +175,16 @@ class Provider(AbstractProvider): return dicts + async def test(self, timeout: float = 45.0) -> bool: + try: + response = await asyncio.wait_for( + self.text_chat(prompt="REPLY `PONG` ONLY"), + timeout=timeout, + ) + return response is not None + except Exception: + return False + class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: @@ -177,6 +197,20 @@ class STTProvider(AbstractProvider): """获取音频的文本""" raise NotImplementedError + async def test(self) -> bool: + try: + sample_audio_path = os.path.join( + get_astrbot_path(), + "samples", + "stt_health_check.wav", + ) + if not os.path.exists(sample_audio_path): + return False + text_result = await self.get_text(sample_audio_path) + return isinstance(text_result, str) and bool(text_result) + except Exception: + return False + class TTSProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: @@ -189,6 +223,13 @@ class TTSProvider(AbstractProvider): """获取文本的音频,返回音频文件路径""" raise NotImplementedError + async def test(self) -> bool: + try: + audio_result = await self.get_audio("hi") + return isinstance(audio_result, str) and bool(audio_result) + except Exception: + return False + class EmbeddingProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: @@ -211,6 +252,15 @@ class EmbeddingProvider(AbstractProvider): """获取向量的维度""" ... + async def test(self) -> bool: + try: + embedding_result = await self.get_embedding("health_check") + return isinstance(embedding_result, list) and ( + not embedding_result or isinstance(embedding_result[0], float) + ) + except Exception: + return False + async def get_embeddings_batch( self, texts: list[str], @@ -294,3 +344,10 @@ class RerankProvider(AbstractProvider): ) -> list[RerankResult]: """获取查询和文档的重排序分数""" ... + + async def test(self) -> bool: + try: + await self.rerank("Apple", documents=["apple", "banana"]) + return True + except Exception: + return False diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 6d68bf4c3..6514032d2 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -18,11 +18,8 @@ from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.provider import Provider -from astrbot.core.provider.entities import ProviderType -from astrbot.core.provider.provider import RerankProvider from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core.utils.astrbot_path import get_astrbot_path from .route import Response, Route, RouteContext @@ -356,169 +353,26 @@ class ConfigRoute(Route): f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", ) - if provider_capability_type == ProviderType.CHAT_COMPLETION: - try: - logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") - response = await asyncio.wait_for( - provider.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=45.0, - ) - logger.debug( - f"Received response from {status_info['name']}: {response}", - ) - if response is not None: - status_info["status"] = "available" - response_text_snippet = "" - if ( - hasattr(response, "completion_text") - and response.completion_text - ): - response_text_snippet = ( - response.completion_text[:70] + "..." - if len(response.completion_text) > 70 - else response.completion_text - ) - elif hasattr(response, "result_chain") and response.result_chain: - try: - response_text_snippet = ( - response.result_chain.get_plain_text()[:70] + "..." - if len(response.result_chain.get_plain_text()) > 70 - else response.result_chain.get_plain_text() - ) - except Exception as _: - pass - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'", - ) - else: - status_info["error"] = ( - "Test call returned None, but expected an LLMResponse object." - ) - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None.", - ) - - except asyncio.TimeoutError: - status_info["error"] = ( - "Connection timed out after 45 seconds during test call." - ) - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) timed out.", - ) - except Exception as e: - error_message = str(e) - status_info["error"] = error_message - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", - ) - logger.debug( - f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", - ) - - elif provider_capability_type == ProviderType.EMBEDDING: - try: - # For embedding, we can call the get_embedding method with a short prompt. - embedding_result = await provider.get_embedding("health_check") - if isinstance(embedding_result, list) and ( - not embedding_result or isinstance(embedding_result[0], float) - ): - status_info["status"] = "available" - else: - status_info["status"] = "unavailable" - status_info["error"] = ( - f"Embedding test failed: unexpected result type {type(embedding_result)}" - ) - except Exception as e: - logger.error( - f"Error testing embedding provider {provider_name}: {e}", - exc_info=True, - ) - status_info["status"] = "unavailable" - status_info["error"] = f"Embedding test failed: {e!s}" - - elif provider_capability_type == ProviderType.TEXT_TO_SPEECH: - try: - # For TTS, we can call the get_audio method with a short prompt. - audio_result = await provider.get_audio("你好") - if isinstance(audio_result, str) and audio_result: - status_info["status"] = "available" - else: - status_info["status"] = "unavailable" - status_info["error"] = ( - f"TTS test failed: unexpected result type {type(audio_result)}" - ) - except Exception as e: - logger.error( - f"Error testing TTS provider {provider_name}: {e}", - exc_info=True, - ) - status_info["status"] = "unavailable" - status_info["error"] = f"TTS test failed: {e!s}" - elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: - try: - logger.debug( - f"Sending health check audio to provider: {status_info['name']}", - ) - sample_audio_path = os.path.join( - get_astrbot_path(), - "samples", - "stt_health_check.wav", - ) - if not os.path.exists(sample_audio_path): - status_info["status"] = "unavailable" - status_info["error"] = ( - "STT test failed: sample audio file not found." - ) - logger.warning( - f"STT test for {status_info['name']} failed: sample audio file not found at {sample_audio_path}", - ) - else: - text_result = await provider.get_text(sample_audio_path) - if isinstance(text_result, str) and text_result: - status_info["status"] = "available" - snippet = ( - text_result[:70] + "..." - if len(text_result) > 70 - else text_result - ) - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{snippet}'", - ) - else: - status_info["status"] = "unavailable" - status_info["error"] = ( - f"STT test failed: unexpected result type {type(text_result)}" - ) - logger.warning( - f"STT test for {status_info['name']} failed: unexpected result type {type(text_result)}", - ) - except Exception as e: - logger.error( - f"Error testing STT provider {provider_name}: {e}", - exc_info=True, - ) - status_info["status"] = "unavailable" - status_info["error"] = f"STT test failed: {e!s}" - elif provider_capability_type == ProviderType.RERANK: - try: - assert isinstance(provider, RerankProvider) - await provider.rerank("Apple", documents=["apple", "banana"]) + try: + result = await provider.test() + if result: status_info["status"] = "available" - except Exception as e: - logger.error( - f"Error testing rerank provider {provider_name}: {e}", - exc_info=True, + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", ) - status_info["status"] = "unavailable" - status_info["error"] = f"Rerank test failed: {e!s}" - - else: - logger.debug( - f"Provider {provider_name} is not a Chat Completion or Embedding provider. Marking as available without test. Meta: {meta}", + else: + status_info["error"] = "Provider test returned False." + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) test returned False.", + ) + except Exception as e: + error_message = str(e) + status_info["error"] = error_message + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", ) - status_info["status"] = "available" - status_info["error"] = ( - "This provider type is not tested and is assumed to be available." + logger.debug( + f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", ) return status_info diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py index d306c41ae..d75049c0a 100644 --- a/packages/astrbot/commands/provider.py +++ b/packages/astrbot/commands/provider.py @@ -1,15 +1,10 @@ import asyncio -import os import re from astrbot import logger from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType -from astrbot.core.provider.provider import RerankProvider -from astrbot.core.utils.astrbot_path import get_astrbot_path - -REACHABILITY_CHECK_TIMEOUT = 30.0 class ProviderCommands: @@ -34,121 +29,20 @@ class ProviderCommands: ) async def _test_provider_capability(self, provider): - """测试单个 provider 的可用性 (复用 Dashboard 的检测逻辑)""" + """测试单个 provider 的可用性""" meta = provider.meta() provider_capability_type = meta.provider_type try: - if provider_capability_type == ProviderType.CHAT_COMPLETION: - # 发送 "Ping" 测试对话 - response = await asyncio.wait_for( - provider.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=REACHABILITY_CHECK_TIMEOUT, - ) - if response is not None: - return True, None, None - err_code = "EMPTY_RESPONSE" - err_reason = "Provider returned empty response" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - - elif provider_capability_type == ProviderType.EMBEDDING: - # 测试 Embedding - embedding_result = await asyncio.wait_for( - provider.get_embedding("health_check"), - timeout=REACHABILITY_CHECK_TIMEOUT, - ) - if ( - isinstance(embedding_result, list) - and embedding_result - and all(isinstance(x, (int, float)) for x in embedding_result) - ): - return True, None, None - err_code = "INVALID_EMBEDDING" - err_reason = "Provider returned invalid embedding" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - - elif provider_capability_type == ProviderType.TEXT_TO_SPEECH: - # 测试 TTS - audio_result = await asyncio.wait_for( - provider.get_audio("你好"), - timeout=REACHABILITY_CHECK_TIMEOUT, - ) - if isinstance(audio_result, str) and audio_result: - # 清理检测生成的临时音频文件,避免频繁检测时堆积 - if os.path.isfile(audio_result): - try: - os.remove(audio_result) - except OSError as e: - logger.debug( - "Failed to cleanup TTS health check file %s: %s", - audio_result, - e, - ) - return True, None, None - err_code = "INVALID_AUDIO" - err_reason = "Provider returned invalid audio" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - - elif provider_capability_type == ProviderType.SPEECH_TO_TEXT: - # 测试 STT - sample_audio_path = os.path.join( - get_astrbot_path(), - "samples", - "stt_health_check.wav", - ) - if not os.path.exists(sample_audio_path): - # 如果样本文件不存在,降级为检查是否实现了方法 - return hasattr(provider, "get_text"), None, None - - text_result = await asyncio.wait_for( - provider.get_text(sample_audio_path), - timeout=REACHABILITY_CHECK_TIMEOUT, - ) - if isinstance(text_result, str) and text_result: - return True, None, None - err_code = "INVALID_TEXT" - err_reason = "Provider returned invalid text" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - - elif provider_capability_type == ProviderType.RERANK: - # 测试 Rerank - if isinstance(provider, RerankProvider): - await asyncio.wait_for( - provider.rerank("Apple", documents=["apple", "banana"]), - timeout=REACHABILITY_CHECK_TIMEOUT, - ) - return True, None, None - err_code = "NOT_RERANK_PROVIDER" - err_reason = "Provider is not RerankProvider" - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason - - else: - # 其他类型暂时视为通过,或者回退到 get_models - if hasattr(provider, "get_models"): - await asyncio.wait_for( - provider.get_models(), timeout=REACHABILITY_CHECK_TIMEOUT - ) - return True, None, None - return True, None, None # 未知类型默认通过 - - except asyncio.TimeoutError: - err_code = "TIMEOUT" - err_reason = "Reachability check timed out" + result = await provider.test() + if result: + return True, None, None + err_code = "TEST_FAILED" + err_reason = "Provider test returned False" + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason except Exception as exc: err_code = ( getattr(exc, "status_code", None) @@ -159,10 +53,10 @@ class ProviderCommands: if not err_code: err_code = exc.__class__.__name__ - self._log_reachability_failure( - provider, provider_capability_type, err_code, err_reason - ) - return False, err_code, err_reason + self._log_reachability_failure( + provider, provider_capability_type, err_code, err_reason + ) + return False, err_code, err_reason async def provider( self,