feat: 1. 适配新版openai sdk
2. 适配官方 function calling
This commit is contained in:
+3
-2
@@ -237,7 +237,7 @@ def initBot(cfg, prov):
|
||||
gu.log("加载Bing模型时发生错误, 请检查1. cookies文件是否正确放置 2. 是否设置了代理(梯子)。", gu.LEVEL_ERROR, max_len=60)
|
||||
if OPENAI_OFFICIAL in prov:
|
||||
gu.log("- OpenAI官方 -", gu.LEVEL_INFO)
|
||||
if cfg['openai']['key'] is not None:
|
||||
if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]:
|
||||
from model.provider.provider_openai_official import ProviderOpenAIOfficial
|
||||
from model.command.command_openai_official import CommandOpenAIOfficial
|
||||
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai'])
|
||||
@@ -646,7 +646,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
|
||||
|
||||
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
|
||||
if _global_object.web_search or web_sch_flag:
|
||||
chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id)
|
||||
official_fc = chosen_provider == OPENAI_OFFICIAL
|
||||
chatgpt_res = gplugin.web_search(qq_msg, llm_instance[chosen_provider], session_id, official_fc)
|
||||
else:
|
||||
chatgpt_res = str(llm_instance[chosen_provider].text_chat(qq_msg, session_id, image_url))
|
||||
elif chosen_provider == REV_EDGEGPT:
|
||||
|
||||
@@ -5,7 +5,7 @@ class Provider:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def text_chat(self, prompt, session_id):
|
||||
def text_chat(self, prompt, session_id, image_url: None, function_call: None):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
@@ -7,6 +8,7 @@ from cores.database.conn import dbConn
|
||||
from model.provider.provider import Provider
|
||||
import threading
|
||||
from util import general_utils as gu
|
||||
import traceback
|
||||
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
key_record_path = abs_path + 'chatgpt_key_record'
|
||||
@@ -14,9 +16,6 @@ key_record_path = abs_path + 'chatgpt_key_record'
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, cfg):
|
||||
self.key_list = []
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
openai.api_base = cfg['api_base']
|
||||
print(f"设置 api_base 为: {openai.api_base}")
|
||||
# 如果 cfg['key']中有长度为1的字符串,那么是格式错误,直接报错
|
||||
for key in cfg['key']:
|
||||
if len(key) == 1:
|
||||
@@ -26,12 +25,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.key_list = cfg['key']
|
||||
else:
|
||||
input("[System] 请先去完善ChatGPT的Key。详情请前往https://beta.openai.com/account/api-keys")
|
||||
if len(self.key_list) == 0:
|
||||
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
|
||||
|
||||
# init key record
|
||||
self.init_key_record()
|
||||
self.key_stat = {}
|
||||
for k in self.key_list:
|
||||
self.key_stat[k] = {'exceed': False, 'used': 0}
|
||||
|
||||
self.chatGPT_configs = cfg['chatGPTConfigs']
|
||||
gu.log(f'加载ChatGPTConfigs: {self.chatGPT_configs}')
|
||||
self.api_base = None
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
self.api_base = cfg['api_base']
|
||||
print(f"设置 api_base 为: {self.api_base}")
|
||||
# openai client
|
||||
self.client = OpenAI(
|
||||
api_key=self.key_list[0],
|
||||
base_url=self.api_base
|
||||
)
|
||||
|
||||
self.openai_model_configs: dict = cfg['chatGPTConfigs']
|
||||
gu.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}')
|
||||
self.openai_configs = cfg
|
||||
# 会话缓存
|
||||
self.session_dict = {}
|
||||
@@ -45,9 +57,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
self.session_dict[session[0]] = json.loads(session[1])['data']
|
||||
gu.log("读取历史记录成功")
|
||||
gu.log("读取历史记录成功。")
|
||||
except BaseException as e:
|
||||
gu.log("读取历史记录失败,但不影响使用", level=gu.LEVEL_ERROR)
|
||||
gu.log("读取历史记录失败,但不影响使用。", level=gu.LEVEL_ERROR)
|
||||
|
||||
|
||||
# 读取统计信息
|
||||
@@ -72,7 +84,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.now_personality = {}
|
||||
|
||||
|
||||
# 转储历史记录的定时器~ Soulter
|
||||
# 转储历史记录
|
||||
def dump_history(self):
|
||||
time.sleep(10)
|
||||
db = dbConn()
|
||||
@@ -95,7 +107,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10*self.history_dump_interval)
|
||||
|
||||
def text_chat(self, prompt, session_id = None, image_url = None):
|
||||
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None):
|
||||
if session_id is None:
|
||||
session_id = "unknown"
|
||||
if "unknown" in self.session_dict:
|
||||
@@ -128,25 +140,45 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 截断倍率
|
||||
truncate_rate = 0.75
|
||||
|
||||
while retry < 15:
|
||||
use_gpt4v = False
|
||||
for i in req:
|
||||
if isinstance(i['content'], list):
|
||||
use_gpt4v = True
|
||||
break
|
||||
if image_url is not None:
|
||||
use_gpt4v = True
|
||||
if use_gpt4v:
|
||||
conf = self.openai_model_configs.copy()
|
||||
conf['model'] = 'gpt-4-vision-preview'
|
||||
else:
|
||||
conf = self.openai_model_configs
|
||||
print(req)
|
||||
while retry < 10:
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
if function_call is None:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=req,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
response = self.client.chat.completions.create(
|
||||
messages=req,
|
||||
tools = function_call,
|
||||
**conf
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
||||
raise e
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
gu.log("当前Key已超额或异常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
retry -= 1
|
||||
elif 'maximum context length' in str(e):
|
||||
gu.log("token超限, 清空对应缓存,并进行消息截断")
|
||||
self.session_dict[session_id] = []
|
||||
@@ -159,20 +191,28 @@ class ProviderOpenAIOfficial(Provider):
|
||||
continue
|
||||
else:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
time.sleep(3)
|
||||
time.sleep(2)
|
||||
err = str(e)
|
||||
retry+=1
|
||||
if retry >= 15:
|
||||
retry += 1
|
||||
if retry >= 10:
|
||||
gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999)
|
||||
raise BaseException("连接出错: "+str(err))
|
||||
|
||||
self.key_stat[openai.api_key]['used'] += response['usage']['total_tokens']
|
||||
self.save_key_record()
|
||||
# print("[ChatGPT] "+str(response["choices"][0]["message"]["content"]))
|
||||
chatgpt_res = str(response["choices"][0]["message"]["content"]).strip()
|
||||
current_usage_tokens = response['usage']['total_tokens']
|
||||
assert isinstance(response, ChatCompletion)
|
||||
print(response)
|
||||
|
||||
gu.log(f"OPENAI RESPONSE: {response['usage']}", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
# 结果分类
|
||||
choice = response.choices[0]
|
||||
if choice.message.content != None:
|
||||
# 文本形式
|
||||
chatgpt_res = str(choice.message.content).strip()
|
||||
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
|
||||
# tools call (function calling)
|
||||
return choice.message.tool_calls[0].function
|
||||
|
||||
gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
|
||||
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
|
||||
current_usage_tokens = response.usage.total_tokens
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > self.max_tokens:
|
||||
@@ -201,6 +241,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
new_record['single_tokens'] = current_usage_tokens
|
||||
|
||||
cache_data_list.append(new_record)
|
||||
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
@@ -212,13 +253,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
image_url = ''
|
||||
while retry < 5:
|
||||
try:
|
||||
# print("test1")
|
||||
response = openai.Image.create(
|
||||
response = self.client.images.generate(
|
||||
prompt=prompt,
|
||||
n=img_num,
|
||||
size=img_size
|
||||
)
|
||||
# print("test2")
|
||||
image_url = []
|
||||
for i in range(img_num):
|
||||
image_url.append(response['data'][i]['url'])
|
||||
@@ -227,17 +266,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
||||
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
gu.log("当前Key已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
|
||||
response, is_switched = self.handle_switch_key(req)
|
||||
gu.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING)
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
# 所有Key都超额或不正常
|
||||
raise e
|
||||
else:
|
||||
break
|
||||
retry += 1
|
||||
else:
|
||||
retry += 1
|
||||
if retry >= 5:
|
||||
raise BaseException("连接超时")
|
||||
|
||||
@@ -340,111 +376,53 @@ class ProviderOpenAIOfficial(Provider):
|
||||
req_list.append(new_record['user'])
|
||||
return context, new_record, req_list
|
||||
|
||||
def handle_switch_key(self, req):
|
||||
def handle_switch_key(self):
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
while True:
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if key == None or self.key_stat[key]['exceed']:
|
||||
continue
|
||||
is_all_exceed = False
|
||||
openai.api_key = key
|
||||
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
|
||||
if len(req) == 0:
|
||||
return None, False
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=req,
|
||||
**self.chatGPT_configs
|
||||
)
|
||||
return response, True
|
||||
except Exception as e:
|
||||
if 'You exceeded' in str(e):
|
||||
gu.log("当前Key已超额, 正在切换")
|
||||
self.key_stat[openai.api_key]['exceed'] = True
|
||||
self.save_key_record()
|
||||
time.sleep(1)
|
||||
continue
|
||||
else:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
else:
|
||||
return True
|
||||
if is_all_exceed:
|
||||
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
|
||||
return None, False
|
||||
else:
|
||||
gu.log("在切换key时程序异常。", level=gu.LEVEL_ERROR)
|
||||
return None, False
|
||||
|
||||
def getConfigs(self):
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if key == None or self.key_stat[key]['exceed']:
|
||||
continue
|
||||
is_all_exceed = False
|
||||
self.client.api_key = key
|
||||
gu.log(f"切换到Key: {key}, 已使用token: {self.key_stat[key]['used']}", level=gu.LEVEL_INFO)
|
||||
break
|
||||
if is_all_exceed:
|
||||
gu.log("所有Key已超额", level=gu.LEVEL_CRITICAL)
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_configs(self):
|
||||
return self.openai_configs
|
||||
|
||||
def save_key_record(self):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.key_stat, f)
|
||||
|
||||
def get_key_stat(self):
|
||||
return self.key_stat
|
||||
|
||||
def get_key_list(self):
|
||||
return self.key_list
|
||||
|
||||
def get_curr_key(self):
|
||||
return openai.api_key
|
||||
return self.client.api_key
|
||||
|
||||
# 添加key
|
||||
def append_key(self, key, sponsor):
|
||||
self.key_list.append(key)
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||
self.save_key_record()
|
||||
self.init_key_record()
|
||||
|
||||
# 检查key是否可用
|
||||
def check_key(self, key):
|
||||
pre_key = openai.api_key
|
||||
openai.api_key = key
|
||||
messages = [{"role": "user", "content": "1"}]
|
||||
client_ = OpenAI(
|
||||
api_key=key,
|
||||
base_url=self.api_base
|
||||
)
|
||||
messages = [{"role": "user", "content": "please just echo `test`"}]
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
client_.chat.completions.create(
|
||||
messages=messages,
|
||||
**self.chatGPT_configs
|
||||
**self.openai_model_configs
|
||||
)
|
||||
openai.api_key = pre_key
|
||||
return True
|
||||
except Exception as e:
|
||||
pass
|
||||
openai.api_key = pre_key
|
||||
return False
|
||||
|
||||
#将key_list的key转储到key_record中,并记录相关数据
|
||||
def init_key_record(self):
|
||||
|
||||
# 不存在,创建
|
||||
if not os.path.exists(key_record_path):
|
||||
with open(key_record_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({}, f)
|
||||
|
||||
# 打开 chatgpt_key_record
|
||||
with open(key_record_path, 'r', encoding='utf-8') as keyfile:
|
||||
try:
|
||||
self.key_stat = json.load(keyfile)
|
||||
except Exception as e:
|
||||
gu.log(str(e), level=gu.LEVEL_ERROR)
|
||||
self.key_stat = {}
|
||||
finally:
|
||||
for key in self.key_list:
|
||||
if key not in self.key_stat:
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0}
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
else:
|
||||
# if self.key_stat[key]['exceed']:
|
||||
# print(f"Key: {key} 已超额")
|
||||
# continue
|
||||
# else:
|
||||
# if openai.api_key is None:
|
||||
# openai.api_key = key
|
||||
# print(f"使用Key: {key}, 已使用token: {self.key_stat[key]['used']}")
|
||||
pass
|
||||
if openai.api_key == None:
|
||||
self.handle_switch_key("")
|
||||
self.save_key_record()
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ class ProviderRevChatGPT(Provider):
|
||||
# print("[RevChatGPT] "+str(resp))
|
||||
return resp
|
||||
|
||||
def text_chat(self, prompt, session_id = None, image_url = None) -> str:
|
||||
def text_chat(self, prompt, session_id = None, image_url = None, function_call=None) -> str:
|
||||
|
||||
# 选择一个人少的账号。
|
||||
selected_revstat = None
|
||||
|
||||
@@ -35,7 +35,7 @@ class ProviderRevEdgeGPT(Provider):
|
||||
except BaseException:
|
||||
return False
|
||||
|
||||
async def text_chat(self, prompt, platform = 'none'):
|
||||
async def text_chat(self, prompt, platform = 'none', image_url=None, function_call=None):
|
||||
while self.busy:
|
||||
time.sleep(1)
|
||||
self.busy = True
|
||||
|
||||
+2
-2
@@ -1,7 +1,7 @@
|
||||
pydantic~=1.10.4
|
||||
requests~=2.28.1
|
||||
openai~=0.27.4
|
||||
qq-botpy~=1.1.2
|
||||
openai~=1.2.3
|
||||
qq-botpy
|
||||
revChatGPT~=6.8.6
|
||||
baidu-aip~=4.16.9
|
||||
EdgeGPT~=0.1.22.1
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
class PromptExceededError(Exception):
|
||||
|
||||
pass
|
||||
+26
-4
@@ -1,6 +1,7 @@
|
||||
|
||||
import json
|
||||
import util.general_utils as gu
|
||||
|
||||
import time
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
@@ -24,9 +25,18 @@ class FuncCall():
|
||||
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None:
|
||||
if name == None or func_args == None or desc == None or func_obj == None:
|
||||
raise FuncCallJsonFormatError("name, func_args, desc must be provided.")
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"properties": {}
|
||||
}
|
||||
for param in func_args:
|
||||
params['properties'][param['name']] = {
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
self._func = {
|
||||
"name": name,
|
||||
"args": func_args,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
@@ -37,11 +47,23 @@ class FuncCall():
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"name": f["name"],
|
||||
"args": f["args"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
|
||||
return json.dumps(_l, indent=intent, ensure_ascii=False)
|
||||
return json.dumps(_l, indent=intent, ensur_ascii=False)
|
||||
|
||||
def get_func(self) -> list:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
}
|
||||
})
|
||||
return _l
|
||||
|
||||
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None):
|
||||
|
||||
|
||||
@@ -60,17 +60,21 @@ def log(
|
||||
tag: str = "System",
|
||||
fg: str = None,
|
||||
bg: str = None,
|
||||
max_len: int = 300):
|
||||
max_len: int = 500,
|
||||
err: Exception = None,):
|
||||
"""
|
||||
日志记录函数
|
||||
日志打印函数
|
||||
"""
|
||||
_set_level_code = level_codes[LEVEL_INFO]
|
||||
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
|
||||
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
|
||||
|
||||
|
||||
if level in level_codes and level_codes[level] < _set_level_code:
|
||||
return
|
||||
|
||||
if err is not None:
|
||||
msg += "\n异常原因: " + str(err)
|
||||
level = LEVEL_ERROR
|
||||
|
||||
if len(msg) > max_len:
|
||||
msg = msg[:max_len] + "..."
|
||||
|
||||
+67
-29
@@ -7,13 +7,23 @@ from util.func_call import (
|
||||
FuncCallJsonFormatError,
|
||||
FuncNotFoundError
|
||||
)
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
import traceback
|
||||
from googlesearch import search, SearchResult
|
||||
from model.provider.provider import Provider
|
||||
import json
|
||||
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", "").replace(" ", "").replace("\r", "")
|
||||
|
||||
def special_fetch_zhihu(link: str) -> str:
|
||||
'''
|
||||
function-calling 函数, 用于获取知乎文章的内容
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
@@ -31,9 +41,10 @@ def special_fetch_zhihu(link: str) -> str:
|
||||
raise Exception("zhihu none")
|
||||
return tidy_text(r.text)
|
||||
|
||||
|
||||
def google_web_search(keyword) -> str:
|
||||
# 获取goole搜索结果,得到title、desc、link
|
||||
'''
|
||||
获取 google 搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
ret = ""
|
||||
index = 1
|
||||
try:
|
||||
@@ -53,6 +64,9 @@ def google_web_search(keyword) -> str:
|
||||
return ret
|
||||
|
||||
def web_keyword_search_via_bing(keyword) -> str:
|
||||
'''
|
||||
获取bing搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
@@ -105,12 +119,11 @@ def web_keyword_search_via_bing(keyword) -> str:
|
||||
ret = f"{str(res)}"
|
||||
return str(ret)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"bing fetch err: {str(e)}")
|
||||
gu.log(f"bing fetch err: {str(e)}")
|
||||
_cnt += 1
|
||||
time.sleep(1)
|
||||
|
||||
print("fail to fetch bing info, using sougou.")
|
||||
gu.log("fail to fetch bing info, using sougou.")
|
||||
return google_web_search(keyword)
|
||||
|
||||
def web_keyword_search_via_sougou(keyword) -> str:
|
||||
@@ -182,53 +195,78 @@ def fetch_website_content(url):
|
||||
gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
return res
|
||||
|
||||
def web_search(question, provider, session_id):
|
||||
|
||||
def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
new_func_call.add_func("google_web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"brief": "google search query (分词,尽量保留所有信息)"
|
||||
"description": "google search query (分词,尽量保留所有信息)"
|
||||
}],
|
||||
"网页搜索。如果问题需要使用搜索(如天气、新闻或任何新的东西),则调用。",
|
||||
"通过搜索引擎搜索。如果问题需要在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
google_web_search
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"brief": "网址"
|
||||
"description": "网址"
|
||||
}],
|
||||
"获取网址的内容",
|
||||
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下https://github.com的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
func_definition1 = new_func_call.func_dump()
|
||||
question1 = f"{question} \n(只能调用一个函数。)"
|
||||
try:
|
||||
res1, has_func = new_func_call.func_call(question1, func_definition1, is_task=False, is_summary=False)
|
||||
except BaseException as e:
|
||||
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
func = provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
|
||||
if isinstance(func, Function):
|
||||
# arguments='{\n "keyword": "北京今天的天气"\n}', name='google_web_search'
|
||||
# 执行对应的结果:
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == func.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(func.arguments)
|
||||
function_invoked_ret = func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
# now func is a string
|
||||
return func
|
||||
else:
|
||||
try:
|
||||
function_invoked_ret, has_func = new_func_call.func_call(question1, new_func_call.func_dump(), is_task=False, is_summary=False)
|
||||
except BaseException as e:
|
||||
res = provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
has_func = True
|
||||
if has_func:
|
||||
provider.forget(session_id)
|
||||
question3 = f"""请你回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。```\n{res1}\n```\n"""
|
||||
question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n"""
|
||||
print(question3)
|
||||
_c = 0
|
||||
while _c < 5:
|
||||
try:
|
||||
print('text chat')
|
||||
res3 = provider.text_chat(question3)
|
||||
break
|
||||
final_ret = provider.text_chat(question3)
|
||||
return final_ret
|
||||
except Exception as e:
|
||||
print(e)
|
||||
_c += 1
|
||||
if _c == 5:
|
||||
raise e
|
||||
if _c == 5: raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
res2 = res2[:int(len(res2) / 2)]
|
||||
provider.forget(session_id)
|
||||
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
|
||||
time.sleep(3)
|
||||
question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{res1}\n{res2}\n```\n"""
|
||||
return res3
|
||||
else:
|
||||
return res1
|
||||
question3 = f"""请回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行回答,再给参考链接, 参考链接首末有空格。```\n{function_invoked_ret}\n```\n"""
|
||||
return function_invoked_ret
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
插件工具函数
|
||||
'''
|
||||
import os
|
||||
import inspect
|
||||
|
||||
|
||||
Reference in New Issue
Block a user