chore: clean code

This commit is contained in:
Soulter
2025-08-23 21:46:59 +08:00
parent e204b180a8
commit cfd05a8d17
+15 -26
View File
@@ -1,7 +1,6 @@
from asyncio import Queue
from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import (
Provider,
TTSProvider,
@@ -38,19 +37,6 @@ class Context:
暴露给插件的接口上下文。
"""
_event_queue: Queue = None
"""事件队列。消息平台通过事件队列传递消息事件。"""
_config: AstrBotConfig = None
"""AstrBot 配置信息"""
_db: BaseDatabase = None
"""AstrBot 数据库"""
provider_manager: ProviderManager = None
platform_manager: PlatformManager = None
registered_web_apis: list = []
# back compatibility
@@ -62,16 +48,19 @@ class Context:
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
message_history_manager: PlatformMessageHistoryManager = None,
persona_manager: PersonaManager = None,
astrbot_config_mgr: AstrBotConfigManager = None,
provider_manager: ProviderManager,
platform_manager: PlatformManager,
conversation_manager: ConversationManager,
message_history_manager: PlatformMessageHistoryManager,
persona_manager: PersonaManager,
astrbot_config_mgr: AstrBotConfigManager,
):
self._event_queue = event_queue
"""事件队列。消息平台通过事件队列传递消息事件。"""
self._config = config
"""AstrBot 默认配置"""
self._db = db
"""AstrBot 数据库"""
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.conversation_manager = conversation_manager
@@ -79,7 +68,7 @@ class Context:
self.persona_manager = persona_manager
self.astrbot_config_mgr = astrbot_config_mgr
def get_registered_star(self, star_name: str) -> StarMetadata:
def get_registered_star(self, star_name: str) -> StarMetadata | None:
"""根据插件名获取插件的 Metadata"""
for star in star_registry:
if star.name == star_name:
@@ -114,7 +103,7 @@ class Context:
"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
def get_provider_by_id(self, provider_id: str) -> Provider | None:
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
return self.provider_manager.inst_map.get(provider_id)
@@ -134,7 +123,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str = None) -> Provider:
def get_using_provider(self, umo: str | None = None) -> Provider | None:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
@@ -146,7 +135,7 @@ class Context:
umo=umo,
)
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
@@ -158,7 +147,7 @@ class Context:
umo=umo,
)
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
@@ -170,7 +159,7 @@ class Context:
umo=umo,
)
def get_config(self, umo: str = None) -> AstrBotConfig:
def get_config(self, umo: str | None = None) -> AstrBotConfig:
"""获取 AstrBot 的配置。"""
if not umo:
# using default config