feat: 1. 适配新版openai sdk

2. 适配官方 function calling
This commit is contained in:
Soulter
2023-11-13 21:54:23 +08:00
parent 5b1aee1b4d
commit cb5975c102
11 changed files with 217 additions and 174 deletions
+3 -2
View File
@@ -237,7 +237,7 @@ def initBot(cfg, prov):
gu.log("加载Bing模型时发生错误, 请检查1. cookies文件是否正确放置 2. 是否设置了代理(梯子)。", gu.LEVEL_ERROR, max_len=60)
if OPENAI_OFFICIAL in prov:
gu.log("- OpenAI官方 -", gu.LEVEL_INFO)
if cfg['openai']['key'] is not None:
if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]:
from model.provider.provider_openai_official import ProviderOpenAIOfficial
from model.command.command_openai_official import CommandOpenAIOfficial
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai'])
@@ -646,7 +646,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id)
official_fc = chosen_provider == OPENAI_OFFICIAL
chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id, official_fc)
else:
chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id, image_url))
elif chosen_provider == REV_EDGEGPT:
+1 -1
View File
@@ -5,7 +5,7 @@ class Provider:
pass
@abc.abstractmethod
def text_chat(self, prompt, session_id):
def text_chat(self, prompt, session_id, image_url=None, function_call=None):
    """Send ``prompt`` to the underlying model for ``session_id``.

    Abstract hook: concrete providers override this method.

    ``image_url`` and ``function_call`` are optional and default to
    ``None``. They were previously written as ``image_url: None`` /
    ``function_call: None``, which are annotations rather than defaults
    and made both arguments mandatory, unlike the concrete providers'
    signatures.
    """
    pass
