Compare commits

...

15 Commits

Author SHA1 Message Date
Soulter f67b171385 perf: 数据库迁移至 data 目录下 2024-05-19 17:10:11 +08:00
Soulter 1780d1355d perf: 将内部pip全部更换为阿里云镜像; 插件依赖更新逻辑优化 2024-05-19 16:45:08 +08:00
Soulter 5a3390e4f3 fix: force update 2024-05-19 16:06:47 +08:00
Soulter 337d96b41d Merge pull request #160 from Soulter/dev_default_openai_refactor
优化自带的 OpenAI LLM 交互, 人格, 网页搜索
2024-05-19 15:23:19 +08:00
Soulter 38a1dfea98 fix: web content scraper add proxy 2024-05-19 15:08:22 +08:00
Soulter fbef73aeec fix: websearch encoding set to utf-8 2024-05-19 14:42:28 +08:00
Soulter d6214c2b7c fix: web search 2024-05-19 12:55:54 +08:00
Soulter d58c86f6fc perf: websearch 优化;项目结构调整 2024-05-19 12:46:07 +08:00
Soulter ea34c20198 perf: 优化人格和LVM的处理过程 2024-05-18 10:34:35 +08:00
Soulter 934ca94e62 refactor: 重写 LLM OpenAI 模块 2024-05-17 22:56:44 +08:00
Soulter 1775327c2e chore: refact openai official 2024-05-17 09:07:11 +08:00
Soulter 707fcad8b4 feat: gpt 模型列表查看指令 models 2024-05-17 00:06:49 +08:00
Soulter f143c5afc6 fix: 修复 plugin v 子指令报错的问题 2024-05-16 23:11:07 +08:00
Soulter 99f94b2611 fix: 修复无法调用某些指令的问题 2024-05-16 23:04:47 +08:00
Soulter e39c1f9116 remove: 移除自动更换多模态模型的功能 2024-05-16 22:46:50 +08:00
35 changed files with 1352 additions and 1089 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
addons/plugins/
data/*
cookies.json
logs/
+3 -2
View File
@@ -12,10 +12,11 @@ from flask.logging import default_handler
from werkzeug.serving import make_server
from util import general_utils as gu
from dataclasses import dataclass
from cores.database.conn import dbConn
from persist.session import dbConn
from type.register import RegisteredPlugin
from typing import List
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from cores.astrbot.types import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
+35 -29
View File
@@ -2,17 +2,14 @@ import re
import threading
import asyncio
import time
import aiohttp
import util.unfit_words as uw
import os
import sys
import io
import traceback
import util.function_calling.gplugin as gplugin
import util.agent.web_searcher as web_searcher
import util.plugin_util as putil
from PIL import Image as PILImage
from nakuru.entities.components import Plain, At, Image
from addons.baidu_aip_judge import BaiduJudge
@@ -22,10 +19,12 @@ from util import general_utils as gu
from util.general_utils import upload, run_monitor
from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items
from .types import *
from type.types import GlobalObject
from type.register import *
from type.message import AstrBotMessage
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from cores.database.conn import dbConn
from persist.session import dbConn
from model.platform._message_result import MessageResult
from SparkleLogging.utils.core import LogManager
from logging import Logger
@@ -77,14 +76,14 @@ def privider_chooser(cfg):
'''
def init(cfg):
def init():
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time
global _global_object
# 迁移旧配置
gu.try_migrate_config(cfg)
gu.try_migrate_config()
# 使用新配置
cfg = cc.get_all()
@@ -105,6 +104,15 @@ def init(cfg):
cc.put("reply_prefix", "")
else:
_global_object.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 语言模型提供商
logger.info("正在载入语言模型...")
@@ -122,6 +130,11 @@ def init(cfg):
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
chosen_provider = OPENAI_OFFICIAL
instance = llm_instance[OPENAI_OFFICIAL]
assert isinstance(instance, ProviderOpenAIOfficial)
instance.DEFAULT_PERSONALITY = _global_object.default_personality
instance.curr_personality = instance.DEFAULT_PERSONALITY
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
@@ -197,14 +210,6 @@ def init(cfg):
cfg, _global_object), daemon=True).start()
platform_str += "QQ_OFFICIAL,"
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 初始化dashboard
_global_object.dashboard_data = DashBoardData(
stats={},
@@ -378,8 +383,13 @@ async def oper_msg(message: AstrBotMessage,
llm_result_str = ""
# check commands and plugins
message_str_no_wake_prefix = message_str
for wake_prefix in _global_object.nick: # nick: tuple
if message_str.startswith(wake_prefix):
message_str_no_wake_prefix = message_str.removeprefix(wake_prefix)
break
hit, command_result = await llm_command_instance[chosen_provider].check_command(
message_str,
message_str_no_wake_prefix,
session_id,
role,
reg_platform,
@@ -423,14 +433,14 @@ async def oper_msg(message: AstrBotMessage,
if chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else:
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality)
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e:
logger.info(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}")
logger.error(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用异常。详细原因:{str(e)}")
# 切换回原来的语言模型
if temp_switch != "":
@@ -453,14 +463,10 @@ async def oper_msg(message: AstrBotMessage,
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
# 画图指令
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
# 保存到本地
async with aiohttp.ClientSession() as session:
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
if command == 'draw':
# 保存到本地
path = await gu.download_image_by_url(command_result[1])
return MessageResult([Image.fromFileSystem(path)])
# 其他指令
else:
try:
-180
View File
@@ -1,180 +0,0 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from nakuru.entities.components import BaseMessageComponent
from typing import Union, List, ClassVar
from types import ModuleType
from enum import Enum
from dataclasses import dataclass
class MessageType(Enum):
    """Kind of conversation a message belongs to."""

    GROUP_MESSAGE = 'GroupMessage'    # message sent in a group
    FRIEND_MESSAGE = 'FriendMessage'  # private / friend (one-to-one) message
    GUILD_MESSAGE = 'GuildMessage'    # guild (channel) message
@dataclass
class MessageMember():
    """Sender of a message."""

    user_id: str  # id of the sender
    # The default is None, so the annotation must admit None as well
    # (PEP 484 discourages the implicit-Optional form ``str = None``).
    nickname: Union[str, None] = None
class AstrBotMessage():
    """AstrBot's message object.

    The class only declares its attributes; none of them is assigned here,
    so they must be set on the instance by whoever builds the message.
    """

    tag: str                             # label of the message source
    type: MessageType                    # kind of message
    self_id: str                         # id identifying the bot itself
    session_id: str                      # conversation (session) id
    message_id: str                      # id of this message
    sender: MessageMember                # who sent the message
    message: List[BaseMessageComponent]  # message chain, in Nakuru's message-chain format
    message_str: str                     # plain-text form of the message
    raw_message: object
    timestamp: int                       # message timestamp

    def __str__(self) -> str:
        """Return the string form of the instance's attribute dictionary."""
        return str(vars(self))
class PluginType(Enum):
    """Category of a plugin."""

    # NOTE: the value really is spelled 'platfrom'; it is a runtime value and
    # is therefore kept exactly as it is.
    PLATFORM = 'platfrom'  # platform plugin
    LLM = 'llm'            # large-language-model plugin
    COMMON = 'common'      # any other plugin
@dataclass
class PluginMetadata:
    """Metadata describing a plugin."""

    # required
    plugin_name: str
    plugin_type: PluginType
    author: str   # plugin author
    desc: str     # short description of the plugin
    version: str  # plugin version

    # optional
    # The default is None, so the annotation admits None explicitly instead
    # of relying on the implicit-Optional form ``str = None``.
    repo: Union[str, None] = None  # address of the plugin repository

    def __str__(self) -> str:
        """Return a one-line summary; the author is not part of it."""
        return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
@dataclass
class RegisteredPlugin:
    """A plugin registered in AstrBot."""

    metadata: PluginMetadata  # descriptive metadata of the plugin
    plugin_instance: object
    module_path: str
    module: ModuleType
    root_dir_name: str

    def __str__(self) -> str:
        """Return a one-line summary built from metadata, module path and root dir name."""
        summary = f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
        return summary
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
    """A platform registered in AstrBot.

    A platform is expected to implement the ``Platform`` interface.
    """

    platform_name: str
    platform_instance: Platform
    # Where the registration came from. The default is None, so the
    # annotation admits None explicitly.
    origin: Union[str, None] = None
@dataclass
class RegisteredLLM:
    """A large-language-model backend registered in AstrBot.

    The backend is expected to implement the ``LLMProvider`` interface.
    """

    llm_name: str
    llm_instance: LLMProvider
    # Where the registration came from. The default is None, so the
    # annotation admits None explicitly.
    origin: Union[str, None] = None
class GlobalObject:
    """Holds shared data passed between modules (e.g. core and command).

    Annotations below describe the values that are actually stored:
    attributes that ``__init__`` sets to ``None`` are annotated as optional.
    ``version`` is declared but not assigned by ``__init__``.
    """

    version: str                            # bot version
    # User-defined aliases of the bot. Stored as a tuple elsewhere in the
    # project (it is iterated as a sequence of prefixes), None until set.
    nick: Union[tuple, None]
    # Exported configuration. NOTE(review): the original comments name both
    # config.json and config.yaml as the source - confirm which one applies.
    base_config: Union[dict, None]
    cached_plugins: List[RegisteredPlugin]  # loaded plugins
    platforms: List[RegisteredPlatform]
    llms: List[RegisteredLLM]
    web_search: bool                        # whether web search is enabled
    reply_prefix: Union[str, None]          # prefix prepended to replies
    unique_session: bool                    # whether independent sessions are enabled
    cnt_total: int                          # total number of messages
    default_personality: Union[dict, None]
    stat: dict
    dashboard_data = None

    def __init__(self):
        self.nick = None            # bot aliases (gocq nickname)
        self.base_config = None
        self.cached_plugins = []    # cached plugins
        self.web_search = False     # whether web search is enabled
        self.reply_prefix = None
        self.unique_session = False
        self.cnt_total = 0
        self.platforms = []
        self.llms = []
        self.default_personality = None
        self.dashboard_data = None
        self.stat = {}
class AstrMessageEvent():
    """A message event.

    Bundles a message with the platform it came from, the sender's role and
    the shared global object.
    """

    context: GlobalObject          # shared data
    message_str: str               # plain message string
    message_obj: AstrBotMessage    # message object
    platform: RegisteredPlatform   # platform the message came from
    role: str                      # basic identity: `admin` or `member`
    # The constructor accepts a string (or None), so the attribute is
    # annotated accordingly; it was previously declared as ``int``.
    session_id: Union[str, None]   # session id

    def __init__(self,
                 message_str: str,
                 message_obj: AstrBotMessage,
                 platform: RegisteredPlatform,
                 role: str,
                 context: GlobalObject,
                 session_id: Union[str, None] = None):
        self.context = context
        self.message_str = message_str
        self.message_obj = message_obj
        self.platform = platform
        self.role = role
        self.session_id = session_id
class CommandResult():
    """Carries the several values a command returns."""

    def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
        self.hit, self.success = hit, success
        self.message_chain = message_chain
        self.command_name = command_name

    def _result_tuple(self):
        """Return ``(success, message_chain, command_name)``; ``hit`` is not included."""
        result = (self.success, self.message_chain, self.command_name)
        return result
+40 -23
View File
@@ -23,27 +23,29 @@ logo_tmpl = """
"""
def make_necessary_dirs():
    """Create the directories the program needs.

    Directories that already exist are left as they are.
    """
    for directory in ("data/config", "temp"):
        os.makedirs(directory, exist_ok=True)
def update_dept():
    """Install/update the dependencies listed in ``requirements.txt``.

    Runs ``pip install -r requirements.txt`` with the interpreter that is
    running this program, using the Aliyun PyPI mirror as package index.

    The command is passed as an argument list rather than a shell string, so
    an interpreter path containing spaces or shell metacharacters cannot
    break or alter the command. pip's exit status is ignored.
    """
    import subprocess  # local import: only this function needs it

    # Path of the running Python executable.
    py = sys.executable
    mirror = "https://mirrors.aliyun.com/pypi/simple/"
    command = [py, "-m", "pip", "install", "-r", "requirements.txt", "-i", mirror]
    try:
        subprocess.run(command, check=False)
    except OSError as e:
        # The interpreter could not be launched at all; keep the update
        # best-effort instead of aborting start-up.
        print(f"Failed to run pip ({e}); dependencies were not updated.")
def main():
logger = LogManager.GetLogger(
log_name='astrbot-core',
out_to_console=True,
# HTTPpost_url='http://localhost:6185/api/log',
# http_mode = True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger.info(logo_tmpl)
# config.yaml 配置文件加载和环境确认
try:
import botpy, logging, yaml
import cores.astrbot.core as qqBot
import botpy, logging
import astrbot.core as bot_core
# delete qqbotpy's logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8')
cfg = yaml.safe_load(ymlfile)
logging.root.removeHandler(handler)
except ImportError as import_error:
logger.error(import_error)
logger.error("检测到一些依赖库没有安装。由于兼容性问题,AstrBot 此版本将不会自动为您安装依赖库。请您先自行安装,然后重试。")
@@ -59,19 +61,14 @@ def main():
input("配置文件不存在,请检查是否已经下载配置文件。")
exit()
except BaseException as e:
raise e
# 设置代理
if 'http_proxy' in cfg and cfg['http_proxy'] != '':
os.environ['HTTP_PROXY'] = cfg['http_proxy']
if 'https_proxy' in cfg and cfg['https_proxy'] != '':
os.environ['HTTPS_PROXY'] = cfg['https_proxy']
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
logger.error(traceback.format_exc())
input("未知错误。")
exit()
make_necessary_dirs()
# 启动主程序(cores/qqbot/core.py
qqBot.init(cfg)
bot_core.init()
def check_env():
@@ -81,6 +78,26 @@ def check_env():
exit()
if __name__ == "__main__":
logger = LogManager.GetLogger(
log_name='astrbot-core',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger.info(logo_tmpl)
# 设置代理
from util.cmd_config import CmdConfig
cc = CmdConfig()
http_proxy = cc.get("http_proxy")
https_proxy = cc.get("https_proxy")
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
if http_proxy:
os.environ['HTTP_PROXY'] = http_proxy
if https_proxy:
os.environ['HTTPS_PROXY'] = https_proxy
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
update_dept()
check_env()
t = threading.Thread(target=main, daemon=True)
t.start()
+11 -15
View File
@@ -13,16 +13,13 @@ from nakuru.entities.components import (
from util import general_utils as gu
from model.provider.provider import Provider
from util.cmd_config import CmdConfig as cc
from cores.astrbot.types import (
GlobalObject,
AstrMessageEvent,
PluginType,
CommandResult,
RegisteredPlugin,
RegisteredPlatform
)
from type.message import *
from type.types import GlobalObject
from type.command import *
from type.plugin import *
from type.register import *
from typing import List, Tuple
from typing import List
from SparkleLogging.utils.core import LogManager
from logging import Logger
@@ -76,6 +73,7 @@ class Command:
else:
raise TypeError("插件返回值格式错误。")
if hit:
logger.debug("hit plugin: " + plugin.metadata.plugin_name)
return True, res
except TypeError as e:
# 参数不匹配,尝试使用旧的参数方案
@@ -188,7 +186,7 @@ class Command:
break
if info:
p = gu.create_text_image(
f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}")
f"【插件信息】", f"名称: {info.plugin_name}\n类型: {info.plugin_type}\n{info.desc}\n版本: {info.version}\n作者: {info.author}")
return True, [Image.fromFileSystem(p)], "plugin"
else:
return False, "未找到该插件", "plugin"
@@ -199,10 +197,10 @@ class Command:
nick: 存储机器人的昵称
'''
def set_nick(self, message: str, platform: str, role: str = "member"):
def set_nick(self, message: str, platform: RegisteredPlatform, role: str = "member"):
if role != "admin":
return True, "你无权使用该指令 :P", "nick"
if platform == PLATFORM_GOCQ:
if str(platform) == PLATFORM_GOCQ:
l = message.split(" ")
if len(l) == 1:
return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick"
@@ -210,7 +208,7 @@ class Command:
cc.put("nick_qq", nick)
self.global_object.nick = tuple(nick)
return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick"
elif platform == PLATFORM_QQCHAN:
elif str(platform) == PLATFORM_QQCHAN:
nick = message.split(" ")[2]
return False, "QQ频道平台不支持为机器人设置昵称。", "nick"
@@ -222,8 +220,6 @@ class Command:
"nick": "设置机器人昵称",
"plugin": "插件安装、卸载和重载",
"web on/off": "LLM 网页搜索能力",
"reset": "重置 LLM 对话",
"/gpt": "切换到 OpenAI 官方接口"
}
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
+100 -119
View File
@@ -1,14 +1,24 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
from util.personality import personalities
from cores.astrbot.types import GlobalObject
from type.types import GlobalObject
from type.command import CommandItem
from SparkleLogging.utils.core import LogManager
from logging import Logger
from openai._exceptions import NotFoundError
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
    """Set up the OpenAI-specific command handler.

    Stores the provider and the shared global object, registers the
    built-in ``reset``, ``his`` and ``status`` command items, then runs the
    base-class initialiser with the same two arguments.
    """
    self.provider = provider
    self.global_object = global_object
    self.personality_str = ""

    # Each item pairs a command name with its handler and a description;
    # the last argument is the same label for every built-in command.
    origin_label = "内置"
    self.commands = [
        CommandItem("reset", self.reset, "重置 LLM 会话。", origin_label),
        CommandItem("his", self.his, "查看与 LLM 的历史记录。", origin_label),
        CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", origin_label),
    ]
    super().__init__(provider, global_object)
async def check_command(self,
@@ -28,6 +38,8 @@ class CommandOpenAIOfficial(Command):
message_obj
)
logger.debug(f"基础指令hit: {hit}, res: {res}")
# 这里是这个 LLM 的专属指令
if hit:
return True, res
@@ -35,12 +47,8 @@ class CommandOpenAIOfficial(Command):
return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
return True, self.token(session_id)
elif self.command_start_with(message, "gpt"):
return True, self.gpt()
elif self.command_start_with(message, "status"):
return True, self.status()
return True, self.status(session_id)
elif self.command_start_with(message, "help", "帮助"):
return True, await self.help()
elif self.command_start_with(message, "unset"):
@@ -51,21 +59,66 @@ class CommandOpenAIOfficial(Command):
return True, self.update(message, role)
elif self.command_start_with(message, "", "draw"):
return True, await self.draw(message)
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"):
return True, await self.switch(message)
elif self.command_start_with(message, "models"):
return True, await self.print_models()
elif self.command_start_with(message, "model"):
return True, await self.set_model(message)
return False, None
async def get_models(self):
    """Return the provider's models whose id starts with ``gpt``.

    The list is requested through the provider's OpenAI client. If the
    request fails with ``NotFoundError``, ``/v1`` is appended to the
    client's base URL and the request is retried once.

    Returns:
        A lazy iterable of the model objects whose ``id`` starts with "gpt".

    Raises:
        Whatever the list request raises (other than the first
        ``NotFoundError``). Previously a ``return`` inside ``finally``
        discarded that exception and failed with ``UnboundLocalError``
        instead, hiding the real cause.
    """
    client = self.provider.client
    try:
        models = await client.models.list()
    except NotFoundError:
        # Drop a trailing slash first so the retried URL is ".../v1" and
        # not "...//v1".
        client.base_url = str(client.base_url).rstrip("/") + "/v1"
        models = await client.models.list()
    return (model for model in models.data if model.id.startswith("gpt"))
async def print_models(self):
    """Build a numbered listing of the models returned by ``get_models``.

    Returns:
        ``(True, listing, "models")`` where ``listing`` is a header line
        followed by one ``"<n>. <model id>"`` line per model, numbered
        from 1. The listing is also written to the debug log.
    """
    available = await self.get_models()
    lines = ["OpenAI GPT 类可用模型"]
    lines.extend(f"{number}. {entry.id}" for number, entry in enumerate(available, start=1))
    listing = "\n".join(lines)
    logger.debug(listing)
    return True, listing, "models"
async def set_model(self, message: str):
    """Switch the provider to the model named in ``message``.

    ``message`` is the whole command text; its second whitespace-separated
    token is either a model id or the 1-based number of a model in the
    listing produced from ``get_models``.

    Returns:
        A ``(True, text, "model")`` tuple. ``text`` is a usage hint when no
        model was given, an error text when the model is unknown, or a
        confirmation containing the selected model id.
    """
    parts = message.split()
    if len(parts) <= 1:
        return True, "请输入 /model 模型名/编号", "model"

    wanted = parts[1]
    models = list(await self.get_models())

    if wanted.isdecimal() and 1 <= int(wanted) <= len(models):
        chosen = models[int(wanted) - 1]
    else:
        # Resolve the name to the model object itself. The previous code
        # only set a flag and then read ``.id`` from the input string,
        # which raised AttributeError for every selection by name.
        chosen = next((m for m in models if m.id == wanted), None)
        if chosen is None:
            return True, "模型不存在或输入非法", "model"

    self.provider.set_model(chosen.id)
    return True, f"模型已设置为 {chosen.id}", "model"
async def help(self):
commands = super().general_commands()
commands[''] = '画画'
commands['key'] = '添加OpenAI key'
commands[''] = '调用 OpenAI DallE 模型生成图片'
commands['set'] = '人格设置面板'
commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token'
commands['status'] = '查看 Api Key 状态和配置信息'
commands['token'] = '查看本轮会话 token'
commands['reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt'
commands['reset p'] = '重置当前与 LLM 的会话,并清除人格。'
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
async def reset(self, session_id: str, message: str = "reset"):
@@ -73,79 +126,44 @@ class CommandOpenAIOfficial(Command):
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
await self.provider.forget(session_id)
await self.provider.forget(session_id, keep_system_prompt=True)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
if self.personality_str != "":
self.set(self.personality_str, session_id) # 重新设置人格
return True, "重置成功", "reset"
await self.provider.forget(session_id)
def his(self, message: str, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "his"
# 分页,每页5条
msg = ''
size_per_page = 3
page = 1
if message[4:]:
page = int(message[4:])
# 检查是否有过历史记录
if session_id not in self.provider.session_dict:
msg = f"历史记录为空"
return True, msg, "his"
l = self.provider.session_dict[session_id]
max_page = len(l)//size_per_page + \
1 if len(l) % size_per_page != 0 else len(l)//size_per_page
p = self.provider.get_prompts_by_cache_list(
self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
return True, f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页", "his"
l = message.split(" ")
if len(l) == 2:
try:
page = int(l[1])
except BaseException as e:
return True, "页码不合法", "his"
contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page)
t_pages = total_num // size_per_page + 1
return True, f"历史记录如下:\n{contexts}\n{page} 页 | 共 {t_pages}\n*输入 /his 2 跳转到第 2 页", "his"
def token(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "token"
return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token"
def gpt(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "gpt"
return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt"
def status(self):
def status(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "status"
chatgpt_cfg_str = ""
key_stat = self.provider.get_key_stat()
index = 1
max = 9000000
gg_count = 0
total = 0
tag = ''
for key in key_stat.keys():
sponsor = ''
total += key_stat[key]['used']
if key_stat[key]['exceed']:
gg_count += 1
continue
if 'sponsor' in key_stat[key]:
sponsor = key_stat[key]['sponsor']
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
index += 1
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
keys_data = self.provider.get_keys_data()
ret = "OpenAI Key"
for k in keys_data:
status = "🟢" if keys_data[k] else "🔴"
ret += "\n|- " + k[:8] + " " + status
def key(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx"
return True, msg, "key"
key = l[1]
if self.provider.check_key(key):
self.provider.append_key(key)
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
else:
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
conf = self.provider.get_configs()
ret += "\n当前模型:" + conf['model']
if conf['model'] in MODELS:
ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens"
if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]):
ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens"
return True, ret, "status"
async def switch(self, message: str):
'''
@@ -162,14 +180,13 @@ class CommandOpenAIOfficial(Command):
return True, ret, "switch"
elif len(l) == 2:
try:
key_stat = self.provider.get_key_stat()
key_stat = self.provider.get_keys_data()
index = int(l[1])
if index > len(key_stat) or index < 1:
return True, "账号序号不合法。", "switch"
else:
try:
new_key = list(key_stat.keys())[index-1]
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key)
except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch"
@@ -218,58 +235,22 @@ class CommandOpenAIOfficial(Command):
'name': ps,
'prompt': personalities[ps]
}
self.provider.session_dict[session_id] = []
new_record = {
"user": {
"role": "user",
"content": personalities[ps],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"人格{ps}已设置。", "set"
else:
self.provider.curr_personality = {
'name': '自定义人格',
'prompt': ps
}
new_record = {
"user": {
"role": "user",
"content": ps,
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id] = []
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
async def draw(self, message):
async def draw(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "draw"
if message.startswith("/画"):
message = message[2:]
elif message.startswith(""):
message = message[1:]
try:
# 画图模式传回3个参数
img_url = await self.provider.image_chat(message)
return True, img_url, "draw"
except Exception as e:
if 'exceeded' in str(e):
return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)"
return False, f"图片生成失败: {e}", "draw"
img_url = await self.provider.image_generate(message)
return True, img_url, "draw"
+1 -1
View File
@@ -5,7 +5,7 @@ from nakuru import (
FriendMessage
)
import botpy.message
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from typing import List, Union
import time
+1 -1
View File
@@ -15,7 +15,7 @@ import time
from ._platfrom import Platform
from ._message_parse import nakuru_message_parse_rev
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
+3 -1
View File
@@ -5,6 +5,8 @@ import botpy.message
import re
import asyncio
import aiohttp
import botpy.types
import botpy.types.message
from util import general_utils as gu
from botpy.types.message import Reference
@@ -15,7 +17,7 @@ from ._message_parse import (
qq_official_message_parse_rev,
qq_official_message_parse
)
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from typing import Union, List
from nakuru.entities.components import BaseMessageComponent
from SparkleLogging.utils.core import LogManager
+417 -337
View File
@@ -5,89 +5,108 @@ import time
import tiktoken
import threading
import traceback
import base64
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from cores.database.conn import dbConn
from persist.session import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
# Known OpenAI model ids mapped to a token count. Elsewhere in this diff the
# `status` command reports this value as the model's maximum context window.
MODELS = {
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-16k": 16385,
"gpt-3.5-turbo-0613": 16385,
"gpt-3.5-turbo-16k-0613": 16385,
}
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.cc = CmdConfig()
def __init__(self, cfg) -> None:
super().__init__()
self.key_list = []
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
raise BaseException(
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
if cfg['key'] != '' and cfg['key'] != None:
self.key_list = cfg['key']
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
os.makedirs("data/openai", exist_ok=True)
self.key_stat = {}
for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
self.cc = CmdConfig
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {} # 记录超额
self.api_base = None
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
logger.info(f"设置 api_base 为: {self.api_base}")
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
for key in self.api_keys:
self.keys_data[key] = True
# 创建 OpenAI Client
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
self.openai_configs = cfg
# 会话缓存
self.session_dict = {}
# 最大缓存token
self.max_tokens = cfg['total_tokens_limit']
# 历史记录持久化间隔时间
self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
self.model_configs: Dict = cfg['chatGPTConfigs']
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
self.DEFAULT_PERSONALITY = {
"name": "default",
"prompt": "你是一个很有帮助的 AI 助手。"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data']
logger.info("读取历史记录成功。")
self.session_memory_lock.acquire()
self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e:
logger.info("读取历史记录失败,但不影响使用。")
# 创建转储定时器线程
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败{e}。仍可正常使用。")
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()
# 人格
self.curr_personality = {}
# 转储历史记录
def dump_history(self):
'''
转储历史记录
'''
time.sleep(10)
db = dbConn()
while True:
try:
# print("转储历史记录...")
for key in self.session_dict:
data = self.session_dict[key]
for key in self.session_memory:
data = self.session_memory[key]
data_json = {
'data': data
}
@@ -95,321 +114,382 @@ class ProviderOpenAIOfficial(Provider):
db.update_session(key, json.dumps(data_json))
else:
db.insert_session(key, json.dumps(data_json))
# print("转储历史记录完毕")
logger.debug("已保存 OpenAI 会话历史记录")
except BaseException as e:
print(e)
# 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval)
finally:
time.sleep(10*60)
def personality_set(self, default_personality: dict, session_id: str):
if not default_personality: return
if session_id not in self.session_memory:
self.session_memory[session_id] = []
self.curr_personality = default_personality
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
new_record = {
"user": {
"role": "user",
"role": "system",
"content": default_personality['prompt'],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
'usage_tokens': 0, # 到该条目的总 token 数
'single-tokens': 0 # 该条目的 token 数
}
self.session_dict[session_id].append(new_record)
async def text_chat(self, prompt,
session_id=None,
image_url=None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制
if session_id not in self.session_dict:
self.session_dict[session_id] = []
self.session_memory[session_id].append(new_record)
if len(self.session_dict[session_id]) == 0:
# 设置默认人格
if default_personality is not None:
self.personality_set(default_personality, session_id)
async def encode_image_bs64(self, image_url: str) -> str:
    '''
    Encode an image as a base64 ``data:`` URL.

    A value starting with "http" is first passed to
    ``gu.download_image_by_url``; whatever that returns is then opened as a
    local file path. Any other value is opened directly as a file path.
    The MIME type in the result is always ``image/jpeg``.
    '''
    local_path = image_url
    if local_path.startswith("http"):
        local_path = await gu.download_image_by_url(local_path)

    with open(local_path, "rb") as image_file:
        raw_bytes = image_file.read()

    encoded = base64.b64encode(raw_bytes).decode()
    return "data:image/jpeg;base64," + encoded
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
prompt = self.enc.decode(_encoded_prompt[:int(
self.openai_model_configs['max_tokens']*0.80)])
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
cache_data_list, new_record, req = self.wrap(
prompt, session_id, image_url)
logger.debug(f"cache: {str(cache_data_list)}")
logger.debug(f"request: {str(req)}")
retry = 0
response = None
err = ''
# 截断倍率
truncate_rate = 0.75
use_gpt4v = False
for i in req:
if isinstance(i['content'], list):
use_gpt4v = True
break
if image_url is not None:
use_gpt4v = True
if use_gpt4v:
conf = self.openai_model_configs.copy()
conf['model'] = 'gpt-4-vision-preview'
else:
conf = self.openai_model_configs
if extra_conf is not None:
conf.update(extra_conf)
while retry < 10:
try:
if function_call is None:
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = await self.client.chat.completions.create(
messages=req,
tools=function_call,
**conf
)
break
except Exception as e:
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.info("当前 Key 已超额或异常, 正在切换",
)
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
retry -= 1
elif 'maximum context length' in str(e):
logger.info("token 超限, 清空对应缓存,并进行消息截断")
self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(
prompt, session_id)
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
async def retrieve_context(self, session_id: str):
'''
根据 session_id 获取保存的 OpenAI 格式的上下文
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
# 转换为 openai 要求的格式
context = []
is_lvm = await self.is_lvm()
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
continue
else:
logger.error(str(e))
time.sleep(2)
err = str(e)
retry += 1
if retry >= 10:
logger.warning(
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
logger.debug(
f"OPENAI RESPONSE: {response.usage}")
context.append(record['user'])
if "AI" in record and record['AI']:
context.append(record['AI'])
# 结果分类
choice = response.choices[0]
if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
return context
async def is_lvm(self):
    '''
    Report whether the configured model is treated as a vision model (LVM).

    The check is name-based only: a model whose name starts with "gpt-4" counts.
    '''
    model_name = self.model_configs['model']
    return model_name.startswith("gpt-4")
async def get_models(self):
    '''
    Request the model list through ``self.client``, log it, and return it.
    '''
    available_models = await self.client.models.list()
    logger.info(f"OpenAI 模型列表:{available_models}")
    return available_models
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
    '''
    Append the new user message to the session memory, then evict old
    records until the running token total fits the model's context window.

    Args:
        session_id: key into ``self.session_memory``; must already exist.
        prompt: text of the user message.
        image_url: optional image; when given it is embedded (through
            ``encode_image_bs64``) next to the text in a multi-part message.

    Raises:
        Exception: if ``session_id`` has no entry in ``self.session_memory``.
    '''
    if session_id not in self.session_memory:
        raise Exception("会话 ID 不存在")

    history = self.session_memory[session_id]
    tokens_num = len(self.tokenizer.encode(prompt))
    previous_total_tokens_num = history[-1]['usage_tokens'] if history else 0

    message = {
        "usage_tokens": previous_total_tokens_num + tokens_num,
        "single_tokens": tokens_num,
        "AI": None
    }
    if image_url:
        user_content = {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": await self.encode_image_bs64(image_url)
                    }
                }
            ]
        }
    else:
        user_content = {
            "role": "user",
            "content": prompt
        }
    message["user"] = user_content
    history.append(message)

    # Evict surplus records according to the model's context window.
    curr_model = self.model_configs['model']
    if curr_model in MODELS:
        maxium_tokens_num = MODELS[curr_model] - 300  # keep at least 300 tokens for the completion
        while (self.session_memory[session_id]
               and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num):
            # pop_record is a coroutine function: without `await` nothing is
            # evicted and this loop would spin forever.
            if await self.pop_record(session_id) is None:
                # Nothing could be removed; stop instead of looping without progress.
                break
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
    '''
    Remove the earliest evictable record from a session and return it.

    When ``pop_system_prompt`` is False, a record whose user role is "system"
    is kept if no other system prompt follows it; a system prompt that is
    followed by another one may still be removed.

    Returns:
        The removed record, or None when the session is empty or every
        remaining record is protected.

    Raises:
        Exception: if ``session_id`` has no entry in ``self.session_memory``.
    '''
    if session_id not in self.session_memory:
        raise Exception("会话 ID 不存在")

    records = self.session_memory[session_id]
    if not records:
        return None

    def _is_system(rec) -> bool:
        # Tolerate records without a 'user' part instead of raising KeyError.
        return (rec.get('user') or {}).get('role') == "system"

    target_idx = None
    for idx, rec in enumerate(records):
        if not pop_system_prompt and _is_system(rec):
            # Only the last system prompt of the session is protected.
            if not any(_is_system(later) for later in records[idx + 1:]):
                continue
        target_idx = idx
        break

    if target_idx is None:
        # Previously this fell through to an unbound `record`.
        return None

    record = records.pop(target_idx)
    # Records are written with either 'single-tokens' or 'single_tokens'
    # elsewhere in this class; accept both spellings.
    freed_tokens = record.get('single-tokens', record.get('single_tokens', 0))

    # usage_tokens is a running total, so only records that came after the
    # removed one shrink.
    for later in records[target_idx:]:
        later['usage_tokens'] -= freed_tokens

    current_total = records[-1]['usage_tokens'] if records else 0
    logger.debug(f"淘汰上下文记录 1 条,释放 {freed_tokens} 个 token。当前上下文总 token 为 {current_total}")
    return record
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None=None,
tools: None=None,
extra_conf: Dict = None,
**kwargs
) -> str:
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
del self.session_memory["unknown"]
if session_id not in self.session_memory:
self.session_memory[session_id] = []
if session_id not in self.session_personality or not self.session_personality[session_id]:
self.personality_set(self.curr_personality, session_id)
self.session_personality[session_id] = True
# 如果 prompt 超过了最大窗口,截断。
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
# 组装上下文,并且根据当前上下文窗口大小截断
await self.assemble_context(session_id, prompt, image_url)
# 获取上下文,openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
if extra_conf: conf.update(extra_conf)
# start request
retry = 0
rate_limit_retry = 0
while retry < 3 or rate_limit_retry < 5:
logger.debug(conf)
logger.debug(contexts)
if tools:
completion_coro = self.client.chat.completions.create(
messages=contexts,
tools=tools,
**conf
)
else:
completion_coro = self.client.chat.completions.create(
messages=contexts,
**conf
)
try:
completion = await completion_coro
break
except AuthenticationError as e:
api_key = self.chosen_api_key[10:] + "..."
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)")
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except BadRequestError as e:
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
if "maximum context length" in str(e):
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
completion_tokens = completion.usage.completion_tokens
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
if choice.message.content:
# 返回文本
completion_text = str(choice.message.content).strip()
elif choice.message.tool_calls and choice.message.tool_calls:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
if current_usage_tokens > self.max_tokens:
t = current_usage_tokens
index = 0
while t > self.max_tokens:
if index >= len(cache_data_list):
break
# 保留人格信息
if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index]
else:
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# 添加新条目进入缓存的prompt
new_record['AI'] = {
'role': 'assistant',
'content': chatgpt_res,
self.session_memory[session_id][-1]['AI'] = {
"role": "assistant",
"content": completion_text
}
new_record['usage_tokens'] = current_usage_tokens
if len(cache_data_list) > 0:
new_record['single_tokens'] = current_usage_tokens - \
int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list
return chatgpt_res
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
retry = 0
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response.data[i].url)
break
except Exception as e:
logger.warning(str(e))
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
logger.warning("当前 Key 已超额或者不正常, 正在切换")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
raise e
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
return image_url
async def forget(self, session_id=None) -> bool:
if session_id is None:
return completion_text
async def switch_to_next_key(self):
'''
切换到下一个 API Key
'''
if not self.api_keys:
logger.error("OpenAI API Key 不存在。")
return False
self.session_dict[session_id] = []
return True
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
for key in self.keys_data:
if self.keys_data[key]:
# 没超额
self.chosen_api_key = key
self.client.api_key = key
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
return True
return False
async def image_generate(self, prompt, session_id, **kwargs) -> str:
    '''
    Generate an image for ``prompt`` and return the URL of the first result.

    The request is attempted up to 3 times with a 1 second pause between
    attempts. ``session_id`` and ``kwargs`` are accepted but not used here.

    Raises:
        Exception: if no image generation config is set, or if the third
            attempt also fails.
    '''
    # Local import: the enclosing module's import list is not guaranteed to
    # contain asyncio.
    import asyncio

    conf = self.image_generator_model_configs
    if not conf:
        logger.error("OpenAI 图片生成模型配置不存在。")
        raise Exception("OpenAI 图片生成模型配置不存在。")

    retry = 0
    while retry < 3:
        try:
            images_response = await self.client.images.generate(
                prompt=prompt,
                **conf
            )
            return images_response.data[0].url
        except Exception as e:
            retry += 1
            if retry >= 3:
                logger.error(traceback.format_exc())
                raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。") from e
            logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
            # time.sleep() here would block the whole event loop for a second.
            await asyncio.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
    '''
    Clear the stored history of one session.

    With ``keep_system_prompt`` the current personality is applied again to
    the emptied session; otherwise the current personality is reset to
    ``DEFAULT_PERSONALITY``.

    Returns:
        False when ``session_id`` is None (nothing is changed), True otherwise.
    '''
    if session_id is None:
        return False

    self.session_memory[session_id] = []
    if not keep_system_prompt:
        self.curr_personality = self.DEFAULT_PERSONALITY
        return True

    self.personality_set(self.curr_personality, session_id)
    return True
def dump_contexts_page(self, session_id: str, size=5, page=1,):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
page_end = page*size
if page_begin < 0:
page_begin = 0
if page_end > len(cache_data_list):
page_end = len(cache_data_list)
cache_data_list = cache_data_list[page_begin:page_end]
for item in cache_data_list:
prompts += str(item['user']['role']) + ":\n" + \
str(item['user']['content']) + "\n"
prompts += str(item['AI']['role']) + ":\n" + \
str(item['AI']['content']) + "\n"
# contexts_str = ""
# for i, key in enumerate(self.session_memory):
# if i < (page-1)*size or i >= page*size:
# continue
# contexts_str += f"Session ID: {key}\n"
# for record in self.session_memory[key]:
# if "user" in record:
# contexts_str += f"User: {record['user']['content']}\n"
# if "AI" in record:
# contexts_str += f"AI: {record['AI']['content']}\n"
# contexts_str += "---\n"
contexts_str = ""
if session_id in self.session_memory:
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
contexts_str += f"User: {text}\n"
if "AI" in record and record['AI']:
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
contexts_str += f"Assistant: {text}\n"
else:
contexts_str = "会话 ID 不存在。"
if divide:
prompts += "----------\n"
return prompts
def wrap(self, prompt, session_id, image_url=None):
if image_url is not None:
prompt = [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
# 获得缓存信息
context = self.session_dict[session_id]
new_record = {
"user": {
"role": "user",
"content": prompt,
},
"AI": {},
'type': "common",
'usage_tokens': 0,
}
req_list = []
for i in context:
if 'user' in i:
req_list.append(i['user'])
if 'AI' in i:
req_list.append(i['AI'])
req_list.append(new_record['user'])
return context, new_record, req_list
def handle_switch_key(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
self.client.api_key = key
logger.warning(
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
break
if is_all_exceed:
logger.warning(
"所有 Key 已超额")
return False
return True
return contexts_str, len(self.session_memory[session_id])
def set_model(self, model: str):
    '''
    Set the model name stored under the 'model' key of ``self.model_configs``.
    '''
    configs = self.model_configs
    configs['model'] = model
def get_configs(self):
return self.openai_configs
return self.model_configs
def get_key_stat(self):
return self.key_stat
def get_key_list(self):
return self.key_list
def get_keys_data(self):
return self.keys_data
def get_curr_key(self):
return self.client.api_key
return self.chosen_api_key
def set_key(self, key):
self.client.api_key = key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
return True
self.client.api_key = key
+3 -3
View File
@@ -2,8 +2,8 @@ class Provider:
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None,
function_call: None,
image_url: None = None,
tools: None = None,
extra_conf: dict = None,
default_personality: dict = None,
**kwargs) -> str:
@@ -14,7 +14,7 @@ class Provider:
[optional]
image_url: 图片url(识图)
function_call: 函数调用
tools: 函数调用工具
extra_conf: 额外配置
default_personality: 默认人格
'''
@@ -1,13 +1,16 @@
import sqlite3
import yaml
import os
import shutil
import time
from typing import Tuple
class dbConn():
def __init__(self):
# 读取参数,并支持中文
conn = sqlite3.connect("data.db")
db_path = "data/data.db"
if os.path.exists("data.db"):
shutil.copy("data.db", db_path)
conn = sqlite3.connect(db_path)
conn.text_factory = str
self.conn = conn
c = conn.cursor()
+28
View File
@@ -0,0 +1,28 @@
from typing import Union, List, Callable
from dataclasses import dataclass
@dataclass
class CommandItem():
    '''
    Describes a single command.
    '''
    # Name of the command.
    command_name: Union[str, tuple]
    # Callback function.
    callback: Callable
    # Description text.
    description: str
    # Where the command was registered from.
    origin: str
class CommandResult():
    '''
    Carries the several values a command returns.
    '''

    def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
        self.hit, self.success = hit, success
        self.message_chain, self.command_name = message_chain, command_name

    def _result_tuple(self):
        '''
        Return ``(success, message_chain, command_name)``; ``hit`` is not included.
        '''
        return self.success, self.message_chain, self.command_name
+62
View File
@@ -0,0 +1,62 @@
from enum import Enum
from typing import List
from dataclasses import dataclass
from nakuru.entities.components import BaseMessageComponent
from type.register import RegisteredPlatform
from type.types import GlobalObject
class MessageType(Enum):
    '''
    Kind of conversation a message belongs to.
    '''
    # Message sent in a group.
    GROUP_MESSAGE = 'GroupMessage'
    # One-to-one message (private chat, friend, ...).
    FRIEND_MESSAGE = 'FriendMessage'
    # Guild (channel) message.
    GUILD_MESSAGE = 'GuildMessage'
@dataclass
class MessageMember():
    '''
    Sender of a message.
    '''
    # Id of the sender.
    user_id: str
    # Optional nickname; defaults to None.
    nickname: str = None
class AstrBotMessage():
    '''
    Message object used by AstrBot.

    Only annotations are declared here; no attribute gets a default value.
    '''
    # Label of the message source.
    tag: str
    # Message type.
    type: MessageType
    # Identifier of the bot itself.
    self_id: str
    # Session id.
    session_id: str
    # Message id.
    message_id: str
    # Sender.
    sender: MessageMember
    # Message chain, in Nakuru's message chain format.
    message: List[BaseMessageComponent]
    # Plain-text form of the message.
    message_str: str
    raw_message: object
    # Message timestamp.
    timestamp: int

    def __str__(self) -> str:
        '''
        Return the string form of the instance attribute dict.
        '''
        return str(vars(self))
class AstrMessageEvent():
    '''
    A message event: the message together with the context it arrived in.
    '''
    # Shared data.
    context: GlobalObject
    # Plain message string.
    message_str: str
    # Message object.
    message_obj: AstrBotMessage
    # Platform the message came from.
    platform: RegisteredPlatform
    # Basic identity: `admin` or `member`.
    role: str
    # Session id. Annotated as str to agree with __init__ (was `int`).
    session_id: str

    def __init__(self,
                 message_str: str,
                 message_obj: AstrBotMessage,
                 platform: RegisteredPlatform,
                 role: str,
                 context: GlobalObject,
                 session_id: str = None):
        '''
        Store the given values on the instance unchanged.

        ``session_id`` defaults to None.
        '''
        self.context = context
        self.message_str = message_str
        self.message_obj = message_obj
        self.platform = platform
        self.role = role
        self.session_id = session_id
+27
View File
@@ -0,0 +1,27 @@
from enum import Enum
from dataclasses import dataclass
class PluginType(Enum):
    '''
    Category of a plugin.
    '''
    # Platform plugin.
    # NOTE(review): the value is spelled 'platfrom'; it is a runtime value
    # and is kept as-is here. Confirm whether anything depends on it before
    # changing the spelling.
    PLATFORM = 'platfrom'
    # Large language model plugin.
    LLM = 'llm'
    # Any other plugin.
    COMMON = 'common'
@dataclass
class PluginMetadata:
    '''
    Metadata of a plugin.
    '''
    # --- required ---
    plugin_name: str
    plugin_type: PluginType
    # Plugin author.
    author: str
    # Short description of the plugin.
    desc: str
    # Plugin version.
    version: str

    # --- optional ---
    # Address of the plugin repository.
    repo: str = None

    def __str__(self) -> str:
        '''
        Return a one-line summary; the author is not part of it.
        '''
        shown = f"{self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo}"
        return f"PluginMetadata({shown})"
+46
View File
@@ -0,0 +1,46 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from type.plugin import *
from typing import List
from types import ModuleType
from dataclasses import dataclass
@dataclass
class RegisteredPlugin:
    '''
    A plugin registered in AstrBot.
    '''
    metadata: PluginMetadata
    plugin_instance: object
    module_path: str
    module: ModuleType
    root_dir_name: str

    def __str__(self) -> str:
        '''
        Return a summary made of the metadata, module path and root directory name.
        '''
        shown = f"{self.metadata}, {self.module_path}, {self.root_dir_name}"
        return f"RegisteredPlugin({shown})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
    '''
    A platform registered in AstrBot. Platforms should implement the Platform interface.
    '''
    platform_name: str
    platform_instance: Platform
    # Where the platform was registered from.
    origin: str = None

    def __str__(self) -> str:
        '''
        Return the platform name.
        '''
        name = self.platform_name
        return name
@dataclass
class RegisteredLLM:
    '''
    A large language model registered in AstrBot. Models should implement the LLMProvider interface.
    '''
    llm_name: str
    llm_instance: LLMProvider
    # Where the model was registered from.
    origin: str = None
+34
View File
@@ -0,0 +1,34 @@
from type.register import *
from typing import List
class GlobalObject:
    '''
    Holds shared data that is passed between modules (such as core and command).
    '''
    # Bot version.
    version: str
    # User-defined aliases of the bot.
    nick: tuple
    # Configuration exported from config.json.
    base_config: dict
    # Loaded plugins.
    cached_plugins: List[RegisteredPlugin]
    platforms: List[RegisteredPlatform]
    llms: List[RegisteredLLM]

    # Whether web search is enabled.
    web_search: bool
    # Prefix added to replies.
    reply_prefix: str
    # Whether independent sessions are enabled.
    unique_session: bool
    # Total number of messages.
    cnt_total: int
    default_personality: dict
    dashboard_data = None

    def __init__(self):
        '''
        Initialise every field to its empty default. ``version`` is not set here.
        '''
        # Scalars and optional values.
        self.nick = None
        self.base_config = None
        self.reply_prefix = None
        self.default_personality = None
        self.dashboard_data = None

        # Flags and counters.
        self.web_search = False
        self.unique_session = False
        self.cnt_total = 0

        # Containers; a fresh object per instance.
        self.cached_plugins = []
        self.platforms = []
        self.llms = []
        self.stat = {}
+183
View File
@@ -0,0 +1,183 @@
import traceback
import random
import json
import asyncio
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.agent.func_call import FuncCall
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
from util.search_engine_scraper.bing import Bing
from util.search_engine_scraper.sogo import Sogo
from util.search_engine_scraper.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
bing_search = Bing()
sogo_search = Sogo()
google = Google()
proxy = os.environ.get("HTTPS_PROXY", None)
def tidy_text(text: str) -> str:
    '''
    Flatten text to a single line.

    Leading and trailing whitespace is removed and every run of whitespace
    (spaces, tabs, newlines, carriage returns) becomes one space.
    '''
    # The former chain ended in a replace that swapped a space for a space,
    # so runs left behind by "\n\n" or "\r\n" were never collapsed.
    # str.split() with no argument splits on any whitespace run and drops
    # empty pieces.
    return " ".join(text.split())
# def special_fetch_zhihu(link: str) -> str:
# '''
# function-calling 函数, 用于获取知乎文章的内容
# '''
# response = requests.get(link, headers=HEADERS)
# response.encoding = "utf-8"
# soup = BeautifulSoup(response.text, "html.parser")
# if "zhuanlan.zhihu.com" in link:
# r = soup.find(class_="Post-RichTextContainer")
# else:
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
# if r is None:
# print("debug: zhihu none")
# raise Exception("zhihu none")
# return tidy_text(r.text)
async def search_from_bing(keyword: str) -> str:
    '''
    Tool function: search the web for ``keyword`` and return a numbered digest.

    Despite the name, the engines are tried in the order Google, Bing, Sogou,
    and the first one that returns results is used. Each result page is then
    fetched and at most 600 characters of its text are added after the title
    and snippet.

    Returns:
        The digest text, or "没有搜索到结果" when no engine returned anything.
    '''
    logger.info("web_searcher - search_from_bing: " + keyword)

    engines = (("google", google), ("bing", bing_search), ("sogo", sogo_search))
    results = []
    for pos, (engine_name, engine) in enumerate(engines):
        is_last = pos == len(engines) - 1
        try:
            results = await engine.search(keyword, 5)
        except Exception as e:
            # Exception, not BaseException: task cancellation and
            # KeyboardInterrupt must not be swallowed here.
            hint = "" if is_last else ", try the next one..."
            logger.error(f"{engine_name} search error: {e}{hint}")
        if results:
            break
        results = []
        logger.debug(f"search {engine_name} failed")

    if not results:
        return "没有搜索到结果"

    entries = []
    for idx, item in enumerate(results, start=1):
        logger.info(f"web_searcher - scraping web: {item.title} - {item.url}")
        try:
            site_result = await fetch_website_content(item.url)
        except Exception:
            # Best effort: a page that cannot be fetched contributes no body text.
            site_result = ""
        if len(site_result) > 600:
            site_result = site_result[:600] + "..."
        entries.append(f"{idx}. {item.title} \n{item.snippet}\n{site_result}\n\n")

    return "".join(entries)
async def fetch_website_content(url):
    '''
    Download ``url`` and return the readable main text of the page.

    The page is fetched with a random User-Agent, through the module-level
    ``proxy``, with a 6 second total timeout. The body is decoded as UTF-8,
    reduced to its main content with readability, stripped of markup and
    passed through ``tidy_text``.
    '''
    # Build a per-request copy: updating HEADERS in place would leak the
    # chosen User-Agent into every other user of that shared dict.
    request_headers = {**HEADERS, 'User-Agent': random.choice(USER_AGENTS)}
    timeout = aiohttp.ClientTimeout(total=6)

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=request_headers, timeout=timeout, proxy=proxy) as response:
            html = await response.text(encoding="utf-8")

    doc = Document(html)
    summary_html = doc.summary(html_partial=True)
    soup = BeautifulSoup(summary_html, 'html.parser')
    return tidy_text(soup.get_text())
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
    '''
    Answer ``prompt`` with the help of the web search tools.

    Two tools are registered: "web_search" (``search_from_bing``) and
    "fetch_website_content". When a tool was invoked, the session is
    forgotten and the provider is asked to summarise the tool output.

    Args:
        prompt: the user's question.
        provider: LLM provider used for tool selection and for the summary.
        session_id: session the conversation belongs to.
        official_fc: use the provider's official function calling (``tools``)
            instead of this project's own ``FuncCall.func_call``.

    Returns:
        The provider's reply. When tool selection or execution fails, a plain
        reply followed by "\n(网页搜索失败, 此为默认回复)" is returned.
    '''
    fallback_suffix = "\n(网页搜索失败, 此为默认回复)"

    new_func_call = FuncCall(provider)
    new_func_call.add_func("web_search", [{
        "type": "string",
        "name": "keyword",
        "description": "搜索关键词"
    }],
        "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
        search_from_bing
    )
    new_func_call.add_func("fetch_website_content", [{
        "type": "string",
        "name": "url",
        "description": "要获取内容的网页链接"
    }],
        "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
        fetch_website_content
    )

    has_func = False
    function_invoked_ret = ""
    if official_fc:
        # Official function calling.
        result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
        if not isinstance(result, Function):
            # The model answered directly without calling a tool.
            return result

        logger.debug(f"web_searcher - function-calling: {result}")
        func_obj = next(
            (i["func_obj"] for i in new_func_call.func_list if i["name"] == result.name),
            None,
        )
        if not func_obj:
            return await provider.text_chat(prompt, session_id) + fallback_suffix
        try:
            args = json.loads(result.arguments)
            function_invoked_ret = await func_obj(**args)
            has_func = True
        except Exception:
            # Exception, not BaseException: let cancellation propagate.
            traceback.print_exc()
            return await provider.text_chat(prompt, session_id) + fallback_suffix
    else:
        # This project's own function calling.
        try:
            args = {
                'question': prompt,
                'func_definition': new_func_call.func_dump(),
                'is_task': False,
                'is_summary': False,
            }
            function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
        except Exception:
            # session_id was missing from this call, which made the fallback
            # itself fail for providers that require it.
            res = await provider.text_chat(prompt, session_id) + fallback_suffix
            return res
        # NOTE(review): this overrides the has_func value returned by
        # func_call, so the summary step below always runs on this path.
        # Kept as found; confirm whether that is intended.
        has_func = True

    if has_func:
        await provider.forget(session_id)
        summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
2. 简单地发表你对这个问题的简略看法。
# 例子
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
# 限制
1. 限制在 200 字以内;
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
{function_invoked_ret}"""
        ret = await provider.text_chat(summary_prompt, session_id)
        return ret
    return function_invoked_ret
+2 -4
View File
@@ -2,8 +2,7 @@ import os
import json
from typing import Union
cpath = "cmd_config.json"
cpath = "data/cmd_config.json"
def check_exist():
if not os.path.exists(cpath):
@@ -89,8 +88,7 @@ def init_astrbot_config_items():
# 加载默认配置
cc = CmdConfig()
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes(
"qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", False)
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
-300
View File
@@ -1,300 +0,0 @@
import requests
import util.general_utils as gu
import traceback
import time
import json
import asyncio
from googlesearch import search, SearchResult
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.function_calling.func_call import (
FuncCall,
FuncCallJsonFormatError,
FuncNotFoundError
)
from model.provider.provider import Provider
def tidy_text(text: str) -> str:
    '''
    Flatten text to a single line.

    Leading and trailing whitespace is removed and every run of whitespace
    (spaces, tabs, newlines, carriage returns) becomes one space.
    '''
    # The former chain ended in a replace that swapped a space for a space,
    # so runs left behind by "\n\n" or "\r\n" were never collapsed.
    return " ".join(text.split())
def special_fetch_zhihu(link: str) -> str:
    '''
    Function-calling helper that fetches the text of a Zhihu page.

    Links containing "zhuanlan.zhihu.com" are read from the element with
    class "Post-RichTextContainer"; other links from the "RichContent-inner"
    element inside the first "List-item".

    Raises:
        Exception: ("zhihu none") when the expected element is not in the page.
    '''
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                      "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    response = requests.get(link, headers=headers)
    response.encoding = "utf-8"
    soup = BeautifulSoup(response.text, "html.parser")

    if "zhuanlan.zhihu.com" in link:
        r = soup.find(class_="Post-RichTextContainer")
    else:
        # find() returns None when nothing matches; chaining a second find()
        # on it raised AttributeError before the check below could run.
        list_item = soup.find(class_="List-item")
        r = list_item.find(class_="RichContent-inner") if list_item is not None else None

    if r is None:
        print("debug: zhihu none")
        raise Exception("zhihu none")
    return tidy_text(r.text)
def google_web_search(keyword) -> str:
    '''
    Run a Google search and return title, url and content for each hit.

    The content is the fetched page text, or the search result's own
    description when fetching fails. If the search itself raises, the result
    of ``web_keyword_search_via_bing`` is returned instead.
    '''
    digest = ""
    try:
        hits = search(keyword, advanced=True, num_results=4)
        for number, hit in enumerate(hits, start=1):
            content = hit.description
            try:
                content = fetch_website_content(hit.url)
            except BaseException as e:
                print(f"(google) fetch_website_content err: {str(e)}")
            digest += f"# No.{str(number)}\ntitle: {hit.title}\nurl: {hit.url}\ncontent: {content}\n\n"
    except Exception as e:
        print(f"google search err: {str(e)}")
        return web_keyword_search_via_bing(keyword)
    return digest
def web_keyword_search_via_bing(keyword) -> str:
    '''
    Search Bing for ``keyword`` and return title, url and content per hit.

    The content of a hit is the fetched page text, or Bing's own snippet
    when fetching fails. The result page is requested up to 5 times, one
    second apart, when the request or the parsing raises. When Bing yields
    no usable hit, or every attempt fails, the result of
    ``web_keyword_search_via_sougou`` is returned instead.
    '''
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                      "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    attempts = 0
    while attempts < 5:
        try:
            # Pass the keyword through `params` so characters such as
            # '&', '#' or '+' are encoded instead of corrupting the query.
            response = requests.get("https://www.bing.com/search",
                                    params={"q": keyword}, headers=headers)
            response.encoding = "utf-8"
            soup = BeautifulSoup(response.text, "html.parser")

            res = ""
            result_cnt = 0
            ols = soup.find(id="b_results")
            # A page without the results container counts as "no hits"
            # rather than raising AttributeError and retrying.
            items = ols.find_all("li", class_="b_algo") if ols is not None else []
            for i in items:
                try:
                    title = i.find("h2").text
                    desc = i.find("p").text
                    link = i.find("h2").find("a").get("href")
                    try:
                        desc = fetch_website_content(link)
                    except Exception as e:
                        # Keep Bing's snippet when the page cannot be fetched.
                        print(f"(bing) fetch_website_content err: {str(e)}")
                    res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n"
                    result_cnt += 1
                    if result_cnt > 5:
                        break
                except Exception as e:
                    # One malformed entry must not discard the others.
                    print(f"bing parse err: {str(e)}")
            if result_cnt == 0:
                break
            return res
        except Exception:
            attempts += 1
            time.sleep(1)
    return web_keyword_search_via_sougou(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
    '''
    Search Sogou for ``keyword``.

    Collects title and link of up to 5 hits, then fetches the page text of
    at most 3 of them.

    Returns:
        The string form of the hit list, followed by "\n网页内容: ..." with
        the fetched texts when at least one page could be fetched.
    '''
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                      "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    }
    # `params` encodes the keyword; it was previously pasted into the URL raw.
    response = requests.get("https://sogou.com/web", params={"query": keyword}, headers=headers)
    response.encoding = "utf-8"
    soup = BeautifulSoup(response.text, "html.parser")

    res = []
    results = soup.find("div", class_="results")
    # find() returns None when the container is missing; treat that as no hits.
    entries = results.find_all("div", class_="vrwrap") if results is not None else []
    for i in entries:
        try:
            title = tidy_text(i.find("h3").text)
            link = tidy_text(i.find("h3").find("a").get("href"))
            if link.startswith("/link?url="):
                link = "https://www.sogou.com" + link
            res.append({
                "title": title,
                "link": link,
            })
            if len(res) >= 5:  # at most 5 hits
                break
        except Exception:
            # Entries without the expected h3/a structure are skipped on purpose.
            continue

    # Fetch the page text of the first hits.
    _detail_store = []
    for i in res:
        # Compare the length: `_detail_store >= 3` compared a list with an
        # int and raised TypeError on the first iteration.
        if len(_detail_store) >= 3:
            break
        try:
            _detail_store.append(fetch_website_content(i["link"]))
        except Exception as e:
            print(f"fetch_website_content err: {str(e)}")

    ret = f"{str(res)}"
    if len(_detail_store) > 0:
        ret += f"\n网页内容: {str(_detail_store)}"
    return ret
def fetch_website_content(url):
    """Download ``url`` and return the cleaned text of its main content.

    The raw response body is passed to ``Document`` and its
    ``summary(html_partial=True)`` output is stripped of markup with
    BeautifulSoup, then normalised with ``tidy_text``.

    Args:
        url: address of the page to fetch (3 second request timeout).

    Returns:
        The tidied plain text of the extracted summary.
    """
    request_headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        )
    }
    page = requests.get(url, headers=request_headers, timeout=3)
    page.encoding = "utf-8"
    # Document works on the raw bytes of the response.
    article_html = Document(page.content).summary(html_partial=True)
    plain_text = BeautifulSoup(article_html, 'html.parser').get_text()
    return tidy_text(plain_text)
async def web_search(question, provider: Provider, session_id, official_fc=False):
    '''
    Answer `question`, optionally backed by a web search or a page fetch.

    Two functions are registered on a fresh FuncCall: `google_web_search`
    (implemented by web_keyword_search_via_bing) and `fetch_website_content`.
    The model decides whether to call one of them; when it does, the function
    result is handed back to the model together with a summarisation prompt.

    question: the user's question.
    provider: LLM provider used for every chat request.
    session_id: conversation id passed to the provider.
    official_fc: use the provider's official function-calling instead of the
        FuncCall based one.

    Returns the model's answer. When the model chose not to call a function
    on the official path, its direct reply is returned. When invoking the
    function fails, a default reply with a "search failed" suffix is returned.
    Raises the last exception if the summarisation request fails 3 times.
    '''
    new_func_call = FuncCall(provider)
    new_func_call.add_func("google_web_search", [{
        "type": "string",
        "name": "keyword",
        "description": "google search query (分词,尽量保留所有信息)"
    }],
        "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
        web_keyword_search_via_bing
    )
    new_func_call.add_func("fetch_website_content", [{
        "type": "string",
        "name": "url",
        "description": "网址"
    }],
        "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
        fetch_website_content
    )
    question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
    has_func = False
    function_invoked_ret = ""
    if official_fc:
        # Official function-calling: the provider returns either a Function
        # to execute or a plain string reply.
        func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
        if isinstance(func, Function):
            # Look up the callable registered under the returned name.
            func_obj = None
            for i in new_func_call.func_list:
                if i["name"] == func.name:
                    func_obj = i["func_obj"]
                    break
            if not func_obj:
                # The model named a function that was never registered.
                return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
            try:
                args = json.loads(func.arguments)
                # Run in a worker thread so the blocking scraper does not
                # stall the event loop.
                function_invoked_ret = await asyncio.to_thread(func_obj, **args)
                has_func = True
            except Exception:
                # Exception rather than BaseException, so task cancellation
                # is not swallowed.
                traceback.print_exc()
                return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
        else:
            # func is already the model's textual answer.
            return func
    else:
        # Our own function-calling implementation.
        try:
            args = {
                'question': question1,
                'func_definition': new_func_call.func_dump(),
                'is_task': False,
                'is_summary': False,
            }
            function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
        except Exception:
            res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
            return res
        # NOTE(review): this overrides the flag returned by func_call; kept
        # as in the original, confirm it is intended.
        has_func = True

    if has_func:
        await provider.forget(session_id)
        question3 = f"""
你的任务是:
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
2. 简单地发表你对这个问题的看法(简略)。
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
不要提到任何函数调用的信息。
一些回复的消息模板:
模板1:
```
从网上的信息来看,可以知道...我个人认为...你觉得呢?
```
模板2:
```
根据网上的最新信息,可以得知...我觉得...你怎么看?
```
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
以下是相关材料:
"""
        _c = 0
        while _c < 3:
            try:
                print('text chat')
                final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
                return final_ret
            except Exception as e:
                print(e)
                _c += 1
                if _c == 3:
                    raise
                if "The message you submitted was too long" in str(e):
                    # Drop the session history and retry with half of the
                    # material.
                    await provider.forget(session_id)
                    function_invoked_ret = function_invoked_ret[:int(
                        len(function_invoked_ret) / 2)]
                    # Non-blocking pause; time.sleep here froze the whole
                    # event loop for 3 seconds.
                    await asyncio.sleep(3)
    return function_invoked_ret
+64 -14
View File
@@ -1,18 +1,23 @@
import datetime
import time
import socket
from PIL import Image, ImageDraw, ImageFont
import os
import re
import requests
from util.cmd_config import CmdConfig
import aiohttp
import socket
from cores.astrbot.types import GlobalObject
import platform
import logging
import json
import sys
import psutil
import ssl
from PIL import Image, ImageDraw, ImageFont
from type.types import GlobalObject
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -355,10 +360,36 @@ def save_temp_img(img: Image) -> str:
# 获得时间戳
timestamp = int(time.time())
p = f"temp/{timestamp}.png"
img.save(p)
p = f"temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
logger.info(f"保存临时图片: {p}")
return p
async def download_image_by_url(url: str) -> str:
    '''
    Download the image at `url` and store it with save_temp_img.

    The first attempt uses a default aiohttp session. If connecting fails
    with ClientConnectorSSLError, the download is repeated once with
    certificate and hostname verification switched off and trust_env=False.

    Returns whatever save_temp_img returns for the downloaded bytes.
    Any other error propagates to the caller.
    '''
    async def _fetch_and_save(client_opts: dict, request_opts: dict) -> str:
        # One GET request; the body is read fully and handed to save_temp_img.
        async with aiohttp.ClientSession(**client_opts) as client:
            async with client.get(url, **request_opts) as reply:
                return save_temp_img(await reply.read())

    try:
        logger.info(f"下载图片: {url}")
        return await _fetch_and_save({}, {})
    except aiohttp.client_exceptions.ClientConnectorSSLError:
        # Fallback: retry without SSL verification.
        insecure_ctx = ssl.create_default_context()
        insecure_ctx.check_hostname = False
        insecure_ctx.verify_mode = ssl.CERT_NONE
        return await _fetch_and_save({"trust_env": False}, {"ssl": insecure_ctx})
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
@@ -391,15 +422,19 @@ def create_markdown_image(text: str):
raise e
def try_migrate_config(old_config: dict):
def try_migrate_config():
'''
迁移配置文件到 cmd_config.json
将 cmd_config.json 迁移至 data/cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
for k in old_config:
cc.put(k, old_config[k])
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
try:
os.remove("cmd_config.json")
except Exception as e:
pass
def get_local_ip_addresses():
@@ -451,6 +486,21 @@ def upload(_global_object: GlobalObject):
pass
time.sleep(10*60)
def retry(n: int = 3):
    '''
    Retry decorator.

    The decorated function is called up to `n` times. The first successful
    call's return value is returned; a failed attempt is logged as a warning
    and retried, and the exception of the last attempt is re-raised.
    With n <= 0 the function is never called and None is returned.

    Only exceptions derived from Exception trigger a retry.
    '''
    # Local import so this helper does not depend on the file's import block.
    import functools

    def decorator(func):
        # wraps() keeps __name__/__doc__ of the decorated function, which the
        # bare wrapper used to hide.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if i == n - 1:
                        # Out of attempts: re-raise with the original traceback.
                        raise
                    logger.warning(f"函数 {func.__name__}{i+1} 次重试... {e}")
        return wrapper
    return decorator
def run_monitor(global_object: GlobalObject):
'''
+5 -11
View File
@@ -1,11 +1,5 @@
from cores.astrbot.types import (
PluginMetadata,
RegisteredLLM,
RegisteredPlugin,
RegisteredPlatform,
RegisteredPlugins,
PluginType,
GlobalObject,
AstrMessageEvent,
CommandResult
)
from type.plugin import PluginMetadata, PluginType
from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins
from type.types import GlobalObject
from type.message import AstrMessageEvent
from type.command import CommandResult
+3 -2
View File
@@ -1,5 +1,6 @@
from cores.astrbot.core import oper_msg
from cores.astrbot.types import AstrMessageEvent, CommandResult
from astrbot.core import oper_msg
from type.message import AstrMessageEvent, AstrBotMessage
from type.command import CommandResult
from model.platform._message_result import MessageResult
'''
+2 -1
View File
@@ -5,7 +5,8 @@
'''
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM
from type.types import GlobalObject
from type.register import RegisteredPlatform, RegisteredLLM
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
'''
+1 -1
View File
@@ -2,4 +2,4 @@
插件类型
'''
from cores.astrbot.types import PluginType
from type.plugin import PluginType
+42 -39
View File
@@ -1,26 +1,24 @@
'''
插件工具函数
'''
import os
import os, sys
import inspect
import shutil
import stat
import traceback
try:
import git.exc
from git.repo import Repo
except ImportError:
pass
import shutil
import importlib
import stat
import traceback
from types import ModuleType
from typing import List
from pip._internal import main as pipmain
from cores.astrbot.types import (
PluginMetadata,
PluginType,
RegisteredPlugin,
RegisteredPlugins
)
from type.plugin import *
from type.register import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# 找出模块里所有的类名
@@ -62,29 +60,35 @@ def get_modules(path):
def get_plugin_store_path():
if os.path.exists("addons/plugins"):
return "addons/plugins"
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
return "QQChannelChatGPT/addons/plugins"
elif os.path.exists("AstrBot/addons/plugins"):
return "AstrBot/addons/plugins"
else:
raise FileNotFoundError("插件文件夹不存在。")
plugin_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../addons/plugins"))
return plugin_dir
def get_plugin_modules():
plugins = []
try:
if os.path.exists("addons/plugins"):
plugins = get_modules("addons/plugins")
plugin_dir = get_plugin_store_path()
if os.path.exists(plugin_dir):
plugins = get_modules(plugin_dir)
return plugins
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
plugins = get_modules("QQChannelChatGPT/addons/plugins")
return plugins
else:
return None
except BaseException as e:
raise e
def check_plugin_dept_update(cached_plugins: RegisteredPlugins, target_plugin: str = None):
    '''
    Install/update plugin dependencies from each plugin's requirements.txt.

    cached_plugins: registered plugins; their `root_dir_name` is used when no
        target is given.
    target_plugin: directory name of a single plugin to process; when falsy,
        every plugin in `cached_plugins` is processed.

    Returns False when the plugin store directory does not exist, otherwise
    None. Plugins without a requirements.txt are skipped.
    '''
    plugin_dir = get_plugin_store_path()
    if not os.path.exists(plugin_dir):
        return False

    if target_plugin:
        names = [target_plugin]
    else:
        names = [plugin.root_dir_name for plugin in cached_plugins]

    for name in names:
        requirements = os.path.join(plugin_dir, name, "requirements.txt")
        if not os.path.exists(requirements):
            continue
        logger.info(f"正在检查更新插件 {name} 的依赖: {requirements}")
        update_plugin_dept(requirements)
def plugin_reload(cached_plugins: RegisteredPlugins):
@@ -103,6 +107,8 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
module_path = plugin['module_path']
root_dir_name = plugin['pname']
check_plugin_dept_update(cached_plugins, root_dir_name)
module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p])
@@ -150,6 +156,11 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
return True, None
else:
return False, fail_rec
def update_plugin_dept(path):
    '''
    Install the requirements file at `path` with pip, using the Aliyun mirror.

    pip is run with the current interpreter (`sys.executable -m pip`). The
    exit status is not checked, so a failed install does not raise.
    '''
    # Local import so this helper does not depend on the file's import block.
    import subprocess

    mirror = "https://mirrors.aliyun.com/pypi/simple/"
    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters (plugin directory names come from repository URLs)
    # are passed to pip verbatim and cannot be interpreted by a shell.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", path, "-i", mirror, "--quiet"],
        check=False,
    )
def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
@@ -161,15 +172,12 @@ def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
d = repo_url.split("/")[-1]
# 转换非法字符:-
d = d.replace("-", "_")
d = d.lower() # 转换为小写
# 创建文件夹
plugin_path = os.path.join(ppath, d)
if os.path.exists(plugin_path):
remove_dir(plugin_path)
Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins)
if not ok:
raise Exception(err)
@@ -204,11 +212,6 @@ def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin_path = os.path.join(ppath, root_dir_name)
repo = Repo(path=plugin_path)
repo.remotes.origin.pull()
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
print("正在安装插件依赖...")
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt")]) != 0:
raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins)
if not ok:
raise Exception(err)
+38
View File
@@ -0,0 +1,38 @@
from typing import List
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENT_BING
class Bing(SearchEngine):
    '''
    Bing scraper built on SearchEngine.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.base_url = "https://www.bing.com"
        # Build a new dict instead of calling .update(): the headers object
        # handed out by the base class may be the module-level HEADERS dict,
        # which must not be modified for every other engine.
        self.headers = {**self.headers, 'User-Agent': USER_AGENT_BING}

    def _set_selector(self, selector: str):
        '''Return the CSS selector registered under `selector` (KeyError if unknown).'''
        selectors = {
            'url': 'div.b_attribution cite',
            'title': 'h2',
            'text': 'p',
            'links': 'ol#b_results > li.b_algo',
            'next': 'div#b_content nav[role="navigation"] a.sb_pagN'
        }
        return selectors[selector]

    async def _get_next_page(self, query) -> str:
        '''Fetch the HTML of the result page for `query`.'''
        # Local import so this class does not depend on the file's import block.
        from urllib.parse import quote_plus

        if self.page == 1:
            # Request the home page first, as the original implementation does.
            await self._get_html(self.base_url)
        # Percent-encode the query so '&', '#', '+' ... stay part of `q`
        # instead of being parsed as URL syntax.
        url = f'{self.base_url}/search?q={quote_plus(query)}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl='
        return await self._get_html(url, None)

    async def search(self, query: str, num_results: int) -> List[SearchResult]:
        '''
        Search Bing and return at most `num_results` results.

        The base class stores the matched element in `result.url`; it is
        replaced here by the element's text.
        '''
        results = await super().search(query, num_results)
        for result in results:
            if not isinstance(result.url, str):
                result.url = result.url.text
        return results
+20
View File
@@ -0,0 +1,20 @@
# Default HTTP request headers; SearchEngine.__init__ uses this dict as the
# engine's initial headers.
HEADERS = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
'Accept': '*/*',
'Connection': 'keep-alive',
'Accept-Language': 'en-GB,en;q=0.5'
}
# User-Agent the Bing scraper puts into its headers on construction.
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
# Pool of browser User-Agent strings; SearchEngine._get_html picks one at
# random for every request.
USER_AGENTS = [
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
]
+73
View File
@@ -0,0 +1,73 @@
import random
try:
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from config import HEADERS, USER_AGENTS
from bs4 import BeautifulSoup
from aiohttp import ClientSession
from dataclasses import dataclass
from typing import List
@dataclass
class SearchResult:
    '''
    One search hit: title, target url and a text snippet.
    '''
    title: str
    url: str
    snippet: str

    def __str__(self) -> str:
        # "<title> - <url>" on the first line, the snippet on the second.
        return "{} - {}\n{}".format(self.title, self.url, self.snippet)
