perf: 优化人格和LVM的处理过程

This commit is contained in:
Soulter
2024-05-18 10:34:35 +08:00
parent 934ca94e62
commit ea34c20198
5 changed files with 128 additions and 43 deletions
+1
View File
@@ -133,6 +133,7 @@ def init(cfg):
instance = llm_instance[OPENAI_OFFICIAL]
assert isinstance(instance, ProviderOpenAIOfficial)
instance.DEFAULT_PERSONALITY = _global_object.default_personality
instance.personality_set(_global_object.default_personality, session_id=None)
# 检查provider设置偏好
+52 -18
View File
@@ -1,5 +1,5 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
from util.personality import personalities
from cores.astrbot.types import GlobalObject, CommandItem
from SparkleLogging.utils.core import LogManager
@@ -47,7 +47,7 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "status"):
return True, self.status()
return True, self.status(session_id)
elif self.command_start_with(message, "help", "帮助"):
return True, await self.help()
elif self.command_start_with(message, "unset"):
@@ -61,11 +61,12 @@ class CommandOpenAIOfficial(Command):
elif self.command_start_with(message, "switch"):
return True, await self.switch(message)
elif self.command_start_with(message, "models"):
return True, await self.get_models()
return True, await self.print_models()
elif self.command_start_with(message, "model"):
return True, await self.set_model(message)
return False, None
async def get_models(self):
ret = "OpenAI GPT 类可用模型"
try:
models = await self.provider.client.models.list()
except NotFoundError as e:
@@ -73,23 +74,49 @@ class CommandOpenAIOfficial(Command):
self.provider.client.base_url = bu + "/v1"
models = await self.provider.client.models.list()
finally:
print(models.data)
i = 1
for model in models.data:
if str(model.id).startswith("gpt"):
ret += f"\n{i}. {model.id}"
i += 1
logger.debug(ret)
return filter(lambda x: x.id.startswith("gpt"), models.data)
async def print_models(self):
    """Render the available GPT models as a numbered, human-readable list.

    Returns:
        A ``(True, text, "models")`` command-result tuple. ``text`` is a header
        line followed by one ``"<n>. <model id>"`` line per model yielded by
        ``self.get_models()``, numbered from 1.
    """
    available = await self.get_models()
    lines = ["OpenAI GPT 类可用模型"]
    # Numbering starts at 1 so it lines up with the indices accepted by /model.
    lines.extend(f"{index}. {entry.id}" for index, entry in enumerate(available, start=1))
    ret = "\n".join(lines)
    logger.debug(ret)
    return True, ret, "models"
async def set_model(self, message: str):
    """Switch the provider to the model named (or numbered) in ``message``.

    Args:
        message: The raw command text, e.g. ``"model gpt-4"`` or ``"model 2"``.
            The second whitespace-separated token is either a model id or a
            1-based index into the list produced by ``self.get_models()``.

    Returns:
        A ``(True, text, "model")`` command-result tuple. ``text`` is a usage
        hint when no argument was given, an error when the argument matches
        neither an index nor a model id, or a confirmation after the provider
        was updated via ``self.provider.set_model``.
    """
    # split() without an argument tolerates repeated spaces between tokens.
    parts = message.split()
    if len(parts) < 2:
        return True, "请输入 /model 模型名/编号", "model"

    wanted = parts[1]
    models = list(await self.get_models())

    # isdecimal() only accepts characters that int() can actually parse.
    if wanted.isdecimal() and 1 <= int(wanted) <= len(models):
        model_id = models[int(wanted) - 1].id
    else:
        # Selection by name: keep the id string itself rather than the raw
        # user input, so both branches hand the same type to the provider.
        model_id = next((m.id for m in models if m.id == wanted), None)
        if model_id is None:
            return True, "模型不存在或输入非法", "model"

    self.provider.set_model(model_id)
    return True, f"模型已设置为 {model_id}", "model"
