perf: 优化人格和LVM的处理过程
This commit is contained in:
@@ -133,6 +133,7 @@ def init(cfg):
|
||||
|
||||
instance = llm_instance[OPENAI_OFFICIAL]
|
||||
assert isinstance(instance, ProviderOpenAIOfficial)
|
||||
instance.DEFAULT_PERSONALITY = _global_object.default_personality
|
||||
instance.personality_set(_global_object.default_personality, session_id=None)
|
||||
|
||||
# 检查provider设置偏好
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from model.command.command import Command
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
from util.personality import personalities
|
||||
from cores.astrbot.types import GlobalObject, CommandItem
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
@@ -47,7 +47,7 @@ class CommandOpenAIOfficial(Command):
|
||||
elif self.command_start_with(message, "his", "历史"):
|
||||
return True, self.his(message, session_id)
|
||||
elif self.command_start_with(message, "status"):
|
||||
return True, self.status()
|
||||
return True, self.status(session_id)
|
||||
elif self.command_start_with(message, "help", "帮助"):
|
||||
return True, await self.help()
|
||||
elif self.command_start_with(message, "unset"):
|
||||
@@ -61,11 +61,12 @@ class CommandOpenAIOfficial(Command):
|
||||
elif self.command_start_with(message, "switch"):
|
||||
return True, await self.switch(message)
|
||||
elif self.command_start_with(message, "models"):
|
||||
return True, await self.get_models()
|
||||
return True, await self.print_models()
|
||||
elif self.command_start_with(message, "model"):
|
||||
return True, await self.set_model(message)
|
||||
return False, None
|
||||
|
||||
async def get_models(self):
|
||||
ret = "OpenAI GPT 类可用模型"
|
||||
try:
|
||||
models = await self.provider.client.models.list()
|
||||
except NotFoundError as e:
|
||||
@@ -73,23 +74,49 @@ class CommandOpenAIOfficial(Command):
|
||||
self.provider.client.base_url = bu + "/v1"
|
||||
models = await self.provider.client.models.list()
|
||||
finally:
|
||||
print(models.data)
|
||||
i = 1
|
||||
for model in models.data:
|
||||
if str(model.id).startswith("gpt"):
|
||||
ret += f"\n{i}. {model.id}"
|
||||
i += 1
|
||||
logger.debug(ret)
|
||||
return filter(lambda x: x.id.startswith("gpt"), models.data)
|
||||
|
||||
async def print_models(self):
|
||||
models = await self.get_models()
|
||||
i = 1
|
||||
ret = "OpenAI GPT 类可用模型"
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model.id}"
|
||||
i += 1
|
||||
logger.debug(ret)
|
||||
return True, ret, "models"
|
||||
|
||||
|
||||
async def set_model(self, message: str):
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
return True, "请输入 /model 模型名/编号", "model"
|
||||
model = str(l[1])
|
||||
models = await self.get_models()
|
||||
models = list(models)
|
||||
if model.isdigit() and int(model) <= len(models) and int(model) >= 1:
|
||||
model = models[int(model)-1]
|
||||
else:
|
||||
f = False
|
||||
for m in models:
|
||||
if model == m.id:
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
return True, "模型不存在或输入非法", "model"
|
||||
|
||||
self.provider.set_model(model.id)
|
||||
return True, f"模型已设置为 {model.id}", "model"
|
||||
|
||||
|
||||
async def help(self):
|
||||
commands = super().general_commands()
|
||||
commands['画'] = '调用 OpenAI DallE 模型生成图片'
|
||||
commands['set'] = '人格设置面板'
|
||||
commands['status'] = '查看 Api Key 状态和配置信息'
|
||||
commands['token'] = '查看本轮会话 token'
|
||||
commands['reset'] = '重置当前与 LLM 的会话'
|
||||
commands['reset p'] = '重置当前与 LLM 的会话,但保留人格(system prompt)'
|
||||
commands['reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt)'
|
||||
commands['reset p'] = '重置当前与 LLM 的会话,并清除人格。'
|
||||
|
||||
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
|
||||
|
||||
@@ -98,10 +125,10 @@ class CommandOpenAIOfficial(Command):
|
||||
return False, "未启用 OpenAI 官方 API", "reset"
|
||||
l = message.split(" ")
|
||||
if len(l) == 1:
|
||||
await self.provider.forget(session_id)
|
||||
await self.provider.forget(session_id, keep_system_prompt=True)
|
||||
return True, "重置成功", "reset"
|
||||
if len(l) == 2 and l[1] == "p":
|
||||
await self.provider.forget(session_id, keep_system_prompt=True)
|
||||
await self.provider.forget(session_id)
|
||||
|
||||
def his(self, message: str, session_id: str):
|
||||
if self.provider is None:
|
||||
@@ -114,21 +141,28 @@ class CommandOpenAIOfficial(Command):
|
||||
page = int(l[1])
|
||||
except BaseException as e:
|
||||
return True, "页码不合法", "his"
|
||||
contexts, total_num = self.provider.dump_contexts_page(size_per_page, page=page)
|
||||
contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page)
|
||||
t_pages = total_num // size_per_page + 1
|
||||
return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his"
|
||||
|
||||
def status(self):
|
||||
def status(self, session_id: str):
|
||||
if self.provider is None:
|
||||
return False, "未启用 OpenAI 官方 API", "status"
|
||||
keys_data = self.provider.get_keys_data()
|
||||
ret = "OpenAI Key"
|
||||
for k in keys_data:
|
||||
status = "🟢" if keys_data[k]['status'] == 0 else "🔴"
|
||||
status = "🟢" if keys_data[k] else "🔴"
|
||||
ret += "\n|- " + k[:8] + " " + status
|
||||
|
||||
conf = self.provider.get_configs()
|
||||
ret += "\n当前模型:" + conf['model']
|
||||
if conf['model'] in MODELS:
|
||||
ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens"
|
||||
|
||||
if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]):
|
||||
ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens"
|
||||
|
||||
return True, ret, "status"
|
||||
|
||||
async def switch(self, message: str):
|
||||
'''
|
||||
|
||||
@@ -5,6 +5,8 @@ import botpy.message
|
||||
import re
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
from util import general_utils as gu
|
||||
|
||||
from botpy.types.message import Reference
|
||||
|
||||
@@ -78,11 +78,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
||||
self.curr_personality = {
|
||||
self.DEFAULT_PERSONALITY = {
|
||||
"name": "default",
|
||||
"prompt": "你是一个很有帮助的 AI 助手。"
|
||||
}
|
||||
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
self.session_personality = {} # 记录了某个session是否已设置人格。
|
||||
# 从 SQLite DB 读取历史记录
|
||||
try:
|
||||
db1 = dbConn()
|
||||
@@ -121,13 +122,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
def personality_set(self, default_personality: dict, session_id: str):
|
||||
if not default_personality: return
|
||||
if session_id not in self.session_memory: self.session_memory[session_id] = []
|
||||
self.curr_personality = default_personality
|
||||
self.session_personality = {} # 重置
|
||||
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
|
||||
tokens_num = len(encoded_prompt)
|
||||
model = self.model_configs['model']
|
||||
if model in MODELS and tokens_num > MODELS[model] - 800:
|
||||
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 800])
|
||||
if model in MODELS and tokens_num > MODELS[model] - 500:
|
||||
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
|
||||
|
||||
new_record = {
|
||||
"user": {
|
||||
@@ -160,13 +161,23 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
# 转换为 openai 要求的格式
|
||||
context = []
|
||||
is_lvm = await self.is_lvm()
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
|
||||
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
|
||||
continue
|
||||
context.append(record['user'])
|
||||
if "AI" in record and record['AI']:
|
||||
context.append(record['AI'])
|
||||
|
||||
return context
|
||||
|
||||
async def is_lvm(self):
|
||||
'''
|
||||
是否是 LVM
|
||||
'''
|
||||
return self.model_configs['model'].startswith("gpt-4")
|
||||
|
||||
async def get_models(self):
|
||||
'''
|
||||
@@ -201,7 +212,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"url": await self.encode_image_bs64(image_url)
|
||||
"image_url": {
|
||||
"url": await self.encode_image_bs64(image_url)
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -249,7 +262,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
continue
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
@@ -274,7 +294,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
if session_id not in self.session_memory:
|
||||
self.session_memory[session_id] = []
|
||||
|
||||
if session_id not in self.session_personality or not self.session_personality[session_id]:
|
||||
self.personality_set(self.curr_personality, session_id)
|
||||
self.session_personality[session_id] = True
|
||||
|
||||
# 如果 prompt 超过了最大窗口,截断。
|
||||
# 1. 可以保证之后 pop 的时候不会出现问题
|
||||
@@ -321,9 +344,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
ok = await self.switch_to_next_key()
|
||||
if ok: continue
|
||||
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||
|
||||
except BadRequestError as e:
|
||||
logger.warn(f"OpenAI 请求异常:{e}。")
|
||||
if "image_url is only supported by certain models." in str(e):
|
||||
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
|
||||
retry += 1
|
||||
except RateLimitError as e:
|
||||
if "You exceeded your current quota" in e:
|
||||
if "You exceeded your current quota" in str(e):
|
||||
self.keys_data[self.chosen_api_key] = False
|
||||
ok = await self.switch_to_next_key()
|
||||
if ok: continue
|
||||
@@ -417,25 +444,41 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_memory[session_id] = []
|
||||
if keep_system_prompt:
|
||||
self.personality_set(self.curr_personality, session_id)
|
||||
else:
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
return True
|
||||
|
||||
def dump_contexts_page(self, size=5, page=1):
|
||||
def dump_contexts_page(self, session_id: str, size=5, page=1,):
|
||||
'''
|
||||
获取缓存的会话
|
||||
'''
|
||||
# contexts_str = ""
|
||||
# for i, key in enumerate(self.session_memory):
|
||||
# if i < (page-1)*size or i >= page*size:
|
||||
# continue
|
||||
# contexts_str += f"Session ID: {key}\n"
|
||||
# for record in self.session_memory[key]:
|
||||
# if "user" in record:
|
||||
# contexts_str += f"User: {record['user']['content']}\n"
|
||||
# if "AI" in record:
|
||||
# contexts_str += f"AI: {record['AI']['content']}\n"
|
||||
# contexts_str += "---\n"
|
||||
contexts_str = ""
|
||||
for i, key in enumerate(self.session_memory):
|
||||
if i < (page-1)*size or i >= page*size:
|
||||
continue
|
||||
contexts_str += f"Session ID: {key}\n"
|
||||
for record in self.session_memory[key]:
|
||||
if "user" in record:
|
||||
contexts_str += f"User: {record['user']['content']}\n"
|
||||
if "AI" in record:
|
||||
contexts_str += f"AI: {record['AI']['content']}\n"
|
||||
contexts_str += "---\n"
|
||||
if session_id in self.session_memory:
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
|
||||
contexts_str += f"User: {text}\n"
|
||||
if "AI" in record and record['AI']:
|
||||
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
|
||||
contexts_str += f"Assistant: {text}\n"
|
||||
else:
|
||||
contexts_str = "会话 ID 不存在。"
|
||||
|
||||
return contexts_str, len(self.session_memory)
|
||||
return contexts_str, len(self.session_memory[session_id])
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
|
||||
def get_configs(self):
|
||||
return self.model_configs
|
||||
|
||||
@@ -360,7 +360,13 @@ def save_temp_img(img: Image) -> str:
|
||||
# 获得时间戳
|
||||
timestamp = int(time.time())
|
||||
p = f"temp/{timestamp}.jpg"
|
||||
img.save(p)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
else:
|
||||
with open(p, "wb") as f:
|
||||
f.write(img)
|
||||
logger.info(f"保存临时图片: {p}")
|
||||
return p
|
||||
|
||||
async def download_image_by_url(url: str) -> str:
|
||||
@@ -368,11 +374,10 @@ async def download_image_by_url(url: str) -> str:
|
||||
下载图片
|
||||
'''
|
||||
try:
|
||||
logger.info(f"下载图片: {url}")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
img = Image.open(await resp.read())
|
||||
p = save_temp_img(img)
|
||||
return p
|
||||
return save_temp_img(await resp.read())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
Reference in New Issue
Block a user