From ea34c20198a9bd070626670fe353ba1813ba928c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 18 May 2024 10:34:35 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E4=BA=BA=E6=A0=BC?= =?UTF-8?q?=E5=92=8CLVM=E7=9A=84=E5=A4=84=E7=90=86=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cores/astrbot/core.py | 1 + model/command/openai_official.py | 70 ++++++++++++++++++------- model/platform/qq_official.py | 2 + model/provider/openai_official.py | 85 +++++++++++++++++++++++-------- util/general_utils.py | 13 +++-- 5 files changed, 128 insertions(+), 43 deletions(-) diff --git a/cores/astrbot/core.py b/cores/astrbot/core.py index 7576ca083..17aac1df0 100644 --- a/cores/astrbot/core.py +++ b/cores/astrbot/core.py @@ -133,6 +133,7 @@ def init(cfg): instance = llm_instance[OPENAI_OFFICIAL] assert isinstance(instance, ProviderOpenAIOfficial) + instance.DEFAULT_PERSONALITY = _global_object.default_personality instance.personality_set(_global_object.default_personality, session_id=None) # 检查provider设置偏好 diff --git a/model/command/openai_official.py b/model/command/openai_official.py index fe686dbc2..731c6941c 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,5 +1,5 @@ from model.command.command import Command -from model.provider.openai_official import ProviderOpenAIOfficial +from model.provider.openai_official import ProviderOpenAIOfficial, MODELS from util.personality import personalities from cores.astrbot.types import GlobalObject, CommandItem from SparkleLogging.utils.core import LogManager @@ -47,7 +47,7 @@ class CommandOpenAIOfficial(Command): elif self.command_start_with(message, "his", "历史"): return True, self.his(message, session_id) elif self.command_start_with(message, "status"): - return True, self.status() + return True, self.status(session_id) elif self.command_start_with(message, "help", "帮助"): return True, await self.help() elif self.command_start_with(message, "unset"): @@ -61,11 +61,12 @@ class CommandOpenAIOfficial(Command): elif self.command_start_with(message, "switch"): return True, await self.switch(message) elif self.command_start_with(message, "models"): - return True, await self.get_models() + return True, await self.print_models() + elif self.command_start_with(message, "model"): + return True, await self.set_model(message) return False, None async def get_models(self): - ret = "OpenAI GPT 类可用模型" try: models = await self.provider.client.models.list() except NotFoundError as e: @@ -73,23 +74,49 @@ class CommandOpenAIOfficial(Command): self.provider.client.base_url = bu + "/v1" models = await self.provider.client.models.list() finally: - print(models.data) - i = 1 - for model in models.data: - if str(model.id).startswith("gpt"): - ret += f"\n{i}. {model.id}" - i += 1 - logger.debug(ret) + return filter(lambda x: x.id.startswith("gpt"), models.data) + + async def print_models(self): + models = await self.get_models() + i = 1 + ret = "OpenAI GPT 类可用模型" + for model in models: + ret += f"\n{i}. {model.id}" + i += 1 + logger.debug(ret) return True, ret, "models" + + async def set_model(self, message: str): + l = message.split(" ") + if len(l) == 1: + return True, "请输入 /model 模型名/编号", "model" + model = str(l[1]) + models = await self.get_models() + models = list(models) + if model.isdigit() and int(model) <= len(models) and int(model) >= 1: + model = models[int(model)-1] + else: + f = False + for m in models: + if model == m.id: + f = True + break + if not f: + return True, "模型不存在或输入非法", "model" + + self.provider.set_model(model.id) + return True, f"模型已设置为 {model.id}", "model" + + async def help(self): commands = super().general_commands() commands['画'] = '调用 OpenAI DallE 模型生成图片' commands['set'] = '人格设置面板' commands['status'] = '查看 Api Key 状态和配置信息' commands['token'] = '查看本轮会话 token' - commands['reset'] = '重置当前与 LLM 的会话' - commands['reset p'] = '重置当前与 LLM 的会话,但保留人格(system prompt)' + commands['reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt)' + commands['reset p'] = '重置当前与 LLM 的会话,并清除人格。' return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" @@ -98,10 +125,10 @@ class CommandOpenAIOfficial(Command): return False, "未启用 OpenAI 官方 API", "reset" l = message.split(" ") if len(l) == 1: - await self.provider.forget(session_id) + await self.provider.forget(session_id, keep_system_prompt=True) return True, "重置成功", "reset" if len(l) == 2 and l[1] == "p": - await self.provider.forget(session_id, keep_system_prompt=True) + await self.provider.forget(session_id) def his(self, message: str, session_id: str): if self.provider is None: @@ -114,21 +141,28 @@ class CommandOpenAIOfficial(Command): page = int(l[1]) except BaseException as e: return True, "页码不合法", "his" - contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page) + contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page) t_pages = total_num // size_per_page + 1 return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his" - def status(self): + def status(self, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "status" keys_data = self.provider.get_keys_data() ret = "OpenAI Key" for k in keys_data: - status = "🟢" if keys_data[k]['status'] == 0 else "🔴" + status = "🟢" if keys_data[k] else "🔴" ret += "\n|- " + k[:8] + " " + status conf = self.provider.get_configs() ret += "\n当前模型:" + conf['model'] + if conf['model'] in MODELS: + ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens" + + if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]): + ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens" + + return True, ret, "status" async def switch(self, message: str): ''' diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 520d3d320..2995e93b3 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -5,6 +5,8 @@ import botpy.message import re import asyncio import aiohttp +import botpy.types +import botpy.types.message from util import general_utils as gu from botpy.types.message import Reference diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 6ff532185..a39a49141 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -78,11 +78,12 @@ class ProviderOpenAIOfficial(Provider): self.session_memory_lock = threading.Lock() self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 - self.curr_personality = { + self.DEFAULT_PERSONALITY = { "name": "default", "prompt": "你是一个很有帮助的 AI 助手。" } - + self.curr_personality = self.DEFAULT_PERSONALITY + self.session_personality = {} # 记录了某个session是否已设置人格。 # 从 SQLite DB 读取历史记录 try: db1 = dbConn() @@ -121,13 +122,13 @@ class ProviderOpenAIOfficial(Provider): def personality_set(self, default_personality: dict, session_id: str): if not default_personality: return - if session_id not in self.session_memory: self.session_memory[session_id] = [] self.curr_personality = default_personality + self.session_personality = {} # 重置 encoded_prompt = self.tokenizer.encode(default_personality['prompt']) tokens_num = len(encoded_prompt) model = self.model_configs['model'] - if model in MODELS and tokens_num > MODELS[model] - 800: - default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800]) + if model in MODELS and tokens_num > MODELS[model] - 500: + default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500]) new_record = { "user": { @@ -160,13 +161,23 @@ class ProviderOpenAIOfficial(Provider): # 转换为 openai 要求的格式 context = [] + is_lvm = await self.is_lvm() for record in self.session_memory[session_id]: if "user" in record and record['user']: + if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list): + logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。") + continue context.append(record['user']) if "AI" in record and record['AI']: context.append(record['AI']) return context + + async def is_lvm(self): + ''' + 是否是 LVM + ''' + return self.model_configs['model'].startswith("gpt-4") async def get_models(self): ''' @@ -201,7 +212,9 @@ class ProviderOpenAIOfficial(Provider): }, { "type": "image_url", - "url": await self.encode_image_bs64(image_url) + "image_url": { + "url": await self.encode_image_bs64(image_url) + } } ] } @@ -249,7 +262,14 @@ class ProviderOpenAIOfficial(Provider): for i in range(len(self.session_memory[session_id])): # 检查是否是 system prompt if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": - continue + # 如果只有一个 system prompt,才不删掉 + f = False + for j in range(i+1, len(self.session_memory[session_id])): + if self.session_memory[session_id][j]['user']['role'] == "system": + f = True + break + if not f: + continue record = self.session_memory[session_id].pop(i) break @@ -274,7 +294,10 @@ class ProviderOpenAIOfficial(Provider): if session_id not in self.session_memory: self.session_memory[session_id] = [] + + if session_id not in self.session_personality or not self.session_personality[session_id]: self.personality_set(self.curr_personality, session_id) + self.session_personality[session_id] = True # 如果 prompt 超过了最大窗口,截断。 # 1. 可以保证之后 pop 的时候不会出现问题 @@ -321,9 +344,13 @@ class ProviderOpenAIOfficial(Provider): ok = await self.switch_to_next_key() if ok: continue else: raise Exception("所有 OpenAI API Key 目前都不可用。") - + except BadRequestError as e: + logger.warn(f"OpenAI 请求异常:{e}。") + if "image_url is only supported by certain models." in str(e): + raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。") + retry += 1 except RateLimitError as e: - if "You exceeded your current quota" in e: + if "You exceeded your current quota" in str(e): self.keys_data[self.chosen_api_key] = False ok = await self.switch_to_next_key() if ok: continue @@ -417,25 +444,41 @@ class ProviderOpenAIOfficial(Provider): self.session_memory[session_id] = [] if keep_system_prompt: self.personality_set(self.curr_personality, session_id) + else: + self.curr_personality = self.DEFAULT_PERSONALITY return True - def dump_contexts_page(self, size=5, page=1): + def dump_contexts_page(self, session_id: str, size=5, page=1,): ''' 获取缓存的会话 ''' + # contexts_str = "" + # for i, key in enumerate(self.session_memory): + # if i < (page-1)*size or i >= page*size: + # continue + # contexts_str += f"Session ID: {key}\n" + # for record in self.session_memory[key]: + # if "user" in record: + # contexts_str += f"User: {record['user']['content']}\n" + # if "AI" in record: + # contexts_str += f"AI: {record['AI']['content']}\n" + # contexts_str += "---\n" contexts_str = "" - for i, key in enumerate(self.session_memory): - if i < (page-1)*size or i >= page*size: - continue - contexts_str += f"Session ID: {key}\n" - for record in self.session_memory[key]: - if "user" in record: - contexts_str += f"User: {record['user']['content']}\n" - if "AI" in record: - contexts_str += f"AI: {record['AI']['content']}\n" - contexts_str += "---\n" + if session_id in self.session_memory: + for record in self.session_memory[session_id]: + if "user" in record and record['user']: + text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content'] + contexts_str += f"User: {text}\n" + if "AI" in record and record['AI']: + text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content'] + contexts_str += f"Assistant: {text}\n" + else: + contexts_str = "会话 ID 不存在。" - return contexts_str, len(self.session_memory) + return contexts_str, len(self.session_memory[session_id]) + + def set_model(self, model: str): + self.model_configs['model'] = model def get_configs(self): return self.model_configs diff --git a/util/general_utils.py b/util/general_utils.py index c9a3e7497..1c2a0543f 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -360,7 +360,13 @@ def save_temp_img(img: Image) -> str: # 获得时间戳 timestamp = int(time.time()) p = f"temp/{timestamp}.jpg" - img.save(p) + + if isinstance(img, Image.Image): + img.save(p) + else: + with open(p, "wb") as f: + f.write(img) + logger.info(f"保存临时图片: {p}") return p async def download_image_by_url(url: str) -> str: @@ -368,11 +374,10 @@ async def download_image_by_url(url: str) -> str: 下载图片 ''' try: + logger.info(f"下载图片: {url}") async with aiohttp.ClientSession() as session: async with session.get(url) as resp: - img = Image.open(await resp.read()) - p = save_temp_img(img) - return p + return save_temp_img(await resp.read()) except Exception as e: raise e