async def help(self):
    """Build the help message for this provider's commands.

    Starts from the general command table supplied by the base class, adds the
    OpenAI-specific entries, and hands the table to the base-class
    ``help_messager`` for rendering.

    Returns:
        A ``(True, rendered_help, "help")`` command-result tuple.
    """
    commands = super().general_commands()
    commands.update({
        '': '调用 OpenAI DallE 模型生成图片',
        'set': '人格设置面板',
        'status': '查看 Api Key 状态和配置信息',
        'token': '查看本轮会话 token',
        # A plain reset keeps the personality; "reset p" also clears it.
        'reset': '重置当前与 LLM 的会话,但保留人格(system prompt)',
        'reset p': '重置当前与 LLM 的会话,并清除人格。',
    })
    return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
@@ -98,10 +125,10 @@ class CommandOpenAIOfficial(Command):
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
await self.provider.forget(session_id)
await self.provider.forget(session_id, keep_system_prompt=True)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
await self.provider.forget(session_id, keep_system_prompt=True)
await self.provider.forget(session_id)
def his(self, message: str, session_id: str):
if self.provider is None:
@@ -114,21 +141,28 @@ class CommandOpenAIOfficial(Command):
page = int(l[1])
except BaseException as e:
return True, "页码不合法", "his"
contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page)
contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page)
t_pages = total_num // size_per_page + 1
return True, f"历史记录如下:\n{contexts}\n{page} 页 | 共 {t_pages}\n*输入 /his 2 跳转到第 2 页", "his"
def status(self, session_id: str):
    """Summarise API-key health, the active model and the session's token usage.

    Args:
        session_id: Session whose most recent record is inspected for its
            ``usage_tokens`` value.

    Returns:
        ``(False, message, "status")`` when no provider is configured,
        otherwise ``(True, report, "status")`` where ``report`` is a
        multi-line text.
    """
    if self.provider is None:
        return False, "未启用 OpenAI 官方 API", "status"

    lines = ["OpenAI Key"]
    # Each key maps to a truthy/falsy availability flag; only the first eight
    # characters of the key are shown.
    for key, usable in self.provider.get_keys_data().items():
        lines.append("|- " + key[:8] + " " + ("🟢" if usable else "🔴"))

    model = self.provider.get_configs()['model']
    lines.append("当前模型:" + model)
    if model in MODELS:
        lines.append("最大上下文窗口:" + str(MODELS[model]) + " tokens")

    memory = self.provider.session_memory
    if session_id in memory and memory[session_id]:
        # The last record may not carry a token count; skip the line then.
        used = memory[session_id][-1].get('usage_tokens')
        if used is not None:
            lines.append("你的会话上下文:" + str(used) + " tokens")

    return True, "\n".join(lines), "status"
