✨ feat: openai_source 支持传入任何自定义参数以适配 Ollama 和 FastGPT 等
This commit is contained in:
@@ -27,7 +27,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
os.environ['https_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['http_proxy'] = self.astrbot_config['http_proxy']
|
||||
os.environ['no_proxy'] = 'localhost,127.0.0.1'
|
||||
os.environ['no_proxy'] = 'localhost'
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("AstrBot v"+ VERSION)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import inspect
|
||||
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -49,6 +50,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
self.default_params = inspect.signature(self.client.chat.completions.create).parameters.keys()
|
||||
|
||||
model_config = provider_config.get("model_config", {})
|
||||
model = model_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
@@ -69,10 +72,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
if tool_list:
|
||||
payloads['tools'] = tool_list
|
||||
|
||||
|
||||
# 不在默认参数中的参数放在 extra_body 中
|
||||
extra_body = {}
|
||||
to_del = []
|
||||
for key in payloads.keys():
|
||||
if key not in self.default_params:
|
||||
extra_body[key] = payloads[key]
|
||||
to_del.append(key)
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False
|
||||
stream=False,
|
||||
extra_body=extra_body
|
||||
)
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
|
||||
@@ -181,8 +181,8 @@ class ChatRoute(Route):
|
||||
self.db.update_conversation(username, cid, history=json.dumps(history))
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接: {str(e)}。")
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user