From 1775327c2ec84176ee72344170879df251e67408 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 17 May 2024 09:07:11 +0800 Subject: [PATCH] chore: refact openai official --- cores/astrbot/types.py | 12 ++++- model/command/openai_official.py | 22 +++------ model/provider/openai_official_new.py | 66 +++++++++++++++++++++++++++ util/cmd_config.py | 6 +-- util/general_utils.py | 18 +++++--- 5 files changed, 97 insertions(+), 27 deletions(-) create mode 100644 model/provider/openai_official_new.py diff --git a/cores/astrbot/types.py b/cores/astrbot/types.py index c95ab33cf..c87cb3043 100644 --- a/cores/astrbot/types.py +++ b/cores/astrbot/types.py @@ -6,7 +6,7 @@ from nakuru import ( GuildMessage, ) from nakuru.entities.components import BaseMessageComponent -from typing import Union, List, ClassVar +from typing import Union, List, ClassVar, Callable from types import ModuleType from enum import Enum from dataclasses import dataclass @@ -167,6 +167,16 @@ class AstrMessageEvent(): self.role = role self.session_id = session_id +@dataclass +class CommandItem(): + ''' + 用来描述单个指令 + ''' + + command_name: Union[str, tuple] # 指令名 + callback: Callable # 回调函数 + description: str # 描述 + origin: str # 注册来源 class CommandResult(): ''' diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 6fa3fb4ad..27e8de3e7 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,7 +1,7 @@ from model.command.command import Command from model.provider.openai_official import ProviderOpenAIOfficial from util.personality import personalities -from cores.astrbot.types import GlobalObject +from cores.astrbot.types import GlobalObject, CommandItem from SparkleLogging.utils.core import LogManager from logging import Logger from openai._exceptions import NotFoundError, RateLimitError, APIError @@ -13,6 +13,12 @@ class CommandOpenAIOfficial(Command): self.provider = provider self.global_object = global_object self.personality_str = "" + self.commands = [ + CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"), + CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"), + CommandItem("status", self.gpt, "查看 GPT 配置信息和用量状态。", "内置"), + + ] super().__init__(provider, global_object) async def check_command(self, @@ -41,10 +47,6 @@ class CommandOpenAIOfficial(Command): return True, await self.reset(session_id, message) elif self.command_start_with(message, "his", "历史"): return True, self.his(message, session_id) - elif self.command_start_with(message, "token"): - return True, self.token(session_id) - elif self.command_start_with(message, "gpt"): - return True, self.gpt() elif self.command_start_with(message, "status"): return True, self.status() elif self.command_start_with(message, "help", "帮助"): @@ -126,16 +128,6 @@ class CommandOpenAIOfficial(Command): self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) return True, f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页", "his" - def token(self, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "token" - return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token" - - def gpt(self): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "gpt" - return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt" - def status(self): if self.provider is None: return False, "未启用 OpenAI 官方 API", "status" diff --git a/model/provider/openai_official_new.py b/model/provider/openai_official_new.py new file mode 100644 index 000000000..6699c8e25 --- /dev/null +++ b/model/provider/openai_official_new.py @@ -0,0 +1,66 @@ +import os +import sys +import json +import time +import tiktoken +import threading +import traceback + +from openai import AsyncOpenAI +from openai.types.images_response import ImagesResponse +from openai.types.chat.chat_completion import ChatCompletion + +from cores.database.conn import dbConn +from model.provider.provider import Provider +from util import general_utils as gu +from util.cmd_config import CmdConfig +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + +class ProviderOpenAIOfficial(Provider): + def __init__(self, cfg) -> None: + super().__init__() + + os.makedirs("data/openai", exist_ok=True) + + self.cc = CmdConfig + self.key_data_path = "data/openai/keys.json" + self.api_keys = [] + self.chosen_api_key = None + self.base_url = None + self.keys_data = { + "keys": [] + } + + if cfg['key']: self.api_keys = cfg['key'] + if cfg['api_base']: self.base_url = cfg['api_base'] + if not self.api_keys: + logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") + else: + self.chosen_api_key = self.api_keys[0] + + self.client = AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=self.base_url + ) + self.model_configs: dict = cfg['chatGPTConfigs'] + self.session_memory = {} # 会话记忆 + self.session_memory_lock = threading.Lock() + self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 + self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 + + # 从 SQLite DB 读取历史记录 + try: + db1 = dbConn() + for session in db1.get_all_session(): + self.session_memory_lock.acquire() + self.session_memory[session[0]] = json.loads(session[1])['data'] + self.session_memory_lock.release() + except BaseException as e: + logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。") + + # 定时保存历史记录 + threading.Thread(target=self.dump_history, daemon=True).start() + diff --git a/util/cmd_config.py b/util/cmd_config.py index ea05ed74d..928f3bb25 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -2,8 +2,7 @@ import os import json from typing import Union -cpath = "cmd_config.json" - +cpath = "data/cmd_config.json" def check_exist(): if not os.path.exists(cpath): @@ -89,8 +88,7 @@ def init_astrbot_config_items(): # 加载默认配置 cc = CmdConfig() cc.init_attributes("qq_forward_threshold", 200) - cc.init_attributes( - "qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") + cc.init_attributes("qq_welcome", "") cc.init_attributes("qq_pic_mode", False) cc.init_attributes("gocq_host", "127.0.0.1") cc.init_attributes("gocq_http_port", 5700) diff --git a/util/general_utils.py b/util/general_utils.py index 7467342eb..33a772668 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -391,15 +391,19 @@ def create_markdown_image(text: str): raise e -def try_migrate_config(old_config: dict): +def try_migrate_config(): ''' - 迁移配置文件到 cmd_config.json + 将 cmd_config.json 迁移至 data/cmd_config.json ''' - cc = CmdConfig() - if cc.get("qqbot", None) is None: - # 未迁移过 - for k in old_config: - cc.put(k, old_config[k]) + if os.path.exists("cmd_config.json"): + with open("cmd_config.json", "r", encoding="utf-8-sig") as f: + data = json.load(f) + with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + try: + os.remove("cmd_config.json") + except Exception as e: + pass def get_local_ip_addresses():