@abc.abstractmethod
+106 -128
View File
@@ -1,4 +1,5 @@
import openai
from openai import OpenAI
from openai.types.chat.chat_completion import ChatCompletion
import json
import time
import os
@@ -7,6 +8,7 @@ from cores.database.conn import dbConn
from model.provider.provider import Provider
import threading
from util import general_utils as gu
import traceback
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
key_record_path = abs_path + 'chatgpt_key_record'
@@ -14,9 +16,6 @@ key_record_path = abs_path + 'chatgpt_key_record'
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.key_list = []
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
openai.api_base = cfg['api_base']
print(f"设置 api_base 为: {openai.api_base}")
# 如果 cfg['key']中有长度为1的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
@@ -26,12 +25,25 @@ class ProviderOpenAIOfficial(Provider):
self.key_list = cfg['key']
else:
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
# init key record
self.init_key_record()
self.key_stat = {}
for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
self.chatGPT_configs = cfg['chatGPTConfigs']
gu.log(f'加载ChatGPTConfigs: {self.chatGPT_configs}')
self.api_base = None
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
print(f"设置 api_base 为: {self.api_base}")
# openai client
self.client = OpenAI(
api_key=self.key_list[0],
base_url=self.api_base
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
gu.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}')
self.openai_configs = cfg
# 会话缓存
self.session_dict = {}
@@ -45,9 +57,9 @@ class ProviderOpenAIOfficial(Provider):
db1 = dbConn()
for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data']
gu.log("读取历史记录成功")
gu.log("读取历史记录成功")
except BaseException as e:
gu.log("读取历史记录失败,但不影响使用", level=gu.LEVEL_ERROR)
gu.log("读取历史记录失败,但不影响使用", level=gu.LEVEL_ERROR)
# 读取统计信息
@@ -72,7 +84,7 @@ class ProviderOpenAIOfficial(Provider):
self.now_personality = {}
# 转储历史记录的定时器~ Soulter
# 转储历史记录
def dump_history(self):
time.sleep(10)
db = dbConn()
@@ -95,7 +107,7 @@ class ProviderOpenAIOfficial(Provider):
# 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval)
def text_chat(self, prompt, session_id = None, image_url = None):
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
@@ -128,25 +140,45 @@ class ProviderOpenAIOfficial(Provider):
# 截断倍率
truncate_rate = 0.75
while retry < 15:
use_gpt4v = False
for i in req:
if isinstance(i['content'], list):
use_gpt4v = True
break
if image_url is not None:
use_gpt4v = True
if use_gpt4v:
conf = self.openai_model_configs.copy()
conf['model'] = 'gpt-4-vision-preview'
else:
conf = self.openai_model_configs
print(req)
while retry < 10:
try:
response = openai.ChatCompletion.create(
messages=req,
**self.chatGPT_configs
)
if function_call is None:
response = self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = self.client.chat.completions.create(
messages=req,
tools = function_call,
**conf
)
break
except Exception as e:
print(traceback.format_exc())
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
gu.log("当前Key已超额或异常, 正在切换", level=gu.LEVEL_WARNING)
self.key_stat[openai.api_key]['exceed'] = True
self.save_key_record()
response, is_switched = self.handle_switch_key(req)
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
else:
break
retry -= 1
elif 'maximum context length' in str(e):
gu.log("token超限, 清空对应缓存,并进行消息截断")
self.session_dict[session_id] = []
@@ -159,20 +191,28 @@ class ProviderOpenAIOfficial(Provider):
continue
else:
gu.log(str(e), level=gu.LEVEL_ERROR)
time.sleep(3)
time.sleep(2)
err = str(e)
retry+=1
if retry >= 15:
retry += 1
if retry >= 10:
gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999)
raise BaseException("连接出错: "+str(err))
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
self.save_key_record()
# print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
chatgpt_res = str(response["choices"][0]["message"]["content"]).strip()
current_usage_tokens = response['usage']['total_tokens']
assert isinstance(response, ChatCompletion)
print(response)
gu.log(f"OPENAI RESPONSE: {response['usage']}", level=gu.LEVEL_DEBUG, max_len=9999)
# 结果分类
choice = response.choices[0]
if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
# tools call (function calling)
return choice.message.tool_calls[0].function
gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
if current_usage_tokens > self.max_tokens:
@@ -201,6 +241,7 @@ class ProviderOpenAIOfficial(Provider):
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list
@@ -212,13 +253,11 @@ class ProviderOpenAIOfficial(Provider):
image_url = ''
while retry < 5:
try:
# print("test1")
response = openai.Image.create(
response = self.client.images.generate(
prompt=prompt,
n=img_num,
size=img_size
)
# print("test2")
# openai>=1.0 returns a typed ImagesResponse object that is not
# subscriptable, so read the generated URLs through attributes
# instead of dict keys.
image_url = [response.data[i].url for i in range(img_num)]
@@ -227,17 +266,14 @@ class ProviderOpenAIOfficial(Provider):
gu.log(str(e), level=gu.LEVEL_ERROR)
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
gu.log("当前Key已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
self.key_stat[openai.api_key]['exceed'] = True
self.save_key_record()
response, is_switched = self.handle_switch_key(req)
gu.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
# 所有Key都超额或不正常
raise e
else:
break
retry += 1
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
@@ -340,111 +376,53 @@ class ProviderOpenAIOfficial(Provider):
req_list.append(new_record['user'])
return context, new_record, req_list
def handle_switch_key(self, req):
def handle_switch_key(self):
# messages = [{"role": "user", "content": prompt}]
while True:
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
openai.api_key = key
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
if len(req) == 0:
return None, False
try:
response = openai.ChatCompletion.create(
messages=req,
**self.chatGPT_configs
)
return response, True
except Exception as e:
if 'You exceeded' in str(e):
gu.log("当前Key已超额, 正在切换")
self.key_stat[openai.api_key]['exceed'] = True
self.save_key_record()
time.sleep(1)
continue
else:
gu.log(str(e), level=gu.LEVEL_ERROR)
else:
return True
if is_all_exceed:
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
return None, False
else:
gu.log("在切换key时程序异常。", level=gu.LEVEL_ERROR)
return None, False
def getConfigs(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
self.client.api_key = key
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
break
if is_all_exceed:
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
return False
return True
def get_configs(self):
return self.openai_configs
def save_key_record(self):
with open(key_record_path, 'w', encoding='utf-8') as f:
json.dump(self.key_stat, f)
def get_key_stat(self):
return self.key_stat
def get_key_list(self):
return self.key_list
def get_curr_key(self):
return openai.api_key
return self.client.api_key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
self.save_key_record()
self.init_key_record()
# 检查key是否可用
def check_key(self, key):
pre_key = openai.api_key
openai.api_key = key
messages = [{"role": "user", "content": "1"}]
client_ = OpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
try:
response = openai.ChatCompletion.create(
client_.chat.completions.create(
messages=messages,
**self.chatGPT_configs
**self.openai_model_configs
)
openai.api_key = pre_key
return True
except Exception as e:
pass
openai.api_key = pre_key
return False
#将key_list的key转储到key_record中,并记录相关数据
def init_key_record(self):
# 不存在,创建
if not os.path.exists(key_record_path):
with open(key_record_path, 'w', encoding='utf-8') as f:
json.dump({}, f)
# 打开 chatgpt_key_record
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
try:
self.key_stat = json.load(keyfile)
except Exception as e:
gu.log(str(e), level=gu.LEVEL_ERROR)
self.key_stat = {}
finally:
for key in self.key_list:
if key not in self.key_stat:
self.key_stat[key] = {'exceed': False, 'used': 0}
# if openai.api_key is None:
# openai.api_key = key
else:
# if self.key_stat[key]['exceed']:
# print(f"Key: {key} 已超额")
# continue
# else:
# if openai.api_key is None:
# openai.api_key = key
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
pass
if openai.api_key == None:
self.handle_switch_key("")
self.save_key_record()
+1 -1
View File
@@ -101,7 +101,7 @@ class ProviderRevChatGPT(Provider):
# print("[RevChatGPT] "+str(resp))
return resp
def text_chat(self, prompt, session_id = None, image_url = None) -> str:
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None) -> str:
# 选择一个人少的账号。
selected_revstat = None
+1 -1
View File
@@ -35,7 +35,7 @@ class ProviderRevEdgeGPT(Provider):
except BaseException:
return False
async def text_chat(self, prompt, platform = 'none'):
async def text_chat(self, prompt, platform = 'none', image_url=None, function_call=None):
while self.busy:
time.sleep(1)
self.busy = True
+2 -2
View File
@@ -1,7 +1,7 @@
pydantic~=1.10.4
requests~=2.28.1
openai~=0.27.4
qq-botpy~=1.1.2
openai~=1.2.3
qq-botpy
revChatGPT~=6.8.6
baidu-aip~=4.16.9
EdgeGPT~=0.1.22.1
-3
View File
@@ -1,3 +0,0 @@
class PromptExceededError(Exception):
pass
+26 -4
View File
@@ -1,6 +1,7 @@
import json
import util.general_utils as gu
import time
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
@@ -24,9 +25,18 @@ class FuncCall():
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None:
if name == None or func_args == None or desc == None or func_obj == None:
raise FuncCallJsonFormatError("name, func_args, desc must be provided.")
params = {
"type": "object", # hardcore here
"properties": {}
}
for param in func_args:
params['properties'][param['name']] = {
"type": param['type'],
"description": param['description']
}
self._func = {
"name": name,
"args": func_args,
"parameters": params,
"description": desc,
"func_obj": func_obj,
}
@@ -37,11 +47,23 @@ class FuncCall():
for f in self.func_list:
_l.append({
"name": f["name"],
"args": f["args"],
"parameters": f["parameters"],
"description": f["description"],
})
return json.dumps(_l, indent=intent, ensure_ascii=False)
return json.dumps(_l, indent=intent, ensur_ascii=False)
def get_func(self) -> list:
    '''
    Build the list of tool definitions from the registered functions.

    Each entry of ``self.func_list`` becomes
    ``{"type": "function", "function": {...}}`` carrying only its
    ``name``, ``parameters`` and ``description``; other keys of the
    entry (such as ``func_obj``) are left out.
    '''
    return [
        {
            "type": "function",
            "function": {
                "name": f["name"],
                "parameters": f["parameters"],
                "description": f["description"],
            },
        }
        for f in self.func_list
    ]
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None):
+7 -3
View File
@@ -60,17 +60,21 @@ def log(
tag: str = "System",
fg: str = None,
bg: str = None,
max_len: int = 300):
max_len: int = 500,
err: Exception = None,):
"""
日志记录函数
日志打印函数
"""
_set_level_code = level_codes[LEVEL_INFO]
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
if level in level_codes and level_codes[level] < _set_level_code:
return
if err is not None:
msg += "\n异常原因: " + str(err)
level = LEVEL_ERROR
if len(msg) > max_len:
msg = msg[:max_len] + "..."
+67 -29
View File
@@ -7,13 +7,23 @@ from util.func_call import (
FuncCallJsonFormatError,
FuncNotFoundError
)
from openai.types.chat.chat_completion_message_tool_call import Function
import traceback
from googlesearch import search, SearchResult
from model.provider.provider import Provider
import json
def tidy_text(text: str) -> str:
    '''
    Clean up text: strip surrounding whitespace, then remove every
    newline, space and carriage return that remains inside it.
    '''
    # A single translate pass replaces three chained str.replace calls.
    return text.strip().translate(str.maketrans("", "", "\n \r"))
def special_fetch_zhihu(link: str) -> str:
'''
function-calling 函数, 用于获取知乎文章的内容
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@@ -31,9 +41,10 @@ def special_fetch_zhihu(link: str) -> str:
raise Exception("zhihu none")
return tidy_text(r.text)
def google_web_search(keyword) -> str:
# 获取goole搜索结果,得到title、desc、link
'''
获取 google 搜索结果, 得到 title、desc、link
'''
ret = ""
index = 1
try:
@@ -53,6 +64,9 @@ def google_web_search(keyword) -> str:
return ret
def web_keyword_search_via_bing(keyword) -> str:
'''
获取bing搜索结果, 得到 title、desc、link
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@@ -105,12 +119,11 @@ def web_keyword_search_via_bing(keyword) -> str:
ret = f"{str(res)}"
return str(ret)
except Exception as e:
print(traceback.format_exc())
print(f"bing fetch err: {str(e)}")
gu.log(f"bing fetch err: {str(e)}")
_cnt += 1
time.sleep(1)
print("fail to fetch bing info, using sougou.")
gu.log("fail to fetch bing info, using sougou.")
return google_web_search(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
@@ -182,53 +195,78 @@ def fetch_website_content(url):
gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
return res
def web_search(question, provider, session_id):
def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("google_web_search", [{
"type": "string",
"name": "keyword",
"brief": "google search query (分词,尽量保留所有信息)"
"description": "google search query (分词,尽量保留所有信息)"
}],
"网页搜索。如果问题需要使用搜索(如天气、新闻或任何新的东西),则调用",
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数",
google_web_search
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"brief": "网址"
"description": "网址"
}],
"获取网址的内容",
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
func_definition1 = new_func_call.func_dump()
question1 = f"{question} \n(只能调用一个函数。)"
try:
res1, has_func = new_func_call.func_call(question1, func_definition1, is_task=False, is_summary=False)
except BaseException as e:
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
if i["name"] == func.name:
func_obj = i["func_obj"]
break
if not func_obj:
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
function_invoked_ret = func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
try:
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
except BaseException as e:
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
has_func = True
if has_func:
provider.forget(session_id)
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。```\n{res1}\n```\n"""
question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n"""
print(question3)
_c = 0
while _c < 5:
try:
print('text chat')
res3 = provider.text_chat(question3)
break
final_ret = provider.text_chat(question3)
return final_ret
except Exception as e:
print(e)
_c += 1
if _c == 5:
raise e
if _c == 5: raise e
if "The message you submitted was too long" in str(e):
res2 = res2[:int(len(res2) / 2)]
provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
time.sleep(3)
question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{res1}\n{res2}\n```\n"""
return res3
else:
return res1
question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{function_invoked_ret}\n```\n"""
return function_invoked_ret
+3
View File
@@ -1,3 +1,6 @@
'''
插件工具函数
'''
import os
import inspect