From deebf61b5f33b33ee5a63759b727a0ff572a4fc3 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 14 Nov 2023 09:33:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A4=A7=E5=B9=85=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=BD=91=E9=A1=B5=E6=90=9C=E7=B4=A2=E7=9A=84=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E6=8F=90=E5=8F=96=E5=87=86=E7=A1=AE=E6=80=A7=20perf:=20?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=20tictoken=20=E9=A2=84=E5=85=88=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=20token?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/provider/provider_openai_official.py | 12 +++-- util/gplugin.py | 61 ++++++++++++---------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/model/provider/provider_openai_official.py b/model/provider/provider_openai_official.py index 579edc547..ba32a7ed5 100644 --- a/model/provider/provider_openai_official.py +++ b/model/provider/provider_openai_official.py @@ -9,6 +9,7 @@ from model.provider.provider import Provider import threading from util import general_utils as gu import traceback +import tiktoken abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' key_record_path = abs_path + 'chatgpt_key_record' @@ -52,6 +53,8 @@ class ProviderOpenAIOfficial(Provider): # 历史记录持久化间隔时间 self.history_dump_interval = 20 + self.enc = tiktoken.get_encoding("cl100k_base") + # 读取历史记录 try: db1 = dbConn() @@ -129,6 +132,11 @@ class ProviderOpenAIOfficial(Provider): f.write(json.dumps(fjson)) f.flush() f.close() + + # 使用 tictoken 截断消息 + _encoded_prompt = self.enc.encode(prompt) + prompt = self.enc.decode(_encoded_prompt[:self.openai_model_configs['max_tokens'] - 100]) + gu.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, max_len=300) cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url) gu.log(f"CACHE_DATA_: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, max_len=99999) @@ -198,7 +206,7 @@ class ProviderOpenAIOfficial(Provider): gu.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见https://github.com/Soulter/QQChannelChatGPT/wiki/%E4%BA%8C%E3%80%81%E9%A1%B9%E7%9B%AE%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6%E9%85%8D%E7%BD%AE", max_len=999) raise BaseException("连接出错: "+str(err)) assert isinstance(response, ChatCompletion) - print(response) + gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999) # 结果分类 choice = response.choices[0] @@ -209,8 +217,6 @@ class ProviderOpenAIOfficial(Provider): # tools call (function calling) return choice.message.tool_calls[0].function - gu.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, max_len=9999) - self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens current_usage_tokens = response.usage.total_tokens diff --git a/util/gplugin.py b/util/gplugin.py index 972bd87b4..519ef9345 100644 --- a/util/gplugin.py +++ b/util/gplugin.py @@ -12,13 +12,14 @@ import traceback from googlesearch import search, SearchResult from model.provider.provider import Provider import json +from readability import Document def tidy_text(text: str) -> str: ''' 清理文本,去除空格、换行符等 ''' - return text.strip().replace("\n", "").replace(" ", "").replace("\r", "") + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") def special_fetch_zhihu(link: str) -> str: ''' @@ -100,12 +101,12 @@ def web_keyword_search_via_bing(keyword) -> str: # 爬取前两条的网页内容 if "zhihu.com" in link: try: - _detail_store.append(special_fetch_zhihu(link)[100:800]) + _detail_store.append(special_fetch_zhihu(link)) except BaseException as e: print(f"zhihu parse err: {str(e)}") else: try: - _detail_store.append(fetch_website_content(link)[100:1000]) + _detail_store.append(fetch_website_content(link)) except BaseException as e: print(f"fetch_website_content err: {str(e)}") @@ -158,7 +159,7 @@ def web_keyword_search_via_sougou(keyword) -> str: if _detail_store >= 3: break try: - _detail_store.append(fetch_website_content(i["link"])[100:1000]) + _detail_store.append(fetch_website_content(i["link"])) except BaseException as e: print(f"fetch_website_content err: {str(e)}") ret = f"{str(res)}" @@ -174,26 +175,32 @@ def fetch_website_content(url): } response = requests.get(url, headers=headers, timeout=3) response.encoding = "utf-8" - soup = BeautifulSoup(response.text, "html.parser") - # 如果有container / content / main等的话,就只取这些部分 - has = False - beleive_ls = ["container", "content", "main"] - res = "" - for cls in beleive_ls: - for i in soup.find_all(class_=cls): - has = True - res += i.text - if not has: - res = soup.text - res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "") - if not has: - res = res[300:1100] - else: - res = res[100:800] - # with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f: - # f.write(res) - gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG) - return res + # soup = BeautifulSoup(response.text, "html.parser") + # # 如果有container / content / main等的话,就只取这些部分 + # has = False + # beleive_ls = ["container", "content", "main"] + # res = "" + # for cls in beleive_ls: + # for i in soup.find_all(class_=cls): + # has = True + # res += i.text + # if not has: + # res = soup.text + # res = res.replace("\n", "").replace(" ", " ").replace("\r", "").replace("\t", "") + # if not has: + # res = res[300:1100] + # else: + # res = res[100:800] + # # with open(f"temp_{time.time()}.html", "w", encoding="utf-8") as f: + # # f.write(res) + # gu.log(f"fetch_website_content: end", tag="fetch_website_content", level=gu.LEVEL_DEBUG) + # return res + doc = Document(response.content) + # print('title:', doc.title()) + ret = doc.summary(html_partial=True) + soup = BeautifulSoup(ret, 'html.parser') + ret = tidy_text(soup.get_text()) + return ret def web_search(question, provider: Provider, session_id, official_fc=False): ''' @@ -253,9 +260,9 @@ def web_search(question, provider: Provider, session_id, official_fc=False): if has_func: provider.forget(session_id) question3 = f"""请你用可爱的语气回答`{question}`问题。\n以下是相关材料,请直接拿此材料针对问题进行总结回答,再给参考链接, 参考链接首末有空格。不要提到任何函数调用的信息。在总结的末尾加上1-2个相关的emoji。```\n{function_invoked_ret}\n```\n""" - print(question3) + gu.log(f"web_search: {question3}", tag="web_search", level=gu.LEVEL_DEBUG, max_len=99999) _c = 0 - while _c < 5: + while _c < 3: try: print('text chat') final_ret = provider.text_chat(question3) @@ -263,7 +270,7 @@ def web_search(question, provider: Provider, session_id, official_fc=False): except Exception as e: print(e) _c += 1 - if _c == 5: raise e + if _c == 3: raise e if "The message you submitted was too long" in str(e): provider.forget(session_id) function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]