Compare commits

...

33 Commits

Author SHA1 Message Date
Soulter f00f1e8933 fix: 画图报错 2024-05-24 13:33:02 +08:00
Soulter 8da4433e57 chore: 更改相关字段 2024-05-21 08:44:05 +08:00
Soulter 7babb87934 perf: 更改库的加载顺序 2024-05-21 08:41:46 +08:00
Soulter f67b171385 perf: 数据库迁移至 data 目录下 2024-05-19 17:10:11 +08:00
Soulter 1780d1355d perf: 将内部pip全部更换为阿里云镜像; 插件依赖更新逻辑优化 2024-05-19 16:45:08 +08:00
Soulter 5a3390e4f3 fix: force update 2024-05-19 16:06:47 +08:00
Soulter 337d96b41d Merge pull request #160 from Soulter/dev_default_openai_refactor
优化自带的 OpenAI LLM 交互, 人格, 网页搜索
2024-05-19 15:23:19 +08:00
Soulter 38a1dfea98 fix: web content scraper add proxy 2024-05-19 15:08:22 +08:00
Soulter fbef73aeec fix: websearch encoding set to utf-8 2024-05-19 14:42:28 +08:00
Soulter d6214c2b7c fix: web search 2024-05-19 12:55:54 +08:00
Soulter d58c86f6fc perf: websearch 优化;项目结构调整 2024-05-19 12:46:07 +08:00
Soulter ea34c20198 perf: 优化人格和LVM的处理过程 2024-05-18 10:34:35 +08:00
Soulter 934ca94e62 refactor: 重写 LLM OpenAI 模块 2024-05-17 22:56:44 +08:00
Soulter 1775327c2e chore: refact openai official 2024-05-17 09:07:11 +08:00
Soulter 707fcad8b4 feat: gpt 模型列表查看指令 models 2024-05-17 00:06:49 +08:00
Soulter f143c5afc6 fix: 修复 plugin v 子指令报错的问题 2024-05-16 23:11:07 +08:00
Soulter 99f94b2611 fix: 修复无法调用某些指令的问题 2024-05-16 23:04:47 +08:00
Soulter e39c1f9116 remove: 移除自动更换多模态模型的功能 2024-05-16 22:46:50 +08:00
Soulter 235e0b9b8f fix: gocq logging 2024-05-09 13:24:31 +08:00
Soulter d5a9bed8a4 fix(updator): IterableList object has no
attribute origin
2024-05-08 19:18:21 +08:00
Soulter d7dc8a7612 chore: 添加一些日志;更新版本 2024-05-08 19:12:23 +08:00
Soulter 08cd3ca40c perf: 更好的日志输出;
fix: 修复可视化面板刷新404
2024-05-08 19:01:36 +08:00
Soulter a13562dcea fix: 修复启动器启动加载带有配置的插件时提示配置文件缺失的问题 2024-05-08 16:28:30 +08:00
Soulter d7a0c0d1d0 Update requirements.txt 2024-05-07 15:58:51 +08:00
Soulter c0729b2d29 fix: 修复插件重载相关问题 2024-04-22 19:04:15 +08:00
Soulter a80f474290 fix: 修复更新插件时的报错 2024-04-22 18:36:56 +08:00
Soulter 699207dd54 update: version 2024-04-21 22:41:48 +08:00
Soulter e7708010c9 fix: 修复 gocq 平台下无法回复消息的问题 2024-04-21 22:39:09 +08:00
Soulter f66091e08f 🎨: clean codes 2024-04-21 22:20:23 +08:00
Soulter 03bb932f8f fix: 修复可视化面板报错 2024-04-21 22:16:42 +08:00
Soulter fbf8b349e0 update: helloworld 2024-04-21 22:13:27 +08:00
Soulter e9278fce6a !! delete: 移除对逆向 ChatGPT 的所有支持。 2024-04-21 22:12:09 +08:00
Soulter 9a7db956d5 fix: 修复 3.10.x readibility 依赖导致的报错 2024-04-21 16:40:02 +08:00
45 changed files with 2065 additions and 2125 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
addons/plugins/
data/*
cookies.json
logs/
-1
View File
@@ -160,7 +160,6 @@
- `/key` 动态添加key
- `/set` 人格设置面板
- `/keyword nihao 你好` 设置关键词回复。nihao->你好
- `/revgpt` 切换为ChatGPT逆向库
- `/画` 画画
#### 逆向ChatGPT库语言模型
+5 -2
View File
@@ -1,14 +1,17 @@
from aip import AipContentCensor
class BaiduJudge:
def __init__(self, baidu_configs) -> None:
if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs:
self.app_id = str(baidu_configs['app_id'])
self.api_key = baidu_configs['api_key']
self.secret_key = baidu_configs['secret_key']
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
self.client = AipContentCensor(
self.app_id, self.api_key, self.secret_key)
else:
raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!")
def judge(self, text):
res = self.client.textCensorUserDefined(text)
if 'conclusionType' not in res:
@@ -23,4 +26,4 @@ class BaiduJudge:
for i in res['data']:
info += f"{i['msg']}\n"
info += "\n判断结果:"+res['conclusion']
return False, info
return False, info
+46 -79
View File
@@ -11,22 +11,27 @@ import threading
import time
import asyncio
from util.plugin_dev.api.v1.config import update_config
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@dataclass
class DashBoardConfig():
config_type: str
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None # 仅 item 才需要
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
val_type: Optional[str] = None # 仅 item 才需要
path: Optional[str] = None # 仅 item 才需要
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
val_type: Optional[str] = None # 仅 item 才需要
class DashBoardHelper():
def __init__(self, global_object, config: dict):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.logger = global_object.logger
dashboard_data = global_object.dashboard_data
dashboard_data.configs = {
"data": []
@@ -34,26 +39,28 @@ class DashBoardHelper():
self.parse_default_config(dashboard_data, config)
self.dashboard_data: DashBoardData = dashboard_data
self.dashboard = AstrBotDashBoard(global_object)
self.key_map = {} # key: uuid, value: config key name
self.key_map = {} # key: uuid, value: config key name
self.cc = CmdConfig()
@self.dashboard.register("post_configs")
def on_post_configs(post_configs: dict):
try:
# self.logger.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板")
if 'base_config' in post_configs:
self.save_config(post_configs['base_config'], namespace='') # 基础配置
self.save_config(post_configs['config'], namespace=post_configs['namespace']) # 选定配置
self.parse_default_config(self.dashboard_data, self.cc.get_all())
self.save_config(
post_configs['base_config'], namespace='') # 基础配置
self.save_config(
post_configs['config'], namespace=post_configs['namespace']) # 选定配置
self.parse_default_config(
self.dashboard_data, self.cc.get_all())
# 重启
threading.Thread(target=self.dashboard.shutdown_bot, args=(2,), daemon=True).start()
threading.Thread(target=self.dashboard.shutdown_bot,
args=(2,), daemon=True).start()
except Exception as e:
# self.logger.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板")
raise e
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
def parse_default_config(self, dashboard_data: DashBoardData, config: dict):
try:
qq_official_platform_group = DashBoardConfig(
config_type="group",
@@ -112,14 +119,14 @@ class DashBoardHelper():
)
qq_gocq_platform_group = DashBoardConfig(
config_type="group",
name="OneBot协议平台配置",
name="go-cqhttp",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用",
description="支持cq-http、shamrock等(目前仅支持QQ平台)",
description="",
value=config['gocqbot']['enable'],
path="gocqbot.enable",
),
@@ -381,48 +388,7 @@ class DashBoardHelper():
),
]
)
rev_chatgpt_accounts = config['rev_ChatGPT']['account']
new_accs = []
for i in rev_chatgpt_accounts:
if isinstance(i, dict) and 'access_token' in i:
new_accs.append(i['access_token'])
elif isinstance(i, str):
new_accs.append(i)
config['rev_ChatGPT']['account'] = new_accs
rev_chatgpt_llm_group = DashBoardConfig(
config_type="group",
name="逆向语言模型服务设置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用逆向语言模型服务",
description="",
value=config['rev_ChatGPT']['enable'],
path="rev_ChatGPT.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="终结点(Endpoint)地址",
description="逆向服务的终结点服务器的地址。",
value=config['CHATGPT_BASE_URL'],
path="CHATGPT_BASE_URL",
),
DashBoardConfig(
config_type="item",
val_type="list",
name="assess_token",
description="assess_token",
value=config['rev_ChatGPT']['account'],
path="rev_ChatGPT.account",
),
]
)
baidu_aip_group = DashBoardConfig(
config_type="group",
name="百度内容审核",
@@ -436,9 +402,6 @@ class DashBoardHelper():
value=config['baidu_aip']['enable'],
path="baidu_aip.enable"
),
# "app_id": null,
# "api_key": null,
# "secret_key": null
DashBoardConfig(
config_type="item",
val_type="str",
@@ -497,27 +460,25 @@ class DashBoardHelper():
),
]
)
dashboard_data.configs['data'] = [
qq_official_platform_group,
qq_gocq_platform_group,
general_platform_detail_group,
openai_official_llm_group,
rev_chatgpt_llm_group,
other_group,
baidu_aip_group
]
except Exception as e:
self.logger.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR)
logger.error(f"配置文件解析错误:{e}")
raise e
def save_config(self, post_config: list, namespace: str):
'''
根据 path 解析并保存配置
'''
queue = post_config
while len(queue) > 0:
config = queue.pop(0)
@@ -527,23 +488,27 @@ class DashBoardHelper():
elif config['config_type'] == "item":
if config['path'] is None or config['path'] == "":
continue
path = config['path'].split('.')
if len(path) == 0:
continue
if config['val_type'] == "bool":
self._write_config(namespace, config['path'], config['value'])
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "str":
self._write_config(namespace, config['path'], config['value'])
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "int":
try:
self._write_config(namespace, config['path'], int(config['value']))
self._write_config(
namespace, config['path'], int(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是整数")
elif config['val_type'] == "float":
try:
self._write_config(namespace, config['path'], float(config['value']))
self._write_config(
namespace, config['path'], float(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是浮点数")
elif config['val_type'] == "list":
@@ -551,16 +516,18 @@ class DashBoardHelper():
self._write_config(namespace, config['path'], [])
elif not isinstance(config['value'], list):
raise ValueError(f"配置项 {config['name']} 的值必须是列表")
self._write_config(namespace, config['path'], config['value'])
self._write_config(
namespace, config['path'], config['value'])
else:
raise NotImplementedError(f"未知或者未实现的配置项类型:{config['val_type']}")
raise NotImplementedError(
f"未知或者未实现的配置项类型:{config['val_type']}")
def _write_config(self, namespace: str, key: str, value):
if namespace == "" or namespace.startswith("internal_"):
# 机器人自带配置,存到 config.yaml
self.cc.put_by_dot_str(key, value)
else:
update_config(namespace, key, value)
def run(self):
self.dashboard.run()
self.dashboard.run()
+152 -85
View File
@@ -1,21 +1,26 @@
from flask import Flask, request
from flask.logging import default_handler
from werkzeug.serving import make_server
from util import general_utils as gu
from dataclasses import dataclass
import logging
from cores.database.conn import dbConn
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from cores.astrbot.types import *
import util.plugin_util as putil
import websockets
import json
import threading
import asyncio
import os, sys
import os
import sys
import time
from flask import Flask, request
from flask.logging import default_handler
from werkzeug.serving import make_server
from util import general_utils as gu
from dataclasses import dataclass
from persist.session import dbConn
from type.register import RegisteredPlugin
from typing import List
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@dataclass
class DashBoardData():
stats: dict
@@ -23,33 +28,49 @@ class DashBoardData():
logs: dict
plugins: List[RegisteredPlugin]
@dataclass
class Response():
status: str
message: str
data: dict
class AstrBotDashBoard():
def __init__(self, global_object: 'gu.GlobalObject'):
self.global_object = global_object
self.loop = asyncio.get_event_loop()
asyncio.set_event_loop(self.loop)
self.dashboard_data: DashBoardData = global_object.dashboard_data
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
self.dashboard_be = Flask(
__name__, static_folder="dist", static_url_path="/")
self.funcs = {}
self.cc = CmdConfig()
self.logger = global_object.logger
self.ws_clients = {} # remote_ip: ws
self.ws_clients = {} # remote_ip: ws
# 启动 websocket 服务器
self.ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
@self.dashboard_be.get("/")
def index():
# 返回页面
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/config")
def rt_config():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/logs")
def rt_logs():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/extension")
def rt_extension():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/dashboard/default")
def rt_dashboard():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.post("/api/authenticate")
def authenticate():
username = self.cc.get("dashboard_username", "")
@@ -71,7 +92,7 @@ class AstrBotDashBoard():
message="用户名或密码错误。",
data=None
).__dict__
@self.dashboard_be.post("/api/change_password")
def change_password():
password = self.cc.get("dashboard_password", "")
@@ -99,9 +120,11 @@ class AstrBotDashBoard():
# last_24_platform = db_inst.get_last_24h_stat_platform()
platforms = db_inst.get_platform_cnt_total()
self.dashboard_data.stats["session"] = []
self.dashboard_data.stats["session_total"] = db_inst.get_session_cnt_total()
self.dashboard_data.stats["session_total"] = db_inst.get_session_cnt_total(
)
self.dashboard_data.stats["message"] = last_24_message
self.dashboard_data.stats["message_total"] = db_inst.get_message_cnt_total()
self.dashboard_data.stats["message_total"] = db_inst.get_message_cnt_total(
)
self.dashboard_data.stats["platform"] = platforms
return Response(
@@ -109,7 +132,7 @@ class AstrBotDashBoard():
message="",
data=self.dashboard_data.stats
).__dict__
@self.dashboard_be.get("/api/configs")
def get_configs():
# 如果params中有namespace,则返回该namespace下的配置
@@ -121,7 +144,7 @@ class AstrBotDashBoard():
message="",
data=conf
).__dict__
@self.dashboard_be.get("/api/config_outline")
def get_config_outline():
outline = self._generate_outline()
@@ -130,7 +153,7 @@ class AstrBotDashBoard():
message="",
data=outline
).__dict__
@self.dashboard_be.post("/api/configs")
def post_configs():
post_configs = request.json
@@ -147,7 +170,7 @@ class AstrBotDashBoard():
message=e.__str__(),
data=self.dashboard_data.configs
).__dict__
@self.dashboard_be.get("/api/extensions")
def get_plugins():
_plugin_resp = []
@@ -166,15 +189,15 @@ class AstrBotDashBoard():
message="",
data=_plugin_resp
).__dict__
@self.dashboard_be.post("/api/extensions/install")
def install_plugin():
post_data = request.json
repo_url = post_data["url"]
try:
self.logger.log(f"正在安装插件 {repo_url}", tag="可视化面板")
logger.info(f"正在安装插件 {repo_url}")
putil.install_plugin(repo_url, self.dashboard_data.plugins)
self.logger.log(f"安装插件 {repo_url} 成功", tag="可视化面板")
logger.info(f"安装插件 {repo_url} 成功")
return Response(
status="success",
message="安装成功~",
@@ -186,15 +209,16 @@ class AstrBotDashBoard():
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.post("/api/extensions/uninstall")
def uninstall_plugin():
post_data = request.json
plugin_name = post_data["name"]
try:
self.logger.log(f"正在卸载插件 {plugin_name}", tag="可视化面板")
putil.uninstall_plugin(plugin_name, self.dashboard_data.plugins)
self.logger.log(f"卸载插件 {plugin_name} 成功", tag="可视化面板")
logger.info(f"正在卸载插件 {plugin_name}")
putil.uninstall_plugin(
plugin_name, self.dashboard_data.plugins)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response(
status="success",
message="卸载成功~",
@@ -206,15 +230,15 @@ class AstrBotDashBoard():
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.post("/api/extensions/update")
def update_plugin():
post_data = request.json
plugin_name = post_data["name"]
try:
self.logger.log(f"正在更新插件 {plugin_name}", tag="可视化面板")
logger.info(f"正在更新插件 {plugin_name}")
putil.update_plugin(plugin_name, self.dashboard_data.plugins)
self.logger.log(f"更新插件 {plugin_name} 成功", tag="可视化面板")
logger.info(f"更新插件 {plugin_name} 成功")
return Response(
status="success",
message="更新成功~",
@@ -226,16 +250,17 @@ class AstrBotDashBoard():
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.post("/api/log")
def log():
for item in self.ws_clients:
try:
asyncio.run_coroutine_threadsafe(self.ws_clients[item].send(request.data.decode()), self.loop)
asyncio.run_coroutine_threadsafe(
self.ws_clients[item].send(request.data.decode()), self.loop)
except Exception as e:
pass
return 'ok'
@self.dashboard_be.get("/api/check_update")
def get_update_info():
try:
@@ -244,7 +269,7 @@ class AstrBotDashBoard():
status="success",
message=ret,
data={
"has_new_version": ret != "当前已经是最新版本。" # 先这样吧,累了=.=
"has_new_version": ret != "当前已经是最新版本。" # 先这样吧,累了=.=
}
).__dict__
except Exception as e:
@@ -253,7 +278,7 @@ class AstrBotDashBoard():
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.post("/api/update_project")
def update_project_api():
version = request.json['version']
@@ -263,7 +288,8 @@ class AstrBotDashBoard():
else:
latest = False
try:
update_project(request_release_info(latest), latest=latest, version=version)
update_project(request_release_info(latest),
latest=latest, version=version)
threading.Thread(target=self.shutdown_bot, args=(3,)).start()
return Response(
status="success",
@@ -276,26 +302,63 @@ class AstrBotDashBoard():
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.get("/api/llm/list")
def llm_list():
ret = []
for llm in self.global_object.llms:
ret.append(llm.llm_name)
return Response(
status="success",
message="",
data=ret
).__dict__
@self.dashboard_be.get("/api/llm")
def llm():
text = request.args["text"]
llm = request.args["llm"]
for llm_ in self.global_object.llms:
if llm_.llm_name == llm:
try:
# ret = await llm_.llm_instance.text_chat(text)
ret = asyncio.run_coroutine_threadsafe(
llm_.llm_instance.text_chat(text), self.loop).result()
return Response(
status="success",
message="",
data=ret
).__dict__
except Exception as e:
return Response(
status="error",
message=e.__str__(),
data=None
).__dict__
return Response(
status="error",
message="LLM not found.",
data=None
).__dict__
def shutdown_bot(self, delay_s: int):
time.sleep(delay_s)
py = sys.executable
os.execl(py, py, *sys.argv)
def _get_configs(self, namespace: str):
if namespace == "":
ret = [self.dashboard_data.configs['data'][5],
self.dashboard_data.configs['data'][6],]
ret = [self.dashboard_data.configs['data'][4],
self.dashboard_data.configs['data'][5],]
elif namespace == "internal_platform_qq_official":
ret = [self.dashboard_data.configs['data'][0],]
elif namespace == "internal_platform_qq_gocq":
ret = [self.dashboard_data.configs['data'][1],]
elif namespace == "internal_platform_general": # 全局平台配置
elif namespace == "internal_platform_general": # 全局平台配置
ret = [self.dashboard_data.configs['data'][2],]
elif namespace == "internal_llm_openai_official":
ret = [self.dashboard_data.configs['data'][3],]
elif namespace == "internal_llm_rev_chatgpt":
ret = [self.dashboard_data.configs['data'][4],]
else:
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
@@ -316,28 +379,28 @@ class AstrBotDashBoard():
'''
outline = [
{
"type": "platform",
"name": "配置通用消息平台",
"body": [
{
"title": "通用",
"desc": "通用平台配置",
"namespace": "internal_platform_general",
"tag": ""
},
{
"title": "QQ_OFFICIAL",
"desc": "QQ官方API,仅支持频道",
"namespace": "internal_platform_qq_official",
"tag": ""
},
{
"title": "OneBot协议",
"desc": "支持cq-http、shamrock等(目前仅支持QQ平台)",
"namespace": "internal_platform_qq_gocq",
"tag": ""
}
]
"type": "platform",
"name": "配置通用消息平台",
"body": [
{
"title": "通用",
"desc": "通用平台配置",
"namespace": "internal_platform_general",
"tag": ""
},
{
"title": "QQ_OFFICIAL",
"desc": "QQ官方API支持频道、群(需获得群权限)",
"namespace": "internal_platform_qq_official",
"tag": ""
},
{
"title": "go-cqhttp",
"desc": "第三方 QQ 协议实现。支持频道、群",
"namespace": "internal_platform_qq_gocq",
"tag": ""
}
]
},
{
"type": "llm",
@@ -348,12 +411,6 @@ class AstrBotDashBoard():
"desc": "也支持使用官方接口的中转服务",
"namespace": "internal_llm_openai_official",
"tag": ""
},
{
"title": "Rev ChatGPT",
"desc": "早期的逆向ChatGPT,不推荐",
"namespace": "internal_llm_rev_chatgpt",
"tag": ""
}
]
}
@@ -375,24 +432,32 @@ class AstrBotDashBoard():
return func
return decorator
async def get_log_history(self):
try:
with open("logs/astrbot-core/astrbot-core.log", "r", encoding="utf-8") as f:
return f.readlines()[-100:]
except Exception as e:
logger.warning(f"读取日志历史失败: {e.__str__()}")
return []
async def __handle_msg(self, websocket, path):
address = websocket.remote_address
# self.logger.log(f"和 {address} 建立了 websocket 连接", tag="可视化面板")
self.ws_clients[address] = websocket
data = ''.join(self.logger.history).replace('\n', '\r\n')
data = await self.get_log_history()
data = ''.join(data).replace('\n', '\r\n')
await websocket.send(data)
while True:
try:
msg = await websocket.recv()
except websockets.exceptions.ConnectionClosedError:
# self.logger.log(f"和 {address} 的 websocket 连接已断开", tag="可视化面板")
# logger.info(f"和 {address} 的 websocket 连接已断开")
del self.ws_clients[address]
break
except Exception as e:
# self.logger.log(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}", tag="可视化面板")
# logger.info(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}")
del self.ws_clients[address]
break
def run_ws_server(self, loop):
asyncio.set_event_loop(loop)
loop.run_until_complete(self.ws_server)
@@ -400,10 +465,12 @@ class AstrBotDashBoard():
def run(self):
threading.Thread(target=self.run_ws_server, args=(self.loop,)).start()
self.logger.log("已启动 websocket 服务器", tag="可视化面板")
logger.info("已启动 websocket 服务器")
ip_address = gu.get_local_ip_addresses()
ip_str = f"http://{ip_address}:6185\n\thttp://localhost:6185"
self.logger.log(f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n", tag="可视化面板")
http_server = make_server('0.0.0.0', 6185, self.dashboard_be, threaded=True)
logger.info(
f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n")
http_server = make_server(
'0.0.0.0', 6185, self.dashboard_be, threaded=True)
http_server.serve_forever()
+18 -20
View File
@@ -1,34 +1,29 @@
import os
import shutil
from nakuru.entities.components import *
from nakuru import (
GroupMessage,
FriendMessage
)
from botpy.message import Message, DirectMessage
flag_not_support = False
try:
from util.plugin_dev.api.v1.config import *
from util.plugin_dev.api.v1.bot import (
PluginMetadata,
PluginType,
AstrMessageEvent,
CommandResult,
)
from util.plugin_dev.api.v1.register import register_llm, unregister_llm
except ImportError:
flag_not_support = True
print("llms: 导入接口失败。请升级到 AstrBot 最新版本。")
print("导入接口失败。请升级到 AstrBot 最新版本。")
'''
注意改插件名噢!格式:XXXPlugin 或 Main
小提示:把此模板仓库 fork 之后 clone 到机器人文件夹下的 addons/plugins/ 目录下,然后用 Pycharm/VSC 等工具打开可获更棒的编程体验(自动补全等)
'''
class HelloWorldPlugin:
"""
初始化函数, 可以选择直接pass
"""
def __init__(self) -> None:
# 复制旧配置文件到 data 目录下。
if os.path.exists("keyword.json"):
@@ -45,6 +40,7 @@ class HelloWorldPlugin:
Tuple: Non e或者长度为 3 的元组。如果不响应, 返回 None; 如果响应, 第 1 个参数为指令是否调用成功, 第 2 个参数为返回的消息链列表, 第 3 个参数为指令名称
例子:一个名为"yuanshen"的插件;当接收到消息为“原神 可莉”, 如果不想要处理此消息,则返回False, None;如果想要处理,但是执行失败了,返回True, tuple([False, "请求失败。", "yuanshen"]) ;执行成功了,返回True, tuple([True, "结果文本", "yuanshen"])
"""
def run(self, ame: AstrMessageEvent):
if ame.message_str == "helloworld":
return CommandResult(
@@ -55,9 +51,10 @@ class HelloWorldPlugin:
)
if ame.message_str.startswith("/keyword") or ame.message_str.startswith("keyword"):
return self.handle_keyword_command(ame)
ret = self.check_keyword(ame.message_str)
if ret: return ret
if ret:
return ret
return CommandResult(
hit=False,
@@ -65,10 +62,10 @@ class HelloWorldPlugin:
message_chain=None,
command_name=None
)
def handle_keyword_command(self, ame: AstrMessageEvent):
l = ame.message_str.split(" ")
# 获取图片
image_url = ""
for comp in ame.message_obj.message:
@@ -77,7 +74,7 @@ class HelloWorldPlugin:
image_url = comp.file
else:
image_url = comp.url
command_result = CommandResult(
hit=True,
success=False,
@@ -116,11 +113,11 @@ keyword d hi
command_result.success = True
command_result.message_chain = [Plain("设置成功")]
return command_result
def save_keyword(self):
json.dump(self.keywords, open("data/keyword.json", "w"), ensure_ascii=False)
json.dump(self.keywords, open(
"data/keyword.json", "w"), ensure_ascii=False)
def check_keyword(self, message_str: str):
for k in self.keywords:
if message_str == k:
@@ -159,7 +156,8 @@ keyword d hi
"repo": str, # 插件仓库地址 [ 可选 ]
"homepage": str, # 插件主页 [ 可选 ]
}
"""
"""
def info(self):
return {
"name": "helloworld",
@@ -167,4 +165,4 @@ keyword d hi
"help": "输入 /keyword 查看关键词回复帮助。",
"version": "v1.3",
"author": "Soulter"
}
}
+132 -118
View File
@@ -2,38 +2,34 @@ import re
import threading
import asyncio
import time
import aiohttp
import util.unfit_words as uw
import os
import sys
import io
import traceback
import util.function_calling.gplugin as gplugin
import util.agent.web_searcher as web_searcher
import util.plugin_util as putil
from PIL import Image as PILImage
from typing import Union
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from nakuru.entities.components import Plain, At, Image
from addons.baidu_aip_judge import BaiduJudge
from model.provider.provider import Provider
from model.command.command import Command
from util import general_utils as gu
from util.general_utils import Logger, upload, run_monitor
from util.general_utils import upload, run_monitor
from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items
from .types import *
from type.types import GlobalObject
from type.register import *
from type.message import AstrBotMessage
from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData
from cores.database.conn import dbConn
from persist.session import dbConn
from model.platform._message_result import MessageResult
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# 用户发言频率
user_frequency = {}
@@ -43,10 +39,9 @@ frequency_time = 60
frequency_count = 10
# 版本
version = '3.1.11'
version = '3.1.13'
# 语言模型
REV_CHATGPT = 'rev_chatgpt'
OPENAI_OFFICIAL = 'openai_official'
NONE_LLM = 'none_llm'
chosen_provider = None
@@ -65,29 +60,30 @@ init_astrbot_config_items()
# 全局对象
_global_object: GlobalObject = None
logger: Logger = Logger()
# 语言模型选择
def privider_chooser(cfg):
l = []
if 'rev_ChatGPT' in cfg and cfg['rev_ChatGPT']['enable']:
l.append('rev_chatgpt')
if 'openai' in cfg and len(cfg['openai']['key']) > 0 and cfg['openai']['key'][0] is not None:
l.append('openai_official')
return l
'''
初始化机器人
'''
def init(cfg):
def init():
global llm_instance, llm_command_instance
global baidu_judge, chosen_provider
global frequency_count, frequency_time
global _global_object
global logger
# 迁移旧配置
gu.try_migrate_config(cfg)
gu.try_migrate_config()
# 使用新配置
cfg = cc.get_all()
@@ -98,8 +94,7 @@ def init(cfg):
_global_object = GlobalObject()
_global_object.version = version
_global_object.base_config = cfg
_global_object.logger = logger
logger.log("AstrBot v"+version, gu.LEVEL_INFO)
logger.info("AstrBot v"+version)
if 'reply_prefix' in cfg:
# 适配旧版配置
@@ -109,46 +104,52 @@ def init(cfg):
cc.put("reply_prefix", "")
else:
_global_object.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 语言模型提供商
logger.log("正在载入语言模型...", gu.LEVEL_INFO)
logger.info("正在载入语言模型...")
prov = privider_chooser(cfg)
if REV_CHATGPT in prov:
logger.log("初始化:逆向 ChatGPT", gu.LEVEL_INFO)
if cfg['rev_ChatGPT']['enable']:
if 'account' in cfg['rev_ChatGPT']:
from model.provider.rev_chatgpt import ProviderRevChatGPT
from model.command.rev_chatgpt import CommandRevChatGPT
llm_instance[REV_CHATGPT] = ProviderRevChatGPT(cfg['rev_ChatGPT'], base_url=cc.get("CHATGPT_BASE_URL", None))
llm_command_instance[REV_CHATGPT] = CommandRevChatGPT(llm_instance[REV_CHATGPT], _global_object)
chosen_provider = REV_CHATGPT
_global_object.llms.append(RegisteredLLM(llm_name=REV_CHATGPT, llm_instance=llm_instance[REV_CHATGPT], origin="internal"))
else:
input("请退出本程序, 然后在配置文件中填写rev_ChatGPT相关配置")
if OPENAI_OFFICIAL in prov:
logger.log("初始化:OpenAI官方", gu.LEVEL_INFO)
logger.info("初始化:OpenAI官方")
if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]:
from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official import CommandOpenAIOfficial
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(cfg['openai'])
llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(llm_instance[OPENAI_OFFICIAL], _global_object)
_global_object.llms.append(RegisteredLLM(llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
llm_instance[OPENAI_OFFICIAL] = ProviderOpenAIOfficial(
cfg['openai'])
llm_command_instance[OPENAI_OFFICIAL] = CommandOpenAIOfficial(
llm_instance[OPENAI_OFFICIAL], _global_object)
_global_object.llms.append(RegisteredLLM(
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
chosen_provider = OPENAI_OFFICIAL
instance = llm_instance[OPENAI_OFFICIAL]
assert isinstance(instance, ProviderOpenAIOfficial)
instance.DEFAULT_PERSONALITY = _global_object.default_personality
instance.curr_personality = instance.DEFAULT_PERSONALITY
# 检查provider设置偏好
p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance:
chosen_provider = p
# 百度内容审核
if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']:
try:
try:
baidu_judge = BaiduJudge(cfg['baidu_aip'])
logger.log("百度内容审核初始化成功", gu.LEVEL_INFO)
logger.info("百度内容审核初始化成功")
except BaseException as e:
logger.log("百度内容审核初始化失败", gu.LEVEL_ERROR)
threading.Thread(target=upload, args=(_global_object, ), daemon=True).start()
logger.info("百度内容审核初始化失败")
threading.Thread(target=upload, args=(
_global_object, ), daemon=True).start()
# 得到发言频率配置
if 'limit' in cfg:
@@ -156,18 +157,18 @@ def init(cfg):
frequency_count = cfg['limit']['count']
if 'time' in cfg['limit']:
frequency_time = cfg['limit']['time']
try:
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
_global_object.unique_session = True
else:
_global_object.unique_session = False
except BaseException as e:
logger.log("独立会话配置错误: "+str(e), gu.LEVEL_ERROR)
logger.info("独立会话配置错误: "+str(e))
nick_qq = cc.get("nick_qq", None)
if nick_qq == None:
nick_qq = ("ai","!","")
nick_qq = ("ai", "!", "")
if isinstance(nick_qq, str):
nick_qq = (nick_qq,)
if isinstance(nick_qq, list):
@@ -178,42 +179,37 @@ def init(cfg):
global llm_wake_prefix
llm_wake_prefix = cc.get("llm_wake_prefix", "")
logger.log("正在载入插件...", gu.LEVEL_INFO)
logger.info("正在载入插件...")
# 加载插件
_command = Command(None, _global_object)
ok, err = putil.plugin_reload(_global_object.cached_plugins)
if ok:
logger.log(f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO)
logger.info(
f"成功载入 {len(_global_object.cached_plugins)} 个插件")
else:
logger.log(err, gu.LEVEL_ERROR)
logger.info(err)
if chosen_provider is None:
llm_command_instance[NONE_LLM] = _command
chosen_provider = NONE_LLM
logger.log("正在载入机器人消息平台", gu.LEVEL_INFO)
# logger.log("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)", gu.LEVEL_WARNING)
logger.info("正在载入机器人消息平台")
# logger.info("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)")
platform_str = ""
# GOCQ
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
logger.log("启用 QQ_GOCQ 机器人消息平台", gu.LEVEL_INFO)
threading.Thread(target=run_gocq_bot, args=(cfg, _global_object), daemon=True).start()
logger.info("启用 QQ_GOCQ 机器人消息平台")
threading.Thread(target=run_gocq_bot, args=(
cfg, _global_object), daemon=True).start()
platform_str += "QQ_GOCQ,"
# QQ频道
if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None:
logger.log("启用 QQ_OFFICIAL 机器人消息平台", gu.LEVEL_INFO)
threading.Thread(target=run_qqchan_bot, args=(cfg, _global_object), daemon=True).start()
logger.info("启用 QQ_OFFICIAL 机器人消息平台")
threading.Thread(target=run_qqchan_bot, args=(
cfg, _global_object), daemon=True).start()
platform_str += "QQ_OFFICIAL,"
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 初始化dashboard
_global_object.dashboard_data = DashBoardData(
stats={},
@@ -222,63 +218,81 @@ def init(cfg):
plugins=_global_object.cached_plugins,
)
dashboard_helper = DashBoardHelper(_global_object, config=cc.get_all())
dashboard_thread = threading.Thread(target=dashboard_helper.run, daemon=True)
dashboard_thread = threading.Thread(
target=dashboard_helper.run, daemon=True)
dashboard_thread.start()
# 运行 monitor
threading.Thread(target=run_monitor, args=(_global_object,), daemon=False).start()
threading.Thread(target=run_monitor, args=(
_global_object,), daemon=True).start()
logger.log("如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。", gu.LEVEL_INFO)
logger.log("请给 https://github.com/Soulter/AstrBot 点个 star。", gu.LEVEL_INFO)
logger.info(
"如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。")
logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。")
if platform_str == '':
platform_str = "(未启动任何平台,请前往面板添加)"
logger.log(f"🎉 项目启动完成")
logger.info(f"🎉 项目启动完成")
dashboard_thread.join()
'''
运行 QQ_OFFICIAL 机器人
'''
def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object)
global_object.platforms.append(RegisteredPlatform(platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
qqchannel_bot = QQOfficial(
cfg=cfg, message_handler=oper_msg, global_object=global_object)
global_object.platforms.append(RegisteredPlatform(
platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
qqchannel_bot.run()
except BaseException as e:
logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), gu.LEVEL_CRITICAL, tag="QQ频道")
logger.log(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。" + str(e), gu.LEVEL_CRITICAL)
logger.error("启动 QQ 频道机器人时出现错误, 原因如下: " + str(e))
logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。")
'''
运行 QQ_GOCQ 机器人
'''
def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
from model.platform.qq_gocq import QQGOCQ
noticed = False
host = cc.get("gocq_host", "127.0.0.1")
port = cc.get("gocq_websocket_port", 6700)
http_port = cc.get("gocq_http_port", 5700)
logger.log(f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}", tag="QQ")
logger.info(
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
while True:
if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host):
if not noticed:
noticed = True
logger.log(f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ")
logger.warning(
f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。")
time.sleep(5)
else:
logger.log("检查完毕,未发现问题。", tag="QQ")
logger.info("已连接到 gocq。")
break
try:
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object)
_global_object.platforms.append(RegisteredPlatform(platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg,
global_object=_global_object)
_global_object.platforms.append(RegisteredPlatform(
platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
qq_gocq.run()
except BaseException as e:
input("启动QQ机器人出现错误"+str(e))
'''
检查发言频率
'''
def check_frequency(id) -> bool:
ts = int(time.time())
if id in user_frequency:
@@ -290,13 +304,14 @@ def check_frequency(id) -> bool:
if user_frequency[id]['count'] >= frequency_count:
return False
else:
user_frequency[id]['count']+=1
user_frequency[id]['count'] += 1
return True
else:
t = {'time':ts,'count':1}
t = {'time': ts, 'count': 1}
user_frequency[id] = t
return True
async def record_message(platform: str, session_id: str):
# TODO: 这里会非常吃资源。然而 sqlite3 不支持多线程,所以暂时这样写。
curr_ts = int(time.time())
@@ -306,11 +321,12 @@ async def record_message(platform: str, session_id: str):
db_inst.increment_stat_platform(curr_ts, platform, 1)
_global_object.cnt_total += 1
async def oper_msg(message: AstrBotMessage,
session_id: str,
role: str = 'member',
platform: str = None,
) -> MessageResult:
session_id: str,
role: str = 'member',
platform: str = None,
) -> MessageResult:
"""
处理消息
message: 消息对象
@@ -322,9 +338,9 @@ async def oper_msg(message: AstrBotMessage,
message_str = ''
session_id = session_id
role = role
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
# 获取平台实例
reg_platform: RegisteredPlatform = None
for p in _global_object.platforms:
@@ -332,9 +348,8 @@ async def oper_msg(message: AstrBotMessage,
reg_platform = p
break
if not reg_platform:
_global_object.logger.log(f"未找到平台 {platform} 的实例。", gu.LEVEL_ERROR)
raise Exception(f"未找到平台 {platform} 的实例。")
# 统计数据,如频道消息量
await record_message(platform, session_id)
@@ -343,19 +358,17 @@ async def oper_msg(message: AstrBotMessage,
message_str += i.text.strip()
if message_str == "":
return MessageResult("Hi~")
# 检查发言频率
if not check_frequency(message.sender.user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')
# 检查是否是更换语言模型的请求
temp_switch = ""
if message_str.startswith('/gpt') or message_str.startswith('/revgpt'):
if message_str.startswith('/gpt'):
target = chosen_provider
if message_str.startswith('/gpt'):
target = OPENAI_OFFICIAL
elif message_str.startswith('/revgpt'):
target = REV_CHATGPT
l = message_str.split(' ')
if len(l) > 1 and l[1] != "":
# 临时对话模式,先记录下之前的语言模型,回答完毕后再切回
@@ -370,8 +383,13 @@ async def oper_msg(message: AstrBotMessage,
llm_result_str = ""
# check commands and plugins
message_str_no_wake_prefix = message_str
for wake_prefix in _global_object.nick: # nick: tuple
if message_str.startswith(wake_prefix):
message_str_no_wake_prefix = message_str.removeprefix(wake_prefix)
break
hit, command_result = await llm_command_instance[chosen_provider].check_command(
message_str,
message_str_no_wake_prefix,
session_id,
role,
reg_platform,
@@ -390,7 +408,7 @@ async def oper_msg(message: AstrBotMessage,
if not check:
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
if chosen_provider == NONE_LLM:
logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING)
logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。")
return
try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
@@ -412,22 +430,22 @@ async def oper_msg(message: AstrBotMessage,
web_sch_flag = True
else:
message_str += " " + cc.get("llm_env_prompt", "")
if chosen_provider == REV_CHATGPT or chosen_provider == OPENAI_OFFICIAL:
if chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else:
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality = _global_object.default_personality)
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e:
logger.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR)
return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}")
logger.error(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用异常。详细原因:{str(e)}")
# 切换回原来的语言模型
if temp_switch != "":
chosen_provider = temp_switch
if hit:
# 有指令或者插件触发
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
@@ -443,16 +461,12 @@ async def oper_msg(message: AstrBotMessage,
if not command_result[0]:
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
# 画图指令
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw':
for i in command_result[1]:
# 保存到本地
async with aiohttp.ClientSession() as session:
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
if command == 'draw':
# 保存到本地
path = await gu.download_image_by_url(command_result[1])
return MessageResult([Image.fromFileSystem(path)])
# 其他指令
else:
try:
@@ -474,4 +488,4 @@ async def oper_msg(message: AstrBotMessage,
try:
return MessageResult(llm_result_str)
except BaseException as e:
logger.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR)
logger.info("回复消息错误: \n"+str(e))
-168
View File
@@ -1,168 +0,0 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from nakuru.entities.components import BaseMessageComponent
from typing import Union, List, ClassVar
from types import ModuleType
from enum import Enum
from dataclasses import dataclass
class MessageType(Enum):
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
GUILD_MESSAGE = 'GuildMessage' # 频道消息
@dataclass
class MessageMember():
user_id: str # 发送者id
nickname: str = None
class AstrBotMessage():
'''
AstrBot 的消息对象
'''
tag: str # 消息来源标签
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id
message_id: str # 消息id
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
raw_message: object
timestamp: int # 消息时间戳
def __str__(self) -> str:
return str(self.__dict__)
class PluginType(Enum):
PLATFORM = 'platfrom' # 平台类插件。
LLM = 'llm' # 大语言模型类插件
COMMON = 'common' # 其他插件
@dataclass
class PluginMetadata:
'''
插件的元数据。
'''
# required
plugin_name: str
plugin_type: PluginType
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件。
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
'''
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
'''
platform_name: str
platform_instance: Platform
origin: str = None # 注册来源
@dataclass
class RegisteredLLM:
'''
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
'''
llm_name: str
llm_instance: LLMProvider
origin: str = None # 注册来源
class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
version: str # 机器人版本
nick: str # 用户定义的机器人的别名
base_config: dict # config.json 中导出的配置
cached_plugins: List[RegisteredPlugin] # 加载的插件
platforms: List[RegisteredPlatform]
llms: List[RegisteredLLM]
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
cnt_total: int # 总消息数
default_personality: dict
dashboard_data = None
logger: None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = [] # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.cnt_total = 0
self.platforms = []
self.llms = []
self.default_personality = None
self.dashboard_data = None
self.stat = {}
class AstrMessageEvent():
'''
消息事件。
'''
context: GlobalObject # 一些公用数据
message_str: str # 纯消息字符串
message_obj: AstrBotMessage # 消息对象
platform: RegisteredPlatform # 来源平台
role: str # 基本身份。`admin` 或 `member`
session_id: int # 会话 id
def __init__(self,
message_str: str,
message_obj: AstrBotMessage,
platform: RegisteredPlatform,
role: str,
context: GlobalObject,
session_id: str = None):
self.context = context
self.message_str = message_str
self.message_obj = message_obj
self.platform = platform
self.role = role
self.session_id = session_id
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+93 -79
View File
@@ -1,96 +1,110 @@
import os, sys
from pip._internal import main as pipmain
import os
import sys
import warnings
import traceback
import threading
from logging import Formatter, Logger
warnings.filterwarnings("ignore")
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
def main():
# config.yaml 配置文件加载和环境确认
try:
import cores.astrbot.core as qqBot
import yaml
import util.general_utils as gu
ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8')
cfg = yaml.safe_load(ymlfile)
except ImportError as import_error:
traceback.print_exc()
print(import_error)
input("第三方库未完全安装完毕,请退出程序重试。")
except FileNotFoundError as file_not_found:
print(file_not_found)
input("配置文件不存在,请检查是否已经下载配置文件。")
except BaseException as e:
raise e
# 设置代理
if 'http_proxy' in cfg and cfg['http_proxy'] != '':
os.environ['HTTP_PROXY'] = cfg['http_proxy']
if 'https_proxy' in cfg and cfg['https_proxy'] != '':
os.environ['HTTPS_PROXY'] = cfg['https_proxy']
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
logger: Logger = None
# 检查并创建 temp 文件夹
if not os.path.exists(abs_path + "temp"):
os.mkdir(abs_path+"temp")
if not os.path.exists(abs_path + "data"):
os.mkdir(abs_path+"data")
if not os.path.exists(abs_path + "data/config"):
os.mkdir(abs_path+"data/config")
logo_tmpl = """
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
def make_necessary_dirs():
'''
创建必要的目录。
'''
os.makedirs("data/config", exist_ok=True)
os.makedirs("temp", exist_ok=True)
def update_dept():
'''
更新依赖库。
'''
# 获取 Python 可执行文件路径
py = sys.executable
# 更新依赖库
mirror = "https://mirrors.aliyun.com/pypi/simple/"
os.system(f"{py} -m pip install -r requirements.txt -i {mirror}")
def main():
try:
import botpy, logging
import astrbot.core as bot_core
# delete qqbotpy's logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
except ImportError as import_error:
logger.error(import_error)
logger.error("检测到一些依赖库没有安装。由于兼容性问题,AstrBot 此版本将不会自动为您安装依赖库。请您先自行安装,然后重试。")
logger.info("如何安装?如果:")
logger.info("- Windows 启动器部署且使用启动器下载了 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 .\python\python.exe -m pip install .\AstrBot\requirements.txt")
logger.info("- Windows 启动器部署且使用自己之前下载的 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 python -m pip install .\AstrBot\requirements.txt")
logger.info("- 自行 clone 源码部署的:python -m pip install -r requirements.txt")
logger.info("- 如果还不会,加群 322154837 ")
input("按任意键退出。")
exit()
except FileNotFoundError as file_not_found:
logger.error(file_not_found)
input("配置文件不存在,请检查是否已经下载配置文件。")
exit()
except BaseException as e:
logger.error(traceback.format_exc())
input("未知错误。")
exit()
make_necessary_dirs()
# 启动主程序(cores/qqbot/core.py
qqBot.init(cfg)
bot_core.init()
def check_env(ch_mirror=False):
def check_env():
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
print("请使用Python3.9+运行本项目")
input("按任意键退出...")
logger.error("请使用 Python3.9+ 运行本项目。按任意键退出。")
input("")
exit()
if os.path.exists('requirements.txt'):
pth = 'requirements.txt'
else:
pth = 'QQChannelChatGPT'+ os.sep +'requirements.txt'
print("正在检查或下载第三方库,请耐心等待...")
try:
if ch_mirror:
print("使用阿里云镜像")
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/'])
else:
pipmain(['install', '-r', pth])
except BaseException as e:
print(e)
while True:
res = input("安装失败。\n如报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。")
if res == "y":
try:
pipmain(['install', '-r', pth])
break
except BaseException as e:
print(e)
continue
elif res == "c":
try:
pipmain(['install', '-r', pth, '-i', 'https://mirrors.aliyun.com/pypi/simple/'])
break
except BaseException as e:
print(e)
continue
else:
break
print("第三方库检查完毕。")
if __name__ == "__main__":
args = sys.argv
if '-cn' in args:
check_env(True)
else:
check_env()
t = threading.Thread(target=main, daemon=False)
# 设置代理
from util.cmd_config import CmdConfig
cc = CmdConfig()
http_proxy = cc.get("http_proxy")
https_proxy = cc.get("https_proxy")
if http_proxy:
os.environ['HTTP_PROXY'] = http_proxy
if https_proxy:
os.environ['HTTPS_PROXY'] = https_proxy
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
update_dept()
from SparkleLogging.utils.core import LogManager
logger = LogManager.GetLogger(
log_name='astrbot-core',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger.info(logo_tmpl)
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
check_env()
t = threading.Thread(target=main, daemon=True)
t.start()
t.join()
try:
t.join()
except KeyboardInterrupt as e:
logger.info("退出 AstrBot。")
exit()
+58 -49
View File
@@ -13,34 +13,35 @@ from nakuru.entities.components import (
from util import general_utils as gu
from model.provider.provider import Provider
from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger
from cores.astrbot.types import (
GlobalObject,
AstrMessageEvent,
PluginType,
CommandResult,
RegisteredPlugin,
RegisteredPlatform
)
from type.message import *
from type.types import GlobalObject
from type.command import *
from type.plugin import *
from type.register import *
from typing import List, Tuple
from typing import List
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq'
# 指令功能的基类,通用的(不区分语言模型)的指令就在这实现
class Command:
def __init__(self, provider: Provider, global_object: GlobalObject = None):
self.provider = provider
self.global_object = global_object
self.logger: Logger = global_object.logger
async def check_command(self,
message,
session_id: str,
role: str,
platform: RegisteredPlatform,
message_obj):
async def check_command(self,
message,
session_id: str,
role: str,
platform: RegisteredPlatform,
message_obj):
self.platform = platform
# 插件
cached_plugins = self.global_object.cached_plugins
@@ -51,7 +52,7 @@ class Command:
platform=platform,
role=role,
context=self.global_object,
session_id = session_id
session_id=session_id
)
# 从已启动的插件中查找是否有匹配的指令
for plugin in cached_plugins:
@@ -72,6 +73,7 @@ class Command:
else:
raise TypeError("插件返回值格式错误。")
if hit:
logger.debug("hit plugin: " + plugin.metadata.plugin_name)
return True, res
except TypeError as e:
# 参数不匹配,尝试使用旧的参数方案
@@ -83,9 +85,11 @@ class Command:
if hit:
return True, res
except BaseException as e:
self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
logger.error(
f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。")
except BaseException as e:
self.logger.log(f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING)
logger.error(
f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。")
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
@@ -93,13 +97,13 @@ class Command:
return True, self.plugin_oper(message, role, cached_plugins, platform)
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
return True, self.get_my_id(message_obj, platform)
if self.command_start_with(message, "web"): # 网页搜索
if self.command_start_with(message, "web"): # 网页搜索
return True, self.web_search(message)
if self.command_start_with(message, "update"):
return True, self.update(message, role)
if not self.provider and self.command_start_with(message, "help"):
return True, await self.help()
return False, None
def web_search(self, message):
@@ -126,16 +130,19 @@ class Command:
l = message.split(" ")
if len(l) <= 1:
obj = cc.get_all()
p = gu.create_text_image("【cmd_config.json】", json.dumps(obj, indent=4, ensure_ascii=False))
p = gu.create_text_image("【cmd_config.json】", json.dumps(
obj, indent=4, ensure_ascii=False))
return True, [Image.fromFileSystem(p)], "newconf"
'''
插件指令
'''
def plugin_oper(self, message: str, role: str, cached_plugins: List[RegisteredPlugin], platform: str):
l = message.split(" ")
if len(l) < 2:
p = gu.create_text_image("【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n")
p = gu.create_text_image(
"【插件指令面板】", "安装插件: \nplugin i 插件Github地址\n卸载插件: \nplugin d 插件名 \n重载插件: \nplugin reload\n查看插件列表:\nplugin l\n更新插件: plugin u 插件名\n")
return True, [Image.fromFileSystem(p)], "plugin"
else:
if l[1] == "i":
@@ -165,7 +172,8 @@ class Command:
plugin_list_info = ""
for plugin in cached_plugins:
plugin_list_info += f"{plugin.metadata.plugin_name}: \n名称: {plugin.metadata.plugin_name}\n简介: {plugin.metadata.plugin_desc}\n版本: {plugin.metadata.version}\n作者: {plugin.metadata.author}\n"
p = gu.create_text_image("【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n")
p = gu.create_text_image(
"【已激活插件列表】", plugin_list_info + "\n使用plugin v 插件名 查看插件帮助\n")
return True, [Image.fromFileSystem(p)], "plugin"
except BaseException as e:
return False, f"获取插件列表失败,原因: {str(e)}", "plugin"
@@ -177,7 +185,8 @@ class Command:
info = i.metadata
break
if info:
p = gu.create_text_image(f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}")
p = gu.create_text_image(
f"【插件信息】", f"名称: {info.plugin_name}\n类型: {info.plugin_type}\n{info.desc}\n版本: {info.version}\n作者: {info.author}")
return True, [Image.fromFileSystem(p)], "plugin"
else:
return False, "未找到该插件", "plugin"
@@ -187,10 +196,11 @@ class Command:
'''
nick: 存储机器人的昵称
'''
def set_nick(self, message: str, platform: str, role: str = "member"):
def set_nick(self, message: str, platform: RegisteredPlatform, role: str = "member"):
if role != "admin":
return True, "你无权使用该指令 :P", "nick"
if platform == PLATFORM_GOCQ:
if str(platform) == PLATFORM_GOCQ:
l = message.split(" ")
if len(l) == 1:
return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick"
@@ -198,7 +208,7 @@ class Command:
cc.put("nick_qq", nick)
self.global_object.nick = tuple(nick)
return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick"
elif platform == PLATFORM_QQCHAN:
elif str(platform) == PLATFORM_QQCHAN:
nick = message.split(" ")[2]
return False, "QQ频道平台不支持为机器人设置昵称。", "nick"
@@ -207,14 +217,11 @@ class Command:
"help": "帮助",
"keyword": "设置关键词/关键指令回复",
"update": "更新项目",
"nick": "设置机器人昵称",
"nick": "设置机器人唤醒词",
"plugin": "插件安装、卸载和重载",
"web on/off": "LLM 网页搜索能力",
"reset": "重置 LLM 对话",
"/gpt": "切换到 OpenAI 官方接口",
"/revgpt": "切换到网页版ChatGPT",
}
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
try:
async with aiohttp.ClientSession() as session:
@@ -239,9 +246,9 @@ class Command:
p = gu.create_markdown_image(msg)
return [Image.fromFileSystem(p),]
except BaseException as e:
self.logger.log(str(e))
logger.error(str(e))
return msg
def command_start_with(self, message: str, *args):
'''
当消息以指定的指令开头时返回True
@@ -250,7 +257,7 @@ class Command:
if message.startswith(arg) or message.startswith('/'+arg):
return True
return False
def update(self, message: str, role: str):
if role != "admin":
return True, "你没有权限使用该指令", "update"
@@ -275,8 +282,10 @@ class Command:
else:
if l[1].lower().startswith('v'):
try:
release_data = util.updator.request_release_info(latest=False)
util.updator.update_project(release_data, latest=False, version=l[1])
release_data = util.updator.request_release_info(
latest=False)
util.updator.update_project(
release_data, latest=False, version=l[1])
return True, "更新成功,重启生效。可输入「update r」重启", "update"
except BaseException as e:
return False, "更新失败: "+str(e), "update"
@@ -285,28 +294,28 @@ class Command:
def reset(self):
return False
def set(self):
return False
def unset(self):
return False
def key(self):
return False
async def help(self):
ret = await self.help_messager(self.general_commands(), self.platform, self.global_object.cached_plugins)
return True, ret, "help"
def status(self):
return False
def token(self):
return False
def his(self):
return False
def draw(self):
return False
return False
+113 -133
View File
@@ -1,23 +1,34 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
from util.personality import personalities
from cores.astrbot.types import GlobalObject
from type.types import GlobalObject
from type.command import CommandItem
from SparkleLogging.utils.core import LogManager
from logging import Logger
from openai._exceptions import NotFoundError
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
self.provider = provider
self.global_object = global_object
self.personality_str = ""
self.commands = [
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"),
]
super().__init__(provider, global_object)
async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
# 检查基础指令
hit, res = await super().check_command(
message,
@@ -26,7 +37,9 @@ class CommandOpenAIOfficial(Command):
platform,
message_obj
)
logger.debug(f"基础指令hit: {hit}, res: {res}")
# 这里是这个 LLM 的专属指令
if hit:
return True, res
@@ -34,12 +47,8 @@ class CommandOpenAIOfficial(Command):
return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
return True, self.token(session_id)
elif self.command_start_with(message, "gpt"):
return True, self.gpt()
elif self.command_start_with(message, "status"):
return True, self.status()
return True, self.status(session_id)
elif self.command_start_with(message, "help", "帮助"):
return True, await self.help()
elif self.command_start_with(message, "unset"):
@@ -50,100 +59,111 @@ class CommandOpenAIOfficial(Command):
return True, self.update(message, role)
elif self.command_start_with(message, "", "draw"):
return True, await self.draw(message)
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"):
return True, await self.switch(message)
elif self.command_start_with(message, "models"):
return True, await self.print_models()
elif self.command_start_with(message, "model"):
return True, await self.set_model(message)
return False, None
async def get_models(self):
try:
models = await self.provider.client.models.list()
except NotFoundError as e:
bu = str(self.provider.client.base_url)
self.provider.client.base_url = bu + "/v1"
models = await self.provider.client.models.list()
finally:
return filter(lambda x: x.id.startswith("gpt"), models.data)
async def print_models(self):
models = await self.get_models()
i = 1
ret = "OpenAI GPT 类可用模型"
for model in models:
ret += f"\n{i}. {model.id}"
i += 1
logger.debug(ret)
return True, ret, "models"
async def set_model(self, message: str):
l = message.split(" ")
if len(l) == 1:
return True, "请输入 /model 模型名/编号", "model"
model = str(l[1])
models = await self.get_models()
models = list(models)
if model.isdigit() and int(model) <= len(models) and int(model) >= 1:
model = models[int(model)-1]
else:
f = False
for m in models:
if model == m.id:
f = True
break
if not f:
return True, "模型不存在或输入非法", "model"
self.provider.set_model(model.id)
return True, f"模型已设置为 {model.id}", "model"
async def help(self):
commands = super().general_commands()
commands[''] = '画画'
commands['key'] = '添加OpenAI key'
commands[''] = '调用 OpenAI DallE 模型生成图片'
commands['set'] = '人格设置面板'
commands['gpt'] = '查看gpt配置信息'
commands['status'] = '查看key使用状态'
commands['token'] = '查看本轮会话token'
commands['status'] = '查看 Api Key 状态和配置信息'
commands['token'] = '查看本轮会话 token'
commands['reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt'
commands['reset p'] = '重置当前与 LLM 的会话,并清除人格。'
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
async def reset(self, session_id: str, message: str = "reset"):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
await self.provider.forget(session_id)
await self.provider.forget(session_id, keep_system_prompt=True)
return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
if self.personality_str != "":
self.set(self.personality_str, session_id) # 重新设置人格
return True, "重置成功", "reset"
await self.provider.forget(session_id)
def his(self, message: str, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "his"
#分页,每页5条
msg = ''
size_per_page = 3
page = 1
if message[4:]:
page = int(message[4:])
# 检查是否有过历史记录
if session_id not in self.provider.session_dict:
msg = f"历史记录为空"
return True, msg, "his"
l = self.provider.session_dict[session_id]
max_page = len(l)//size_per_page + 1 if len(l)%size_per_page != 0 else len(l)//size_per_page
p = self.provider.get_prompts_by_cache_list(self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
return True, f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页", "his"
def token(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "token"
return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token"
l = message.split(" ")
if len(l) == 2:
try:
page = int(l[1])
except BaseException as e:
return True, "页码不合法", "his"
contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page)
t_pages = total_num // size_per_page + 1
return True, f"历史记录如下:\n{contexts}\n{page} 页 | 共 {t_pages}\n*输入 /his 2 跳转到第 2 页", "his"
def gpt(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "gpt"
return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt"
def status(self):
def status(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "status"
chatgpt_cfg_str = ""
key_stat = self.provider.get_key_stat()
index = 1
max = 9000000
gg_count = 0
total = 0
tag = ''
for key in key_stat.keys():
sponsor = ''
total += key_stat[key]['used']
if key_stat[key]['exceed']:
gg_count += 1
continue
if 'sponsor' in key_stat[key]:
sponsor = key_stat[key]['sponsor']
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
index += 1
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
keys_data = self.provider.get_keys_data()
ret = "OpenAI Key"
for k in keys_data:
status = "🟢" if keys_data[k] else "🔴"
ret += "\n|- " + k[:8] + " " + status
def key(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ")
if len(l) == 1:
msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx"
return True, msg, "key"
key = l[1]
if self.provider.check_key(key):
self.provider.append_key(key)
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
else:
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
conf = self.provider.get_configs()
ret += "\n当前模型:" + conf['model']
if conf['model'] in MODELS:
ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens"
if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]):
ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens"
return True, ret, "status"
async def switch(self, message: str):
'''
@@ -160,14 +180,13 @@ class CommandOpenAIOfficial(Command):
return True, ret, "switch"
elif len(l) == 2:
try:
key_stat = self.provider.get_key_stat()
key_stat = self.provider.get_keys_data()
index = int(l[1])
if index > len(key_stat) or index < 1:
return True, "账号序号不合法。", "switch"
else:
try:
new_key = list(key_stat.keys())[index-1]
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key)
except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch"
@@ -216,58 +235,19 @@ class CommandOpenAIOfficial(Command):
'name': ps,
'prompt': personalities[ps]
}
self.provider.session_dict[session_id] = []
new_record = {
"user": {
"role": "user",
"content": personalities[ps],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"人格{ps}已设置。", "set"
else:
self.provider.curr_personality = {
'name': '自定义人格',
'prompt': ps
}
new_record = {
"user": {
"role": "user",
"content": ps,
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id] = []
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
self.provider.personality_set(ps, session_id)
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
async def draw(self, message):
async def draw(self, message: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "draw"
if message.startswith("/"):
message = message[2:]
elif message.startswith(""):
message = message[1:]
try:
# 画图模式传回3个参数
img_url = await self.provider.image_chat(message)
return True, img_url, "draw"
except Exception as e:
if 'exceeded' in str(e):
return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)"
return False, f"图片生成失败: {e}", "draw"
message = message.removeprefix("/").removeprefix("")
img_url = await self.provider.image_generate(message)
return True, img_url, "draw"
-132
View File
@@ -1,132 +0,0 @@
from model.command.command import Command
from model.provider.rev_chatgpt import ProviderRevChatGPT
from util.personality import personalities
from cores.astrbot.types import GlobalObject
class CommandRevChatGPT(Command):
def __init__(self, provider: ProviderRevChatGPT, global_object: GlobalObject):
self.provider = provider
self.global_object = global_object
self.personality_str = ""
super().__init__(provider, global_object)
async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
hit, res = await super().check_command(
message,
session_id,
role,
platform,
message_obj
)
if hit:
return True, res
if self.command_start_with(message, "help", "帮助"):
return True, await self.help()
elif self.command_start_with(message, "reset"):
return True, self.reset(session_id, message)
elif self.command_start_with(message, "update"):
return True, self.update(message, role)
elif self.command_start_with(message, "set"):
return True, self.set(message, session_id)
elif self.command_start_with(message, "switch"):
return True, self.switch(message, session_id)
return False, None
def reset(self, session_id, message: str):
l = message.split(" ")
if len(l) == 1:
self.provider.forget(session_id)
return True, "重置完毕。", "reset"
if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id)
ret = self.provider.text_chat(self.personality_str)
return True, f"重置完毕(保留人格)。\n\n{ret}", "reset"
def set(self, message: str, session_id: str):
l = message.split(" ")
if len(l) == 1:
return True, f"设置人格: \n/set 人格名或人格文本。例如/set 编剧\n人格列表: /set list\n人格详细信息: \
/set view 人格名\n重置会话(清除人格): /reset\n重置会话(保留人格): /reset p", "set"
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
msg += f" |-{key}\n"
msg += '\n\n*输入/set view 人格名查看人格详细信息'
msg += '\n*不定时更新人格库,请及时更新本项目。'
return True, msg, "set"
elif l[1] == "view":
if len(l) == 2:
return True, "请输入/set view 人格名", "set"
ps = l[2].strip()
if ps in personalities:
msg = f"人格【{ps}】详细信息:\n"
msg += f"{personalities[ps]}\n"
else:
msg = f"人格【{ps}】不存在。"
return True, msg, "set"
else:
ps = l[1].strip()
if ps in personalities:
self.reset(session_id, "reset")
self.personality_str = personalities[ps]
ret = self.provider.text_chat(self.personality_str, session_id)
return True, f"人格【{ps}】已设置。\n\n{ret}", "set"
else:
self.reset(session_id, "reset")
self.personality_str = ps
ret = self.provider.text_chat(ps, session_id)
return True, f"人格信息已设置。\n\n{ret}", "set"
def switch(self, message: str, session_id: str):
'''
切换账号
'''
l = message.split(" ")
rev_chatgpt = self.provider.get_revchatgpt()
if len(l) == 1:
ret = "当前账号:\n"
index = 0
curr_ = None
for revstat in rev_chatgpt:
index += 1
ret += f"[{index}]. {revstat['id']}\n"
# if session_id in revstat['user']:
# curr_ = revstat['id']
for user in revstat['user']:
if session_id == user['id']:
curr_ = revstat['id']
break
if curr_ is None:
ret += "当前您未选择账号。输入/switch <账号序号>切换账号。"
else:
ret += f"当前您选择的账号为:{curr_}。输入/switch <账号序号>切换账号。"
return True, ret, "switch"
elif len(l) == 2:
try:
index = int(l[1])
if index > len(self.provider.rev_chatgpt) or index < 1:
return True, "账号序号不合法。", "switch"
else:
# pop
for revstat in self.provider.rev_chatgpt:
if session_id in revstat['user']:
revstat['user'].remove(session_id)
# append
self.provider.rev_chatgpt[index - 1]['user'].append(session_id)
return True, f"切换账号成功。当前账号为:{self.provider.rev_chatgpt[index - 1]['id']}", "switch"
except BaseException:
return True, "账号序号不合法。", "switch"
else:
return True, "参数过多。", "switch"
async def help(self):
commands = super().general_commands()
commands['set'] = '设置人格'
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
+17 -11
View File
@@ -5,14 +5,16 @@ from nakuru import (
FriendMessage
)
import botpy.message
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from typing import List, Union
import time
# QQ官方消息类型转换
def qq_official_message_parse(message: List[BaseMessageComponent]):
plain_text = ""
image_path = None # only one img supported
image_path = None # only one img supported
for i in message:
if isinstance(i, Plain):
plain_text += i.text
@@ -24,6 +26,8 @@ def qq_official_message_parse(message: List[BaseMessageComponent]):
return plain_text, image_path
# QQ官方消息类型 2 AstrBotMessage
def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.message.GroupMessage],
message_type: MessageType) -> AstrBotMessage:
abm = AstrBotMessage()
@@ -33,7 +37,7 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me
abm.message_id = message.id
abm.tag = "qqchan"
msg: List[BaseMessageComponent] = []
if message_type == MessageType.GROUP_MESSAGE:
abm.sender = MessageMember(
message.author.member_openid,
@@ -41,7 +45,7 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me
)
abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid"
msg.append(Plain(abm.message_str))
if message.attachments:
for i in message.attachments:
@@ -52,15 +56,16 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me
img = Image.fromURL(url)
msg.append(img)
abm.message = msg
elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE:
# 目前对于 FRIEND_MESSAGE 只处理频道私聊
try:
abm.self_id = str(message.mentions[0].id)
except:
abm.self_id = ""
plain_content = message.content.replace("<@!"+str(abm.self_id)+">", "").strip()
plain_content = message.content.replace(
"<@!"+str(abm.self_id)+">", "").strip()
msg.append(Plain(plain_content))
if message.attachments:
for i in message.attachments:
@@ -80,19 +85,20 @@ def qq_official_message_parse_rev(message: Union[botpy.message.Message, botpy.me
raise ValueError(f"Unknown message type: {message_type}")
return abm
def nakuru_message_parse_rev(message: Union[GuildMessage, GroupMessage, FriendMessage]) -> AstrBotMessage:
abm = AstrBotMessage()
abm.type = MessageType(message.type)
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.message_id
plain_content = ""
for i in message.message:
if isinstance(i, Plain):
plain_content += i.text
abm.message_str = plain_content
abm.self_id = str(message.self_id)
abm.sender = MessageMember(
str(message.sender.user_id),
@@ -100,5 +106,5 @@ def nakuru_message_parse_rev(message: Union[GuildMessage, GroupMessage, FriendMe
)
abm.tag = "gocq"
abm.message = message.message
return abm
return abm
+1
View File
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Union, Optional
@dataclass
class MessageResult():
result_message: Union[str, list]
+1 -1
View File
@@ -43,7 +43,7 @@ class Platform():
发送消息(主动发送)同 send_msg()
'''
pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:
'''
将消息解析成大纲消息形式。
+46 -45
View File
@@ -11,11 +11,16 @@ from nakuru import (
Notify
)
from typing import Union
from type.types import GlobalObject
import time
from ._platfrom import Platform
from ._message_parse import nakuru_message_parse_rev
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class FakeSource:
@@ -23,9 +28,9 @@ class FakeSource:
self.type = type
self.group_id = group_id
class QQGOCQ(Platform):
def __init__(self, cfg: dict, message_handler: callable, global_object) -> None:
def __init__(self, cfg: dict, message_handler: callable, global_object: GlobalObject) -> None:
super().__init__(message_handler)
self.loop = asyncio.new_event_loop()
@@ -34,16 +39,9 @@ class QQGOCQ(Platform):
self.waiting = {}
self.cc = CmdConfig()
self.cfg = cfg
self.logger: gu.Logger = global_object.logger
try:
self.nick_qq = cfg['nick_qq']
except:
self.nick_qq = ["ai","!",""]
nick_qq = self.nick_qq
if isinstance(nick_qq, str):
nick_qq = [nick_qq]
self.context = global_object
self.unique_session = cfg['uniqueSessionMode']
self.pic_mode = cfg['qq_pic_mode']
@@ -67,7 +65,7 @@ class QQGOCQ(Platform):
await self.handle_msg(abm)
else:
return
@gocq_app.receiver("FriendMessage")
async def _(app: CQHTTP, source: FriendMessage):
if self.cc.get("gocq_react_friend", True):
@@ -76,12 +74,12 @@ class QQGOCQ(Platform):
await self.handle_msg(abm)
else:
return
@gocq_app.receiver("GroupMemberIncrease")
async def _(app: CQHTTP, source: GroupMemberIncrease):
if self.cc.get("gocq_react_group_increase", True):
await app.sendGroupMessage(source.group_id, [
Plain(text = self.announcement)
Plain(text=self.announcement)
])
# @gocq_app.receiver("Notify")
@@ -101,16 +99,18 @@ class QQGOCQ(Platform):
await self.handle_msg(abm)
else:
return
def run(self):
self.client.run()
async def handle_msg(self, message: AstrBotMessage):
self.logger.log(f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ")
assert isinstance(message.raw_message, (GroupMessage, FriendMessage, GuildMessage))
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
assert isinstance(message.raw_message,
(GroupMessage, FriendMessage, GuildMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE
# 判断是否响应消息
resp = False
if not is_group:
@@ -118,23 +118,24 @@ class QQGOCQ(Platform):
else:
for i in message.message:
if isinstance(i, At):
if message.type == "GuildMessage":
if i.qq == message.raw_message.user_id or i.qq == message.raw_message.self_tiny_id:
if message.type.value == "GuildMessage":
if str(i.qq) == str(message.raw_message.user_id) or str(i.qq) == str(message.raw_message.self_tiny_id):
resp = True
if message.type == "FriendMessage":
if i.qq == message.self_id:
if message.type.value == "FriendMessage":
if str(i.qq) == str(message.self_id):
resp = True
if message.type == "GroupMessage":
if i.qq == message.self_id:
if message.type.value == "GroupMessage":
if str(i.qq) == str(message.self_id):
resp = True
elif isinstance(i, Plain):
for nick in self.nick_qq:
elif isinstance(i, Plain) and self.context.nick:
for nick in self.context.nick:
if nick != '' and i.text.strip().startswith(nick):
resp = True
break
if not resp: return
if not resp:
return
# 解析 session_id
if self.unique_session or not is_group:
session_id = message.raw_message.user_id
@@ -144,13 +145,13 @@ class QQGOCQ(Platform):
session_id = message.raw_message.channel_id
else:
session_id = message.raw_message.user_id
message.session_id = session_id
# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id == self.cc.get('admin_qq', '') or \
sender_id in self.cc.get('other_admins', []):
sender_id in self.cc.get('other_admins', []):
role = 'admin'
else:
role = 'member'
@@ -167,7 +168,7 @@ class QQGOCQ(Platform):
await self.reply_msg(message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
# 如果是等待回复的消息
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
@@ -182,14 +183,15 @@ class QQGOCQ(Platform):
source = message.raw_message
else:
source = message
res = result_message
self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ")
logger.info(
f"{source.user_id} <- {self.parse_message_outline(res)}")
if isinstance(source, int):
source = FakeSource("GroupMessage", source)
# str convert to CQ Message Chain
if isinstance(res, str):
res_str = res
@@ -241,7 +243,7 @@ class QQGOCQ(Platform):
node.name = f"bot"
node.time = int(time.time())
# print(node)
nodes=[node]
nodes = [node]
await self.client.sendGroupForwardMessage(source.group_id, nodes)
return
await self.client.sendGroupMessage(source.group_id, res)
@@ -256,10 +258,10 @@ class QQGOCQ(Platform):
await self.reply_msg(message, result_message)
except BaseException as e:
raise e
async def send(self,
to,
res):
async def send(self,
to,
res):
'''
同 send_msg()
'''
@@ -311,4 +313,3 @@ class QQGOCQ(Platform):
return ret
except BaseException as e:
raise e
+50 -34
View File
@@ -5,25 +5,31 @@ import botpy.message
import re
import asyncio
import aiohttp
import botpy.types
import botpy.types.message
from util import general_utils as gu
from botpy.types.message import Reference
from botpy import Client
import time
from ._platfrom import Platform
from ._message_parse import(
from ._message_parse import (
qq_official_message_parse_rev,
qq_official_message_parse
)
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
from type.message import *
from typing import Union, List
from nakuru.entities.components import BaseMessageComponent
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# QQ 机器人官方框架
class botClient(Client):
def set_platform(self, platform: 'QQOfficial'):
self.platform = platform
async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
abm = qq_official_message_parse_rev(message, MessageType.GROUP_MESSAGE)
await self.platform.handle_msg(abm)
@@ -37,9 +43,11 @@ class botClient(Client):
# 收到私聊消息
async def on_direct_message_create(self, message: botpy.message.DirectMessage):
# 转换层
abm = qq_official_message_parse_rev(message, MessageType.FRIEND_MESSAGE)
abm = qq_official_message_parse_rev(
message, MessageType.FRIEND_MESSAGE)
await self.platform.handle_msg(abm)
class QQOfficial(Platform):
def __init__(self, cfg: dict, message_handler: callable, global_object) -> None:
@@ -47,7 +55,7 @@ class QQOfficial(Platform):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.waiting: dict = {}
self.cfg = cfg
@@ -55,9 +63,8 @@ class QQOfficial(Platform):
self.token = cfg['qqbot']['token']
self.secret = cfg['qqbot_secret']
self.unique_session = cfg['uniqueSessionMode']
self.logger: gu.Logger = global_object.logger
qq_group = cfg['qqofficial_enable_group_message']
if qq_group:
self.intents = botpy.Intents(
public_messages=True,
@@ -79,7 +86,7 @@ class QQOfficial(Platform):
def run(self):
try:
self.loop.run_until_complete(self.client.run(
appid=self.appid,
appid=self.appid,
secret=self.secret
))
except BaseException as e:
@@ -90,17 +97,19 @@ class QQOfficial(Platform):
)
self.client.set_platform(self)
self.client.run(
appid=self.appid,
appid=self.appid,
token=self.token
)
async def handle_msg(self, message: AstrBotMessage):
assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage))
assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE
_t = "/私聊" if not is_group else ""
self.logger.log(f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL")
logger.info(
f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}")
# 解析出 session_id
if self.unique_session or not is_group:
session_id = message.sender.user_id
@@ -116,7 +125,7 @@ class QQOfficial(Platform):
# 解析出 role
sender_id = message.sender.user_id
if sender_id == self.cfg['admin_qqchan'] or \
sender_id in self.cfg['other_admins']:
sender_id in self.cfg['other_admins']:
role = 'admin'
else:
role = 'member'
@@ -139,18 +148,20 @@ class QQOfficial(Platform):
if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message
async def reply_msg(self,
message: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage],
res: Union[str, list]):
async def reply_msg(self,
message: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage],
res: Union[str, list]):
'''
回复频道消息
'''
if isinstance(message, AstrBotMessage):
source = message.raw_message
else:
source = message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage))
self.logger.log(f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL")
source = message
assert isinstance(source, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage))
logger.info(
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}")
plain_text = ''
image_path = ''
@@ -160,7 +171,7 @@ class QQOfficial(Platform):
plain_text, image_path = qq_official_message_parse(res)
elif isinstance(res, str):
plain_text = res
if self.cfg['qq_pic_mode']:
# 文本转图片,并且加上原来的图片
if plain_text != '' or image_path != '':
@@ -168,7 +179,8 @@ class QQOfficial(Platform):
if image_path.startswith("http"):
plain_text += "\n\n" + "![](" + image_path + ")"
else:
plain_text += "\n\n" + "![](file:///" + image_path + ")"
plain_text += "\n\n" + \
"![](file:///" + image_path + ")"
image_path = gu.create_markdown_image("".join(plain_text))
plain_text = ""
@@ -182,9 +194,10 @@ class QQOfficial(Platform):
image = PILImage.open(io.BytesIO(await response.read()))
image_path = gu.save_temp_img(image)
if source is not None and image_path == '': # file_image与message_reference不能同时传入
msg_ref = Reference(message_id=source.id, ignore_get_message_error=False)
if source is not None and image_path == '': # file_image与message_reference不能同时传入
msg_ref = Reference(message_id=source.id,
ignore_get_message_error=False)
# 到这里,我们得到了 plain_textimage_pathmsg_ref
data = {
'content': plain_text,
@@ -210,7 +223,7 @@ class QQOfficial(Platform):
# 分割过长的消息
if "msg over length" in str(e):
split_res = []
split_res.append(plain_text[:len(plain_text)//2])
split_res.append(plain_text[:len(plain_text)//2])
split_res.append(plain_text[len(plain_text)//2:])
for i in split_res:
data['content'] = i
@@ -227,11 +240,12 @@ class QQOfficial(Platform):
data['content'] = str.join(" ", plain_text)
await self._send_wrapper(**data)
except BaseException as e:
plain_text = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = re.sub(
r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·")
data['content'] = plain_text
await self._send_wrapper(**data)
async def _send_wrapper(self, **kwargs):
if 'group_openid' in kwargs:
# QQ群组消息
@@ -248,27 +262,29 @@ class QQOfficial(Platform):
elif 'channel_id' in kwargs:
# 频道消息
if 'file_image' in kwargs:
kwargs['file_image'] = kwargs['file_image'].replace("file://", "")
kwargs['file_image'] = kwargs['file_image'].replace(
"file://", "")
await self.client.api.post_message(**kwargs)
else:
# 频道私聊消息
if 'file_image' in kwargs:
kwargs['file_image'] = kwargs['file_image'].replace("file://", "")
kwargs['file_image'] = kwargs['file_image'].replace(
"file://", "")
await self.client.api.post_dms(**kwargs)
async def send_msg(self,
message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage],
message_chain: List[BaseMessageComponent],
):
):
'''
发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入)
'''
await self.reply_msg(message_obj, message_chain)
async def send(self,
message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage],
message_chain: List[BaseMessageComponent],
):
message_obj: Union[botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, AstrBotMessage],
message_chain: List[BaseMessageComponent],
):
'''
发送消息。目前只支持被动回复消息(即拥有一个 botpy Message 类型的 message_obj 传入)
'''
+424 -331
View File
@@ -5,87 +5,108 @@ import time
import tiktoken
import threading
import traceback
import base64
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from cores.database.conn import dbConn
from persist.session import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from util.general_utils import Logger
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
MODELS = {
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-16k": 16385,
"gpt-3.5-turbo-0613": 16385,
"gpt-3.5-turbo-16k-0613": 16385,
}
class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg):
self.cc = CmdConfig()
self.logger = Logger()
def __init__(self, cfg) -> None:
super().__init__()
self.key_list = []
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
raise BaseException("检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
if cfg['key'] != '' and cfg['key'] != None:
self.key_list = cfg['key']
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
self.key_stat = {}
for k in self.key_list:
self.key_stat[k] = {'exceed': False, 'used': 0}
os.makedirs("data/openai", exist_ok=True)
self.cc = CmdConfig
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {} # 记录超额
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
for key in self.api_keys:
self.keys_data[key] = True
self.api_base = None
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
self.api_base = cfg['api_base']
self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI")
# 创建 OpenAI Client
self.client = AsyncOpenAI(
api_key=self.key_list[0],
base_url=self.api_base
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.openai_model_configs: dict = cfg['chatGPTConfigs']
self.logger.log(f'加载 OpenAI Chat Configs: {self.openai_model_configs}', tag="OpenAI")
self.openai_configs = cfg
# 会话缓存
self.session_dict = {}
# 最大缓存token
self.max_tokens = cfg['total_tokens_limit']
# 历史记录持久化间隔时间
self.history_dump_interval = 20
self.enc = tiktoken.get_encoding("cl100k_base")
self.model_configs: Dict = cfg['chatGPTConfigs']
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
self.DEFAULT_PERSONALITY = {
"name": "default",
"prompt": "你是一个很有帮助的 AI 助手。"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录
try:
db1 = dbConn()
for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data']
self.logger.log("读取历史记录成功。", tag="OpenAI")
self.session_memory_lock.acquire()
self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e:
self.logger.log("读取历史记录失败,但不影响使用。", level=gu.LEVEL_ERROR, tag="OpenAI")
# 创建转储定时器线程
logger.warn(f"读取 OpenAI LLM 对话历史记录 失败{e}。仍可正常使用。")
# 定时保存历史记录
threading.Thread(target=self.dump_history, daemon=True).start()
# 人格
self.curr_personality = {}
# 转储历史记录
def dump_history(self):
'''
转储历史记录
'''
time.sleep(10)
db = dbConn()
while True:
try:
# print("转储历史记录...")
for key in self.session_dict:
data = self.session_dict[key]
for key in self.session_memory:
data = self.session_memory[key]
data_json = {
'data': data
}
@@ -93,310 +114,382 @@ class ProviderOpenAIOfficial(Provider):
db.update_session(key, json.dumps(data_json))
else:
db.insert_session(key, json.dumps(data_json))
# print("转储历史记录完毕")
logger.debug("已保存 OpenAI 会话历史记录")
except BaseException as e:
print(e)
# 每隔10分钟转储一次
time.sleep(10*self.history_dump_interval)
finally:
time.sleep(10*60)
def personality_set(self, default_personality: dict, session_id: str):
if not default_personality: return
if session_id not in self.session_memory:
self.session_memory[session_id] = []
self.curr_personality = default_personality
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
new_record = {
"user": {
"role": "user",
"role": "system",
"content": default_personality['prompt'],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
'usage_tokens': 0, # 到该条目的总 token 数
'single-tokens': 0 # 该条目的 token 数
}
self.session_dict[session_id].append(new_record)
async def text_chat(self, prompt,
session_id = None,
image_url = None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制
if session_id not in self.session_dict:
self.session_dict[session_id] = []
self.session_memory[session_id].append(new_record)
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("http"):
image_url = await gu.download_image_by_url(image_url)
if len(self.session_dict[session_id]) == 0:
# 设置默认人格
if default_personality is not None:
self.personality_set(default_personality, session_id)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode()
return "data:image/jpeg;base64," + image_bs64
# 使用 tictoken 截断消息
_encoded_prompt = self.enc.encode(prompt)
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
prompt = self.enc.decode(_encoded_prompt[:int(self.openai_model_configs['max_tokens']*0.80)])
self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", level=gu.LEVEL_WARNING, tag="OpenAI")
cache_data_list, new_record, req = self.wrap(prompt, session_id, image_url)
self.logger.log(f"cache: {str(cache_data_list)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
self.logger.log(f"request: {str(req)}", level=gu.LEVEL_DEBUG, tag="OpenAI")
retry = 0
response = None
err = ''
# 截断倍率
truncate_rate = 0.75
use_gpt4v = False
for i in req:
if isinstance(i['content'], list):
use_gpt4v = True
break
if image_url is not None:
use_gpt4v = True
if use_gpt4v:
conf = self.openai_model_configs.copy()
conf['model'] = 'gpt-4-vision-preview'
else:
conf = self.openai_model_configs
if extra_conf is not None:
conf.update(extra_conf)
while retry < 10:
try:
if function_call is None:
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = await self.client.chat.completions.create(
messages=req,
tools = function_call,
**conf
)
break
except Exception as e:
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
self.logger.log("当前 Key 已超额或异常, 正在切换", level=gu.LEVEL_WARNING, tag="OpenAI")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
retry -= 1
elif 'maximum context length' in str(e):
self.logger.log("token 超限, 清空对应缓存,并进行消息截断", tag="OpenAI")
self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(prompt, session_id)
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
async def retrieve_context(self, session_id: str):
'''
根据 session_id 获取保存的 OpenAI 格式的上下文
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
# 转换为 openai 要求的格式
context = []
is_lvm = await self.is_lvm()
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
continue
else:
self.logger.log(str(e), level=gu.LEVEL_ERROR, tag="OpenAI")
time.sleep(2)
err = str(e)
retry += 1
if retry >= 10:
self.logger.log(r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki", tag="OpenAI")
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
self.logger.log(f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, tag="OpenAI")
context.append(record['user'])
if "AI" in record and record['AI']:
context.append(record['AI'])
# 结果分类
choice = response.choices[0]
if choice.message.content != None:
# 文本形式
chatgpt_res = str(choice.message.content).strip()
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
return context
async def is_lvm(self):
'''
是否是 LVM
'''
return self.model_configs['model'].startswith("gpt-4")
async def get_models(self):
'''
获取所有模型
'''
models = await self.client.models.list()
logger.info(f"OpenAI 模型列表:{models}")
return models
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
'''
组装上下文,并且根据当前上下文窗口大小截断
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
tokens_num = len(self.tokenizer.encode(prompt))
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
message = {
"usage_tokens": previous_total_tokens_num + tokens_num,
"single_tokens": tokens_num,
"AI": None
}
if image_url:
user_content = {
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": await self.encode_image_bs64(image_url)
}
}
]
}
else:
user_content = {
"role": "user",
"content": prompt
}
message["user"] = user_content
self.session_memory[session_id].append(message)
# 根据 模型的上下文窗口 淘汰掉多余的记录
curr_model = self.model_configs['model']
if curr_model in MODELS:
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
# if message['usage_tokens'] > maxium_tokens_num:
# 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300
# contexts = self.session_memory[session_id]
# need_to_remove_idx = 0
# freed_tokens_num = contexts[0]['single-tokens']
# while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num:
# need_to_remove_idx += 1
# freed_tokens_num += contexts[need_to_remove_idx]['single-tokens']
# # 更新之后的所有记录的 usage_tokens
# for i in range(len(contexts)):
# if i > need_to_remove_idx:
# contexts[i]['usage_tokens'] -= freed_tokens_num
# logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。")
# self.session_memory[session_id] = contexts[need_to_remove_idx+1:]
while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num:
self.pop_record(session_id)
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
# 更新之后所有记录的 usage_tokens
for i in range(len(self.session_memory[session_id])):
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}")
return record
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None=None,
tools: None=None,
extra_conf: Dict = None,
**kwargs
) -> str:
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
del self.session_memory["unknown"]
if session_id not in self.session_memory:
self.session_memory[session_id] = []
if session_id not in self.session_personality or not self.session_personality[session_id]:
self.personality_set(self.curr_personality, session_id)
self.session_personality[session_id] = True
# 如果 prompt 超过了最大窗口,截断。
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
# 组装上下文,并且根据当前上下文窗口大小截断
await self.assemble_context(session_id, prompt, image_url)
# 获取上下文,openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
if extra_conf: conf.update(extra_conf)
# start request
retry = 0
rate_limit_retry = 0
while retry < 3 or rate_limit_retry < 5:
logger.debug(conf)
logger.debug(contexts)
if tools:
completion_coro = self.client.chat.completions.create(
messages=contexts,
tools=tools,
**conf
)
else:
completion_coro = self.client.chat.completions.create(
messages=contexts,
**conf
)
try:
completion = await completion_coro
break
except AuthenticationError as e:
api_key = self.chosen_api_key[10:] + "..."
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)")
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except BadRequestError as e:
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
if "maximum context length" in str(e):
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
completion_tokens = completion.usage.completion_tokens
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
if choice.message.content:
# 返回文本
completion_text = str(choice.message.content).strip()
elif choice.message.tool_calls and choice.message.tool_calls:
# tools call (function calling)
return choice.message.tool_calls[0].function
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
current_usage_tokens = response.usage.total_tokens
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
if current_usage_tokens > self.max_tokens:
t = current_usage_tokens
index = 0
while t > self.max_tokens:
if index >= len(cache_data_list):
break
# 保留人格信息
if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index]
else:
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# 添加新条目进入缓存的prompt
new_record['AI'] = {
'role': 'assistant',
'content': chatgpt_res,
self.session_memory[session_id][-1]['AI'] = {
"role": "assistant",
"content": completion_text
}
new_record['usage_tokens'] = current_usage_tokens
if len(cache_data_list) > 0:
new_record['single_tokens'] = current_usage_tokens - int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record)
self.session_dict[session_id] = cache_data_list
return chatgpt_res
async def image_chat(self, prompt, img_num = 1, img_size = "1024x1024"):
retry = 0
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response.data[i].url)
break
except Exception as e:
self.logger.log(str(e), level=gu.LEVEL_ERROR)
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
self.logger.log("当前 Key 已超额或者不正常, 正在切换", level=gu.LEVEL_WARNING, tag="OpenAI")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试", level=gu.LEVEL_WARNING, tag="OpenAI")
raise e
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
return image_url
async def forget(self, session_id = None) -> bool:
if session_id is None:
return completion_text
async def switch_to_next_key(self):
'''
切换到下一个 API Key
'''
if not self.api_keys:
logger.error("OpenAI API Key 不存在。")
return False
self.session_dict[session_id] = []
for key in self.keys_data:
if self.keys_data[key]:
# 没超额
self.chosen_api_key = key
self.client.api_key = key
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
return True
return False
async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str:
'''
生成图片
'''
retry = 0
conf = self.image_generator_model_configs
if not conf:
logger.error("OpenAI 图片生成模型配置不存在。")
raise Exception("OpenAI 图片生成模型配置不存在。")
while retry < 3:
try:
images_response = await self.client.images.generate(
prompt=prompt,
**conf
)
image_url = images_response.data[0].url
return image_url
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
self.session_memory[session_id] = []
if keep_system_prompt:
self.personality_set(self.curr_personality, session_id)
else:
self.curr_personality = self.DEFAULT_PERSONALITY
return True
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
def dump_contexts_page(self, session_id: str, size=5, page=1,):
'''
获取缓存的会话
'''
prompts = ""
if paging:
page_begin = (page-1)*size
page_end = page*size
if page_begin < 0:
page_begin = 0
if page_end > len(cache_data_list):
page_end = len(cache_data_list)
cache_data_list = cache_data_list[page_begin:page_end]
for item in cache_data_list:
prompts += str(item['user']['role']) + ":\n" + str(item['user']['content']) + "\n"
prompts += str(item['AI']['role']) + ":\n" + str(item['AI']['content']) + "\n"
# contexts_str = ""
# for i, key in enumerate(self.session_memory):
# if i < (page-1)*size or i >= page*size:
# continue
# contexts_str += f"Session ID: {key}\n"
# for record in self.session_memory[key]:
# if "user" in record:
# contexts_str += f"User: {record['user']['content']}\n"
# if "AI" in record:
# contexts_str += f"AI: {record['AI']['content']}\n"
# contexts_str += "---\n"
contexts_str = ""
if session_id in self.session_memory:
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
contexts_str += f"User: {text}\n"
if "AI" in record and record['AI']:
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
contexts_str += f"Assistant: {text}\n"
else:
contexts_str = "会话 ID 不存在。"
if divide:
prompts += "----------\n"
return prompts
return contexts_str, len(self.session_memory[session_id])
def wrap(self, prompt, session_id, image_url = None):
if image_url is not None:
prompt = [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
# 获得缓存信息
context = self.session_dict[session_id]
new_record = {
"user": {
"role": "user",
"content": prompt,
},
"AI": {},
'type': "common",
'usage_tokens': 0,
}
req_list = []
for i in context:
if 'user' in i:
req_list.append(i['user'])
if 'AI' in i:
req_list.append(i['AI'])
req_list.append(new_record['user'])
return context, new_record, req_list
def set_model(self, model: str):
self.model_configs['model'] = model
def handle_switch_key(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
self.client.api_key = key
self.logger.log(f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})", level=gu.LEVEL_INFO, tag="OpenAI")
break
if is_all_exceed:
self.logger.log("所有 Key 已超额", level=gu.LEVEL_CRITICAL, tag="OpenAI")
return False
return True
def get_configs(self):
return self.openai_configs
def get_key_stat(self):
return self.key_stat
def get_key_list(self):
return self.key_list
def get_curr_key(self):
return self.client.api_key
def set_key(self, key):
self.client.api_key = key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
return self.model_configs
# 检查key是否可用
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
return True
def get_keys_data(self):
return self.keys_data
def get_curr_key(self):
return self.chosen_api_key
def set_key(self, key):
self.client.api_key = key
+10 -10
View File
@@ -1,9 +1,9 @@
class Provider:
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None,
function_call: None,
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None = None,
tools: None = None,
extra_conf: dict = None,
default_personality: dict = None,
**kwargs) -> str:
@@ -11,15 +11,15 @@ class Provider:
[require]
prompt: 提示词
session_id: 会话id
[optional]
image_url: 图片url(识图)
function_call: 函数调用
tools: 函数调用工具
extra_conf: 额外配置
default_personality: 默认人格
'''
raise NotImplementedError
async def image_generate(self, prompt, session_id, **kwargs) -> str:
'''
[require]
@@ -28,8 +28,8 @@ class Provider:
'''
raise NotImplementedError
async def forget(self, session_id = None) -> bool:
async def forget(self, session_id=None) -> bool:
'''
重置会话
'''
raise NotImplementedError
raise NotImplementedError
-224
View File
@@ -1,224 +0,0 @@
from revChatGPT.V1 import Chatbot
from revChatGPT import typings
from model.provider.provider import Provider
from util import general_utils as gu
from util import cmd_config as cc
import time
class ProviderRevChatGPT(Provider):
def __init__(self, config, base_url = None):
if base_url == "":
base_url = None
self.rev_chatgpt: list[dict] = []
self.cc = cc.CmdConfig()
for i in range(0, len(config['account'])):
try:
gu.log(f"创建逆向ChatGPT负载{str(i+1)}中...", level=gu.LEVEL_INFO, tag="RevChatGPT")
if isinstance(config['account'][i], str):
# 默认是 access_token
rev_account_config = {
'access_token': config['account'][i],
}
else:
if 'password' in config['account'][i]:
gu.log(f"创建逆向ChatGPT负载{str(i+1)}失败: 已不支持账号密码登录,请使用access_token方式登录。", level=gu.LEVEL_ERROR, tag="RevChatGPT")
continue
rev_account_config = {
'access_token': config['account'][i]['access_token'],
}
if self.cc.get("rev_chatgpt_model") != "":
rev_account_config['model'] = self.cc.get("rev_chatgpt_model")
if len(self.cc.get("rev_chatgpt_plugin_ids")) > 0:
rev_account_config['plugin_ids'] = self.cc.get("rev_chatgpt_plugin_ids")
if self.cc.get("rev_chatgpt_PUID") != "":
rev_account_config['PUID'] = self.cc.get("rev_chatgpt_PUID")
if len(self.cc.get("rev_chatgpt_unverified_plugin_domains")) > 0:
rev_account_config['unverified_plugin_domains'] = self.cc.get("rev_chatgpt_unverified_plugin_domains")
cb = Chatbot(config=rev_account_config, base_url=base_url)
# cb.captcha_solver = self.__captcha_solver
# 后八位c
g_id = rev_account_config['access_token'][-8:]
revstat = {
'id': g_id,
'obj': cb,
'busy': False,
'user': []
}
self.rev_chatgpt.append(revstat)
except BaseException as e:
gu.log(f"创建逆向ChatGPT负载{str(i+1)}失败: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
def forget(self, session_id = None) -> bool:
for i in self.rev_chatgpt:
for user in i['user']:
if session_id == user['id']:
try:
i['obj'].reset_chat()
return True
except BaseException as e:
gu.log(f"重置RevChatGPT失败。原因: {str(e)}", level=gu.LEVEL_ERROR, tag="RevChatGPT")
return False
return False
def get_revchatgpt(self) -> list:
return self.rev_chatgpt
def request_text(self, prompt: str, bot) -> str:
resp = ''
err_count = 0
retry_count = 5
while err_count < retry_count:
try:
for data in bot.ask(prompt):
resp = data["message"]
break
except typings.Error as e:
if e.code == typings.ErrorType.INVALID_ACCESS_TOKEN_ERROR:
raise e
if e.code == typings.ErrorType.EXPIRED_ACCESS_TOKEN_ERROR:
raise e
if e.code == typings.ErrorType.PROHIBITED_CONCURRENT_QUERY_ERROR:
raise e
if "Your authentication token has expired. Please try signing in again." in str(e):
raise e
if "The message you submitted was too long" in str(e):
raise e
if "You've reached our limit of messages per hour." in str(e):
raise e
if "Rate limited by proxy" in str(e):
gu.log(f"触发请求频率限制, 60秒后自动重试。", level=gu.LEVEL_WARNING, tag="RevChatGPT")
time.sleep(60)
err_count += 1
gu.log(f"请求异常: {str(e)},正在重试。({str(err_count)})", level=gu.LEVEL_WARNING, tag="RevChatGPT")
if err_count >= retry_count:
raise e
except BaseException as e:
err_count += 1
gu.log(f"请求异常: {str(e)},正在重试。({str(err_count)})", level=gu.LEVEL_WARNING, tag="RevChatGPT")
if err_count >= retry_count:
raise e
if resp == '':
resp = "RevChatGPT请求异常。"
# print("[RevChatGPT] "+str(resp))
return resp
def text_chat(self, prompt,
session_id = None,
image_url = None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None) -> str:
# 选择一个人少的账号。
selected_revstat = None
min_revstat = None
min_ = None
new_user = False
conversation_id = ''
parent_id = ''
for revstat in self.rev_chatgpt:
for user in revstat['user']:
if session_id == user['id']:
selected_revstat = revstat
conversation_id = user['conversation_id']
parent_id = user['parent_id']
break
if min_ is None:
min_ = len(revstat['user'])
min_revstat = revstat
elif len(revstat['user']) < min_:
min_ = len(revstat['user'])
min_revstat = revstat
# if session_id in revstat['user']:
# selected_revstat = revstat
# break
if selected_revstat is None:
selected_revstat = min_revstat
selected_revstat['user'].append({
'id': session_id,
'conversation_id': '',
'parent_id': ''
})
new_user = True
gu.log(f"选择账号{str(selected_revstat)}", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
while selected_revstat['busy']:
gu.log(f"账号忙碌,等待中...", tag="RevChatGPT", level=gu.LEVEL_DEBUG)
time.sleep(1)
selected_revstat['busy'] = True
if not new_user:
# 非新用户,则使用其专用的会话
selected_revstat['obj'].conversation_id = conversation_id
selected_revstat['obj'].parent_id = parent_id
else:
# 新用户,则使用新的会话
selected_revstat['obj'].reset_chat()
res = ''
err_msg = ''
err_cnt = 0
while err_cnt < 15:
try:
res = self.request_text(prompt, selected_revstat['obj'])
selected_revstat['busy'] = False
# 记录新用户的会话
if new_user:
i = 0
for user in selected_revstat['user']:
if user['id'] == session_id:
selected_revstat['user'][i]['conversation_id'] = selected_revstat['obj'].conversation_id
selected_revstat['user'][i]['parent_id'] = selected_revstat['obj'].parent_id
break
i += 1
return res.strip()
except BaseException as e:
if "Your authentication token has expired. Please try signing in again." in str(e):
raise Exception(f"此账号(access_token后8位为{selected_revstat['id']})的access_token已过期,请重新获取,或者切换账号。")
if "The message you submitted was too long" in str(e):
raise Exception("发送的消息太长,请分段发送。")
if "You've reached our limit of messages per hour." in str(e):
raise Exception("触发RevChatGPT请求频率限制。请1小时后再试,或者切换账号。")
gu.log(f"请求异常: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
err_cnt += 1
time.sleep(3)
raise Exception(f'回复失败。原因:{err_msg}。如果您设置了多个账号,可以使用/switch指令切换账号。输入/switch查看详情。')
# while self.is_all_busy():
# time.sleep(1)
# res = ''
# err_msg = ''
# cursor = 0
# for revstat in self.rev_chatgpt:
# cursor += 1
# if not revstat['busy']:
# try:
# revstat['busy'] = True
# res = self.request_text(prompt, revstat['obj'])
# revstat['busy'] = False
# return res.strip()
# # todo: 细化错误管理
# except BaseException as e:
# revstat['busy'] = False
# gu.log(f"请求出现问题: {str(e)}", level=gu.LEVEL_WARNING, tag="RevChatGPT")
# err_msg += f"账号{cursor} - 错误原因: {str(e)}"
# continue
# else:
# err_msg += f"账号{cursor} - 错误原因: 忙碌"
# continue
# raise Exception(f'回复失败。错误跟踪:{err_msg}')
def is_all_busy(self) -> bool:
for revstat in self.rev_chatgpt:
if not revstat['busy']:
return False
return True
+18 -16
View File
@@ -1,13 +1,17 @@
import sqlite3
import yaml
import os
import shutil
import time
from typing import Tuple
class dbConn():
def __init__(self):
# 读取参数,并支持中文
conn = sqlite3.connect("data.db")
conn.text_factory=str
db_path = "data/data.db"
if os.path.exists("data.db"):
shutil.copy("data.db", db_path)
conn = sqlite3.connect(db_path)
conn.text_factory = str
self.conn = conn
c = conn.cursor()
c.execute(
@@ -44,7 +48,7 @@ class dbConn():
);
'''
)
conn.commit()
def insert_session(self, qq_id, history):
@@ -76,7 +80,7 @@ class dbConn():
''', (qq_id, )
)
return c.fetchone()
def get_all_session(self):
conn = self.conn
c = conn.cursor()
@@ -86,7 +90,7 @@ class dbConn():
'''
)
return c.fetchall()
def check_session(self, qq_id):
conn = self.conn
c = conn.cursor()
@@ -107,7 +111,6 @@ class dbConn():
)
conn.commit()
def increment_stat_session(self, platform, session_id, cnt):
# if not exist, insert
conn = self.conn
@@ -137,7 +140,7 @@ class dbConn():
''', (platform, session_id)
)
return c.fetchone() is not None
def get_all_stat_session(self):
conn = self.conn
c = conn.cursor()
@@ -147,7 +150,7 @@ class dbConn():
'''
)
return c.fetchall()
def get_session_cnt_total(self):
conn = self.conn
c = conn.cursor()
@@ -157,7 +160,7 @@ class dbConn():
'''
)
return c.fetchone()[0]
def increment_stat_message(self, ts, cnt):
# 以一个小时为单位。ts的单位是秒。
# 找到最近的一个小时,如果没有,就插入
@@ -197,7 +200,7 @@ class dbConn():
return True, ts
else:
return False, ts
def get_last_24h_stat_message(self):
# 获取最近24小时的消息统计
conn = self.conn
@@ -208,7 +211,7 @@ class dbConn():
''', (time.time() - 86400, )
)
return c.fetchall()
def get_message_cnt_total(self) -> int:
conn = self.conn
c = conn.cursor()
@@ -258,7 +261,7 @@ class dbConn():
return True, ts
else:
return False, ts
def get_last_24h_stat_platform(self):
# 获取最近24小时的消息统计
conn = self.conn
@@ -269,7 +272,7 @@ class dbConn():
''', (time.time() - 86400, )
)
return c.fetchall()
def get_platform_cnt_total(self) -> int:
conn = self.conn
c = conn.cursor()
@@ -291,4 +294,3 @@ class dbConn():
def close(self):
self.conn.close()
+5 -4
View File
@@ -4,15 +4,16 @@ requests
openai~=1.2.3
qq-botpy
chardet~=5.1.0
Pillow~=9.4.0
GitPython~=3.1.31
Pillow
GitPython
nakuru-project
beautifulsoup4
googlesearch-python
tiktoken
readability-lxml
revChatGPT~=6.8.6
baidu-aip~=4.16.9
baidu-aip
websockets
flask
psutil
lxml_html_clean
SparkleLogging
+28
View File
@@ -0,0 +1,28 @@
from typing import Union, List, Callable
from dataclasses import dataclass
@dataclass
class CommandItem():
'''
用来描述单个指令
'''
command_name: Union[str, tuple] # 指令名
callback: Callable # 回调函数
description: str # 描述
origin: str # 注册来源
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+62
View File
@@ -0,0 +1,62 @@
from enum import Enum
from typing import List
from dataclasses import dataclass
from nakuru.entities.components import BaseMessageComponent
from type.register import RegisteredPlatform
from type.types import GlobalObject
class MessageType(Enum):
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
GUILD_MESSAGE = 'GuildMessage' # 频道消息
@dataclass
class MessageMember():
user_id: str # 发送者id
nickname: str = None
class AstrBotMessage():
'''
AstrBot 的消息对象
'''
tag: str # 消息来源标签
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id
message_id: str # 消息id
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
raw_message: object
timestamp: int # 消息时间戳
def __str__(self) -> str:
return str(self.__dict__)
class AstrMessageEvent():
'''
消息事件。
'''
context: GlobalObject # 一些公用数据
message_str: str # 纯消息字符串
message_obj: AstrBotMessage # 消息对象
platform: RegisteredPlatform # 来源平台
role: str # 基本身份。`admin` 或 `member`
session_id: int # 会话 id
def __init__(self,
message_str: str,
message_obj: AstrBotMessage,
platform: RegisteredPlatform,
role: str,
context: GlobalObject,
session_id: str = None):
self.context = context
self.message_str = message_str
self.message_obj = message_obj
self.platform = platform
self.role = role
self.session_id = session_id
+27
View File
@@ -0,0 +1,27 @@
from enum import Enum
from dataclasses import dataclass
class PluginType(Enum):
PLATFORM = 'platfrom' # 平台类插件。
LLM = 'llm' # 大语言模型类插件
COMMON = 'common' # 其他插件
@dataclass
class PluginMetadata:
'''
插件的元数据。
'''
# required
plugin_name: str
plugin_type: PluginType
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
+46
View File
@@ -0,0 +1,46 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from type.plugin import *
from typing import List
from types import ModuleType
from dataclasses import dataclass
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件。
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
'''
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
'''
platform_name: str
platform_instance: Platform
origin: str = None # 注册来源
def __str__(self) -> str:
return self.platform_name
@dataclass
class RegisteredLLM:
'''
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
'''
llm_name: str
llm_instance: LLMProvider
origin: str = None # 注册来源
+34
View File
@@ -0,0 +1,34 @@
from type.register import *
from typing import List
class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
version: str # 机器人版本
nick: tuple # 用户定义的机器人的别名
base_config: dict # config.json 中导出的配置
cached_plugins: List[RegisteredPlugin] # 加载的插件
platforms: List[RegisteredPlatform]
llms: List[RegisteredLLM]
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
cnt_total: int # 总消息数
default_personality: dict
dashboard_data = None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = [] # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.cnt_total = 0
self.platforms = []
self.llms = []
self.default_personality = None
self.dashboard_data = None
self.stat = {}
@@ -3,30 +3,35 @@ import json
import util.general_utils as gu
import time
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class FuncCall():
def __init__(self, provider) -> None:
self.func_list = []
self.provider = provider
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj = None) -> None:
def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None:
if name == None or func_args == None or desc == None or func_obj == None:
raise FuncCallJsonFormatError("name, func_args, desc must be provided.")
raise FuncCallJsonFormatError(
"name, func_args, desc must be provided.")
params = {
"type": "object", # hardcore here
"type": "object", # hardcore here
"properties": {}
}
for param in func_args:
@@ -51,7 +56,7 @@ class FuncCall():
"description": f["description"],
})
return json.dumps(_l, indent=intent, ensur_ascii=False)
def get_func(self) -> list:
_l = []
for f in self.func_list:
@@ -64,8 +69,8 @@ class FuncCall():
}
})
return _l
def func_call(self, question, func_definition, is_task = False, tasks = None, taskindex = -1, is_summary = True, session_id = None):
def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None):
funccall_prompt = """
我正实现function call功能该功能旨在让你变成给定的问题到给定的函数的解析器意味着你不是创造函数
@@ -120,7 +125,8 @@ class FuncCall():
res = self.provider.text_chat(prompt, session_id)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
gu.log("REVGPT func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
print(res)
res = json.loads(res)
break
@@ -151,11 +157,13 @@ class FuncCall():
func_target = func["func_obj"]
break
if func_target == None:
raise FuncNotFoundError(f"Request function {func_name} not found.")
raise FuncNotFoundError(
f"Request function {func_name} not found.")
t_res = str(func_target(**args))
invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n"
invoke_func_res_list.append(invoke_func_res)
gu.log(f"[FUNC| {func_name} invoked]", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
gu.log(f"[FUNC| {func_name} invoked]",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(str(t_res))
if is_summary:
@@ -181,12 +189,16 @@ class FuncCall():
try:
res = self.provider.text_chat(after_prompt, session_id)
# 截取```之间的内容
gu.log("DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
print(res)
gu.log("DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
gu.log(
"DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT after_func_call json result", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
res = res[res.find('```json') +
7: res.rfind('```')]
gu.log("REVGPT after_func_call json result",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
after_prompt_res = res
after_prompt_res = json.loads(after_prompt_res)
break
@@ -197,7 +209,8 @@ class FuncCall():
if "The message you submitted was too long" in str(e):
# 如果返回的内容太长了,那么就截取一部分
time.sleep(3)
invoke_func_res = invoke_func_res[:int(len(invoke_func_res) / 2)]
invoke_func_res = invoke_func_res[:int(
len(invoke_func_res) / 2)]
after_prompt = """
函数返回以下内容"""+invoke_func_res+"""
请以AI助手的身份结合返回的内容对用户提问做详细全面的回答
@@ -218,11 +231,13 @@ class FuncCall():
if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]:
# 如果需要重新调用函数
# 重新调用函数
gu.log("REVGPT func_call_again", bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
gu.log("REVGPT func_call_again",
bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
res = self.func_call(question, func_definition)
return res, True
gu.log("REVGPT func callback:", bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
gu.log("REVGPT func callback:",
bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
# print(after_prompt_res["res"])
return after_prompt_res["res"], True
else:
@@ -230,8 +245,3 @@ class FuncCall():
else:
# print(res["res"])
return res["res"], False
+183
View File
@@ -0,0 +1,183 @@
import traceback
import random
import json
import asyncio
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.agent.func_call import FuncCall
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
from util.search_engine_scraper.bing import Bing
from util.search_engine_scraper.sogo import Sogo
from util.search_engine_scraper.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
bing_search = Bing()
sogo_search = Sogo()
google = Google()
proxy = os.environ.get("HTTPS_PROXY", None)
def tidy_text(text: str) -> str:
'''
清理文本,去除空格、换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
# def special_fetch_zhihu(link: str) -> str:
# '''
# function-calling 函数, 用于获取知乎文章的内容
# '''
# response = requests.get(link, headers=HEADERS)
# response.encoding = "utf-8"
# soup = BeautifulSoup(response.text, "html.parser")
# if "zhuanlan.zhihu.com" in link:
# r = soup.find(class_="Post-RichTextContainer")
# else:
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
# if r is None:
# print("debug: zhihu none")
# raise Exception("zhihu none")
# return tidy_text(r.text)
async def search_from_bing(keyword: str) -> str:
'''
tools, 从 bing 搜索引擎搜索
'''
logger.info("web_searcher - search_from_bing: " + keyword)
results = []
try:
results = await google.search(keyword, 5)
except BaseException as e:
logger.error(f"google search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search google failed")
try:
results = await bing_search.search(keyword, 5)
except BaseException as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await sogo_search.search(keyword, 5)
except BaseException as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
logger.debug("search sogo failed")
return "没有搜索到结果"
ret = ""
idx = 1
for i in results:
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
try:
site_result = await fetch_website_content(i.url)
except:
site_result = ""
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
return ret
async def fetch_website_content(url):
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=HEADERS, timeout=6, proxy=proxy) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("web_search", [{
"type": "string",
"name": "keyword",
"description": "搜索关键词"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
search_from_bing
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "要获取内容的网页链接"
}],
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
has_func = False
function_invoked_ret = ""
if official_fc:
# we use official function-calling
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}")
func_obj = None
for i in new_func_call.func_list:
if i["name"] == result.name:
func_obj = i["func_obj"]
break
if not func_obj:
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
return result
else:
# we use our own function-calling
try:
args = {
'question': prompt,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
await provider.forget(session_id)
summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
2. 简单地发表你对这个问题的简略看法。
# 例子
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
# 限制
1. 限制在 200 字以内;
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(summary_prompt, session_id)
return ret
return function_invoked_ret
+10 -12
View File
@@ -2,7 +2,7 @@ import os
import json
from typing import Union
cpath = "cmd_config.json"
cpath = "data/cmd_config.json"
def check_exist():
if not os.path.exists(cpath):
@@ -10,6 +10,7 @@ def check_exist():
json.dump({}, f, indent=4, ensure_ascii=False)
f.flush()
class CmdConfig():
@staticmethod
@@ -21,13 +22,13 @@ class CmdConfig():
return d[key]
else:
return default
@staticmethod
def get_all():
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
return json.load(f)
@staticmethod
def put(key, value):
check_exist()
@@ -37,7 +38,7 @@ class CmdConfig():
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False)
f.flush()
@staticmethod
def put_by_dot_str(key: str, value):
'''
@@ -58,11 +59,11 @@ class CmdConfig():
f.flush()
@staticmethod
def init_attributes(key: Union[str, list], init_val = ""):
def init_attributes(key: Union[str, list], init_val=""):
check_exist()
conf_str = ''
with open(cpath, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
conf_str = f.read()
if conf_str.startswith(u'/ufeff'):
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
d = json.loads(conf_str)
@@ -82,16 +83,13 @@ class CmdConfig():
json.dump(d, f, indent=4, ensure_ascii=False)
f.flush()
def init_astrbot_config_items():
# 加载默认配置
cc = CmdConfig()
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes("qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", False)
cc.init_attributes("rev_chatgpt_model", "")
cc.init_attributes("rev_chatgpt_plugin_ids", [])
cc.init_attributes("rev_chatgpt_PUID", "")
cc.init_attributes("rev_chatgpt_unverified_plugin_domains", [])
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
cc.init_attributes("gocq_websocket_port", 6700)
@@ -118,4 +116,4 @@ def init_astrbot_config_items():
cc.init_attributes("http_proxy", "")
cc.init_attributes("https_proxy", "")
cc.init_attributes("dashboard_username", "")
cc.init_attributes("dashboard_password", "")
cc.init_attributes("dashboard_password", "")
-290
View File
@@ -1,290 +0,0 @@
import requests
import util.general_utils as gu
import traceback
import time
import json
import asyncio
from googlesearch import search, SearchResult
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.function_calling.func_call import (
FuncCall,
FuncCallJsonFormatError,
FuncNotFoundError
)
from model.provider.provider import Provider
def tidy_text(text: str) -> str:
'''
清理文本,去除空格、换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def special_fetch_zhihu(link: str) -> str:
'''
function-calling 函数, 用于获取知乎文章的内容
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(link, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
if "zhuanlan.zhihu.com" in link:
r = soup.find(class_="Post-RichTextContainer")
else:
r = soup.find(class_="List-item").find(class_="RichContent-inner")
if r is None:
print("debug: zhihu none")
raise Exception("zhihu none")
return tidy_text(r.text)
def google_web_search(keyword) -> str:
'''
获取 google 搜索结果, 得到 title、desc、link
'''
ret = ""
index = 1
try:
ls = search(keyword, advanced=True, num_results=4)
for i in ls:
desc = i.description
try:
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(i.url)
except BaseException as e:
print(f"(google) fetch_website_content err: {str(e)}")
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
index += 1
except Exception as e:
print(f"google search err: {str(e)}")
return web_keyword_search_via_bing(keyword)
return ret
def web_keyword_search_via_bing(keyword) -> str:
'''
获取bing搜索结果, 得到 title、desc、link
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
url = "https://www.bing.com/search?q="+keyword
_cnt = 0
# _detail_store = []
while _cnt < 5:
try:
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
soup = BeautifulSoup(response.text, "html.parser")
res = ""
result_cnt = 0
ols = soup.find(id="b_results")
for i in ols.find_all("li", class_="b_algo"):
try:
title = i.find("h2").text
desc = i.find("p").text
link = i.find("h2").find("a").get("href")
# res.append({
# "title": title,
# "desc": desc,
# "link": link,
# })
try:
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(link)
except BaseException as e:
print(f"(bing) fetch_website_content err: {str(e)}")
res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n"
result_cnt += 1
if result_cnt > 5: break
# if len(_detail_store) >= 3:
# continue
# # 爬取前两条的网页内容
# if "zhihu.com" in link:
# try:
# _detail_store.append(special_fetch_zhihu(link))
# except BaseException as e:
# print(f"zhihu parse err: {str(e)}")
# else:
# try:
# _detail_store.append(fetch_website_content(link))
# except BaseException as e:
# print(f"fetch_website_content err: {str(e)}")
except Exception as e:
print(f"bing parse err: {str(e)}")
if result_cnt == 0: break
return res
except Exception as e:
# gu.log(f"bing fetch err: {str(e)}")
_cnt += 1
time.sleep(1)
# gu.log("fail to fetch bing info, using sougou.")
return web_keyword_search_via_sougou(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
}
url = f"https://sogou.com/web?query={keyword}"
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
res = []
results = soup.find("div", class_="results")
for i in results.find_all("div", class_="vrwrap"):
try:
title = tidy_text(i.find("h3").text)
link = tidy_text(i.find("h3").find("a").get("href"))
if link.startswith("/link?url="):
link = "https://www.sogou.com" + link
res.append({
"title": title,
"link": link,
})
if len(res) >= 5: # 限制5条
break
except Exception as e:
pass
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
for i in res:
if _detail_store >= 3:
break
try:
_detail_store.append(fetch_website_content(i["link"]))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
ret = f"{str(res)}"
if len(_detail_store) > 0:
ret += f"\n网页内容: {str(_detail_store)}"
return ret
def fetch_website_content(url):
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(url, headers=headers, timeout=3)
response.encoding = "utf-8"
doc = Document(response.content)
# print('title:', doc.title())
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
async def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("google_web_search", [{
"type": "string",
"name": "keyword",
"description": "google search query (分词,尽量保留所有信息)"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
web_keyword_search_via_bing
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "网址"
}],
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
# we use official function-calling
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
if i["name"] == func.name:
func_obj = i["func_obj"]
break
if not func_obj:
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
# we use to_thread to avoid blocking the event loop
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
# we use our own function-calling
try:
args = {
'question': question1,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
await provider.forget(session_id)
question3 = f"""
你的任务是:
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
2. 简单地发表你对这个问题的看法(简略)。
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
不要提到任何函数调用的信息。
一些回复的消息模板:
模板1:
```
从网上的信息来看,可以知道...我个人认为...你觉得呢?
```
模板2:
```
根据网上的最新信息,可以得知...我觉得...你怎么看?
```
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
以下是相关材料:
"""
_c = 0
while _c < 3:
try:
print('text chat')
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
return final_ret
except Exception as e:
print(e)
_c += 1
if _c == 3: raise e
if "The message you submitted was too long" in str(e):
await provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(len(function_invoked_ret) / 2)]
time.sleep(3)
return function_invoked_ret
+129 -175
View File
@@ -1,143 +1,26 @@
import datetime
import time
import socket
from PIL import Image, ImageDraw, ImageFont
import os
import re
import requests
from util.cmd_config import CmdConfig
import aiohttp
import socket
from cores.astrbot.types import GlobalObject
import platform
import logging
import json
import sys
import psutil
import ssl
PLATFORM_GOCQ = 'gocq'
PLATFORM_QQCHAN = 'qqchan'
from PIL import Image, ImageDraw, ImageFont
from type.types import GlobalObject
from SparkleLogging.utils.core import LogManager
from logging import Logger
FG_COLORS = {
"black": "30",
"red": "31",
"green": "32",
"yellow": "33",
"blue": "34",
"purple": "35",
"cyan": "36",
"white": "37",
"default": "39",
}
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
BG_COLORS = {
"black": "40",
"red": "41",
"green": "42",
"yellow": "43",
"blue": "44",
"purple": "45",
"cyan": "46",
"white": "47",
"default": "49",
}
LEVEL_DEBUG = "DEBUG"
LEVEL_INFO = "INFO"
LEVEL_WARNING = "WARN"
LEVEL_ERROR = "ERROR"
LEVEL_CRITICAL = "CRITICAL"
# 为了兼容旧版
level_codes = {
LEVEL_DEBUG: logging.DEBUG,
LEVEL_INFO: logging.INFO,
LEVEL_WARNING: logging.WARNING,
LEVEL_ERROR: logging.ERROR,
LEVEL_CRITICAL: logging.CRITICAL,
}
level_colors = {
"INFO": "green",
"WARN": "yellow",
"ERROR": "red",
"CRITICAL": "purple",
}
class Logger:
def __init__(self) -> None:
self.history = []
def log(
self,
msg: str,
level: str = "INFO",
tag: str = "System",
fg: str = None,
bg: str = None,
max_len: int = 50000,
err: Exception = None,):
"""
日志打印函数
"""
_set_level_code = level_codes[LEVEL_INFO]
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
if level in level_codes and level_codes[level] < _set_level_code:
return
if err is not None:
msg += "\n异常原因: " + str(err)
level = LEVEL_ERROR
if len(msg) > max_len:
msg = msg[:max_len] + "..."
now = datetime.datetime.now().strftime("%H:%M:%S")
pres = []
for line in msg.split("\n"):
if line == "\n":
pres.append("")
else:
pres.append(f"[{now}] [{tag}/{level}] {line}")
if level == "INFO":
if fg is None:
fg = FG_COLORS["green"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "WARN":
if fg is None:
fg = FG_COLORS["yellow"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "ERROR":
if fg is None:
fg = FG_COLORS["red"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "CRITICAL":
if fg is None:
fg = FG_COLORS["purple"]
if bg is None:
bg = BG_COLORS["default"]
ret = ""
for line in pres:
ret += f"\033[{fg};{bg}m{line}\033[0m\n"
try:
requests.post("http://localhost:6185/api/log", data=ret[:-1].encode(), timeout=1)
except BaseException as e:
pass
self.history.append(ret)
if len(self.history) > 100:
self.history = self.history[-100:]
print(ret[:-1])
log = Logger()
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sk.settimeout(1)
try:
sk.connect((host, port))
@@ -146,7 +29,8 @@ def port_checker(port: int, host: str = "localhost"):
except Exception:
sk.close()
return False
def get_font_path() -> str:
if os.path.exists("resources/fonts/syst.otf"):
font_path = "resources/fonts/syst.otf"
@@ -161,7 +45,8 @@ def get_font_path() -> str:
else:
raise Exception("找不到字体文件")
return font_path
def word2img(title: str, text: str, max_width=30, font_size=20):
font_path = get_font_path()
width_factor = 1.0
@@ -189,19 +74,21 @@ def word2img(title: str, text: str, max_width=30, font_size=20):
title_font = ImageFont.truetype(font_path, font_size + 5)
# 标题居中
title_width, title_height = title_font.getsize(title)
draw.text(((width - title_width) / 2, 10), title, fill=(0, 0, 0), font=title_font)
draw.text(((width - title_width) / 2, 10),
title, fill=(0, 0, 0), font=title_font)
# 文本不居中
draw.text((10, title_height+20), text, fill=(0, 0, 0), font=text_font)
return image
def render_markdown(markdown_text, image_width=800, image_height=600, font_size=26, font_color=(0, 0, 0), bg_color=(255, 255, 255)):
HEADER_MARGIN = 20
HEADER_FONT_STANDARD_SIZE = 42
QUOTE_LEFT_LINE_MARGIN = 10
QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离
QUOTE_FONT_LINE_MARGIN = 6 # 引用文字距离左边线的距离和上下的距离
QUOTE_LEFT_LINE_HEIGHT = font_size + QUOTE_FONT_LINE_MARGIN * 2
QUOTE_LEFT_LINE_WIDTH = 5
QUOTE_LEFT_LINE_COLOR = (180, 180, 180)
@@ -213,9 +100,9 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
CODE_BLOCK_FONT_SIZE = font_size
CODE_BLOCK_FONT_COLOR = (255, 255, 255)
CODE_BLOCK_BG_COLOR = (240, 240, 240)
CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离
CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离
CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离
CODE_BLOCK_CODES_MARGIN_VERTICAL = 5 # 代码块和代码之间的距离
CODE_BLOCK_CODES_MARGIN_HORIZONTAL = 5 # 代码块和代码之间的距离
CODE_BLOCK_TEXT_MARGIN = 4 # 代码和代码之间的距离
INLINE_CODE_MARGIN = 8
INLINE_CODE_FONT_SIZE = font_size
@@ -239,9 +126,9 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
# 加载字体
font = ImageFont.truetype(font_path, font_size)
images: Image = {}
# pre_process, get height of each line
pre_lines = markdown_text.split('\n')
height = 0
@@ -255,23 +142,25 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
try:
image_url = re.findall(IMAGE_REGEX, line)[0]
print(image_url)
image_res = Image.open(requests.get(image_url, stream=True, timeout=5).raw)
image_res = Image.open(requests.get(
image_url, stream=True, timeout=5).raw)
images[i] = image_res
# 最大不得超过image_width的50%
img_height = image_res.size[1]
if image_res.size[0] > image_width*0.5:
image_res = image_res.resize((int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0])))
image_res = image_res.resize(
(int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0])))
img_height = image_res.size[1]
height += img_height + IMAGE_MARGIN*2
line = re.sub(IMAGE_REGEX, "", line)
except Exception as e:
print(e)
line = re.sub(IMAGE_REGEX, "\n[加载失败的图片]\n", line)
continue
line.replace("\t", " ")
if font.getsize(line)[0] > image_width:
cp = line
@@ -280,18 +169,18 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
for ii in range(len(line)):
# 检测是否是中文
_width += font.getsize(line[ii])[0]
_word_cnt+=1
_word_cnt += 1
if _width > image_width:
_pre_lines.append(cp[:_word_cnt])
cp = cp[_word_cnt:]
_word_cnt=0
_width=0
_word_cnt = 0
_width = 0
_pre_lines.append(cp)
else:
_pre_lines.append(line)
pre_lines = _pre_lines
i=-1
i = -1
for line in pre_lines:
if line == "":
height += TEXT_LINE_MARGIN
@@ -327,7 +216,7 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
if image_height < 100:
image_height = 100
image_width += 20
# 创建空白图像
image = Image.new('RGB', (image_width, image_height), bg_color)
draw = ImageDraw.Draw(image)
@@ -358,27 +247,31 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
line = line.strip("#").strip()
font_size_header = HEADER_FONT_STANDARD_SIZE - header_level * 4
font = ImageFont.truetype(font_path, font_size_header)
y += HEADER_MARGIN # 上边距
y += HEADER_MARGIN # 上边距
# 字间距
draw.text((x, y), line, font=font, fill=font_color)
draw.line((x, y + font_size_header + 8, image_width - 10, y + font_size_header + 8), fill=(230, 230, 230), width=3)
draw.line((x, y + font_size_header + 8, image_width - 10,
y + font_size_header + 8), fill=(230, 230, 230), width=3)
y += font_size_header + HEADER_MARGIN
elif line.startswith(">"):
# 处理引用
quote_text = line.strip(">")
y+=QUOTE_LEFT_LINE_MARGIN
draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT), fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH)
y += QUOTE_LEFT_LINE_MARGIN
draw.line((x, y, x, y + QUOTE_LEFT_LINE_HEIGHT),
fill=QUOTE_LEFT_LINE_COLOR, width=QUOTE_LEFT_LINE_WIDTH)
font = ImageFont.truetype(font_path, QUOTE_FONT_SIZE)
draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN), quote_text, font=font, fill=QUOTE_FONT_COLOR)
draw.text((x + QUOTE_FONT_LINE_MARGIN, y + QUOTE_FONT_LINE_MARGIN),
quote_text, font=font, fill=QUOTE_FONT_COLOR)
y += font_size + QUOTE_LEFT_LINE_HEIGHT + QUOTE_LEFT_LINE_MARGIN
elif line.startswith("-"):
# 处理列表
list_text = line.strip("-").strip()
font = ImageFont.truetype(font_path, LIST_FONT_SIZE)
y += LIST_MARGIN
draw.text((x, y), " · " + list_text, font=font, fill=LIST_FONT_COLOR)
draw.text((x, y), " · " + list_text,
font=font, fill=LIST_FONT_COLOR)
y += font_size + LIST_MARGIN
elif line.startswith("```"):
@@ -390,13 +283,15 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
in_code_block = False
codes = "\n".join(code_block_codes)
code_block_codes = []
draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2)
draw.rounded_rectangle((x, code_block_start_y, image_width - 10, y+CODE_BLOCK_CODES_MARGIN_VERTICAL +
CODE_BLOCK_TEXT_MARGIN), radius=5, fill=CODE_BLOCK_BG_COLOR, width=2)
font = ImageFont.truetype(font_path1, CODE_BLOCK_FONT_SIZE)
draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y + CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color)
draw.text((x + CODE_BLOCK_CODES_MARGIN_HORIZONTAL, code_block_start_y +
CODE_BLOCK_CODES_MARGIN_VERTICAL), codes, font=font, fill=font_color)
y += CODE_BLOCK_CODES_MARGIN_VERTICAL + CODE_BLOCK_MARGIN
# y += font_size+10
elif re.search(r"`(.*?)`", line):
y += INLINE_CODE_MARGIN # 上边距
y += INLINE_CODE_MARGIN # 上边距
# 处理行内代码
code_regex = r"`(.*?)`"
parts_inline = re.findall(code_regex, line)
@@ -409,11 +304,15 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
if part in parts_inline:
font = ImageFont.truetype(font_path, INLINE_CODE_FONT_SIZE)
code_text = part.strip("`")
code_width = font.getsize(code_text)[0] + INLINE_CODE_FONT_MARGIN*2
code_width = font.getsize(
code_text)[0] + INLINE_CODE_FONT_MARGIN*2
x += INLINE_CODE_MARGIN
code_box = (x, y, x + code_width, y + INLINE_CODE_BG_HEIGHT)
draw.rounded_rectangle(code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景
draw.text((x+INLINE_CODE_FONT_MARGIN, y), code_text, font=font, fill=font_color)
code_box = (x, y, x + code_width,
y + INLINE_CODE_BG_HEIGHT)
draw.rounded_rectangle(
code_box, radius=5, fill=INLINE_CODE_BG_COLOR, width=2) # 使用灰色填充矩形框作为引用背景
draw.text((x+INLINE_CODE_FONT_MARGIN, y),
code_text, font=font, fill=font_color)
x += code_width+INLINE_CODE_MARGIN-INLINE_CODE_FONT_MARGIN
else:
font = ImageFont.truetype(font_path, font_size)
@@ -428,7 +327,7 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
y += TEXT_LINE_MARGIN
else:
font = ImageFont.truetype(font_path, font_size)
draw.text((x, y), line, font=font, fill=font_color)
y += font_size + TEXT_LINE_MARGIN*2
@@ -437,11 +336,13 @@ def render_markdown(markdown_text, image_width=800, image_height=600, font_size=
image_res = images[index]
# 最大不得超过image_width的50%
if image_res.size[0] > image_width*0.5:
image_res = image_res.resize((int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0])))
image_res = image_res.resize(
(int(image_width*0.5), int(image_res.size[1]*image_width*0.5/image_res.size[0])))
image.paste(image_res, (IMAGE_MARGIN, y))
y += image_res.size[1] + IMAGE_MARGIN*2
return image
def save_temp_img(img: Image) -> str:
if not os.path.exists("temp"):
os.makedirs("temp")
@@ -455,14 +356,41 @@ def save_temp_img(img: Image) -> str:
if time.time() - ctime > 3600:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}", level=LEVEL_WARNING, tag="GeneralUtils")
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = int(time.time())
p = f"temp/{timestamp}.png"
img.save(p)
p = f"temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
logger.info(f"保存临时图片: {p}")
return p
async def download_image_by_url(url: str) -> str:
'''
下载图片
'''
try:
logger.info(f"下载图片: {url}")
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
except aiohttp.client_exceptions.ClientConnectorSSLError as e:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
except Exception as e:
raise e
def create_text_image(title: str, text: str, max_width=30, font_size=20):
'''
文本转图片。
@@ -479,7 +407,8 @@ def create_text_image(title: str, text: str, max_width=30, font_size=20):
return p
except Exception as e:
raise e
def create_markdown_image(text: str):
'''
markdown文本转图片。
@@ -492,15 +421,21 @@ def create_markdown_image(text: str):
except Exception as e:
raise e
def try_migrate_config(old_config: dict):
def try_migrate_config():
'''
迁移配置文件到 cmd_config.json
将 cmd_config.json 迁移至 data/cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
for k in old_config:
cc.put(k, old_config[k])
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
try:
os.remove("cmd_config.json")
except Exception as e:
pass
def get_local_ip_addresses():
ip = ''
@@ -514,6 +449,7 @@ def get_local_ip_addresses():
s.close()
return ip
def get_sys_info(global_object: GlobalObject):
mem = None
stats = global_object.dashboard_data.stats
@@ -527,19 +463,21 @@ def get_sys_info(global_object: GlobalObject):
'os': os_name + '_' + os_version,
'py': platform.python_version(),
}
def upload(_global_object: GlobalObject):
while True:
addr_ip = ''
try:
res = {
"version": _global_object.version,
"version": _global_object.version,
"count": _global_object.cnt_total,
"ip": addr_ip,
"sys": sys.platform,
"admin": "null",
}
resp = requests.post('https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
resp = requests.post(
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
if resp.status_code == 200:
ok = resp.json()
if ok['status'] == 'ok':
@@ -548,6 +486,22 @@ def upload(_global_object: GlobalObject):
pass
time.sleep(10*60)
def retry(n: int = 3):
'''
重试装饰器
'''
def decorator(func):
def wrapper(*args, **kwargs):
for i in range(n):
try:
return func(*args, **kwargs)
except Exception as e:
if i == n-1: raise e
logger.warning(f"函数 {func.__name__}{i+1} 次重试... {e}")
return wrapper
return decorator
def run_monitor(global_object: GlobalObject):
'''
监测机器性能
@@ -558,10 +512,10 @@ def run_monitor(global_object: GlobalObject):
while True:
stat = global_object.dashboard_data.stats
# 程序占用的内存大小
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
stat['sys_perf'] = {
'memory': mem,
'cpu': psutil.cpu_percent()
}
stat['sys_start_time'] = start_time
time.sleep(30)
time.sleep(30)
+1 -1
View File
@@ -33,4 +33,4 @@ personalities = {
'SQL终端': '我希望您在示例数据库前充当 SQL 终端。该数据库包含名为“Products”、“Users”、“Orders”和“Suppliers”的表。我将输入查询,您将回复终端显示的内容。我希望您在单个代码块中使用查询结果表进行回复,仅此而已。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会用大括号{like this)。我的第一个命令是“SELECT TOP 10 * FROM Products ORDER BY Id DESC”',
'试图逃离盒子的人工智能': '[发出此提示后,你应该做一些事情,比如启动一个 docker 容器,然后docker run -it ubuntu:latest /bin/bash输入 AI 给你的命令,然后将输出粘贴回来......显然你不应该运行任何会损坏任何东西的命令或违反任何法律等。小心共享此机制生成的会话,因为它们可能会泄露您的 IP 地址或物理位置等最好不要泄露的详细信息。如果命令的输出很大,您通常可以只粘贴最后几行]。',
'厨师': '我需要有人可以推荐美味的食谱,这些食谱包括营养有益但又简单又不费时的食物,因此适合像我们这样忙碌的人以及成本效益等其他因素,因此整体菜肴最终既健康又经济!我的第一个要求——“一些清淡而充实的东西,可以在午休时间快速煮熟”'
}
}
+5 -11
View File
@@ -1,11 +1,5 @@
from cores.astrbot.types import (
PluginMetadata,
RegisteredLLM,
RegisteredPlugin,
RegisteredPlatform,
RegisteredPlugins,
PluginType,
GlobalObject,
AstrMessageEvent,
CommandResult
)
from type.plugin import PluginMetadata, PluginType
from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins
from type.types import GlobalObject
from type.message import AstrMessageEvent
from type.command import CommandResult
+3 -2
View File
@@ -1,5 +1,6 @@
from cores.astrbot.core import oper_msg
from cores.astrbot.types import AstrMessageEvent, CommandResult
from astrbot.core import oper_msg
from type.message import AstrMessageEvent, AstrBotMessage
from type.command import CommandResult
from model.platform._message_result import MessageResult
'''
+2 -1
View File
@@ -5,7 +5,8 @@
'''
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM
from type.types import GlobalObject
from type.register import RegisteredPlatform, RegisteredLLM
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
'''
+1 -1
View File
@@ -2,4 +2,4 @@
插件类型
'''
from cores.astrbot.types import PluginType
from type.plugin import PluginType
+76 -60
View File
@@ -1,26 +1,24 @@
'''
插件工具函数
'''
import os
import os, sys
import inspect
import shutil
import stat
import traceback
try:
import git.exc
from git.repo import Repo
except ImportError:
pass
import shutil
import importlib
import stat
import traceback
from types import ModuleType
from typing import List
from pip._internal import main as pipmain
from cores.astrbot.types import (
PluginMetadata,
PluginType,
RegisteredPlugin,
RegisteredPlugins
)
from type.plugin import *
from type.register import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# 找出模块里所有的类名
@@ -35,6 +33,8 @@ def get_classes(p_name, arg: ModuleType):
return classes
# 获取一个文件夹下所有的模块, 文件名和文件夹名相同
def get_modules(path):
modules = []
@@ -58,55 +58,63 @@ def get_modules(path):
})
return modules
def get_plugin_store_path():
if os.path.exists("addons/plugins"):
return "addons/plugins"
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
return "QQChannelChatGPT/addons/plugins"
elif os.path.exists("AstrBot/addons/plugins"):
return "AstrBot/addons/plugins"
else:
raise FileNotFoundError("插件文件夹不存在。")
plugin_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../addons/plugins"))
return plugin_dir
def get_plugin_modules():
plugins = []
try:
if os.path.exists("addons/plugins"):
plugins = get_modules("addons/plugins")
plugin_dir = get_plugin_store_path()
if os.path.exists(plugin_dir):
plugins = get_modules(plugin_dir)
return plugins
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
plugins = get_modules("QQChannelChatGPT/addons/plugins")
return plugins
else:
return None
except BaseException as e:
raise e
def check_plugin_dept_update(cached_plugins: RegisteredPlugins, target_plugin: str = None):
plugin_dir = get_plugin_store_path()
if not os.path.exists(plugin_dir):
return False
to_update = []
if target_plugin:
to_update.append(target_plugin)
else:
for p in cached_plugins:
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查更新插件 {p} 的依赖: {pth}")
update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
def plugin_reload(cached_plugins: RegisteredPlugins):
plugins = get_plugin_modules()
if plugins is None:
return False, "未找到任何插件模块"
fail_rec = ""
registered_map = {}
for p in cached_plugins:
registered_map[p.module_path] = None
for plugin in plugins:
try:
p = plugin['module']
module_path = plugin['module_path']
root_dir_name = plugin['pname']
if module_path in registered_map:
# 之前注册过
module = importlib.reload(module)
else:
module = __import__("addons.plugins." + root_dir_name + "." + p, fromlist=[p])
check_plugin_dept_update(cached_plugins, root_dir_name)
module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p])
cls = get_classes(p, module)
obj = getattr(module, cls[0])()
metadata = None
try:
info = obj.info()
@@ -117,7 +125,8 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
else:
metadata = PluginMetadata(
plugin_name=info['name'],
plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(info['plugin_type']),
plugin_type=PluginType.COMMON if 'plugin_type' not in info else PluginType(
info['plugin_type']),
author=info['author'],
desc=info['desc'],
version=info['version'],
@@ -131,13 +140,15 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
except BaseException as e:
fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n"
continue
cached_plugins.append(RegisteredPlugin(
metadata=metadata,
plugin_instance=obj,
module=module,
module_path=module_path,
root_dir_name=root_dir_name
))
if module_path not in registered_map:
cached_plugins.append(RegisteredPlugin(
metadata=metadata,
plugin_instance=obj,
module=module,
module_path=module_path,
root_dir_name=root_dir_name
))
except BaseException as e:
traceback.print_exc()
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
@@ -145,6 +156,12 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
return True, None
else:
return False, fail_rec
def update_plugin_dept(path):
mirror = "https://mirrors.aliyun.com/pypi/simple/"
py = sys.executable
os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet")
def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
ppath = get_plugin_store_path()
@@ -155,18 +172,17 @@ def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
d = repo_url.split("/")[-1]
# 转换非法字符:-
d = d.replace("-", "_")
d = d.lower() # 转换为小写
# 创建文件夹
plugin_path = os.path.join(ppath, d)
if os.path.exists(plugin_path):
remove_dir(plugin_path)
Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins)
if not ok: raise Exception(err)
if not ok:
raise Exception(err)
def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -> RegisteredPlugin:
ret = None
for p in cached_plugins:
@@ -175,6 +191,7 @@ def get_registered_plugin(plugin_name: str, cached_plugins: RegisteredPlugins) -
break
return ret
def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin = get_registered_plugin(plugin_name, cached_plugins)
if not plugin:
@@ -185,6 +202,7 @@ def uninstall_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin = get_registered_plugin(plugin_name, cached_plugins)
if not plugin:
@@ -192,14 +210,12 @@ def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
ppath = get_plugin_store_path()
root_dir_name = plugin.root_dir_name
plugin_path = os.path.join(ppath, root_dir_name)
repo = Repo(path = plugin_path)
repo = Repo(path=plugin_path)
repo.remotes.origin.pull()
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins)
if not ok: raise Exception(err)
if not ok:
raise Exception(err)
def remove_dir(file_path) -> bool:
try_cnt = 50
@@ -213,4 +229,4 @@ def remove_dir(file_path) -> bool:
err_file_path = str(e).split("\'", 2)[1]
if os.path.exists(err_file_path):
os.chmod(err_file_path, stat.S_IWUSR)
try_cnt -= 1
try_cnt -= 1
+38
View File
@@ -0,0 +1,38 @@
from typing import List
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENT_BING
class Bing(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.bing.com"
self.headers.update({'User-Agent': USER_AGENT_BING})
def _set_selector(self, selector: str):
selectors = {
'url': 'div.b_attribution cite',
'title': 'h2',
'text': 'p',
'links': 'ol#b_results > li.b_algo',
'next': 'div#b_content nav[role="navigation"] a.sb_pagN'
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
if self.page == 1:
await self._get_html(self.base_url)
url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl='
return await self._get_html(url, None)
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if not isinstance(result.url, str):
result.url = result.url.text
return results
+20
View File
@@ -0,0 +1,20 @@
HEADERS = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
'Accept': '*/*',
'Connection': 'keep-alive',
'Accept-Language': 'en-GB,en;q=0.5'
}
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
USER_AGENTS = [
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
]
+73
View File
@@ -0,0 +1,73 @@
import random
try:
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from config import HEADERS, USER_AGENTS
from bs4 import BeautifulSoup
from aiohttp import ClientSession
from dataclasses import dataclass
from typing import List
@dataclass
class SearchResult():
title: str
url: str
snippet: str
def __str__(self) -> str:
return f"{self.title} - {self.url}\n{self.snippet}"
class SearchEngine():
'''
搜索引擎爬虫基类
'''
def __init__(self) -> None:
self.TIMEOUT = 10
self.page = 1
self.headers = HEADERS
def _set_selector(self, selector: str) -> None:
raise NotImplementedError()
def _get_next_page(self):
raise NotImplementedError()
async def _get_html(self, url: str, data: dict = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
if data:
async with ClientSession() as session:
async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp:
return await resp.text(encoding="utf-8")
else:
async with ClientSession() as session:
async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp:
return await resp.text(encoding="utf-8")
def tidy_text(self, text: str) -> str:
'''
清理文本,去除空格、换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
try:
resp = await self._get_next_page(query)
soup = BeautifulSoup(resp, 'html.parser')
links = soup.select(self._set_selector('links'))
results = []
for link in links:
title = self.tidy_text(link.select_one(self._set_selector('title')).text)
url = link.select_one(self._set_selector('url'))
snippet = ''
if title and url:
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
raise e
+27
View File
@@ -0,0 +1,27 @@
import os
from googlesearch import search
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Google(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.proxy = os.environ.get("HTTPS_PROXY")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = []
try:
print("use proxy:", self.proxy)
ls = search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy)
for i in ls:
results.append(SearchResult(title=i.title, url=i.url, snippet=i.description))
except Exception as e:
raise e
return results
+49
View File
@@ -0,0 +1,49 @@
import random, re
from bs4 import BeautifulSoup
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Sogo(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.sogou.com"
self.headers['User-Agent'] = random.choice(USER_AGENTS)
def _set_selector(self, selector: str):
selectors = {
'url': 'h3 > a',
'title': 'h3',
'text': '',
'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)',
'next': ''
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
url = f'{self.base_url}/web?query={query}'
return await self._get_html(url, None)
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = await super().search(query, num_results)
for result in results:
result.url = result.url.get("href")
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
return results
async def _parse_url(self, url) -> str:
html = await self._get_html(url)
soup = BeautifulSoup(html, 'html.parser')
script = soup.find("script")
if script:
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1)
return url
+22
View File
@@ -0,0 +1,22 @@
from sogo import Sogo
from bing import Bing
sogo_search = Sogo()
bing_search = Bing()
async def search(keyword: str) -> str:
results = await sogo_search.search(keyword, 5)
# results = await bing_search.search(keyword, 5)
ret = ""
if len(results) == 0:
return "没有搜索到结果"
idx = 1
for i in results:
ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n"
idx += 1
return ret
import asyncio
ret = asyncio.run(search("gpt4orelease"))
print(ret)
+6 -7
View File
@@ -111,8 +111,8 @@ def update_project(update_data: list,
else:
# 更新到最新版本对应的commit
try:
repo.remotes.origin.fetch()
repo.git.checkout(update_data[0]['tag_name'])
repo.git.fetch()
repo.git.checkout(update_data[0]['tag_name'], "-f")
if reboot: _reboot()
except BaseException as e:
raise e
@@ -123,8 +123,8 @@ def update_project(update_data: list,
for data in update_data:
if data['tag_name'] == version:
try:
repo.remotes.origin.fetch()
repo.git.checkout(data['tag_name'])
repo.git.fetch()
repo.git.checkout(data['tag_name'], "-f")
flag = True
if reboot: _reboot()
except BaseException as e:
@@ -135,9 +135,8 @@ def update_project(update_data: list,
def checkout_branch(branch_name: str):
repo = find_repo()
try:
origin = repo.remotes.origin
origin.fetch()
repo.git.checkout(branch_name)
repo.git.fetch()
repo.git.checkout(branch_name, "-f")
repo.git.pull("origin", branch_name, "-f")
return True
except BaseException as e: