chore: refact openai official

This commit is contained in:
Soulter
2024-05-17 09:07:11 +08:00
parent 707fcad8b4
commit 1775327c2e
5 changed files with 97 additions and 27 deletions
+11 -1
View File
@@ -6,7 +6,7 @@ from nakuru import (
GuildMessage,
)
from nakuru.entities.components import BaseMessageComponent
from typing import Union, List, ClassVar
from typing import Union, List, ClassVar, Callable
from types import ModuleType
from enum import Enum
from dataclasses import dataclass
@@ -167,6 +167,16 @@ class AstrMessageEvent():
self.role = role
self.session_id = session_id
@dataclass
class CommandItem():
'''
用来描述单个指令
'''
command_name: Union[str, tuple] # 指令名
callback: Callable # 回调函数
description: str # 描述
origin: str # 注册来源
class CommandResult():
'''
+7 -15
View File
@@ -1,7 +1,7 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial
from util.personality import personalities
from cores.astrbot.types import GlobalObject
from cores.astrbot.types import GlobalObject, CommandItem
from SparkleLogging.utils.core import LogManager
from logging import Logger
from openai._exceptions import NotFoundError, RateLimitError, APIError
@@ -13,6 +13,12 @@ class CommandOpenAIOfficial(Command):
self.provider = provider
self.global_object = global_object
self.personality_str = ""
self.commands = [
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
CommandItem("status", self.gpt, "查看 GPT 配置信息和用量状态。", "内置"),
]
super().__init__(provider, global_object)
async def check_command(self,
@@ -41,10 +47,6 @@ class CommandOpenAIOfficial(Command):
return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
return True, self.token(session_id)
elif self.command_start_with(message, "gpt"):
return True, self.gpt()
elif self.command_start_with(message, "status"):
return True, self.status()
elif self.command_start_with(message, "help", "帮助"):
@@ -126,16 +128,6 @@ class CommandOpenAIOfficial(Command):
self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
return True, f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页", "his"
def token(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "token"
return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token"
def gpt(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "gpt"
return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt"
def status(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "status"
+66
View File
@@ -0,0 +1,66 @@
import os
import sys
import json
import time
import tiktoken
import threading
import traceback
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from cores.database.conn import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg) -> None:
super().__init__()
os.makedirs("data/openai", exist_ok=True)
self.cc = CmdConfig
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {
"keys": []
}
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.model_configs: dict = cfg['chatGPTConfigs']
self.session_memory = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_memory_lock.acquire()
self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e:
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。")
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()
+2 -4
View File
@@ -2,8 +2,7 @@ import os
import json
from typing import Union
cpath = "cmd_config.json"
cpath = "data/cmd_config.json"
def check_exist():
if not os.path.exists(cpath):
@@ -89,8 +88,7 @@ def init_astrbot_config_items():
# 加载默认配置
cc = CmdConfig()
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes(
"qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", False)
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
+11 -7
View File
@@ -391,15 +391,19 @@ def create_markdown_image(text: str):
raise e
def try_migrate_config(old_config: dict):
def try_migrate_config():
'''
迁移配置文件到 cmd_config.json
将 cmd_config.json 迁移至 data/cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
for k in old_config:
cc.put(k, old_config[k])
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
try:
os.remove("cmd_config.json")
except Exception as e:
pass
def get_local_ip_addresses():