diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 1a3d2d7e1..4579879a3 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -321,7 +321,7 @@ CONFIG_METADATA_2 = { "type": "list", "config_template": { "openai": { - "id": "default", + "id": "openai", "type": "openai_chat_completion", "enable": True, "key": [], @@ -330,6 +330,17 @@ CONFIG_METADATA_2 = { "model": "gpt-4o-mini", }, }, + "azure_openai": { + "id": "azure", + "type": "openai_chat_completion", + "enable": True, + "api_version": "2024-05-01-preview", + "key": [], + "api_base": "", + "model_config": { + "model": "gpt-4o-mini", + }, + }, "ollama": { "id": "ollama_default", "type": "openai_chat_completion", diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5f25abeb7..ef0b74a32 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -2,7 +2,7 @@ import base64 import json import os -from openai import AsyncOpenAI, NOT_GIVEN +from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import NotFoundError from astrbot.core.utils.io import download_image_by_url @@ -29,12 +29,24 @@ class ProviderOpenAIOfficial(Provider): self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - - self.client = AsyncOpenAI( - api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), - timeout=provider_config.get("timeout", NOT_GIVEN), - ) + + # 适配 azure openai #332 + if "api_version" in provider_config: + # 使用 azure api + self.client = AsyncAzureOpenAI( + api_key=self.chosen_api_key, + api_version=provider_config.get("api_version", None), + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + else: + # 使用 openai api + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=provider_config.get("api_base", None), + timeout=provider_config.get("timeout", NOT_GIVEN), + ) + self.set_model(provider_config['model_config']['model']) async def get_human_readable_context(self, session_id, page, page_size):