From 707fcad8b4f3c84dee45ce672dcc75c0e98f6c82 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 17 May 2024 00:06:49 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20gpt=20=E6=A8=A1=E5=9E=8B=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E6=9F=A5=E7=9C=8B=E6=8C=87=E4=BB=A4=20models?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/command/openai_official.py | 22 +++++++++++++++++++++- model/provider/openai_official.py | 6 ++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/model/command/openai_official.py b/model/command/openai_official.py index bd7b9cc24..6fa3fb4ad 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -4,6 +4,7 @@ from util.personality import personalities from cores.astrbot.types import GlobalObject from SparkleLogging.utils.core import LogManager from logging import Logger +from openai._exceptions import NotFoundError, RateLimitError, APIError logger: Logger = LogManager.GetLogger(log_name='astrbot-core') @@ -60,8 +61,27 @@ class CommandOpenAIOfficial(Command): return True, self.key(message) elif self.command_start_with(message, "switch"): return True, await self.switch(message) - + elif self.command_start_with(message, "models"): + return True, await self.get_models() return False, None + + async def get_models(self): + ret = "OpenAI GPT 类可用模型" + try: + models = await self.provider.client.models.list() + except NotFoundError as e: + bu = str(self.provider.client.base_url) + self.provider.client.base_url = bu + "/v1" + models = await self.provider.client.models.list() + finally: + print(models.data) + i = 1 + for model in models.data: + if str(model.id).startswith("gpt"): + ret += f"\n{i}. {model.id}" + i += 1 + logger.debug(ret) + return True, ret, "models" async def help(self): commands = super().general_commands() diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index e83d37e64..72f1c2f89 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -78,6 +78,12 @@ class ProviderOpenAIOfficial(Provider): # 人格 self.curr_personality = {} + + def make_tmp_client(self, api_key: str, base_url: str): + return AsyncOpenAI( + api_key=api_key, + base_url=base_url + ) # 转储历史记录 def dump_history(self):