class SearchEngine():
    '''
    Base class for search-engine scrapers.

    Subclasses provide the CSS selectors (`_set_selector`) and the request for
    a result page (`_get_next_page`); `search` parses that page into
    SearchResult objects.
    '''

    def __init__(self) -> None:
        self.TIMEOUT = 10
        self.page = 1
        # Copy the defaults: the module-level HEADERS dict is shared by every
        # engine, so per-instance changes must not leak into it.
        self.headers = dict(HEADERS)

    def _set_selector(self, selector: str) -> str:
        '''Return the CSS selector for `selector`; must be overridden.'''
        raise NotImplementedError()

    async def _get_next_page(self, query: str) -> str:
        '''Return the HTML of the result page for `query`; must be overridden.'''
        raise NotImplementedError()

    async def _get_html(self, url: str, data: dict = None) -> str:
        '''
        Request `url` and return the response body decoded as UTF-8.

        A POST with `data` is sent when `data` is truthy, otherwise a GET.
        Every request gets `url` as Referer and a random User-Agent.
        '''
        # Work on a copy so the per-request Referer/User-Agent do not
        # accumulate in the instance headers.
        headers = dict(self.headers)
        headers["Referer"] = url
        headers["User-Agent"] = random.choice(USER_AGENTS)
        async with ClientSession() as session:
            if data:
                request = session.post(url, headers=headers, data=data, timeout=self.TIMEOUT)
            else:
                request = session.get(url, headers=headers, timeout=self.TIMEOUT)
            async with request as resp:
                return await resp.text(encoding="utf-8")

    def tidy_text(self, text: str) -> str:
        '''
        Clean up text: strip surrounding whitespace and replace line breaks
        with spaces.
        '''
        return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")

    async def search(self, query: str, num_results: int) -> List[SearchResult]:
        '''
        Fetch the result page for `query` and parse it.

        Returns at most `num_results` SearchResult objects. `url` holds the
        element matched by the 'url' selector (subclasses turn it into a
        string) and `snippet` is left empty. Entries lacking a title or url
        element are skipped.
        '''
        resp = await self._get_next_page(query)
        soup = BeautifulSoup(resp, 'html.parser')
        results = []
        for link in soup.select(self._set_selector('links')):
            title_tag = link.select_one(self._set_selector('title'))
            url = link.select_one(self._set_selector('url'))
            # A single malformed entry used to raise on `.text` of None and
            # abort the whole search; skip it instead.
            if title_tag is None or url is None:
                continue
            title = self.tidy_text(title_tag.text)
            if title:
                results.append(SearchResult(title=title, url=url, snippet=''))
        return results[:num_results]
