feat: 大幅优化网页搜索的信息提取准确性

perf: 使用 tiktoken 预先计算 token
This commit is contained in:
Soulter
2023-11-14 09:33:18 +08:00
parent d5e5b06e86
commit deebf61b5f
2 changed files with 43 additions and 30 deletions
+9 -3
View File
@@ -9,6 +9,7 @@ from model.provider.provider import Provider
import threading
from util import general_utils as gu
import traceback
import tiktoken
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
key_record_path = abs_path + 'chatgpt_key_record'
@@ -52,6 +53,8 @@ class ProviderOpenAIOfficial(Provider):
# 历史记录持久化间隔时间
self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
# 读取历史记录
try:
db1 = dbConn()
@@ -129,6 +132,11 @@ class ProviderOpenAIOfficial(Provider):
f.write(json.dumps(fjson))
f.flush()
f.close()
# 使用 tiktoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
prompt = self.enc.decode(_encoded_prompt[:self.openai_model_configs['max_tokens'] - 100])
gu.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, max_len=300)
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
gu.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, max_len=99999)
@@ -198,7 +206,7 @@ class ProviderOpenAIOfficial(Provider):
gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999)
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
print(response)
gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
# 结果分类
choice = response.choices[0]
@@ -209,8 +217,6 @@ class ProviderOpenAIOfficial(Provider):
# tools call (function calling)
return choice.message.tool_calls[0].function
gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999)
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
+34 -27
View File
@@ -12,13 +12,14 @@ import traceback
from googlesearch import search, SearchResult
from model.provider.provider import Provider
import json
from readability import Document
def tidy_text(text: str) -> str:
    """Flatten scraped text onto a single line with single-space separators.

    Leading and trailing whitespace is stripped, every newline and carriage
    return becomes a space, and each run of consecutive spaces is collapsed
    to one space. Word boundaries are kept (words are never glued together),
    and other characters, including tabs inside the text, are left untouched.

    Args:
        text: Raw text, typically extracted from an HTML page.

    Returns:
        The cleaned, single-line text; an empty string when ``text`` holds
        nothing but whitespace.
    """
    # Line breaks become spaces rather than being deleted so that the words
    # on either side of a break stay separated.
    flattened = text.strip().replace("\n", " ").replace("\r", " ")
    # Splitting on a literal space and dropping the empty pieces collapses
    # space runs of any length, which a single replace pass cannot do.
    return " ".join(piece for piece in flattened.split(" ") if piece)
def special_fetch_zhihu(link: str) -> str:
'''
@@ -100,12 +101,12 @@ def web_keyword_search_via_bing(keyword) -> str:
# 爬取前两条的网页内容
if "zhihu.com" in link:
try:
_detail_store.append(special_fetch_zhihu(link)[100:800])
_detail_store.append(special_fetch_zhihu(link))
except BaseException as e:
print(f"zhihu parse err: {str(e)}")
else:
try:
_detail_store.append(fetch_website_content(link)[100:1000])
_detail_store.append(fetch_website_content(link))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
@@ -158,7 +159,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
if _detail_store >= 3:
break
try:
_detail_store.append(fetch_website_content(i["link"])[100:1000])
_detail_store.append(fetch_website_content(i["link"]))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
ret = f"{str(res)}"
@@ -174,26 +175,32 @@ def fetch_website_content(url):
}
response = requests.get(url, headers=headers, timeout=3)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
# 如果有container / content / main等的话,就只取这些部分
has = False
beleive_ls = ["container", "content", "main"]
res = ""
for cls in beleive_ls:
for i in soup.find_all(class_=cls):
has = True
res += i.text
if not has:
res = soup.text
res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "")
if not has:
res = res[300:1100]
else:
res = res[100:800]
# with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f:
# f.write(res)
gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
return res
# soup = BeautifulSoup(response.text, "html.parser")
# # 如果有container / content / main等的话,就只取这些部分
# has = False
# beleive_ls = ["container", "content", "main"]
# res = ""
# for cls in beleive_ls:
# for i in soup.find_all(class_=cls):
# has = True
# res += i.text
# if not has:
# res = soup.text
# res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "")
# if not has:
# res = res[300:1100]
# else:
# res = res[100:800]
# # with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f:
# # f.write(res)
# gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
# return res
doc = Document(response.content)
# print('title:', doc.title())
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
def web_search(question, provider: Provider, session_id, official_fc=False):
'''
@@ -253,9 +260,9 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
if has_func:
provider.forget(session_id)
question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n"""
print(question3)
gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999)
_c = 0
while _c < 5:
while _c < 3:
try:
print('text chat')
final_ret = provider.text_chat(question3)
@@ -263,7 +270,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False):
except Exception as e:
print(e)
_c += 1
if _c == 5: raise e
if _c == 3: raise e
if "The message you submitted was too long" in str(e):
provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]