diff --git a/.gitignore b/.gitignore index 33d25a74c..71c6b2ef9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ cmd_config.json addons/plugins/ data/* cookies.json +logs/ diff --git a/addons/dashboard/server.py b/addons/dashboard/server.py index 2bb775aec..7ec4ffe21 100644 --- a/addons/dashboard/server.py +++ b/addons/dashboard/server.py @@ -12,10 +12,11 @@ from flask.logging import default_handler from werkzeug.serving import make_server from util import general_utils as gu from dataclasses import dataclass -from cores.database.conn import dbConn +from persist.session import dbConn +from type.register import RegisteredPlugin +from typing import List from util.cmd_config import CmdConfig from util.updator import check_update, update_project, request_release_info -from cores.astrbot.types import * from SparkleLogging.utils.core import LogManager from logging import Logger logger: Logger = LogManager.GetLogger(log_name='astrbot-core') diff --git a/cores/astrbot/core.py b/astrbot/core.py similarity index 93% rename from cores/astrbot/core.py rename to astrbot/core.py index 611ade77b..2b8470452 100644 --- a/cores/astrbot/core.py +++ b/astrbot/core.py @@ -2,17 +2,14 @@ import re import threading import asyncio import time -import aiohttp import util.unfit_words as uw import os import sys -import io import traceback -import util.function_calling.gplugin as gplugin +import util.agent.web_searcher as web_searcher import util.plugin_util as putil -from PIL import Image as PILImage from nakuru.entities.components import Plain, At, Image from addons.baidu_aip_judge import BaiduJudge @@ -22,10 +19,12 @@ from util import general_utils as gu from util.general_utils import upload, run_monitor from util.cmd_config import CmdConfig as cc from util.cmd_config import init_astrbot_config_items -from .types import * +from type.types import GlobalObject +from type.register import * +from type.message import AstrBotMessage from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData -from cores.database.conn import dbConn +from persist.session import dbConn from model.platform._message_result import MessageResult from SparkleLogging.utils.core import LogManager from logging import Logger @@ -77,14 +76,14 @@ def privider_chooser(cfg): ''' -def init(cfg): +def init(): global llm_instance, llm_command_instance global baidu_judge, chosen_provider global frequency_count, frequency_time global _global_object # 迁移旧配置 - gu.try_migrate_config(cfg) + gu.try_migrate_config() # 使用新配置 cfg = cc.get_all() @@ -105,6 +104,15 @@ def init(cfg): cc.put("reply_prefix", "") else: _global_object.reply_prefix = cfg['reply_prefix'] + + default_personality_str = cc.get("default_personality_str", "") + if default_personality_str == "": + _global_object.default_personality = None + else: + _global_object.default_personality = { + "name": "default", + "prompt": default_personality_str, + } # 语言模型提供商 logger.info("正在载入语言模型...") @@ -122,6 +130,11 @@ def init(cfg): llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) chosen_provider = OPENAI_OFFICIAL + instance = llm_instance[OPENAI_OFFICIAL] + assert isinstance(instance, ProviderOpenAIOfficial) + instance.DEFAULT_PERSONALITY = _global_object.default_personality + instance.curr_personality = instance.DEFAULT_PERSONALITY + # 检查provider设置偏好 p = cc.get("chosen_provider", None) if p is not None and p in llm_instance: @@ -197,14 +210,6 @@ def init(cfg): cfg, _global_object), daemon=True).start() platform_str += "QQ_OFFICIAL," - default_personality_str = cc.get("default_personality_str", "") - if default_personality_str == "": - _global_object.default_personality = None - else: - _global_object.default_personality = { - "name": "default", - "prompt": default_personality_str, - } # 初始化dashboard _global_object.dashboard_data = DashBoardData( stats={}, @@ -428,14 +433,14 @@ async def oper_msg(message: AstrBotMessage, if chosen_provider == OPENAI_OFFICIAL: if _global_object.web_search or web_sch_flag: official_fc = chosen_provider == OPENAI_OFFICIAL - llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) + llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: - llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality) + llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url) llm_result_str = _global_object.reply_prefix + llm_result_str except BaseException as e: - logger.info(f"调用异常:{traceback.format_exc()}") - return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}") + logger.error(f"调用异常:{traceback.format_exc()}") + return MessageResult(f"调用异常。详细原因:{str(e)}") # 切换回原来的语言模型 if temp_switch != "": @@ -458,14 +463,10 @@ async def oper_msg(message: AstrBotMessage, return MessageResult(f"指令调用错误: \n{str(command_result[1])}") # 画图指令 - if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': - for i in command_result[1]: - # 保存到本地 - async with aiohttp.ClientSession() as session: - async with session.get(i) as resp: - if resp.status == 200: - image = PILImage.open(io.BytesIO(await resp.read())) - return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))]) + if command == 'draw': + # 保存到本地 + path = await gu.download_image_by_url(command_result[1]) + return MessageResult([Image.fromFileSystem(path)]) # 其他指令 else: try: diff --git a/cores/astrbot/types.py b/cores/astrbot/types.py deleted file mode 100644 index c95ab33cf..000000000 --- a/cores/astrbot/types.py +++ /dev/null @@ -1,183 +0,0 @@ -from model.provider.provider import Provider as LLMProvider -from model.platform._platfrom import Platform -from nakuru import ( - GroupMessage, - FriendMessage, - GuildMessage, -) -from nakuru.entities.components import BaseMessageComponent -from typing import Union, List, ClassVar -from types import ModuleType -from enum import Enum -from dataclasses import dataclass - - -class MessageType(Enum): - GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 - FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - GUILD_MESSAGE = 'GuildMessage' # 频道消息 - - -@dataclass -class MessageMember(): - user_id: str # 发送者id - nickname: str = None - - -class AstrBotMessage(): - ''' - AstrBot 的消息对象 - ''' - tag: str # 消息来源标签 - type: MessageType # 消息类型 - self_id: str # 机器人的识别id - session_id: str # 会话id - message_id: str # 消息id - sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 - message_str: str # 最直观的纯文本消息字符串 - raw_message: object - timestamp: int # 消息时间戳 - - def __str__(self) -> str: - return str(self.__dict__) - - -class PluginType(Enum): - PLATFORM = 'platfrom' # 平台类插件。 - LLM = 'llm' # 大语言模型类插件 - COMMON = 'common' # 其他插件 - - -@dataclass -class PluginMetadata: - ''' - 插件的元数据。 - ''' - # required - plugin_name: str - plugin_type: PluginType - author: str # 插件作者 - desc: str # 插件简介 - version: str # 插件版本 - - # optional - repo: str = None # 插件仓库地址 - - def __str__(self) -> str: - return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" - - -@dataclass -class RegisteredPlugin: - ''' - 注册在 AstrBot 中的插件。 - ''' - metadata: PluginMetadata - plugin_instance: object - module_path: str - module: ModuleType - root_dir_name: str - - def __str__(self) -> str: - return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" - - -RegisteredPlugins = List[RegisteredPlugin] - - -@dataclass -class RegisteredPlatform: - ''' - 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 - ''' - platform_name: str - platform_instance: Platform - origin: str = None # 注册来源 - - def __str__(self) -> str: - return self.platform_name - - -@dataclass -class RegisteredLLM: - ''' - 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 - ''' - llm_name: str - llm_instance: LLMProvider - origin: str = None # 注册来源 - - -class GlobalObject: - ''' - 存放一些公用的数据,用于在不同模块(如core与command)之间传递 - ''' - version: str # 机器人版本 - nick: tuple # 用户定义的机器人的别名 - base_config: dict # config.json 中导出的配置 - cached_plugins: List[RegisteredPlugin] # 加载的插件 - platforms: List[RegisteredPlatform] - llms: List[RegisteredLLM] - - web_search: bool # 是否开启了网页搜索 - reply_prefix: str # 回复前缀 - unique_session: bool # 是否开启了独立会话 - cnt_total: int # 总消息数 - default_personality: dict - dashboard_data = None - - def __init__(self): - self.nick = None # gocq 的昵称 - self.base_config = None # config.yaml - self.cached_plugins = [] # 缓存的插件 - self.web_search = False # 是否开启了网页搜索 - self.reply_prefix = None - self.unique_session = False - self.cnt_total = 0 - self.platforms = [] - self.llms = [] - self.default_personality = None - self.dashboard_data = None - self.stat = {} - - -class AstrMessageEvent(): - ''' - 消息事件。 - ''' - context: GlobalObject # 一些公用数据 - message_str: str # 纯消息字符串 - message_obj: AstrBotMessage # 消息对象 - platform: RegisteredPlatform # 来源平台 - role: str # 基本身份。`admin` 或 `member` - session_id: int # 会话 id - - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform: RegisteredPlatform, - role: str, - context: GlobalObject, - session_id: str = None): - self.context = context - self.message_str = message_str - self.message_obj = message_obj - self.platform = platform - self.role = role - self.session_id = session_id - - -class CommandResult(): - ''' - 用于在Command中返回多个值 - ''' - - def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: - self.hit = hit - self.success = success - self.message_chain = message_chain - self.command_name = command_name - - def _result_tuple(self): - return (self.success, self.message_chain, self.command_name) diff --git a/main.py b/main.py index d8a9622c3..60125cf9a 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ def make_necessary_dirs(): os.makedirs("temp", exist_ok=True) def main(): + logger = LogManager.GetLogger( log_name='astrbot-core', out_to_console=True, @@ -35,15 +36,25 @@ def main(): custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") ) logger.info(logo_tmpl) - # config.yaml 配置文件加载和环境确认 + + # 设置代理 + from util.cmd_config import CmdConfig + cc = CmdConfig() + http_proxy = cc.get("http_proxy") + https_proxy = cc.get("https_proxy") + logger.info(f"使用代理: {http_proxy}, {https_proxy}") + if http_proxy: + os.environ['HTTP_PROXY'] = http_proxy + if https_proxy: + os.environ['HTTPS_PROXY'] = https_proxy + os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' + try: - import botpy, logging, yaml - import cores.astrbot.core as qqBot + import botpy, logging + import astrbot.core as bot_core # delete qqbotpy's logger for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') - cfg = yaml.safe_load(ymlfile) + logging.root.removeHandler(handler) except ImportError as import_error: logger.error(import_error) logger.error("检测到一些依赖库没有安装。由于兼容性问题,AstrBot 此版本将不会自动为您安装依赖库。请您先自行安装,然后重试。") @@ -59,19 +70,14 @@ def main(): input("配置文件不存在,请检查是否已经下载配置文件。") exit() except BaseException as e: - raise e - - # 设置代理 - if 'http_proxy' in cfg and cfg['http_proxy'] != '': - os.environ['HTTP_PROXY'] = cfg['http_proxy'] - if 'https_proxy' in cfg and cfg['https_proxy'] != '': - os.environ['HTTPS_PROXY'] = cfg['https_proxy'] - os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com' + logger.error(traceback.format_exc()) + input("未知错误。") + exit() make_necessary_dirs() # 启动主程序(cores/qqbot/core.py) - qqBot.init(cfg) + bot_core.init() def check_env(): diff --git a/model/command/command.py b/model/command/command.py index 362cc54c1..bc9132e0b 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -13,16 +13,13 @@ from nakuru.entities.components import ( from util import general_utils as gu from model.provider.provider import Provider from util.cmd_config import CmdConfig as cc -from cores.astrbot.types import ( - GlobalObject, - AstrMessageEvent, - PluginType, - CommandResult, - RegisteredPlugin, - RegisteredPlatform -) +from type.message import * +from type.types import GlobalObject +from type.command import * +from type.plugin import * +from type.register import * -from typing import List, Tuple +from typing import List from SparkleLogging.utils.core import LogManager from logging import Logger @@ -223,8 +220,6 @@ class Command: "nick": "设置机器人昵称", "plugin": "插件安装、卸载和重载", "web on/off": "LLM 网页搜索能力", - "reset": "重置 LLM 对话", - "/gpt": "切换到 OpenAI 官方接口" } async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 6fa3fb4ad..10d02fd6e 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,10 +1,11 @@ from model.command.command import Command -from model.provider.openai_official import ProviderOpenAIOfficial +from model.provider.openai_official import ProviderOpenAIOfficial, MODELS from util.personality import personalities -from cores.astrbot.types import GlobalObject +from type.types import GlobalObject +from type.command import CommandItem from SparkleLogging.utils.core import LogManager from logging import Logger -from openai._exceptions import NotFoundError, RateLimitError, APIError +from openai._exceptions import NotFoundError logger: Logger = LogManager.GetLogger(log_name='astrbot-core') @@ -13,6 +14,11 @@ class CommandOpenAIOfficial(Command): self.provider = provider self.global_object = global_object self.personality_str = "" + self.commands = [ + CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"), + CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"), + CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"), + ] super().__init__(provider, global_object) async def check_command(self, @@ -41,12 +47,8 @@ class CommandOpenAIOfficial(Command): return True, await self.reset(session_id, message) elif self.command_start_with(message, "his", "历史"): return True, self.his(message, session_id) - elif self.command_start_with(message, "token"): - return True, self.token(session_id) - elif self.command_start_with(message, "gpt"): - return True, self.gpt() elif self.command_start_with(message, "status"): - return True, self.status() + return True, self.status(session_id) elif self.command_start_with(message, "help", "帮助"): return True, await self.help() elif self.command_start_with(message, "unset"): @@ -57,16 +59,15 @@ class CommandOpenAIOfficial(Command): return True, self.update(message, role) elif self.command_start_with(message, "画", "draw"): return True, await self.draw(message) - elif self.command_start_with(message, "key"): - return True, self.key(message) elif self.command_start_with(message, "switch"): return True, await self.switch(message) elif self.command_start_with(message, "models"): - return True, await self.get_models() + return True, await self.print_models() + elif self.command_start_with(message, "model"): + return True, await self.set_model(message) return False, None async def get_models(self): - ret = "OpenAI GPT 类可用模型" try: models = await self.provider.client.models.list() except NotFoundError as e: @@ -74,23 +75,50 @@ class CommandOpenAIOfficial(Command): self.provider.client.base_url = bu + "/v1" models = await self.provider.client.models.list() finally: - print(models.data) - i = 1 - for model in models.data: - if str(model.id).startswith("gpt"): - ret += f"\n{i}. {model.id}" - i += 1 - logger.debug(ret) + return filter(lambda x: x.id.startswith("gpt"), models.data) + + async def print_models(self): + models = await self.get_models() + i = 1 + ret = "OpenAI GPT 类可用模型" + for model in models: + ret += f"\n{i}. {model.id}" + i += 1 + logger.debug(ret) return True, ret, "models" + + async def set_model(self, message: str): + l = message.split(" ") + if len(l) == 1: + return True, "请输入 /model 模型名/编号", "model" + model = str(l[1]) + models = await self.get_models() + models = list(models) + if model.isdigit() and int(model) <= len(models) and int(model) >= 1: + model = models[int(model)-1] + else: + f = False + for m in models: + if model == m.id: + f = True + break + if not f: + return True, "模型不存在或输入非法", "model" + + self.provider.set_model(model.id) + return True, f"模型已设置为 {model.id}", "model" + + async def help(self): commands = super().general_commands() - commands['画'] = '画画' - commands['key'] = '添加OpenAI key' + commands['画'] = '调用 OpenAI DallE 模型生成图片' commands['set'] = '人格设置面板' - commands['gpt'] = '查看gpt配置信息' - commands['status'] = '查看key使用状态' - commands['token'] = '查看本轮会话token' + commands['status'] = '查看 Api Key 状态和配置信息' + commands['token'] = '查看本轮会话 token' + commands['reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt)' + commands['reset p'] = '重置当前与 LLM 的会话,并清除人格。' + return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" async def reset(self, session_id: str, message: str = "reset"): @@ -98,79 +126,44 @@ class CommandOpenAIOfficial(Command): return False, "未启用 OpenAI 官方 API", "reset" l = message.split(" ") if len(l) == 1: - await self.provider.forget(session_id) + await self.provider.forget(session_id, keep_system_prompt=True) return True, "重置成功", "reset" if len(l) == 2 and l[1] == "p": - self.provider.forget(session_id) - if self.personality_str != "": - self.set(self.personality_str, session_id) # 重新设置人格 - return True, "重置成功", "reset" + await self.provider.forget(session_id) def his(self, message: str, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "his" - # 分页,每页5条 - msg = '' size_per_page = 3 page = 1 - if message[4:]: - page = int(message[4:]) - # 检查是否有过历史记录 - if session_id not in self.provider.session_dict: - msg = f"历史记录为空" - return True, msg, "his" - l = self.provider.session_dict[session_id] - max_page = len(l)//size_per_page + \ - 1 if len(l) % size_per_page != 0 else len(l)//size_per_page - p = self.provider.get_prompts_by_cache_list( - self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page) - return True, f"历史记录如下:\n{p}\n第{page}页 | 共{max_page}页\n*输入/his 2跳转到第2页", "his" + l = message.split(" ") + if len(l) == 2: + try: + page = int(l[1]) + except BaseException as e: + return True, "页码不合法", "his" + contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page) + t_pages = total_num // size_per_page + 1 + return True, f"历史记录如下:\n{contexts}\n第 {page} 页 | 共 {t_pages} 页\n*输入 /his 2 跳转到第 2 页", "his" - def token(self, session_id: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "token" - return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token" - - def gpt(self): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "gpt" - return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt" - - def status(self): + def status(self, session_id: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "status" - chatgpt_cfg_str = "" - key_stat = self.provider.get_key_stat() - index = 1 - max = 9000000 - gg_count = 0 - total = 0 - tag = '' - for key in key_stat.keys(): - sponsor = '' - total += key_stat[key]['used'] - if key_stat[key]['exceed']: - gg_count += 1 - continue - if 'sponsor' in key_stat[key]: - sponsor = key_stat[key]['sponsor'] - chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n" - index += 1 - return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status" + keys_data = self.provider.get_keys_data() + ret = "OpenAI Key" + for k in keys_data: + status = "🟢" if keys_data[k] else "🔴" + ret += "\n|- " + k[:8] + " " + status - def key(self, message: str): - if self.provider is None: - return False, "未启用 OpenAI 官方 API", "reset" - l = message.split(" ") - if len(l) == 1: - msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx" - return True, msg, "key" - key = l[1] - if self.provider.check_key(key): - self.provider.append_key(key) - return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~" - else: - return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key" + conf = self.provider.get_configs() + ret += "\n当前模型:" + conf['model'] + if conf['model'] in MODELS: + ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens" + + if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]): + ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens" + + return True, ret, "status" async def switch(self, message: str): ''' @@ -187,14 +180,13 @@ class CommandOpenAIOfficial(Command): return True, ret, "switch" elif len(l) == 2: try: - key_stat = self.provider.get_key_stat() + key_stat = self.provider.get_keys_data() index = int(l[1]) if index > len(key_stat) or index < 1: return True, "账号序号不合法。", "switch" else: try: new_key = list(key_stat.keys())[index-1] - ret = await self.provider.check_key(new_key) self.provider.set_key(new_key) except BaseException as e: return True, "账号切换失败,原因: " + str(e), "switch" @@ -243,58 +235,22 @@ class CommandOpenAIOfficial(Command): 'name': ps, 'prompt': personalities[ps] } - self.provider.session_dict[session_id] = [] - new_record = { - "user": { - "role": "user", - "content": personalities[ps], - }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 - } - self.provider.session_dict[session_id].append(new_record) - self.personality_str = message + self.provider.personality_set(ps, session_id) return True, f"人格{ps}已设置。", "set" else: self.provider.curr_personality = { 'name': '自定义人格', 'prompt': ps } - new_record = { - "user": { - "role": "user", - "content": ps, - }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 - } - self.provider.session_dict[session_id] = [] - self.provider.session_dict[session_id].append(new_record) - self.personality_str = message + self.provider.personality_set(ps, session_id) return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" - async def draw(self, message): + async def draw(self, message: str): if self.provider is None: return False, "未启用 OpenAI 官方 API", "draw" if message.startswith("/画"): message = message[2:] elif message.startswith("画"): message = message[1:] - try: - # 画图模式传回3个参数 - img_url = await self.provider.image_chat(message) - return True, img_url, "draw" - except Exception as e: - if 'exceeded' in str(e): - return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)" - return False, f"图片生成失败: {e}", "draw" + img_url = await self.provider.image_generate(message) + return True, img_url, "draw" \ No newline at end of file diff --git a/model/platform/_message_parse.py b/model/platform/_message_parse.py index a55f0beb0..af25d5955 100644 --- a/model/platform/_message_parse.py +++ b/model/platform/_message_parse.py @@ -5,7 +5,7 @@ from nakuru import ( FriendMessage ) import botpy.message -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from typing import List, Union import time diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index 20955b101..68bd04557 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -15,7 +15,7 @@ import time from ._platfrom import Platform from ._message_parse import nakuru_message_parse_rev -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from SparkleLogging.utils.core import LogManager from logging import Logger diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 520d3d320..daa5a0f5f 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -5,6 +5,8 @@ import botpy.message import re import asyncio import aiohttp +import botpy.types +import botpy.types.message from util import general_utils as gu from botpy.types.message import Reference @@ -15,7 +17,7 @@ from ._message_parse import ( qq_official_message_parse_rev, qq_official_message_parse ) -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from typing import Union, List from nakuru.entities.components import BaseMessageComponent from SparkleLogging.utils.core import LogManager diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 72f1c2f89..fff43ca15 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -5,95 +5,108 @@ import time import tiktoken import threading import traceback +import base64 from openai import AsyncOpenAI from openai.types.images_response import ImagesResponse from openai.types.chat.chat_completion import ChatCompletion +from openai._exceptions import * -from cores.database.conn import dbConn +from persist.session import dbConn from model.provider.provider import Provider from util import general_utils as gu from util.cmd_config import CmdConfig from SparkleLogging.utils.core import LogManager from logging import Logger +from typing import List, Dict logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - -abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' - +MODELS = { + "gpt-4o": 128000, + "gpt-4o-2024-05-13": 128000, + "gpt-4-turbo": 128000, + "gpt-4-turbo-2024-04-09": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-0125-preview": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4-1106-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5-turbo": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-0613": 16385, + "gpt-3.5-turbo-16k-0613": 16385, +} class ProviderOpenAIOfficial(Provider): - def __init__(self, cfg): - self.cc = CmdConfig() + def __init__(self, cfg) -> None: + super().__init__() - self.key_list = [] - # 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错 - for key in cfg['key']: - if len(key) == 1: - raise BaseException( - "检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") - if cfg['key'] != '' and cfg['key'] != None: - self.key_list = cfg['key'] - if len(self.key_list) == 0: - raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。") + os.makedirs("data/openai", exist_ok=True) - self.key_stat = {} - for k in self.key_list: - self.key_stat[k] = {'exceed': False, 'used': 0} + self.cc = CmdConfig + self.key_data_path = "data/openai/keys.json" + self.api_keys = [] + self.chosen_api_key = None + self.base_url = None + self.keys_data = {} # 记录超额 - self.api_base = None - if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': - self.api_base = cfg['api_base'] - logger.info(f"设置 api_base 为: {self.api_base}") + if cfg['key']: self.api_keys = cfg['key'] + if cfg['api_base']: self.base_url = cfg['api_base'] + if not self.api_keys: + logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") + else: + self.chosen_api_key = self.api_keys[0] + + for key in self.api_keys: + self.keys_data[key] = True - # 创建 OpenAI Client self.client = AsyncOpenAI( - api_key=self.key_list[0], - base_url=self.api_base + api_key=self.chosen_api_key, + base_url=self.base_url ) - - self.openai_model_configs: dict = cfg['chatGPTConfigs'] - self.openai_configs = cfg - # 会话缓存 - self.session_dict = {} - # 最大缓存token - self.max_tokens = cfg['total_tokens_limit'] - # 历史记录持久化间隔时间 - self.history_dump_interval = 20 - - self.enc = tiktoken.get_encoding("cl100k_base") - + self.model_configs: Dict = cfg['chatGPTConfigs'] + self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None) + self.session_memory: Dict[str, List] = {} # 会话记忆 + self.session_memory_lock = threading.Lock() + self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 + self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 + self.DEFAULT_PERSONALITY = { + "name": "default", + "prompt": "你是一个很有帮助的 AI 助手。" + } + self.curr_personality = self.DEFAULT_PERSONALITY + self.session_personality = {} # 记录了某个session是否已设置人格。 # 从 SQLite DB 读取历史记录 try: db1 = dbConn() for session in db1.get_all_session(): - self.session_dict[session[0]] = json.loads(session[1])['data'] - logger.info("读取历史记录成功。") + self.session_memory_lock.acquire() + self.session_memory[session[0]] = json.loads(session[1])['data'] + self.session_memory_lock.release() except BaseException as e: - logger.info("读取历史记录失败,但不影响使用。") - - # 创建转储定时器线程 + logger.warn(f"读取 OpenAI LLM 对话历史记录 失败:{e}。仍可正常使用。") + + # 定时保存历史记录 threading.Thread(target=self.dump_history, daemon=True).start() - # 人格 - self.curr_personality = {} - - def make_tmp_client(self, api_key: str, base_url: str): - return AsyncOpenAI( - api_key=api_key, - base_url=base_url - ) - - # 转储历史记录 def dump_history(self): + ''' + 转储历史记录 + ''' time.sleep(10) db = dbConn() while True: try: - # print("转储历史记录...") - for key in self.session_dict: - data = self.session_dict[key] + for key in self.session_memory: + data = self.session_memory[key] data_json = { 'data': data } @@ -101,309 +114,382 @@ class ProviderOpenAIOfficial(Provider): db.update_session(key, json.dumps(data_json)) else: db.insert_session(key, json.dumps(data_json)) - # print("转储历史记录完毕") + logger.debug("已保存 OpenAI 会话历史记录") except BaseException as e: print(e) - # 每隔10分钟转储一次 - time.sleep(10*self.history_dump_interval) - + finally: + time.sleep(10*60) + def personality_set(self, default_personality: dict, session_id: str): + if not default_personality: return + if session_id not in self.session_memory: + self.session_memory[session_id] = [] self.curr_personality = default_personality + self.session_personality = {} # 重置 + encoded_prompt = self.tokenizer.encode(default_personality['prompt']) + tokens_num = len(encoded_prompt) + model = self.model_configs['model'] + if model in MODELS and tokens_num > MODELS[model] - 500: + default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500]) + new_record = { "user": { - "role": "user", + "role": "system", "content": default_personality['prompt'], }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 + 'usage_tokens': 0, # 到该条目的总 token 数 + 'single-tokens': 0 # 该条目的 token 数 } - self.session_dict[session_id].append(new_record) - async def text_chat(self, prompt, - session_id=None, - image_url=None, - function_call=None, - extra_conf: dict = None, - default_personality: dict = None): - if session_id is None: - session_id = "unknown" - if "unknown" in self.session_dict: - del self.session_dict["unknown"] - # 会话机制 - if session_id not in self.session_dict: - self.session_dict[session_id] = [] + self.session_memory[session_id].append(new_record) - if len(self.session_dict[session_id]) == 0: - # 设置默认人格 - if default_personality is not None: - self.personality_set(default_personality, session_id) + async def encode_image_bs64(self, image_url: str) -> str: + ''' + 将图片转换为 base64 + ''' + if image_url.startswith("http"): + image_url = await gu.download_image_by_url(image_url) + + with open(image_url, "rb") as f: + image_bs64 = base64.b64encode(f.read()).decode() + return "data:image/jpeg;base64," + image_bs64 - # 使用 tictoken 截断消息 - _encoded_prompt = self.enc.encode(prompt) - if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): - prompt = self.enc.decode(_encoded_prompt[:int( - self.openai_model_configs['max_tokens']*0.80)]) - logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。") - - cache_data_list, new_record, req = self.wrap( - prompt, session_id, image_url) - logger.debug(f"cache: {str(cache_data_list)}") - logger.debug(f"request: {str(req)}") - retry = 0 - response = None - err = '' - - # 截断倍率 - truncate_rate = 0.75 - - conf = self.openai_model_configs - if extra_conf is not None: - conf.update(extra_conf) - - while retry < 10: - try: - if function_call is None: - response = await self.client.chat.completions.create( - messages=req, - **conf - ) - else: - response = await self.client.chat.completions.create( - messages=req, - tools=function_call, - **conf - ) - break - except Exception as e: - traceback.print_exc() - if 'Invalid content type. image_url is only supported by certain models.' in str(e): - raise e - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.info("当前 Key 已超额或异常, 正在切换", - ) - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - retry -= 1 - elif 'maximum context length' in str(e): - logger.info("token 超限, 清空对应缓存,并进行消息截断") - self.session_dict[session_id] = [] - prompt = prompt[:int(len(prompt)*truncate_rate)] - truncate_rate -= 0.05 - cache_data_list, new_record, req = self.wrap( - prompt, session_id) - - elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e): - time.sleep(30) + async def retrieve_context(self, session_id: str): + ''' + 根据 session_id 获取保存的 OpenAI 格式的上下文 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + # 转换为 openai 要求的格式 + context = [] + is_lvm = await self.is_lvm() + for record in self.session_memory[session_id]: + if "user" in record and record['user']: + if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list): + logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。") continue - else: - logger.error(str(e)) - time.sleep(2) - err = str(e) - retry += 1 - if retry >= 10: - logger.warning( - r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki") - raise BaseException("连接出错: "+str(err)) - assert isinstance(response, ChatCompletion) - logger.debug( - f"OPENAI RESPONSE: {response.usage}") + context.append(record['user']) + if "AI" in record and record['AI']: + context.append(record['AI']) - # 结果分类 - choice = response.choices[0] - if choice.message.content != None: - # 文本形式 - chatgpt_res = str(choice.message.content).strip() - elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0: + return context + + async def is_lvm(self): + ''' + 是否是 LVM + ''' + return self.model_configs['model'].startswith("gpt-4") + + async def get_models(self): + ''' + 获取所有模型 + ''' + models = await self.client.models.list() + logger.info(f"OpenAI 模型列表:{models}") + return models + + async def assemble_context(self, session_id: str, prompt: str, image_url: str = None): + ''' + 组装上下文,并且根据当前上下文窗口大小截断 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + tokens_num = len(self.tokenizer.encode(prompt)) + previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens'] + + message = { + "usage_tokens": previous_total_tokens_num + tokens_num, + "single_tokens": tokens_num, + "AI": None + } + if image_url: + user_content = { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": await self.encode_image_bs64(image_url) + } + } + ] + } + else: + user_content = { + "role": "user", + "content": prompt + } + + message["user"] = user_content + self.session_memory[session_id].append(message) + + # 根据 模型的上下文窗口 淘汰掉多余的记录 + curr_model = self.model_configs['model'] + if curr_model in MODELS: + maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion + # if message['usage_tokens'] > maxium_tokens_num: + # 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300 + # contexts = self.session_memory[session_id] + # need_to_remove_idx = 0 + # freed_tokens_num = contexts[0]['single-tokens'] + # while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num: + # need_to_remove_idx += 1 + # freed_tokens_num += contexts[need_to_remove_idx]['single-tokens'] + # # 更新之后的所有记录的 usage_tokens + # for i in range(len(contexts)): + # if i > need_to_remove_idx: + # contexts[i]['usage_tokens'] -= freed_tokens_num + # logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。") + # self.session_memory[session_id] = contexts[need_to_remove_idx+1:] + while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num: + self.pop_record(session_id) + + + async def pop_record(self, session_id: str, pop_system_prompt: bool = False): + ''' + 弹出第一条记录 + ''' + if session_id not in self.session_memory: + raise Exception("会话 ID 不存在") + + if len(self.session_memory[session_id]) == 0: + return None + + for i in range(len(self.session_memory[session_id])): + # 检查是否是 system prompt + if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system": + # 如果只有一个 system prompt,才不删掉 + f = False + for j in range(i+1, len(self.session_memory[session_id])): + if self.session_memory[session_id][j]['user']['role'] == "system": + f = True + break + if not f: + continue + record = self.session_memory[session_id].pop(i) + break + + # 更新之后所有记录的 usage_tokens + for i in range(len(self.session_memory[session_id])): + self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens'] + logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}。") + return record + + async def text_chat(self, + prompt: str, + session_id: str, + image_url: None=None, + tools: None=None, + extra_conf: Dict = None, + **kwargs + ) -> str: + if not session_id: + session_id = "unknown" + if "unknown" in self.session_memory: + del self.session_memory["unknown"] + + if session_id not in self.session_memory: + self.session_memory[session_id] = [] + + if session_id not in self.session_personality or not self.session_personality[session_id]: + self.personality_set(self.curr_personality, session_id) + self.session_personality[session_id] = True + + # 如果 prompt 超过了最大窗口,截断。 + # 1. 可以保证之后 pop 的时候不会出现问题 + # 2. 可以保证不会超过最大 token 数 + _encoded_prompt = self.tokenizer.encode(prompt) + curr_model = self.model_configs['model'] + if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300: + _encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300] + prompt = self.tokenizer.decode(_encoded_prompt) + + # 组装上下文,并且根据当前上下文窗口大小截断 + await self.assemble_context(session_id, prompt, image_url) + + # 获取上下文,openai 格式 + contexts = await self.retrieve_context(session_id) + + conf = self.model_configs + if extra_conf: conf.update(extra_conf) + + # start request + retry = 0 + rate_limit_retry = 0 + while retry < 3 or rate_limit_retry < 5: + logger.debug(conf) + logger.debug(contexts) + if tools: + completion_coro = self.client.chat.completions.create( + messages=contexts, + tools=tools, + **conf + ) + else: + completion_coro = self.client.chat.completions.create( + messages=contexts, + **conf + ) + try: + completion = await completion_coro + break + except AuthenticationError as e: + api_key = self.chosen_api_key[10:] + "..." + logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)") + self.keys_data[self.chosen_api_key] = False + ok = await self.switch_to_next_key() + if ok: continue + else: raise Exception("所有 OpenAI API Key 目前都不可用。") + except BadRequestError as e: + logger.warn(f"OpenAI 请求异常:{e}。") + if "image_url is only supported by certain models." in str(e): + raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。") + retry += 1 + except RateLimitError as e: + if "You exceeded your current quota" in str(e): + self.keys_data[self.chosen_api_key] = False + ok = await self.switch_to_next_key() + if ok: continue + else: raise Exception("所有 OpenAI API Key 目前都不可用。") + logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}") + await self.switch_to_next_key() + rate_limit_retry += 1 + time.sleep(1) + except Exception as e: + retry += 1 + if retry >= 3: + logger.error(traceback.format_exc()) + raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。") + if "maximum context length" in str(e): + logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + self.pop_record(session_id) + + logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。") + time.sleep(1) + + assert isinstance(completion, ChatCompletion) + logger.debug(f"openai completion: {completion.usage}") + + choice = completion.choices[0] + + usage_tokens = completion.usage.total_tokens + completion_tokens = completion.usage.completion_tokens + self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens + self.session_memory[session_id][-1]['single_tokens'] += completion_tokens + + if choice.message.content: + # 返回文本 + completion_text = str(choice.message.content).strip() + elif choice.message.tool_calls and choice.message.tool_calls: # tools call (function calling) return choice.message.tool_calls[0].function - - self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens - current_usage_tokens = response.usage.total_tokens - - # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens - if current_usage_tokens > self.max_tokens: - t = current_usage_tokens - index = 0 - while t > self.max_tokens: - if index >= len(cache_data_list): - break - # 保留人格信息 - if cache_data_list[index]['type'] != 'personality': - t -= int(cache_data_list[index]['single_tokens']) - del cache_data_list[index] - else: - index += 1 - # 删除完后更新相关字段 - self.session_dict[session_id] = cache_data_list - - # 添加新条目进入缓存的prompt - new_record['AI'] = { - 'role': 'assistant', - 'content': chatgpt_res, + + self.session_memory[session_id][-1]['AI'] = { + "role": "assistant", + "content": completion_text } - new_record['usage_tokens'] = current_usage_tokens - if len(cache_data_list) > 0: - new_record['single_tokens'] = current_usage_tokens - \ - int(cache_data_list[-1]['usage_tokens']) - else: - new_record['single_tokens'] = current_usage_tokens - cache_data_list.append(new_record) - - self.session_dict[session_id] = cache_data_list - - return chatgpt_res - - async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): - retry = 0 - image_url = '' - - image_generate_configs = self.cc.get("openai_image_generate", None) - - while retry < 5: - try: - response: ImagesResponse = await self.client.images.generate( - prompt=prompt, - **image_generate_configs - ) - image_url = [] - for i in range(img_num): - image_url.append(response.data[i].url) - break - except Exception as e: - logger.warning(str(e)) - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( - e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.warning("当前 Key 已超额或者不正常, 正在切换") - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - elif 'Your request was rejected as a result of our safety system.' in str(e): - logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试") - raise e - else: - retry += 1 - if retry >= 5: - raise BaseException("连接超时") - - return image_url - - async def forget(self, session_id=None) -> bool: - if session_id is None: + return completion_text + + async def switch_to_next_key(self): + ''' + 切换到下一个 API Key + ''' + if not self.api_keys: + logger.error("OpenAI API Key 不存在。") return False - self.session_dict[session_id] = [] - return True - def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): + for key in self.keys_data: + if self.keys_data[key]: + # 没超额 + self.chosen_api_key = key + self.client.api_key = key + logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。") + return True + + return False + + async def image_generate(self, prompt, session_id, **kwargs) -> str: + ''' + 生成图片 + ''' + retry = 0 + conf = self.image_generator_model_configs + if not conf: + logger.error("OpenAI 图片生成模型配置不存在。") + raise Exception("OpenAI 图片生成模型配置不存在。") + + while retry < 3: + try: + images_response = await self.client.images.generate( + prompt=prompt, + **conf + ) + image_url = images_response.data[0].url + return image_url + except Exception as e: + retry += 1 + if retry >= 3: + logger.error(traceback.format_exc()) + raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") + logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。") + time.sleep(1) + + async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool: + if session_id is None: return False + self.session_memory[session_id] = [] + if keep_system_prompt: + self.personality_set(self.curr_personality, session_id) + else: + self.curr_personality = self.DEFAULT_PERSONALITY + return True + + def dump_contexts_page(self, session_id: str, size=5, page=1,): ''' 获取缓存的会话 ''' - prompts = "" - if paging: - page_begin = (page-1)*size - page_end = page*size - if page_begin < 0: - page_begin = 0 - if page_end > len(cache_data_list): - page_end = len(cache_data_list) - cache_data_list = cache_data_list[page_begin:page_end] - for item in cache_data_list: - prompts += str(item['user']['role']) + ":\n" + \ - str(item['user']['content']) + "\n" - prompts += str(item['AI']['role']) + ":\n" + \ - str(item['AI']['content']) + "\n" + # contexts_str = "" + # for i, key in enumerate(self.session_memory): + # if i < (page-1)*size or i >= page*size: + # continue + # contexts_str += f"Session ID: {key}\n" + # for record in self.session_memory[key]: + # if "user" in record: + # contexts_str += f"User: {record['user']['content']}\n" + # if "AI" in record: + # contexts_str += f"AI: {record['AI']['content']}\n" + # contexts_str += "---\n" + contexts_str = "" + if session_id in self.session_memory: + for record in self.session_memory[session_id]: + if "user" in record and record['user']: + text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content'] + contexts_str += f"User: {text}\n" + if "AI" in record and record['AI']: + text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content'] + contexts_str += f"Assistant: {text}\n" + else: + contexts_str = "会话 ID 不存在。" - if divide: - prompts += "----------\n" - return prompts - - def wrap(self, prompt, session_id, image_url=None): - if image_url is not None: - prompt = [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - } - ] - # 获得缓存信息 - context = self.session_dict[session_id] - new_record = { - "user": { - "role": "user", - "content": prompt, - }, - "AI": {}, - 'type': "common", - 'usage_tokens': 0, - } - req_list = [] - for i in context: - if 'user' in i: - req_list.append(i['user']) - if 'AI' in i: - req_list.append(i['AI']) - req_list.append(new_record['user']) - return context, new_record, req_list - - def handle_switch_key(self): - is_all_exceed = True - for key in self.key_stat: - if key == None or self.key_stat[key]['exceed']: - continue - is_all_exceed = False - self.client.api_key = key - logger.warning( - f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})") - break - if is_all_exceed: - logger.warning( - "所有 Key 已超额") - return False - return True + return contexts_str, len(self.session_memory[session_id]) + def set_model(self, model: str): + self.model_configs['model'] = model + def get_configs(self): - return self.openai_configs + return self.model_configs - def get_key_stat(self): - return self.key_stat - - def get_key_list(self): - return self.key_list + def get_keys_data(self): + return self.keys_data def get_curr_key(self): - return self.client.api_key + return self.chosen_api_key def set_key(self, key): - self.client.api_key = key - - # 添加key - def append_key(self, key, sponsor): - self.key_list.append(key) - self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} - - # 检查key是否可用 - async def check_key(self, key): - client_ = AsyncOpenAI( - api_key=key, - base_url=self.api_base - ) - messages = [{"role": "user", "content": "please just echo `test`"}] - await client_.chat.completions.create( - messages=messages, - **self.openai_model_configs - ) - return True + self.client.api_key = key \ No newline at end of file diff --git a/model/provider/provider.py b/model/provider/provider.py index f2a80202c..8bf437ae8 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -2,8 +2,8 @@ class Provider: async def text_chat(self, prompt: str, session_id: str, - image_url: None, - function_call: None, + image_url: None = None, + tools: None = None, extra_conf: dict = None, default_personality: dict = None, **kwargs) -> str: @@ -14,7 +14,7 @@ class Provider: [optional] image_url: 图片url(识图) - function_call: 函数调用 + tools: 函数调用工具 extra_conf: 额外配置 default_personality: 默认人格 ''' diff --git a/cores/database/conn.py b/persist/session.py similarity index 100% rename from cores/database/conn.py rename to persist/session.py diff --git a/type/command.py b/type/command.py new file mode 100644 index 000000000..0504f947d --- /dev/null +++ b/type/command.py @@ -0,0 +1,28 @@ +from typing import Union, List, Callable +from dataclasses import dataclass + + +@dataclass +class CommandItem(): + ''' + 用来描述单个指令 + ''' + + command_name: Union[str, tuple] # 指令名 + callback: Callable # 回调函数 + description: str # 描述 + origin: str # 注册来源 + +class CommandResult(): + ''' + 用于在Command中返回多个值 + ''' + + def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: + self.hit = hit + self.success = success + self.message_chain = message_chain + self.command_name = command_name + + def _result_tuple(self): + return (self.success, self.message_chain, self.command_name) diff --git a/type/message.py b/type/message.py new file mode 100644 index 000000000..ca49d51d6 --- /dev/null +++ b/type/message.py @@ -0,0 +1,62 @@ +from enum import Enum +from typing import List +from dataclasses import dataclass +from nakuru.entities.components import BaseMessageComponent + +from type.register import RegisteredPlatform +from type.types import GlobalObject + +class MessageType(Enum): + GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 + FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 + GUILD_MESSAGE = 'GuildMessage' # 频道消息 + + +@dataclass +class MessageMember(): + user_id: str # 发送者id + nickname: str = None + + +class AstrBotMessage(): + ''' + AstrBot 的消息对象 + ''' + tag: str # 消息来源标签 + type: MessageType # 消息类型 + self_id: str # 机器人的识别id + session_id: str # 会话id + message_id: str # 消息id + sender: MessageMember # 发送者 + message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message_str: str # 最直观的纯文本消息字符串 + raw_message: object + timestamp: int # 消息时间戳 + + def __str__(self) -> str: + return str(self.__dict__) + +class AstrMessageEvent(): + ''' + 消息事件。 + ''' + context: GlobalObject # 一些公用数据 + message_str: str # 纯消息字符串 + message_obj: AstrBotMessage # 消息对象 + platform: RegisteredPlatform # 来源平台 + role: str # 基本身份。`admin` 或 `member` + session_id: int # 会话 id + + def __init__(self, + message_str: str, + message_obj: AstrBotMessage, + platform: RegisteredPlatform, + role: str, + context: GlobalObject, + session_id: str = None): + self.context = context + self.message_str = message_str + self.message_obj = message_obj + self.platform = platform + self.role = role + self.session_id = session_id diff --git a/type/plugin.py b/type/plugin.py new file mode 100644 index 000000000..cbc878ed3 --- /dev/null +++ b/type/plugin.py @@ -0,0 +1,27 @@ +from enum import Enum +from dataclasses import dataclass + +class PluginType(Enum): + PLATFORM = 'platfrom' # 平台类插件。 + LLM = 'llm' # 大语言模型类插件 + COMMON = 'common' # 其他插件 + + +@dataclass +class PluginMetadata: + ''' + 插件的元数据。 + ''' + # required + plugin_name: str + plugin_type: PluginType + author: str # 插件作者 + desc: str # 插件简介 + version: str # 插件版本 + + # optional + repo: str = None # 插件仓库地址 + + def __str__(self) -> str: + return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" + diff --git a/type/register.py b/type/register.py new file mode 100644 index 000000000..8d99fc666 --- /dev/null +++ b/type/register.py @@ -0,0 +1,46 @@ +from model.provider.provider import Provider as LLMProvider +from model.platform._platfrom import Platform +from type.plugin import * +from typing import List +from types import ModuleType +from dataclasses import dataclass + +@dataclass +class RegisteredPlugin: + ''' + 注册在 AstrBot 中的插件。 + ''' + metadata: PluginMetadata + plugin_instance: object + module_path: str + module: ModuleType + root_dir_name: str + + def __str__(self) -> str: + return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" + + +RegisteredPlugins = List[RegisteredPlugin] + + +@dataclass +class RegisteredPlatform: + ''' + 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 + ''' + platform_name: str + platform_instance: Platform + origin: str = None # 注册来源 + + def __str__(self) -> str: + return self.platform_name + + +@dataclass +class RegisteredLLM: + ''' + 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 + ''' + llm_name: str + llm_instance: LLMProvider + origin: str = None # 注册来源 diff --git a/type/types.py b/type/types.py new file mode 100644 index 000000000..2f9fd4928 --- /dev/null +++ b/type/types.py @@ -0,0 +1,34 @@ +from type.register import * +from typing import List + +class GlobalObject: + ''' + 存放一些公用的数据,用于在不同模块(如core与command)之间传递 + ''' + version: str # 机器人版本 + nick: tuple # 用户定义的机器人的别名 + base_config: dict # config.json 中导出的配置 + cached_plugins: List[RegisteredPlugin] # 加载的插件 + platforms: List[RegisteredPlatform] + llms: List[RegisteredLLM] + + web_search: bool # 是否开启了网页搜索 + reply_prefix: str # 回复前缀 + unique_session: bool # 是否开启了独立会话 + cnt_total: int # 总消息数 + default_personality: dict + dashboard_data = None + + def __init__(self): + self.nick = None # gocq 的昵称 + self.base_config = None # config.yaml + self.cached_plugins = [] # 缓存的插件 + self.web_search = False # 是否开启了网页搜索 + self.reply_prefix = None + self.unique_session = False + self.cnt_total = 0 + self.platforms = [] + self.llms = [] + self.default_personality = None + self.dashboard_data = None + self.stat = {} diff --git a/util/function_calling/func_call.py b/util/agent/func_call.py similarity index 100% rename from util/function_calling/func_call.py rename to util/agent/func_call.py diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py new file mode 100644 index 000000000..e07f80847 --- /dev/null +++ b/util/agent/web_searcher.py @@ -0,0 +1,183 @@ +import traceback +import random +import json +import asyncio +import aiohttp +import os + +from readability import Document +from bs4 import BeautifulSoup +from openai.types.chat.chat_completion_message_tool_call import Function +from util.agent.func_call import FuncCall +from util.search_engine_scraper.config import HEADERS, USER_AGENTS +from util.search_engine_scraper.bing import Bing +from util.search_engine_scraper.sogo import Sogo +from util.search_engine_scraper.google import Google +from model.provider.provider import Provider +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + + +bing_search = Bing() +sogo_search = Sogo() +google = Google() +proxy = os.environ.get("HTTPS_PROXY", None) + +def tidy_text(text: str) -> str: + ''' + 清理文本,去除空格、换行符等 + ''' + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + +# def special_fetch_zhihu(link: str) -> str: +# ''' +# function-calling 函数, 用于获取知乎文章的内容 +# ''' +# response = requests.get(link, headers=HEADERS) +# response.encoding = "utf-8" +# soup = BeautifulSoup(response.text, "html.parser") + +# if "zhuanlan.zhihu.com" in link: +# r = soup.find(class_="Post-RichTextContainer") +# else: +# r = soup.find(class_="List-item").find(class_="RichContent-inner") +# if r is None: +# print("debug: zhihu none") +# raise Exception("zhihu none") +# return tidy_text(r.text) + +async def search_from_bing(keyword: str) -> str: + ''' + tools, 从 bing 搜索引擎搜索 + ''' + logger.info("web_searcher - search_from_bing: " + keyword) + results = [] + try: + results = await google.search(keyword, 5) + except BaseException as e: + logger.error(f"google search error: {e}, try the next one...") + if len(results) == 0: + logger.debug("search google failed") + try: + results = await bing_search.search(keyword, 5) + except BaseException as e: + logger.error(f"bing search error: {e}, try the next one...") + if len(results) == 0: + logger.debug("search bing failed") + try: + results = await sogo_search.search(keyword, 5) + except BaseException as e: + logger.error(f"sogo search error: {e}") + if len(results) == 0: + logger.debug("search sogo failed") + return "没有搜索到结果" + ret = "" + idx = 1 + for i in results: + logger.info(f"web_searcher - scraping web: {i.title} - {i.url}") + try: + site_result = await fetch_website_content(i.url) + except: + site_result = "" + site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result + ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n" + idx += 1 + return ret + + +async def fetch_website_content(url): + header = HEADERS + header.update({'User-Agent': random.choice(USER_AGENTS)}) + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=HEADERS, timeout=6, proxy=proxy) as response: + html = await response.text(encoding="utf-8") + doc = Document(html) + ret = doc.summary(html_partial=True) + soup = BeautifulSoup(ret, 'html.parser') + ret = tidy_text(soup.get_text()) + return ret + + +async def web_search(prompt, provider: Provider, session_id, official_fc=False): + ''' + official_fc: 使用官方 function-calling + ''' + new_func_call = FuncCall(provider) + + new_func_call.add_func("web_search", [{ + "type": "string", + "name": "keyword", + "description": "搜索关键词" + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + search_from_bing + ) + new_func_call.add_func("fetch_website_content", [{ + "type": "string", + "name": "url", + "description": "要获取内容的网页链接" + }], + "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content + ) + + has_func = False + function_invoked_ret = "" + if official_fc: + # we use official function-calling + result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func()) + if isinstance(result, Function): + logger.debug(f"web_searcher - function-calling: {result}") + func_obj = None + for i in new_func_call.func_list: + if i["name"] == result.name: + func_obj = i["func_obj"] + break + if not func_obj: + return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + try: + args = json.loads(result.arguments) + function_invoked_ret = await func_obj(**args) + has_func = True + except BaseException as e: + traceback.print_exc() + return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + else: + return result + else: + # we use our own function-calling + try: + args = { + 'question': prompt, + 'func_definition': new_func_call.func_dump(), + 'is_task': False, + 'is_summary': False, + } + function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) + except BaseException as e: + res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)" + return res + has_func = True + + if has_func: + await provider.forget(session_id) + summary_prompt = f""" +你是一个专业且高效的助手,你的任务是 +1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; +2. 简单地发表你对这个问题的简略看法。 + +# 例子 +1. 从网上的信息来看,可以知道...我个人认为...你觉得呢? +2. 根据网上的最新信息,可以得知...我觉得...你怎么看? + +# 限制 +1. 限制在 200 字以内; +2. 请**直接输出总结**,不要输出多余的内容和提示语。 + +# 相关材料 +{function_invoked_ret}""" + ret = await provider.text_chat(summary_prompt, session_id) + return ret + return function_invoked_ret diff --git a/util/cmd_config.py b/util/cmd_config.py index ea05ed74d..928f3bb25 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -2,8 +2,7 @@ import os import json from typing import Union -cpath = "cmd_config.json" - +cpath = "data/cmd_config.json" def check_exist(): if not os.path.exists(cpath): @@ -89,8 +88,7 @@ def init_astrbot_config_items(): # 加载默认配置 cc = CmdConfig() cc.init_attributes("qq_forward_threshold", 200) - cc.init_attributes( - "qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n") + cc.init_attributes("qq_welcome", "") cc.init_attributes("qq_pic_mode", False) cc.init_attributes("gocq_host", "127.0.0.1") cc.init_attributes("gocq_http_port", 5700) diff --git a/util/function_calling/gplugin.py b/util/function_calling/gplugin.py deleted file mode 100644 index 5ddab894e..000000000 --- a/util/function_calling/gplugin.py +++ /dev/null @@ -1,300 +0,0 @@ -import requests -import util.general_utils as gu -import traceback -import time -import json -import asyncio -from googlesearch import search, SearchResult -from readability import Document -from bs4 import BeautifulSoup -from openai.types.chat.chat_completion_message_tool_call import Function -from util.function_calling.func_call import ( - FuncCall, - FuncCallJsonFormatError, - FuncNotFoundError -) -from model.provider.provider import Provider - - -def tidy_text(text: str) -> str: - ''' - 清理文本,去除空格、换行符等 - ''' - return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") - - -def special_fetch_zhihu(link: str) -> str: - ''' - function-calling 函数, 用于获取知乎文章的内容 - ''' - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - response = requests.get(link, headers=headers) - response.encoding = "utf-8" - soup = BeautifulSoup(response.text, "html.parser") - - if "zhuanlan.zhihu.com" in link: - r = soup.find(class_="Post-RichTextContainer") - else: - r = soup.find(class_="List-item").find(class_="RichContent-inner") - if r is None: - print("debug: zhihu none") - raise Exception("zhihu none") - return tidy_text(r.text) - - -def google_web_search(keyword) -> str: - ''' - 获取 google 搜索结果, 得到 title、desc、link - ''' - ret = "" - index = 1 - try: - ls = search(keyword, advanced=True, num_results=4) - for i in ls: - desc = i.description - try: - # gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO) - desc = fetch_website_content(i.url) - except BaseException as e: - print(f"(google) fetch_website_content err: {str(e)}") - # gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999) - ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n" - index += 1 - except Exception as e: - print(f"google search err: {str(e)}") - return web_keyword_search_via_bing(keyword) - return ret - - -def web_keyword_search_via_bing(keyword) -> str: - ''' - 获取bing搜索结果, 得到 title、desc、link - ''' - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - url = "https://www.bing.com/search?q="+keyword - _cnt = 0 - # _detail_store = [] - while _cnt < 5: - try: - response = requests.get(url, headers=headers) - response.encoding = "utf-8" - # gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999) - soup = BeautifulSoup(response.text, "html.parser") - res = "" - result_cnt = 0 - ols = soup.find(id="b_results") - for i in ols.find_all("li", class_="b_algo"): - try: - title = i.find("h2").text - desc = i.find("p").text - link = i.find("h2").find("a").get("href") - # res.append({ - # "title": title, - # "desc": desc, - # "link": link, - # }) - try: - # gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO) - desc = fetch_website_content(link) - except BaseException as e: - print(f"(bing) fetch_website_content err: {str(e)}") - - res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n" - result_cnt += 1 - if result_cnt > 5: - break - - # if len(_detail_store) >= 3: - # continue - # # 爬取前两条的网页内容 - # if "zhihu.com" in link: - # try: - # _detail_store.append(special_fetch_zhihu(link)) - # except BaseException as e: - # print(f"zhihu parse err: {str(e)}") - # else: - # try: - # _detail_store.append(fetch_website_content(link)) - # except BaseException as e: - # print(f"fetch_website_content err: {str(e)}") - - except Exception as e: - print(f"bing parse err: {str(e)}") - if result_cnt == 0: - break - return res - except Exception as e: - # gu.log(f"bing fetch err: {str(e)}") - _cnt += 1 - time.sleep(1) - - # gu.log("fail to fetch bing info, using sougou.") - return web_keyword_search_via_sougou(keyword) - - -def web_keyword_search_via_sougou(keyword) -> str: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", - } - url = f"https://sogou.com/web?query={keyword}" - response = requests.get(url, headers=headers) - response.encoding = "utf-8" - soup = BeautifulSoup(response.text, "html.parser") - - res = [] - results = soup.find("div", class_="results") - for i in results.find_all("div", class_="vrwrap"): - try: - title = tidy_text(i.find("h3").text) - link = tidy_text(i.find("h3").find("a").get("href")) - if link.startswith("/link?url="): - link = "https://www.sogou.com" + link - res.append({ - "title": title, - "link": link, - }) - if len(res) >= 5: # 限制5条 - break - except Exception as e: - pass - # gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR) - # 爬取网页内容 - _detail_store = [] - for i in res: - if _detail_store >= 3: - break - try: - _detail_store.append(fetch_website_content(i["link"])) - except BaseException as e: - print(f"fetch_website_content err: {str(e)}") - ret = f"{str(res)}" - if len(_detail_store) > 0: - ret += f"\n网页内容: {str(_detail_store)}" - return ret - - -def fetch_website_content(url): - # gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG) - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - response = requests.get(url, headers=headers, timeout=3) - response.encoding = "utf-8" - doc = Document(response.content) - # print('title:', doc.title()) - ret = doc.summary(html_partial=True) - soup = BeautifulSoup(ret, 'html.parser') - ret = tidy_text(soup.get_text()) - return ret - - -async def web_search(question, provider: Provider, session_id, official_fc=False): - ''' - official_fc: 使用官方 function-calling - ''' - new_func_call = FuncCall(provider) - new_func_call.add_func("google_web_search", [{ - "type": "string", - "name": "keyword", - "description": "google search query (分词,尽量保留所有信息)" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - web_keyword_search_via_bing - ) - new_func_call.add_func("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "网址" - }], - "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) - question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。" - has_func = False - function_invoked_ret = "" - if official_fc: - # we use official function-calling - func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func()) - if isinstance(func, Function): - # 执行对应的结果: - func_obj = None - for i in new_func_call.func_list: - if i["name"] == func.name: - func_obj = i["func_obj"] - break - if not func_obj: - # gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR) - return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" - try: - args = json.loads(func.arguments) - # we use to_thread to avoid blocking the event loop - function_invoked_ret = await asyncio.to_thread(func_obj, **args) - has_func = True - except BaseException as e: - traceback.print_exc() - return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" - else: - # now func is a string - return func - else: - # we use our own function-calling - try: - args = { - 'question': question1, - 'func_definition': new_func_call.func_dump(), - 'is_task': False, - 'is_summary': False, - } - function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) - except BaseException as e: - res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)" - return res - has_func = True - - if has_func: - await provider.forget(session_id) - question3 = f""" -你的任务是: -1. 根据末尾的材料对问题`{question}`做切题的总结(详细); -2. 简单地发表你对这个问题的看法(简略)。 -你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。 -不要提到任何函数调用的信息。 - -一些回复的消息模板: -模板1: -``` -从网上的信息来看,可以知道...我个人认为...你觉得呢? -``` -模板2: -``` -根据网上的最新信息,可以得知...我觉得...你怎么看? -``` -你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。 - -以下是相关材料: -""" - _c = 0 - while _c < 3: - try: - print('text chat') - final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id) - return final_ret - except Exception as e: - print(e) - _c += 1 - if _c == 3: - raise e - if "The message you submitted was too long" in str(e): - await provider.forget(session_id) - function_invoked_ret = function_invoked_ret[:int( - len(function_invoked_ret) / 2)] - time.sleep(3) - return function_invoked_ret diff --git a/util/general_utils.py b/util/general_utils.py index 7467342eb..d9f2d851c 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -1,18 +1,23 @@ -import datetime import time import socket -from PIL import Image, ImageDraw, ImageFont import os import re import requests -from util.cmd_config import CmdConfig +import aiohttp import socket -from cores.astrbot.types import GlobalObject import platform -import logging import json import sys import psutil +import ssl + +from PIL import Image, ImageDraw, ImageFont +from type.types import GlobalObject +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + def port_checker(port: int, host: str = "localhost"): sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -355,10 +360,36 @@ def save_temp_img(img: Image) -> str: # 获得时间戳 timestamp = int(time.time()) - p = f"temp/{timestamp}.png" - img.save(p) + p = f"temp/{timestamp}.jpg" + + if isinstance(img, Image.Image): + img.save(p) + else: + with open(p, "wb") as f: + f.write(img) + logger.info(f"保存临时图片: {p}") return p +async def download_image_by_url(url: str) -> str: + ''' + 下载图片 + ''' + try: + logger.info(f"下载图片: {url}") + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + return save_temp_img(await resp.read()) + except aiohttp.client_exceptions.ClientConnectorSSLError as e: + # 关闭SSL验证 + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) + except Exception as e: + raise e + def create_text_image(title: str, text: str, max_width=30, font_size=20): ''' @@ -391,15 +422,19 @@ def create_markdown_image(text: str): raise e -def try_migrate_config(old_config: dict): +def try_migrate_config(): ''' - 迁移配置文件到 cmd_config.json + 将 cmd_config.json 迁移至 data/cmd_config.json ''' - cc = CmdConfig() - if cc.get("qqbot", None) is None: - # 未迁移过 - for k in old_config: - cc.put(k, old_config[k]) + if os.path.exists("cmd_config.json"): + with open("cmd_config.json", "r", encoding="utf-8-sig") as f: + data = json.load(f) + with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + try: + os.remove("cmd_config.json") + except Exception as e: + pass def get_local_ip_addresses(): @@ -451,6 +486,21 @@ def upload(_global_object: GlobalObject): pass time.sleep(10*60) +def retry(n: int = 3): + ''' + 重试装饰器 + ''' + def decorator(func): + def wrapper(*args, **kwargs): + for i in range(n): + try: + return func(*args, **kwargs) + except Exception as e: + if i == n-1: raise e + logger.warning(f"函数 {func.__name__} 第 {i+1} 次重试... {e}") + return wrapper + return decorator + def run_monitor(global_object: GlobalObject): ''' diff --git a/util/plugin_dev/api/v1/bot.py b/util/plugin_dev/api/v1/bot.py index f855b6132..46a3b7fce 100644 --- a/util/plugin_dev/api/v1/bot.py +++ b/util/plugin_dev/api/v1/bot.py @@ -1,11 +1,5 @@ -from cores.astrbot.types import ( - PluginMetadata, - RegisteredLLM, - RegisteredPlugin, - RegisteredPlatform, - RegisteredPlugins, - PluginType, - GlobalObject, - AstrMessageEvent, - CommandResult -) \ No newline at end of file +from type.plugin import PluginMetadata, PluginType +from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins +from type.types import GlobalObject +from type.message import AstrMessageEvent +from type.command import CommandResult \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index a0f2853e9..39722f709 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,5 +1,6 @@ -from cores.astrbot.core import oper_msg -from cores.astrbot.types import AstrMessageEvent, CommandResult +from astrbot.core import oper_msg +from type.message import AstrMessageEvent, AstrBotMessage +from type.command import CommandResult from model.platform._message_result import MessageResult ''' diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py index bf28cdd68..83ccd5460 100644 --- a/util/plugin_dev/api/v1/register.py +++ b/util/plugin_dev/api/v1/register.py @@ -5,7 +5,8 @@ ''' from model.provider.provider import Provider as LLMProvider from model.platform._platfrom import Platform -from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM +from type.types import GlobalObject +from type.register import RegisteredPlatform, RegisteredLLM def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None: ''' diff --git a/util/plugin_dev/api/v1/types.py b/util/plugin_dev/api/v1/types.py index 3c42be69b..96d07a775 100644 --- a/util/plugin_dev/api/v1/types.py +++ b/util/plugin_dev/api/v1/types.py @@ -2,4 +2,4 @@ 插件类型 ''' -from cores.astrbot.types import PluginType \ No newline at end of file +from type.plugin import PluginType \ No newline at end of file diff --git a/util/plugin_util.py b/util/plugin_util.py index 1045c108a..d8bf57a60 100644 --- a/util/plugin_util.py +++ b/util/plugin_util.py @@ -3,24 +3,18 @@ ''' import os import inspect +import shutil +import stat +import traceback + try: - import git.exc from git.repo import Repo except ImportError: pass -import shutil -import importlib -import stat -import traceback from types import ModuleType -from typing import List from pip._internal import main as pipmain -from cores.astrbot.types import ( - PluginMetadata, - PluginType, - RegisteredPlugin, - RegisteredPlugins -) +from type.plugin import * +from type.register import * # 找出模块里所有的类名 diff --git a/util/search_engine_scraper/bing.py b/util/search_engine_scraper/bing.py new file mode 100644 index 000000000..f51a2a042 --- /dev/null +++ b/util/search_engine_scraper/bing.py @@ -0,0 +1,38 @@ +from typing import List + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENT_BING + +class Bing(SearchEngine): + def __init__(self) -> None: + super().__init__() + self.base_url = "https://www.bing.com" + self.headers.update({'User-Agent': USER_AGENT_BING}) + + def _set_selector(self, selector: str): + selectors = { + 'url': 'div.b_attribution cite', + 'title': 'h2', + 'text': 'p', + 'links': 'ol#b_results > li.b_algo', + 'next': 'div#b_content nav[role="navigation"] a.sb_pagN' + } + return selectors[selector] + + async def _get_next_page(self, query) -> str: + if self.page == 1: + await self._get_html(self.base_url) + url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl=' + return await self._get_html(url, None) + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + results = await super().search(query, num_results) + for result in results: + if not isinstance(result.url, str): + result.url = result.url.text + + return results \ No newline at end of file diff --git a/util/search_engine_scraper/config.py b/util/search_engine_scraper/config.py new file mode 100644 index 000000000..ab9cec6f8 --- /dev/null +++ b/util/search_engine_scraper/config.py @@ -0,0 +1,20 @@ +HEADERS = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0', + 'Accept': '*/*', + 'Connection': 'keep-alive', + 'Accept-Language': 'en-GB,en;q=0.5' +} + +USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0' +USER_AGENTS = [ + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36', + 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0', + 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0' +] \ No newline at end of file diff --git a/util/search_engine_scraper/engine.py b/util/search_engine_scraper/engine.py new file mode 100644 index 000000000..d78f5bff3 --- /dev/null +++ b/util/search_engine_scraper/engine.py @@ -0,0 +1,73 @@ +import random +try: + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from config import HEADERS, USER_AGENTS + +from bs4 import BeautifulSoup +from aiohttp import ClientSession +from dataclasses import dataclass +from typing import List + + +@dataclass +class SearchResult(): + title: str + url: str + snippet: str + + def __str__(self) -> str: + return f"{self.title} - {self.url}\n{self.snippet}" + +class SearchEngine(): + ''' + 搜索引擎爬虫基类 + ''' + + def __init__(self) -> None: + self.TIMEOUT = 10 + self.page = 1 + self.headers = HEADERS + + def _set_selector(self, selector: str) -> None: + raise NotImplementedError() + + def _get_next_page(self): + raise NotImplementedError() + + async def _get_html(self, url: str, data: dict = None) -> str: + headers = self.headers + headers["Referer"] = url + headers["User-Agent"] = random.choice(USER_AGENTS) + if data: + async with ClientSession() as session: + async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp: + return await resp.text(encoding="utf-8") + else: + async with ClientSession() as session: + async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp: + return await resp.text(encoding="utf-8") + + + def tidy_text(self, text: str) -> str: + ''' + 清理文本,去除空格、换行符等 + ''' + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + try: + resp = await self._get_next_page(query) + soup = BeautifulSoup(resp, 'html.parser') + links = soup.select(self._set_selector('links')) + results = [] + for link in links: + title = self.tidy_text(link.select_one(self._set_selector('title')).text) + url = link.select_one(self._set_selector('url')) + snippet = '' + if title and url: + results.append(SearchResult(title=title, url=url, snippet=snippet)) + return results[:num_results] if len(results) > num_results else results + except Exception as e: + raise e \ No newline at end of file diff --git a/util/search_engine_scraper/google.py b/util/search_engine_scraper/google.py new file mode 100644 index 000000000..9eff69816 --- /dev/null +++ b/util/search_engine_scraper/google.py @@ -0,0 +1,27 @@ +import os +from googlesearch import search + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENTS + +from typing import List + +class Google(SearchEngine): + def __init__(self) -> None: + super().__init__() + self.proxy = os.environ.get("HTTPS_PROXY") + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + results = [] + try: + print("use proxy:", self.proxy) + ls = search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy) + for i in ls: + results.append(SearchResult(title=i.title, url=i.url, snippet=i.description)) + except Exception as e: + raise e + return results \ No newline at end of file diff --git a/util/search_engine_scraper/sogo.py b/util/search_engine_scraper/sogo.py new file mode 100644 index 000000000..4ddf88dfc --- /dev/null +++ b/util/search_engine_scraper/sogo.py @@ -0,0 +1,49 @@ +import random, re +from bs4 import BeautifulSoup + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENTS + +from typing import List + +class Sogo(SearchEngine): + def __init__(self) -> None: + super().__init__() + self.base_url = "https://www.sogou.com" + self.headers['User-Agent'] = random.choice(USER_AGENTS) + + + def _set_selector(self, selector: str): + selectors = { + 'url': 'h3 > a', + 'title': 'h3', + 'text': '', + 'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)', + 'next': '' + } + return selectors[selector] + + async def _get_next_page(self, query) -> str: + url = f'{self.base_url}/web?query={query}' + return await self._get_html(url, None) + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + results = await super().search(query, num_results) + for result in results: + result.url = result.url.get("href") + if result.url.startswith("/link?"): + result.url = self.base_url + result.url + result.url = await self._parse_url(result.url) + return results + + async def _parse_url(self, url) -> str: + html = await self._get_html(url) + soup = BeautifulSoup(html, 'html.parser') + script = soup.find("script") + if script: + url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1) + return url \ No newline at end of file diff --git a/util/search_engine_scraper/test.py b/util/search_engine_scraper/test.py new file mode 100644 index 000000000..a1be25ee1 --- /dev/null +++ b/util/search_engine_scraper/test.py @@ -0,0 +1,22 @@ +from sogo import Sogo +from bing import Bing + +sogo_search = Sogo() +bing_search = Bing() +async def search(keyword: str) -> str: + results = await sogo_search.search(keyword, 5) + # results = await bing_search.search(keyword, 5) + ret = "" + if len(results) == 0: + return "没有搜索到结果" + + idx = 1 + for i in results: + ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n" + idx += 1 + + return ret + +import asyncio +ret = asyncio.run(search("gpt4orelease")) +print(ret) \ No newline at end of file