diff --git a/model/command/openai_official.py b/model/command/openai_official.py index bd7b9cc24..6fa3fb4ad 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -4,6 +4,7 @@ from util.personality import personalities from cores.astrbot.types import GlobalObject from SparkleLogging.utils.core import LogManager from logging import Logger +from openai._exceptions import NotFoundError, RateLimitError, APIError logger: Logger = LogManager.GetLogger(log_name='astrbot-core') @@ -60,8 +61,27 @@ class CommandOpenAIOfficial(Command): return True, self.key(message) elif self.command_start_with(message, "switch"): return True, await self.switch(message) - + elif self.command_start_with(message, "models"): + return True, await self.get_models() return False, None + + async def get_models(self): + ret = "OpenAI GPT 类可用模型" + try: + models = await self.provider.client.models.list() + except NotFoundError as e: + bu = str(self.provider.client.base_url) + self.provider.client.base_url = bu + "/v1" + models = await self.provider.client.models.list() + finally: + print(models.data) + i = 1 + for model in models.data: + if str(model.id).startswith("gpt"): + ret += f"\n{i}. {model.id}" + i += 1 + logger.debug(ret) + return True, ret, "models" async def help(self): commands = super().general_commands() diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index e83d37e64..72f1c2f89 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -78,6 +78,12 @@ class ProviderOpenAIOfficial(Provider): # 人格 self.curr_personality = {} + + def make_tmp_client(self, api_key: str, base_url: str): + return AsyncOpenAI( + api_key=api_key, + base_url=base_url + ) # 转储历史记录 def dump_history(self):