From 7cbaed8c6cb95e0aefc59c689f194a498c0b1459 Mon Sep 17 00:00:00 2001 From: sheffey <57262511+SheffeyG@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:08:22 +0800 Subject: [PATCH 1/2] fix: add status checking for embedding model providers --- astrbot/core/provider/manager.py | 4 +- astrbot/core/star/context.py | 6 +- astrbot/dashboard/routes/config.py | 88 ++++++++++++++++++------------ 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 05747c3ff..df21e6a12 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -7,7 +7,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.db import BaseDatabase from .entities import ProviderType -from .provider import Personality, Provider, STTProvider, TTSProvider +from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider from .register import llm_tools, provider_cls_map @@ -93,7 +93,7 @@ class ProviderManager: """加载的 Speech To Text Provider 的实例""" self.tts_provider_insts: List[TTSProvider] = [] """加载的 Text To Speech Provider 的实例""" - self.embedding_provider_insts: List[Provider] = [] + self.embedding_provider_insts: List[EmbeddingProvider] = [] """加载的 Embedding Provider 的实例""" self.inst_map: dict[str, Provider] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index dda664b1b..0b14525d3 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -2,7 +2,7 @@ from asyncio import Queue from typing import List, Union from astrbot.core import sp -from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider +from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig @@ -141,6 +141,10 @@ class Context: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts + def get_all_embedding_providers(self) -> List[EmbeddingProvider]: + """获取所有用于 Embedding 任务的 Provider。""" + return self.provider_manager.embedding_provider_insts + def get_using_provider(self, umo: str = None) -> Provider: """ 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1dbe4de4a..dde4caa7b 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -187,6 +187,7 @@ class ConfigRoute(Route): """辅助函数:测试单个 provider 的可用性""" meta = provider.meta() provider_name = provider.provider_config.get("id", "Unknown Provider") + provider_type = provider.provider_config.get("provider_type", "Unknown Type") logger.debug(f"Got provider meta: {meta}") if not provider_name and meta: provider_name = meta.id @@ -197,6 +198,7 @@ class ConfigRoute(Route): "model": getattr(meta, "model", "Unknown Model"), "type": getattr(meta, "type", "Unknown Type"), "name": provider_name, + "provider_type": provider_type, "status": "unavailable", # 默认为不可用 "error": None, } @@ -204,40 +206,60 @@ class ConfigRoute(Route): f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})" ) try: - logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") - response = await asyncio.wait_for( - provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 - ) - logger.debug(f"Received response from {status_info['name']}: {response}") - # 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用 - if response is not None: - status_info["status"] = "available" - response_text_snippet = "" - if hasattr(response, "completion_text") and response.completion_text: - response_text_snippet = ( - response.completion_text[:70] + "..." - if len(response.completion_text) > 70 - else response.completion_text - ) - elif hasattr(response, "result_chain") and response.result_chain: - try: - response_text_snippet = ( - response.result_chain.get_plain_text()[:70] + "..." - if len(response.result_chain.get_plain_text()) > 70 - else response.result_chain.get_plain_text() - ) - except Exception as _: - pass - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" + if status_info["provider_type"] == "chat_completion": + logger.debug(f"Sending 'Ping' to provider: {status_info['name']}") + response = await asyncio.wait_for( + provider.text_chat(prompt="REPLY `PONG` ONLY"), timeout=45.0 ) + logger.debug(f"Received response from {status_info['name']}: {response}") + # 只要 text_chat 调用成功返回一个 LLMResponse 对象 (即 response 不为 None),就认为可用 + if response is not None: + status_info["status"] = "available" + response_text_snippet = "" + if hasattr(response, "completion_text") and response.completion_text: + response_text_snippet = ( + response.completion_text[:70] + "..." + if len(response.completion_text) > 70 + else response.completion_text + ) + elif hasattr(response, "result_chain") and response.result_chain: + try: + response_text_snippet = ( + response.result_chain.get_plain_text()[:70] + "..." + if len(response.result_chain.get_plain_text()) > 70 + else response.result_chain.get_plain_text() + ) + except Exception as _: + pass + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{response_text_snippet}'" + ) + else: + # 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None + status_info["error"] = ( + "Test call returned None, but expected an LLMResponse object." + ) + logger.warning( + f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None." + ) + elif status_info["provider_type"] == "embedding": + logger.debug(f"Sending 'astrbot' to embedding provider: {status_info['name']}") + response = await asyncio.wait_for( + provider.get_embedding("astrbot"), timeout=45.0 + ) + logger.debug(f"Received response from {status_info['name']}: {response}") + # 若返回向量则认为该嵌入模型可用 + if response and isinstance(response, list) and all(isinstance(x, float) for x in response): + status_info["status"] = "available" + logger.info( + f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{str(response)[:10]}...'" + ) else: - # 这个分支理论上不应该被走到,除非 text_chat 实现可能返回 None status_info["error"] = ( - "Test call returned None, but expected an LLMResponse object." + f"Status checking for provider type '{status_info['type']}' not implemented." ) logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) test call returned None." + f"Provider {status_info['name']}'s status checking not implemented yet" ) except asyncio.TimeoutError: @@ -263,7 +285,7 @@ class ConfigRoute(Route): # 记录更详细的traceback信息,但只在是严重错误时 if status_code == 500: log_fn(traceback.format_exc()) - return Response().error(message, status_code=status_code).__dict__ + return Response().error(message).__dict__ async def check_one_provider_status(self): """API: check a single LLM Provider's status by id""" @@ -274,11 +296,9 @@ class ConfigRoute(Route): logger.info(f"API call: /config/provider/check_one id={provider_id}") try: all_providers = self.core_lifecycle.star_context.get_all_providers() + all_providers += self.core_lifecycle.star_context.get_all_embedding_providers() # replace manual loop with next(filter(...)) - target = next( - (p for p in all_providers if p.provider_config.get("id") == provider_id), - None - ) + target = next(filter(lambda p: p.provider_config.get("id") == provider_id, all_providers), None) if not target: return self._error_response(f"Provider with id '{provider_id}' not found", 404, logger.warning) From 4d214bb5c1b81c441df4819b582d82cf9bcde4c6 Mon Sep 17 00:00:00 2001 From: sheffey <57262511+SheffeyG@users.noreply.github.com> Date: Fri, 11 Jul 2025 18:10:46 +0800 Subject: [PATCH 2/2] check general numbers type instead --- astrbot/dashboard/routes/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index dde4caa7b..90c61ca68 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,3 +1,4 @@ +import numbers import typing import traceback from .route import Route, Response, RouteContext @@ -249,7 +250,7 @@ class ConfigRoute(Route): ) logger.debug(f"Received response from {status_info['name']}: {response}") # 若返回向量则认为该嵌入模型可用 - if response and isinstance(response, list) and all(isinstance(x, float) for x in response): + if response and isinstance(response, typing.Iterable) and all(isinstance(x, numbers.Number) for x in response): status_info["status"] = "available" logger.info( f"Provider {status_info['name']} (ID: {status_info['id']}) is available. Response snippet: '{str(response)[:10]}...'"