async def switch(self, message: str):
'''
+2
View File
@@ -5,6 +5,8 @@ import botpy.message
import re
import asyncio
import aiohttp
import botpy.types
import botpy.types.message
from util import general_utils as gu
from botpy.types.message import Reference
+64 -21
View File
@@ -78,11 +78,12 @@ class ProviderOpenAIOfficial(Provider):
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
self.curr_personality = {
self.DEFAULT_PERSONALITY = {
"name": "default",
"prompt": "你是一个很有帮助的 AI 助手。"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
@@ -121,13 +122,13 @@ class ProviderOpenAIOfficial(Provider):
def personality_set(self, default_personality: dict, session_id: str):
if not default_personality: return
if session_id not in self.session_memory: self.session_memory[session_id] = []
self.curr_personality = default_personality
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
if model in MODELS and tokens_num > MODELS[model] - 800:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800])
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
new_record = {
"user": {
@@ -160,13 +161,23 @@ class ProviderOpenAIOfficial(Provider):
# 转换为 openai 要求的格式
context = []
is_lvm = await self.is_lvm()
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
continue
context.append(record['user'])
if "AI" in record and record['AI']:
context.append(record['AI'])
return context
async def is_lvm(self):
    """Report whether the configured model is treated as a vision model (LVM).

    The decision is purely name-based: any model id that begins with
    ``"gpt-4"`` is considered vision-capable.
    """
    model_name = self.model_configs['model']
    return model_name.startswith("gpt-4")
async def get_models(self):
'''
@@ -201,7 +212,9 @@ class ProviderOpenAIOfficial(Provider):
},
{
"type": "image_url",
"url": await self.encode_image_bs64(image_url)
"image_url": {
"url": await self.encode_image_bs64(image_url)
}
}
]
}
@@ -249,7 +262,14 @@ class ProviderOpenAIOfficial(Provider):
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
continue
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
@@ -274,7 +294,10 @@ class ProviderOpenAIOfficial(Provider):
if session_id not in self.session_memory:
self.session_memory[session_id] = []
if session_id not in self.session_personality or not self.session_personality[session_id]:
self.personality_set(self.curr_personality, session_id)
self.session_personality[session_id] = True
# 如果 prompt 超过了最大窗口,截断。
# 1. 可以保证之后 pop 的时候不会出现问题
@@ -321,9 +344,13 @@ class ProviderOpenAIOfficial(Provider):
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except BadRequestError as e:
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1
except RateLimitError as e:
if "You exceeded your current quota" in e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
@@ -417,25 +444,41 @@ class ProviderOpenAIOfficial(Provider):
self.session_memory[session_id] = []
if keep_system_prompt:
self.personality_set(self.curr_personality, session_id)
else:
self.curr_personality = self.DEFAULT_PERSONALITY
return True
def dump_contexts_page(self, session_id: str, size=5, page=1,):
    """Render one page of a session's cached conversation as plain text.

    Args:
        session_id: Key into ``self.session_memory``.
        size: Number of records shown per page.
        page: 1-based page number; values below 1 are treated as page 1.

    Returns:
        A ``(text, total)`` tuple. ``text`` holds one ``"User: ..."`` and/or
        ``"Assistant: ..."`` line per message on the requested page, each
        message truncated to 100 characters. ``total`` is the number of
        records stored for the session. For an unknown session the text is an
        error message and ``total`` is 0.
    """
    if session_id not in self.session_memory:
        return "会话 ID 不存在。", 0

    def preview(message: dict) -> str:
        # Multimodal messages store a list of parts instead of a string;
        # flatten them so slicing and concatenation below stay string-only.
        content = message.get('content', "")
        if isinstance(content, list):
            content = " ".join(
                str(part.get('text', "")) if part.get('type') == "text" else "[图片]"
                for part in content
                if isinstance(part, dict)
            )
        elif not isinstance(content, str):
            content = str(content)
        return content[:100] + "..." if len(content) > 100 else content

    records = self.session_memory[session_id]
    start = (max(int(page), 1) - 1) * size
    lines = []
    for record in records[start:start + size]:
        if "user" in record and record['user']:
            lines.append(f"User: {preview(record['user'])}\n")
        if "AI" in record and record['AI']:
            lines.append(f"Assistant: {preview(record['AI'])}\n")
    return "".join(lines), len(records)
def set_model(self, model: str):
    """Store ``model`` as the active model id in ``self.model_configs``."""
    configs = self.model_configs
    configs["model"] = model
def get_configs(self):
    """Return the provider's model-configuration mapping (the live object, not a copy)."""
    configs = self.model_configs
    return configs
+9 -4
View File
@@ -360,7 +360,13 @@ def save_temp_img(img: Image) -> str:
# 获得时间戳
timestamp = int(time.time())
p = f"temp/{timestamp}.jpg"
img.save(p)
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
logger.info(f"保存临时图片: {p}")
return p
async def download_image_by_url(url: str) -> str:
@@ -368,11 +374,10 @@ async def download_image_by_url(url: str) -> str:
下载图片
'''
try:
logger.info(f"下载图片: {url}")
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
img = Image.open(await resp.read())
p = save_temp_img(img)
return p
return save_temp_img(await resp.read())
except Exception as e:
raise e