+27
View File
@@ -0,0 +1,27 @@
import os
from googlesearch import search
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Google(SearchEngine):
    '''
    Google search backed by the `googlesearch` package.

    The proxy is taken from the HTTPS_PROXY environment variable (None when
    unset).
    '''

    def __init__(self) -> None:
        super().__init__()
        self.proxy = os.environ.get("HTTPS_PROXY")

    async def search(self, query: str, num_results: int) -> List[SearchResult]:
        '''
        Search Google for `query` and return the hits as SearchResult objects.

        Errors raised by `googlesearch.search` propagate to the caller.
        '''
        # Local import so this class does not depend on the file's import block.
        import asyncio

        def _blocking_search():
            # Consume the results inside the worker thread; the HTTP requests
            # happen while iterating.
            return list(search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy))

        print("use proxy:", self.proxy)
        # googlesearch.search is synchronous; calling it directly in this
        # coroutine would block the event loop for the whole request.
        hits = await asyncio.to_thread(_blocking_search)
        return [
            SearchResult(title=hit.title, url=hit.url, snippet=hit.description)
            for hit in hits
        ]
+49
View File
@@ -0,0 +1,49 @@
import random, re
from bs4 import BeautifulSoup
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Sogo(SearchEngine):
    '''
    Sogou scraper built on SearchEngine.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.base_url = "https://www.sogou.com"
        # Build a new dict instead of assigning a key: the headers object
        # handed out by the base class may be the module-level HEADERS dict.
        self.headers = {**self.headers, 'User-Agent': random.choice(USER_AGENTS)}

    def _set_selector(self, selector: str):
        '''Return the CSS selector registered under `selector` (KeyError if unknown).'''
        selectors = {
            'url': 'h3 > a',
            'title': 'h3',
            'text': '',
            'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)',
            'next': ''
        }
        return selectors[selector]

    async def _get_next_page(self, query) -> str:
        '''Fetch the HTML of the result page for `query`.'''
        # Local import so this class does not depend on the file's import block.
        from urllib.parse import quote_plus

        # Percent-encode the query so '&', '#', '+' ... stay part of it.
        url = f'{self.base_url}/web?query={quote_plus(query)}'
        return await self._get_html(url, None)

    async def search(self, query: str, num_results: int) -> List[SearchResult]:
        '''
        Search Sogou and return at most `num_results` results.

        The base class stores the matched anchor element in `result.url`; it
        is replaced by the anchor's href. Relative "/link?" redirect links are
        made absolute and resolved with `_parse_url`.
        '''
        results = await super().search(query, num_results)
        for result in results:
            # An anchor without href yields None; fall back to an empty url
            # instead of failing on None.startswith.
            href = result.url.get("href") or ""
            if href.startswith("/link?"):
                href = await self._parse_url(self.base_url + href)
            result.url = href
        return results

    async def _parse_url(self, url) -> str:
        '''
        Resolve a redirect page to its target.

        The target is read from the `window.location.replace("...")` call in
        the page's first <script>; `url` is returned unchanged when no such
        call is found.
        '''
        html = await self._get_html(url)
        soup = BeautifulSoup(html, 'html.parser')
        script = soup.find("script")
        # script.string is None for an empty or multi-node script element,
        # and the pattern may not match; both used to raise.
        if script is not None and script.string:
            match = re.search(r'window.location.replace\("(.+?)"\)', script.string)
            if match:
                return match.group(1)
        return url
+22
View File
@@ -0,0 +1,22 @@
from sogo import Sogo
from bing import Bing
sogo_search = Sogo()
bing_search = Bing()


async def search(keyword: str) -> str:
    '''
    Run a Sogou search for `keyword` (5 results) and format the hits as a
    numbered list; returns a fixed message when nothing was found.
    '''
    results = await sogo_search.search(keyword, 5)
    # results = await bing_search.search(keyword, 5)
    if not results:
        return "没有搜索到结果"
    return "".join(
        f"{idx}. {hit.title}({hit.url})\n{hit.snippet}\n\n"
        for idx, hit in enumerate(results, start=1)
    )


# Only run the demo query when executed as a script, so importing this module
# does not trigger a network search.
if __name__ == "__main__":
    import asyncio

    ret = asyncio.run(search("gpt4orelease"))
    print(ret)
+3 -3
View File
@@ -112,7 +112,7 @@ def update_project(update_data: list,
# 更新到最新版本对应的commit
try:
repo.git.fetch()
repo.git.checkout(update_data[0]['tag_name'])
repo.git.checkout(update_data[0]['tag_name'], "-f")
if reboot: _reboot()
except BaseException as e:
raise e
@@ -124,7 +124,7 @@ def update_project(update_data: list,
if data['tag_name'] == version:
try:
repo.git.fetch()
repo.git.checkout(data['tag_name'])
repo.git.checkout(data['tag_name'], "-f")
flag = True
if reboot: _reboot()
except BaseException as e:
@@ -136,7 +136,7 @@ def checkout_branch(branch_name: str):
repo = find_repo()
try:
repo.git.fetch()
repo.git.checkout(branch_name)
repo.git.checkout(branch_name, "-f")
repo.git.pull("origin", branch_name, "-f")
return True
except BaseException as e: