diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index a38643b34..76db898aa 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,7 +1,6 @@ from asyncio import Queue from typing import List, Union -from astrbot.core import sp from astrbot.core.provider.provider import ( Provider, TTSProvider, @@ -38,19 +37,6 @@ class Context: 暴露给插件的接口上下文。 """ - _event_queue: Queue = None - """事件队列。消息平台通过事件队列传递消息事件。""" - - _config: AstrBotConfig = None - """AstrBot 配置信息""" - - _db: BaseDatabase = None - """AstrBot 数据库""" - - provider_manager: ProviderManager = None - - platform_manager: PlatformManager = None - registered_web_apis: list = [] # back compatibility @@ -62,16 +48,19 @@ class Context: event_queue: Queue, config: AstrBotConfig, db: BaseDatabase, - provider_manager: ProviderManager = None, - platform_manager: PlatformManager = None, - conversation_manager: ConversationManager = None, - message_history_manager: PlatformMessageHistoryManager = None, - persona_manager: PersonaManager = None, - astrbot_config_mgr: AstrBotConfigManager = None, + provider_manager: ProviderManager, + platform_manager: PlatformManager, + conversation_manager: ConversationManager, + message_history_manager: PlatformMessageHistoryManager, + persona_manager: PersonaManager, + astrbot_config_mgr: AstrBotConfigManager, ): self._event_queue = event_queue + """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config + """AstrBot 默认配置""" self._db = db + """AstrBot 数据库""" self.provider_manager = provider_manager self.platform_manager = platform_manager self.conversation_manager = conversation_manager @@ -79,7 +68,7 @@ class Context: self.persona_manager = persona_manager self.astrbot_config_mgr = astrbot_config_mgr - def get_registered_star(self, star_name: str) -> StarMetadata: + def get_registered_star(self, star_name: str) -> StarMetadata | None: """根据插件名获取插件的 Metadata""" for star in star_registry: if star.name == star_name: @@ -114,7 +103,7 @@ class Context: """ self.provider_manager.provider_insts.append(provider) - def get_provider_by_id(self, provider_id: str) -> Provider: + def get_provider_by_id(self, provider_id: str) -> Provider | None: """通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.inst_map.get(provider_id) @@ -134,7 +123,7 @@ class Context: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts - def get_using_provider(self, umo: str = None) -> Provider: + def get_using_provider(self, umo: str | None = None) -> Provider | None: """ 获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 @@ -146,7 +135,7 @@ class Context: umo=umo, ) - def get_using_tts_provider(self, umo: str = None) -> TTSProvider: + def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider: """ 获取当前使用的用于 TTS 任务的 Provider。 @@ -158,7 +147,7 @@ class Context: umo=umo, ) - def get_using_stt_provider(self, umo: str = None) -> STTProvider: + def get_using_stt_provider(self, umo: str | None = None) -> STTProvider: """ 获取当前使用的用于 STT 任务的 Provider。 @@ -170,7 +159,7 @@ class Context: umo=umo, ) - def get_config(self, umo: str = None) -> AstrBotConfig: + def get_config(self, umo: str | None = None) -> AstrBotConfig: """获取 AstrBot 的配置。""" if not umo: # using default config