fix: 修复弹出记录报错的问题 #272
This commit is contained in:
@@ -591,14 +591,14 @@ CONFIG_METADATA_2 = {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。对话需要成对(用户和助手),输入完一个角色之后按回车",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -72,31 +72,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
async def pop_record(self, session_id: str):
|
||||
'''
|
||||
弹出第一条记录
|
||||
弹出最早的一个对话
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
if len(self.session_memory[session_id]) < 2:
|
||||
return
|
||||
|
||||
try:
|
||||
self.session_memory[session_id].pop(0)
|
||||
self.session_memory[session_id].pop(0)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
@@ -188,11 +178,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
logger.warning("上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
await self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
if kwargs.get("persist", True):
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
@@ -223,6 +215,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user