feat: 适配 Azure OpenAI #332
This commit is contained in:
@@ -321,7 +321,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"openai": {
|
||||
"id": "default",
|
||||
"id": "openai",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
@@ -330,6 +330,17 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"azure_openai": {
|
||||
"id": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"model_config": {
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
},
|
||||
"ollama": {
|
||||
"id": "ollama_default",
|
||||
"type": "openai_chat_completion",
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI, NOT_GIVEN
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import NotFoundError
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
@@ -29,12 +29,24 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
# 适配 azure openai #332
|
||||
if "api_version" in provider_config:
|
||||
# 使用 azure api
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
else:
|
||||
# 使用 openai api
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
|
||||
Reference in New Issue
Block a user