Compare commits

...

34 Commits

Author SHA1 Message Date
Soulter e8773cea7f fix: 修复配置文件没有有效迁移的问题 2024-05-25 20:59:37 +08:00
Soulter 4d36ffcb08 fix: 优化插件的结果处理 2024-05-25 18:46:38 +08:00
Soulter c653e492c4 Merge pull request #164 from Soulter/stat-upload-perf
/models 指令优化
2024-05-25 18:35:56 +08:00
Soulter f08de1f404 perf: 添加 models 指令到帮助中 2024-05-25 18:34:08 +08:00
Soulter 1218691b61 perf: model 指令放宽限制,支持输入自定义模型。设置模型后持久化保存。 2024-05-25 18:29:01 +08:00
Soulter 61fc27ff79 Merge pull request #163 from Soulter/stat-upload-perf
优化统计记录数据结构
2024-05-25 18:28:08 +08:00
Soulter 123ee24f7e fix: stat perf 2024-05-25 18:01:16 +08:00
Soulter 52c9045a28 feat: 优化了统计信息数据结构 2024-05-25 17:47:41 +08:00
Soulter f00f1e8933 fix: 画图报错 2024-05-24 13:33:02 +08:00
Soulter 8da4433e57 chore: 更改相关字段 2024-05-21 08:44:05 +08:00
Soulter 7babb87934 perf: 更改库的加载顺序 2024-05-21 08:41:46 +08:00
Soulter f67b171385 perf: 数据库迁移至 data 目录下 2024-05-19 17:10:11 +08:00
Soulter 1780d1355d perf: 将内部pip全部更换为阿里云镜像; 插件依赖更新逻辑优化 2024-05-19 16:45:08 +08:00
Soulter 5a3390e4f3 fix: force update 2024-05-19 16:06:47 +08:00
Soulter 337d96b41d Merge pull request #160 from Soulter/dev_default_openai_refactor
优化自带的 OpenAI LLM 交互, 人格, 网页搜索
2024-05-19 15:23:19 +08:00
Soulter 38a1dfea98 fix: web content scraper add proxy 2024-05-19 15:08:22 +08:00
Soulter fbef73aeec fix: websearch encoding set to utf-8 2024-05-19 14:42:28 +08:00
Soulter d6214c2b7c fix: web search 2024-05-19 12:55:54 +08:00
Soulter d58c86f6fc perf: websearch 优化;项目结构调整 2024-05-19 12:46:07 +08:00
Soulter ea34c20198 perf: 优化人格和LVM的处理过程 2024-05-18 10:34:35 +08:00
Soulter 934ca94e62 refactor: 重写 LLM OpenAI 模块 2024-05-17 22:56:44 +08:00
Soulter 1775327c2e chore: refact openai official 2024-05-17 09:07:11 +08:00
Soulter 707fcad8b4 feat: gpt 模型列表查看指令 models 2024-05-17 00:06:49 +08:00
Soulter f143c5afc6 fix: 修复 plugin v 子指令报错的问题 2024-05-16 23:11:07 +08:00
Soulter 99f94b2611 fix: 修复无法调用某些指令的问题 2024-05-16 23:04:47 +08:00
Soulter e39c1f9116 remove: 移除自动更换多模态模型的功能 2024-05-16 22:46:50 +08:00
Soulter 235e0b9b8f fix: gocq logging 2024-05-09 13:24:31 +08:00
Soulter d5a9bed8a4 fix(updator): IterableList object has no
attribute origin
2024-05-08 19:18:21 +08:00
Soulter d7dc8a7612 chore: 添加一些日志;更新版本 2024-05-08 19:12:23 +08:00
Soulter 08cd3ca40c perf: 更好的日志输出;
fix: 修复可视化面板刷新404
2024-05-08 19:01:36 +08:00
Soulter a13562dcea fix: 修复启动器启动加载带有配置的插件时提示配置文件缺失的问题 2024-05-08 16:28:30 +08:00
Soulter d7a0c0d1d0 Update requirements.txt 2024-05-07 15:58:51 +08:00
Soulter c0729b2d29 fix: 修复插件重载相关问题 2024-04-22 19:04:15 +08:00
Soulter a80f474290 fix: 修复更新插件时的报错 2024-04-22 18:36:56 +08:00
39 changed files with 1643 additions and 1438 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
addons/plugins/ addons/plugins/
data/* data/*
cookies.json cookies.json
logs/
+7 -6
View File
@@ -11,6 +11,10 @@ import threading
import time import time
import asyncio import asyncio
from util.plugin_dev.api.v1.config import update_config from util.plugin_dev.api.v1.config import update_config
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@dataclass @dataclass
@@ -28,7 +32,6 @@ class DashBoardHelper():
def __init__(self, global_object, config: dict): def __init__(self, global_object, config: dict):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
self.logger = global_object.logger
dashboard_data = global_object.dashboard_data dashboard_data = global_object.dashboard_data
dashboard_data.configs = { dashboard_data.configs = {
"data": [] "data": []
@@ -42,7 +45,6 @@ class DashBoardHelper():
@self.dashboard.register("post_configs") @self.dashboard.register("post_configs")
def on_post_configs(post_configs: dict): def on_post_configs(post_configs: dict):
try: try:
# self.logger.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板")
if 'base_config' in post_configs: if 'base_config' in post_configs:
self.save_config( self.save_config(
post_configs['base_config'], namespace='') # 基础配置 post_configs['base_config'], namespace='') # 基础配置
@@ -54,7 +56,6 @@ class DashBoardHelper():
threading.Thread(target=self.dashboard.shutdown_bot, threading.Thread(target=self.dashboard.shutdown_bot,
args=(2,), daemon=True).start() args=(2,), daemon=True).start()
except Exception as e: except Exception as e:
# self.logger.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板")
raise e raise e
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中 # 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
@@ -118,14 +119,14 @@ class DashBoardHelper():
) )
qq_gocq_platform_group = DashBoardConfig( qq_gocq_platform_group = DashBoardConfig(
config_type="group", config_type="group",
name="OneBot协议平台配置", name="go-cqhttp",
description="", description="",
body=[ body=[
DashBoardConfig( DashBoardConfig(
config_type="item", config_type="item",
val_type="bool", val_type="bool",
name="启用", name="启用",
description="支持cq-http、shamrock等(目前仅支持QQ平台)", description="",
value=config['gocqbot']['enable'], value=config['gocqbot']['enable'],
path="gocqbot.enable", path="gocqbot.enable",
), ),
@@ -470,7 +471,7 @@ class DashBoardHelper():
] ]
except Exception as e: except Exception as e:
self.logger.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR) logger.error(f"配置文件解析错误:{e}")
raise e raise e
def save_config(self, post_config: list, namespace: str): def save_config(self, post_config: list, namespace: str):
+54 -29
View File
@@ -1,13 +1,3 @@
from flask import Flask, request
from flask.logging import default_handler
from werkzeug.serving import make_server
from util import general_utils as gu
from dataclasses import dataclass
import logging
from cores.database.conn import dbConn
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from cores.astrbot.types import *
import util.plugin_util as putil import util.plugin_util as putil
import websockets import websockets
import json import json
@@ -17,6 +7,19 @@ import os
import sys import sys
import time import time
from flask import Flask, request
from flask.logging import default_handler
from werkzeug.serving import make_server
from util import general_utils as gu
from dataclasses import dataclass
from persist.session import dbConn
from type.register import RegisteredPlugin
from typing import List
from util.cmd_config import CmdConfig
from util.updator import check_update, update_project, request_release_info
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
@dataclass @dataclass
class DashBoardData(): class DashBoardData():
@@ -41,11 +44,8 @@ class AstrBotDashBoard():
self.dashboard_data: DashBoardData = global_object.dashboard_data self.dashboard_data: DashBoardData = global_object.dashboard_data
self.dashboard_be = Flask( self.dashboard_be = Flask(
__name__, static_folder="dist", static_url_path="/") __name__, static_folder="dist", static_url_path="/")
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
self.funcs = {} self.funcs = {}
self.cc = CmdConfig() self.cc = CmdConfig()
self.logger = global_object.logger
self.ws_clients = {} # remote_ip: ws self.ws_clients = {} # remote_ip: ws
# 启动 websocket 服务器 # 启动 websocket 服务器
self.ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186) self.ws_server = websockets.serve(self.__handle_msg, "0.0.0.0", 6186)
@@ -55,6 +55,22 @@ class AstrBotDashBoard():
# 返回页面 # 返回页面
return self.dashboard_be.send_static_file("index.html") return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/config")
def rt_config():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/logs")
def rt_logs():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/extension")
def rt_extension():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/dashboard/default")
def rt_dashboard():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.post("/api/authenticate") @self.dashboard_be.post("/api/authenticate")
def authenticate(): def authenticate():
username = self.cc.get("dashboard_username", "") username = self.cc.get("dashboard_username", "")
@@ -179,9 +195,9 @@ class AstrBotDashBoard():
post_data = request.json post_data = request.json
repo_url = post_data["url"] repo_url = post_data["url"]
try: try:
self.logger.log(f"正在安装插件 {repo_url}", tag="可视化面板") logger.info(f"正在安装插件 {repo_url}")
putil.install_plugin(repo_url, self.dashboard_data.plugins) putil.install_plugin(repo_url, self.dashboard_data.plugins)
self.logger.log(f"安装插件 {repo_url} 成功", tag="可视化面板") logger.info(f"安装插件 {repo_url} 成功")
return Response( return Response(
status="success", status="success",
message="安装成功~", message="安装成功~",
@@ -199,10 +215,10 @@ class AstrBotDashBoard():
post_data = request.json post_data = request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
self.logger.log(f"正在卸载插件 {plugin_name}", tag="可视化面板") logger.info(f"正在卸载插件 {plugin_name}")
putil.uninstall_plugin( putil.uninstall_plugin(
plugin_name, self.dashboard_data.plugins) plugin_name, self.dashboard_data.plugins)
self.logger.log(f"卸载插件 {plugin_name} 成功", tag="可视化面板") logger.info(f"卸载插件 {plugin_name} 成功")
return Response( return Response(
status="success", status="success",
message="卸载成功~", message="卸载成功~",
@@ -220,9 +236,9 @@ class AstrBotDashBoard():
post_data = request.json post_data = request.json
plugin_name = post_data["name"] plugin_name = post_data["name"]
try: try:
self.logger.log(f"正在更新插件 {plugin_name}", tag="可视化面板") logger.info(f"正在更新插件 {plugin_name}")
putil.update_plugin(plugin_name, self.dashboard_data.plugins) putil.update_plugin(plugin_name, self.dashboard_data.plugins)
self.logger.log(f"更新插件 {plugin_name} 成功", tag="可视化面板") logger.info(f"更新插件 {plugin_name} 成功")
return Response( return Response(
status="success", status="success",
message="更新成功~", message="更新成功~",
@@ -374,13 +390,13 @@ class AstrBotDashBoard():
}, },
{ {
"title": "QQ_OFFICIAL", "title": "QQ_OFFICIAL",
"desc": "QQ官方API,仅支持频道", "desc": "QQ官方API支持频道、群(需获得群权限)",
"namespace": "internal_platform_qq_official", "namespace": "internal_platform_qq_official",
"tag": "" "tag": ""
}, },
{ {
"title": "OneBot协议", "title": "go-cqhttp",
"desc": "支持cq-http、shamrock等(目前仅支持QQ平台)", "desc": "第三方 QQ 协议实现。支持频道、群",
"namespace": "internal_platform_qq_gocq", "namespace": "internal_platform_qq_gocq",
"tag": "" "tag": ""
} }
@@ -416,21 +432,29 @@ class AstrBotDashBoard():
return func return func
return decorator return decorator
async def get_log_history(self):
try:
with open("logs/astrbot-core/astrbot-core.log", "r", encoding="utf-8") as f:
return f.readlines()[-100:]
except Exception as e:
logger.warning(f"读取日志历史失败: {e.__str__()}")
return []
async def __handle_msg(self, websocket, path): async def __handle_msg(self, websocket, path):
address = websocket.remote_address address = websocket.remote_address
# self.logger.log(f"和 {address} 建立了 websocket 连接", tag="可视化面板")
self.ws_clients[address] = websocket self.ws_clients[address] = websocket
data = ''.join(self.logger.history).replace('\n', '\r\n') data = await self.get_log_history()
data = ''.join(data).replace('\n', '\r\n')
await websocket.send(data) await websocket.send(data)
while True: while True:
try: try:
msg = await websocket.recv() msg = await websocket.recv()
except websockets.exceptions.ConnectionClosedError: except websockets.exceptions.ConnectionClosedError:
# self.logger.log(f"和 {address} 的 websocket 连接已断开", tag="可视化面板") # logger.info(f"和 {address} 的 websocket 连接已断开")
del self.ws_clients[address] del self.ws_clients[address]
break break
except Exception as e: except Exception as e:
# self.logger.log(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}", tag="可视化面板") # logger.info(f"和 {path} 的 websocket 连接发生了错误: {e.__str__()}")
del self.ws_clients[address] del self.ws_clients[address]
break break
@@ -441,11 +465,12 @@ class AstrBotDashBoard():
def run(self): def run(self):
threading.Thread(target=self.run_ws_server, args=(self.loop,)).start() threading.Thread(target=self.run_ws_server, args=(self.loop,)).start()
self.logger.log("已启动 websocket 服务器", tag="可视化面板") logger.info("已启动 websocket 服务器")
ip_address = gu.get_local_ip_addresses() ip_address = gu.get_local_ip_addresses()
ip_str = f"http://{ip_address}:6185\n\thttp://localhost:6185" ip_str = f"http://{ip_address}:6185\n\thttp://localhost:6185"
self.logger.log( logger.info(
f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n", tag="可视化面板") f"\n==================\n您可访问:\n\n\t{ip_str}\n\n来登录可视化面板,默认账号密码为空。\n注意: 所有配置项现已全量迁移至 cmd_config.json 文件下,可登录可视化面板在线修改配置。\n==================\n")
http_server = make_server( http_server = make_server(
'0.0.0.0', 6185, self.dashboard_be, threaded=True) '0.0.0.0', 6185, self.dashboard_be, threaded=True)
http_server.serve_forever() http_server.serve_forever()
+71 -80
View File
@@ -2,32 +2,35 @@ import re
import threading import threading
import asyncio import asyncio
import time import time
import aiohttp
import util.unfit_words as uw import util.unfit_words as uw
import os import os
import sys import sys
import io
import traceback import traceback
import util.function_calling.gplugin as gplugin import util.agent.web_searcher as web_searcher
import util.plugin_util as putil import util.plugin_util as putil
from PIL import Image as PILImage
from nakuru.entities.components import Plain, At, Image from nakuru.entities.components import Plain, At, Image
from addons.baidu_aip_judge import BaiduJudge from addons.baidu_aip_judge import BaiduJudge
from model.provider.provider import Provider from model.provider.provider import Provider
from model.command.command import Command from model.command.command import Command
from util import general_utils as gu from util import general_utils as gu
from util.general_utils import Logger, upload, run_monitor from util.general_utils import upload, run_monitor
from util.cmd_config import CmdConfig as cc from util.cmd_config import CmdConfig as cc
from util.cmd_config import init_astrbot_config_items from util.cmd_config import init_astrbot_config_items
from .types import * from type.types import GlobalObject
from type.register import *
from type.message import AstrBotMessage
from type.config import *
from addons.dashboard.helper import DashBoardHelper from addons.dashboard.helper import DashBoardHelper
from addons.dashboard.server import DashBoardData from addons.dashboard.server import DashBoardData
from cores.database.conn import dbConn from persist.session import dbConn
from model.platform._message_result import MessageResult from model.platform._message_result import MessageResult
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# 用户发言频率 # 用户发言频率
user_frequency = {} user_frequency = {}
@@ -36,9 +39,6 @@ frequency_time = 60
# 计数默认值 # 计数默认值
frequency_count = 10 frequency_count = 10
# 版本
version = '3.1.12'
# 语言模型 # 语言模型
OPENAI_OFFICIAL = 'openai_official' OPENAI_OFFICIAL = 'openai_official'
NONE_LLM = 'none_llm' NONE_LLM = 'none_llm'
@@ -54,13 +54,8 @@ baidu_judge = None
# CLI # CLI
PLATFORM_CLI = 'cli' PLATFORM_CLI = 'cli'
init_astrbot_config_items()
# 全局对象 # 全局对象
_global_object: GlobalObject = None _global_object: GlobalObject = None
logger: Logger = Logger()
# 语言模型选择
def privider_chooser(cfg): def privider_chooser(cfg):
@@ -69,22 +64,16 @@ def privider_chooser(cfg):
l.append('openai_official') l.append('openai_official')
return l return l
def init():
''' '''
初始化机器人 初始化机器人
''' '''
def init(cfg):
global llm_instance, llm_command_instance global llm_instance, llm_command_instance
global baidu_judge, chosen_provider global baidu_judge, chosen_provider
global frequency_count, frequency_time global frequency_count, frequency_time
global _global_object global _global_object
global logger
# 迁移旧配置 init_astrbot_config_items()
gu.try_migrate_config(cfg)
# 使用新配置
cfg = cc.get_all() cfg = cc.get_all()
_event_loop = asyncio.new_event_loop() _event_loop = asyncio.new_event_loop()
@@ -92,10 +81,9 @@ def init(cfg):
# 初始化 global_object # 初始化 global_object
_global_object = GlobalObject() _global_object = GlobalObject()
_global_object.version = version _global_object.version = VERSION
_global_object.base_config = cfg _global_object.base_config = cfg
_global_object.logger = logger logger.info("AstrBot v" + VERSION)
logger.log("AstrBot v"+version, gu.LEVEL_INFO)
if 'reply_prefix' in cfg: if 'reply_prefix' in cfg:
# 适配旧版配置 # 适配旧版配置
@@ -105,12 +93,21 @@ def init(cfg):
cc.put("reply_prefix", "") cc.put("reply_prefix", "")
else: else:
_global_object.reply_prefix = cfg['reply_prefix'] _global_object.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 语言模型提供商 # 语言模型提供商
logger.log("正在载入语言模型...", gu.LEVEL_INFO) logger.info("正在载入语言模型...")
prov = privider_chooser(cfg) prov = privider_chooser(cfg)
if OPENAI_OFFICIAL in prov: if OPENAI_OFFICIAL in prov:
logger.log("初始化:OpenAI官方", gu.LEVEL_INFO) logger.info("初始化:OpenAI官方")
if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]: if cfg['openai']['key'] is not None and cfg['openai']['key'] != [None]:
from model.provider.openai_official import ProviderOpenAIOfficial from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official import CommandOpenAIOfficial from model.command.openai_official import CommandOpenAIOfficial
@@ -122,6 +119,11 @@ def init(cfg):
llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal")) llm_name=OPENAI_OFFICIAL, llm_instance=llm_instance[OPENAI_OFFICIAL], origin="internal"))
chosen_provider = OPENAI_OFFICIAL chosen_provider = OPENAI_OFFICIAL
instance = llm_instance[OPENAI_OFFICIAL]
assert isinstance(instance, ProviderOpenAIOfficial)
instance.DEFAULT_PERSONALITY = _global_object.default_personality
instance.curr_personality = instance.DEFAULT_PERSONALITY
# 检查provider设置偏好 # 检查provider设置偏好
p = cc.get("chosen_provider", None) p = cc.get("chosen_provider", None)
if p is not None and p in llm_instance: if p is not None and p in llm_instance:
@@ -131,9 +133,9 @@ def init(cfg):
if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']: if 'baidu_aip' in cfg and 'enable' in cfg['baidu_aip'] and cfg['baidu_aip']['enable']:
try: try:
baidu_judge = BaiduJudge(cfg['baidu_aip']) baidu_judge = BaiduJudge(cfg['baidu_aip'])
logger.log("百度内容审核初始化成功", gu.LEVEL_INFO) logger.info("百度内容审核初始化成功")
except BaseException as e: except BaseException as e:
logger.log("百度内容审核初始化失败", gu.LEVEL_ERROR) logger.info("百度内容审核初始化失败")
threading.Thread(target=upload, args=( threading.Thread(target=upload, args=(
_global_object, ), daemon=True).start() _global_object, ), daemon=True).start()
@@ -151,7 +153,7 @@ def init(cfg):
else: else:
_global_object.unique_session = False _global_object.unique_session = False
except BaseException as e: except BaseException as e:
logger.log("独立会话配置错误: "+str(e), gu.LEVEL_ERROR) logger.info("独立会话配置错误: "+str(e))
nick_qq = cc.get("nick_qq", None) nick_qq = cc.get("nick_qq", None)
if nick_qq == None: if nick_qq == None:
@@ -166,45 +168,37 @@ def init(cfg):
global llm_wake_prefix global llm_wake_prefix
llm_wake_prefix = cc.get("llm_wake_prefix", "") llm_wake_prefix = cc.get("llm_wake_prefix", "")
logger.log("正在载入插件...", gu.LEVEL_INFO) logger.info("正在载入插件...")
# 加载插件 # 加载插件
_command = Command(None, _global_object) _command = Command(None, _global_object)
ok, err = putil.plugin_reload(_global_object.cached_plugins) ok, err = putil.plugin_reload(_global_object.cached_plugins)
if ok: if ok:
logger.log( logger.info(
f"成功载入 {len(_global_object.cached_plugins)} 个插件", gu.LEVEL_INFO) f"成功载入 {len(_global_object.cached_plugins)} 个插件")
else: else:
logger.log(err, gu.LEVEL_ERROR) logger.info(err)
if chosen_provider is None: if chosen_provider is None:
llm_command_instance[NONE_LLM] = _command llm_command_instance[NONE_LLM] = _command
chosen_provider = NONE_LLM chosen_provider = NONE_LLM
logger.log("正在载入机器人消息平台", gu.LEVEL_INFO) logger.info("正在载入机器人消息平台")
# logger.log("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)", gu.LEVEL_WARNING) # logger.info("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)")
platform_str = "" platform_str = ""
# GOCQ # GOCQ
if 'gocqbot' in cfg and cfg['gocqbot']['enable']: if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
logger.log("启用 QQ_GOCQ 机器人消息平台", gu.LEVEL_INFO) logger.info("启用 QQ_GOCQ 机器人消息平台")
threading.Thread(target=run_gocq_bot, args=( threading.Thread(target=run_gocq_bot, args=(
cfg, _global_object), daemon=True).start() cfg, _global_object), daemon=True).start()
platform_str += "QQ_GOCQ," platform_str += "QQ_GOCQ,"
# QQ频道 # QQ频道
if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None: if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None:
logger.log("启用 QQ_OFFICIAL 机器人消息平台", gu.LEVEL_INFO) logger.info("启用 QQ_OFFICIAL 机器人消息平台")
threading.Thread(target=run_qqchan_bot, args=( threading.Thread(target=run_qqchan_bot, args=(
cfg, _global_object), daemon=True).start() cfg, _global_object), daemon=True).start()
platform_str += "QQ_OFFICIAL," platform_str += "QQ_OFFICIAL,"
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
_global_object.default_personality = None
else:
_global_object.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
# 初始化dashboard # 初始化dashboard
_global_object.dashboard_data = DashBoardData( _global_object.dashboard_data = DashBoardData(
stats={}, stats={},
@@ -221,12 +215,12 @@ def init(cfg):
threading.Thread(target=run_monitor, args=( threading.Thread(target=run_monitor, args=(
_global_object,), daemon=True).start() _global_object,), daemon=True).start()
logger.log( logger.info(
"如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。", gu.LEVEL_INFO) "如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。")
logger.log("请给 https://github.com/Soulter/AstrBot 点个 star。", gu.LEVEL_INFO) logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。")
if platform_str == '': if platform_str == '':
platform_str = "(未启动任何平台,请前往面板添加)" platform_str = "(未启动任何平台,请前往面板添加)"
logger.log(f"🎉 项目启动完成") logger.info(f"🎉 项目启动完成")
dashboard_thread.join() dashboard_thread.join()
@@ -245,10 +239,8 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
qqchannel_bot.run() qqchannel_bot.run()
except BaseException as e: except BaseException as e:
logger.log("启动QQ频道机器人时出现错误, 原因如下: " + str(e), logger.error("启动 QQ 频道机器人时出现错误, 原因如下: " + str(e))
gu.LEVEL_CRITICAL, tag="QQ频道") logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。")
logger.log(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。" +
str(e), gu.LEVEL_CRITICAL)
''' '''
@@ -263,17 +255,17 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
host = cc.get("gocq_host", "127.0.0.1") host = cc.get("gocq_host", "127.0.0.1")
port = cc.get("gocq_websocket_port", 6700) port = cc.get("gocq_websocket_port", 6700)
http_port = cc.get("gocq_http_port", 5700) http_port = cc.get("gocq_http_port", 5700)
logger.log( logger.info(
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}", tag="QQ") f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
while True: while True:
if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host): if not gu.port_checker(port=port, host=host) or not gu.port_checker(port=http_port, host=host):
if not noticed: if not noticed:
noticed = True noticed = True
logger.log( logger.warning(
f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。", gu.LEVEL_CRITICAL, tag="QQ") f"连接到{host}:{port}(或{http_port})失败。程序会每隔 5s 自动重试。")
time.sleep(5) time.sleep(5)
else: else:
logger.log("检查完毕,未发现问题。", tag="QQ") logger.info("已连接到 gocq。")
break break
try: try:
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg,
@@ -316,7 +308,6 @@ async def record_message(platform: str, session_id: str):
db_inst.increment_stat_session(platform, session_id, 1) db_inst.increment_stat_session(platform, session_id, 1)
db_inst.increment_stat_message(curr_ts, 1) db_inst.increment_stat_message(curr_ts, 1)
db_inst.increment_stat_platform(curr_ts, platform, 1) db_inst.increment_stat_platform(curr_ts, platform, 1)
_global_object.cnt_total += 1
async def oper_msg(message: AstrBotMessage, async def oper_msg(message: AstrBotMessage,
@@ -345,7 +336,6 @@ async def oper_msg(message: AstrBotMessage,
reg_platform = p reg_platform = p
break break
if not reg_platform: if not reg_platform:
_global_object.logger.log(f"未找到平台 {platform} 的实例。", gu.LEVEL_ERROR)
raise Exception(f"未找到平台 {platform} 的实例。") raise Exception(f"未找到平台 {platform} 的实例。")
# 统计数据,如频道消息量 # 统计数据,如频道消息量
@@ -381,8 +371,13 @@ async def oper_msg(message: AstrBotMessage,
llm_result_str = "" llm_result_str = ""
# check commands and plugins # check commands and plugins
message_str_no_wake_prefix = message_str
for wake_prefix in _global_object.nick: # nick: tuple
if message_str.startswith(wake_prefix):
message_str_no_wake_prefix = message_str.removeprefix(wake_prefix)
break
hit, command_result = await llm_command_instance[chosen_provider].check_command( hit, command_result = await llm_command_instance[chosen_provider].check_command(
message_str, message_str_no_wake_prefix,
session_id, session_id,
role, role,
reg_platform, reg_platform,
@@ -401,7 +396,7 @@ async def oper_msg(message: AstrBotMessage,
if not check: if not check:
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}") return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
if chosen_provider == NONE_LLM: if chosen_provider == NONE_LLM:
logger.log("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。", gu.LEVEL_WARNING) logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。")
return return
try: try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix): if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
@@ -426,14 +421,14 @@ async def oper_msg(message: AstrBotMessage,
if chosen_provider == OPENAI_OFFICIAL: if chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag: if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL official_fc = chosen_provider == OPENAI_OFFICIAL
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
else: else:
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url, default_personality=_global_object.default_personality) llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
llm_result_str = _global_object.reply_prefix + llm_result_str llm_result_str = _global_object.reply_prefix + llm_result_str
except BaseException as e: except BaseException as e:
logger.log(f"调用异常:{traceback.format_exc()}", gu.LEVEL_ERROR) logger.error(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用语言模型例程时出现异常。原因: {str(e)}") return MessageResult(f"调用异常。详细原因:{str(e)}")
# 切换回原来的语言模型 # 切换回原来的语言模型
if temp_switch != "": if temp_switch != "":
@@ -456,14 +451,10 @@ async def oper_msg(message: AstrBotMessage,
return MessageResult(f"指令调用错误: \n{str(command_result[1])}") return MessageResult(f"指令调用错误: \n{str(command_result[1])}")
# 画图指令 # 画图指令
if isinstance(command_result[1], list) and len(command_result) == 3 and command == 'draw': if command == 'draw':
for i in command_result[1]: # 保存到本地
# 保存到本地 path = await gu.download_image_by_url(command_result[1])
async with aiohttp.ClientSession() as session: return MessageResult([Image.fromFileSystem(path)])
async with session.get(i) as resp:
if resp.status == 200:
image = PILImage.open(io.BytesIO(await resp.read()))
return MessageResult([Image.fromFileSystem(gu.save_temp_img(image))])
# 其他指令 # 其他指令
else: else:
try: try:
@@ -485,4 +476,4 @@ async def oper_msg(message: AstrBotMessage,
try: try:
return MessageResult(llm_result_str) return MessageResult(llm_result_str)
except BaseException as e: except BaseException as e:
logger.log("回复消息错误: \n"+str(e), gu.LEVEL_ERROR) logger.info("回复消息错误: \n"+str(e))
-181
View File
@@ -1,181 +0,0 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from nakuru import (
GroupMessage,
FriendMessage,
GuildMessage,
)
from nakuru.entities.components import BaseMessageComponent
from typing import Union, List, ClassVar
from types import ModuleType
from enum import Enum
from dataclasses import dataclass
class MessageType(Enum):
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
GUILD_MESSAGE = 'GuildMessage' # 频道消息
@dataclass
class MessageMember():
user_id: str # 发送者id
nickname: str = None
class AstrBotMessage():
'''
AstrBot 的消息对象
'''
tag: str # 消息来源标签
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id
message_id: str # 消息id
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
raw_message: object
timestamp: int # 消息时间戳
def __str__(self) -> str:
return str(self.__dict__)
class PluginType(Enum):
PLATFORM = 'platfrom' # 平台类插件。
LLM = 'llm' # 大语言模型类插件
COMMON = 'common' # 其他插件
@dataclass
class PluginMetadata:
'''
插件的元数据。
'''
# required
plugin_name: str
plugin_type: PluginType
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件。
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
'''
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
'''
platform_name: str
platform_instance: Platform
origin: str = None # 注册来源
@dataclass
class RegisteredLLM:
'''
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
'''
llm_name: str
llm_instance: LLMProvider
origin: str = None # 注册来源
class GlobalObject:
'''
存放一些公用的数据,用于在不同模块(如core与command)之间传递
'''
version: str # 机器人版本
nick: str # 用户定义的机器人的别名
base_config: dict # config.json 中导出的配置
cached_plugins: List[RegisteredPlugin] # 加载的插件
platforms: List[RegisteredPlatform]
llms: List[RegisteredLLM]
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
cnt_total: int # 总消息数
default_personality: dict
dashboard_data = None
logger: None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = [] # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.cnt_total = 0
self.platforms = []
self.llms = []
self.default_personality = None
self.dashboard_data = None
self.stat = {}
class AstrMessageEvent():
'''
消息事件。
'''
context: GlobalObject # 一些公用数据
message_str: str # 纯消息字符串
message_obj: AstrBotMessage # 消息对象
platform: RegisteredPlatform # 来源平台
role: str # 基本身份。`admin` 或 `member`
session_id: int # 会话 id
def __init__(self,
message_str: str,
message_obj: AstrBotMessage,
platform: RegisteredPlatform,
role: str,
context: GlobalObject,
session_id: str = None):
self.context = context
self.message_str = message_str
self.message_obj = message_obj
self.platform = platform
self.role = role
self.session_id = session_id
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+82 -79
View File
@@ -1,106 +1,109 @@
import os import os
import sys import sys
from pip._internal import main as pipmain
import warnings import warnings
import traceback import traceback
import threading import threading
from logging import Formatter, Logger
from util.cmd_config import CmdConfig, try_migrate_config
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
logger: Logger = None
logo_tmpl = """
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
def make_necessary_dirs():
'''
创建必要的目录。
'''
os.makedirs("data/config", exist_ok=True)
os.makedirs("temp", exist_ok=True)
def update_dept():
'''
更新依赖库。
'''
# 获取 Python 可执行文件路径
py = sys.executable
# 更新依赖库
mirror = "https://mirrors.aliyun.com/pypi/simple/"
os.system(f"{py} -m pip install -r requirements.txt -i {mirror}")
def main(): def main():
# config.yaml 配置文件加载和环境确认
try: try:
import cores.astrbot.core as qqBot import botpy, logging
import yaml import astrbot.core as bot_core
ymlfile = open(abs_path+"configs/config.yaml", 'r', encoding='utf-8') # delete qqbotpy's logger
cfg = yaml.safe_load(ymlfile) for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
except ImportError as import_error: except ImportError as import_error:
traceback.print_exc() logger.error(import_error)
print(import_error) logger.error("检测到一些依赖库没有安装。由于兼容性问题,AstrBot 此版本将不会自动为您安装依赖库。请您先自行安装,然后重试。")
input("第三方库未完全安装完毕,请退出程序重试。") logger.info("如何安装?如果:")
logger.info("- Windows 启动器部署且使用启动器下载了 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 .\python\python.exe -m pip install .\AstrBot\requirements.txt")
logger.info("- Windows 启动器部署且使用自己之前下载的 Python的:在 launcher.exe 所在目录下的地址框输入 powershell,然后执行 python -m pip install .\AstrBot\requirements.txt")
logger.info("- 自行 clone 源码部署的:python -m pip install -r requirements.txt")
logger.info("- 如果还不会,加群 322154837 ")
input("按任意键退出。")
exit()
except FileNotFoundError as file_not_found: except FileNotFoundError as file_not_found:
print(file_not_found) logger.error(file_not_found)
input("配置文件不存在,请检查是否已经下载配置文件。") input("配置文件不存在,请检查是否已经下载配置文件。")
exit()
except BaseException as e: except BaseException as e:
raise e logger.error(traceback.format_exc())
input("未知错误。")
# 设置代理
if 'http_proxy' in cfg and cfg['http_proxy'] != '':
os.environ['HTTP_PROXY'] = cfg['http_proxy']
if 'https_proxy' in cfg and cfg['https_proxy'] != '':
os.environ['HTTPS_PROXY'] = cfg['https_proxy']
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
# 检查并创建 temp 文件夹
if not os.path.exists(abs_path + "temp"):
os.mkdir(abs_path+"temp")
if not os.path.exists(abs_path + "data"):
os.mkdir(abs_path+"data")
if not os.path.exists(abs_path + "data/config"):
os.mkdir(abs_path+"data/config")
# 启动主程序(cores/qqbot/core.py
qqBot.init(cfg)
def check_env(ch_mirror=False):
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
print("请使用Python3.9+运行本项目")
input("按任意键退出...")
exit() exit()
if os.path.exists('requirements.txt'): make_necessary_dirs()
pth = 'requirements.txt'
else:
pth = 'QQChannelChatGPT' + os.sep + 'requirements.txt'
print("正在检查或下载第三方库,请耐心等待...")
try:
if ch_mirror:
print("使用阿里云镜像")
pipmain(['install', '-r', pth, '-i',
'https://mirrors.aliyun.com/pypi/simple/'])
else:
pipmain(['install', '-r', pth])
except BaseException as e:
print(e)
while True:
res = input(
"安装失败。\n如报错ValueError: check_hostname requires server_hostname,请尝试先关闭代理后重试。\n1.输入y回车重试\n2. 输入c回车使用国内镜像源下载\n3. 输入其他按键回车继续往下执行。")
if res == "y":
try:
pipmain(['install', '-r', pth])
break
except BaseException as e:
print(e)
continue
elif res == "c":
try:
pipmain(['install', '-r', pth, '-i',
'https://mirrors.aliyun.com/pypi/simple/'])
break
except BaseException as e:
print(e)
continue
else:
break
print("第三方库检查完毕。")
# 启动主程序(cores/qqbot/core.py
bot_core.init()
def check_env():
if not (sys.version_info.major == 3 and sys.version_info.minor >= 9):
logger.error("请使用 Python3.9+ 运行本项目。按任意键退出。")
input("")
exit()
if __name__ == "__main__": if __name__ == "__main__":
args = sys.argv update_dept()
if '-cn' in args: try_migrate_config()
check_env(True) cc = CmdConfig()
else: http_proxy = cc.get("http_proxy")
check_env() https_proxy = cc.get("https_proxy")
if http_proxy:
os.environ['HTTP_PROXY'] = http_proxy
if https_proxy:
os.environ['HTTPS_PROXY'] = https_proxy
os.environ['NO_PROXY'] = 'https://api.sgroup.qq.com'
from SparkleLogging.utils.core import LogManager
logger = LogManager.GetLogger(
log_name='astrbot-core',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
logger.info(logo_tmpl)
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
check_env()
t = threading.Thread(target=main, daemon=True) t = threading.Thread(target=main, daemon=True)
t.start() t.start()
try: try:
t.join() t.join()
except KeyboardInterrupt as e: except KeyboardInterrupt as e:
print("退出 AstrBot。") logger.info("退出 AstrBot。")
exit() exit()
+24 -23
View File
@@ -13,17 +13,17 @@ from nakuru.entities.components import (
from util import general_utils as gu from util import general_utils as gu
from model.provider.provider import Provider from model.provider.provider import Provider
from util.cmd_config import CmdConfig as cc from util.cmd_config import CmdConfig as cc
from util.general_utils import Logger from type.message import *
from cores.astrbot.types import ( from type.types import GlobalObject
GlobalObject, from type.command import *
AstrMessageEvent, from type.plugin import *
PluginType, from type.register import *
CommandResult,
RegisteredPlugin,
RegisteredPlatform
)
from typing import List, Tuple from typing import List
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
PLATFORM_QQCHAN = 'qqchan' PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq' PLATFORM_GOCQ = 'gocq'
@@ -35,7 +35,6 @@ class Command:
def __init__(self, provider: Provider, global_object: GlobalObject = None): def __init__(self, provider: Provider, global_object: GlobalObject = None):
self.provider = provider self.provider = provider
self.global_object = global_object self.global_object = global_object
self.logger: Logger = global_object.logger
async def check_command(self, async def check_command(self,
message, message,
@@ -65,6 +64,8 @@ class Command:
result = await plugin.plugin_instance.run(ame) result = await plugin.plugin_instance.run(ame)
else: else:
result = await asyncio.to_thread(plugin.plugin_instance.run, ame) result = await asyncio.to_thread(plugin.plugin_instance.run, ame)
if not result:
continue
if isinstance(result, CommandResult): if isinstance(result, CommandResult):
hit = result.hit hit = result.hit
res = result._result_tuple() res = result._result_tuple()
@@ -74,6 +75,8 @@ class Command:
else: else:
raise TypeError("插件返回值格式错误。") raise TypeError("插件返回值格式错误。")
if hit: if hit:
plugin.trig()
logger.debug("hit plugin: " + plugin.metadata.plugin_name)
return True, res return True, res
except TypeError as e: except TypeError as e:
# 参数不匹配,尝试使用旧的参数方案 # 参数不匹配,尝试使用旧的参数方案
@@ -85,11 +88,11 @@ class Command:
if hit: if hit:
return True, res return True, res
except BaseException as e: except BaseException as e:
self.logger.log( logger.error(
f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。")
except BaseException as e: except BaseException as e:
self.logger.log( logger.error(
f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。", level=gu.LEVEL_WARNING) f"{plugin.metadata.plugin_name} 插件异常,原因: {str(e)}\n如果你没有相关装插件的想法, 请直接忽略此报错, 不影响其他功能的运行。")
if self.command_start_with(message, "nick"): if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role) return True, self.set_nick(message, platform, role)
@@ -186,7 +189,7 @@ class Command:
break break
if info: if info:
p = gu.create_text_image( p = gu.create_text_image(
f"【插件信息】", f"名称: {info['name']}\n{info['desc']}\n版本: {info['version']}\n作者: {info['author']}\n\n帮助:\n{info['help']}") f"【插件信息】", f"名称: {info.plugin_name}\n类型: {info.plugin_type}\n{info.desc}\n版本: {info.version}\n作者: {info.author}")
return True, [Image.fromFileSystem(p)], "plugin" return True, [Image.fromFileSystem(p)], "plugin"
else: else:
return False, "未找到该插件", "plugin" return False, "未找到该插件", "plugin"
@@ -197,10 +200,10 @@ class Command:
nick: 存储机器人的昵称 nick: 存储机器人的昵称
''' '''
def set_nick(self, message: str, platform: str, role: str = "member"): def set_nick(self, message: str, platform: RegisteredPlatform, role: str = "member"):
if role != "admin": if role != "admin":
return True, "你无权使用该指令 :P", "nick" return True, "你无权使用该指令 :P", "nick"
if platform == PLATFORM_GOCQ: if str(platform) == PLATFORM_GOCQ:
l = message.split(" ") l = message.split(" ")
if len(l) == 1: if len(l) == 1:
return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick" return True, "【设置机器人昵称】示例:\n支持多昵称\nnick 昵称1 昵称2 昵称3", "nick"
@@ -208,7 +211,7 @@ class Command:
cc.put("nick_qq", nick) cc.put("nick_qq", nick)
self.global_object.nick = tuple(nick) self.global_object.nick = tuple(nick)
return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick" return True, f"设置成功!现在你可以叫我这些昵称来提问我啦~", "nick"
elif platform == PLATFORM_QQCHAN: elif str(platform) == PLATFORM_QQCHAN:
nick = message.split(" ")[2] nick = message.split(" ")[2]
return False, "QQ频道平台不支持为机器人设置昵称。", "nick" return False, "QQ频道平台不支持为机器人设置昵称。", "nick"
@@ -217,11 +220,9 @@ class Command:
"help": "帮助", "help": "帮助",
"keyword": "设置关键词/关键指令回复", "keyword": "设置关键词/关键指令回复",
"update": "更新项目", "update": "更新项目",
"nick": "设置机器人昵称", "nick": "设置机器人唤醒词",
"plugin": "插件安装、卸载和重载", "plugin": "插件安装、卸载和重载",
"web on/off": "LLM 网页搜索能力", "web on/off": "LLM 网页搜索能力",
"reset": "重置 LLM 对话",
"/gpt": "切换到 OpenAI 官方接口"
} }
async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None): async def help_messager(self, commands: dict, platform: str, cached_plugins: List[RegisteredPlugin] = None):
@@ -248,7 +249,7 @@ class Command:
p = gu.create_markdown_image(msg) p = gu.create_markdown_image(msg)
return [Image.fromFileSystem(p),] return [Image.fromFileSystem(p),]
except BaseException as e: except BaseException as e:
self.logger.log(str(e)) logger.error(str(e))
return msg return msg
def command_start_with(self, message: str, *args): def command_start_with(self, message: str, *args):
+97 -124
View File
@@ -1,14 +1,24 @@
from model.command.command import Command from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
from util.personality import personalities from util.personality import personalities
from cores.astrbot.types import GlobalObject from type.types import GlobalObject
from type.command import CommandItem
from SparkleLogging.utils.core import LogManager
from logging import Logger
from openai._exceptions import NotFoundError
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class CommandOpenAIOfficial(Command): class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject): def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
self.provider = provider self.provider = provider
self.global_object = global_object self.global_object = global_object
self.personality_str = "" self.personality_str = ""
self.commands = [
CommandItem("reset", self.reset, "重置 LLM 会话。", "内置"),
CommandItem("his", self.his, "查看与 LLM 的历史记录。", "内置"),
CommandItem("status", self.status, "查看 GPT 配置信息和用量状态。", "内置"),
]
super().__init__(provider, global_object) super().__init__(provider, global_object)
async def check_command(self, async def check_command(self,
@@ -28,6 +38,8 @@ class CommandOpenAIOfficial(Command):
message_obj message_obj
) )
logger.debug(f"基础指令hit: {hit}, res: {res}")
# 这里是这个 LLM 的专属指令 # 这里是这个 LLM 的专属指令
if hit: if hit:
return True, res return True, res
@@ -35,12 +47,8 @@ class CommandOpenAIOfficial(Command):
return True, await self.reset(session_id, message) return True, await self.reset(session_id, message)
elif self.command_start_with(message, "his", "历史"): elif self.command_start_with(message, "his", "历史"):
return True, self.his(message, session_id) return True, self.his(message, session_id)
elif self.command_start_with(message, "token"):
return True, self.token(session_id)
elif self.command_start_with(message, "gpt"):
return True, self.gpt()
elif self.command_start_with(message, "status"): elif self.command_start_with(message, "status"):
return True, self.status() return True, self.status(session_id)
elif self.command_start_with(message, "help", "帮助"): elif self.command_start_with(message, "help", "帮助"):
return True, await self.help() return True, await self.help()
elif self.command_start_with(message, "unset"): elif self.command_start_with(message, "unset"):
@@ -51,21 +59,61 @@ class CommandOpenAIOfficial(Command):
return True, self.update(message, role) return True, self.update(message, role)
elif self.command_start_with(message, "", "draw"): elif self.command_start_with(message, "", "draw"):
return True, await self.draw(message) return True, await self.draw(message)
elif self.command_start_with(message, "key"):
return True, self.key(message)
elif self.command_start_with(message, "switch"): elif self.command_start_with(message, "switch"):
return True, await self.switch(message) return True, await self.switch(message)
elif self.command_start_with(message, "models"):
return True, await self.print_models()
elif self.command_start_with(message, "model"):
return True, await self.set_model(message)
return False, None return False, None
async def get_models(self):
try:
models = await self.provider.client.models.list()
except NotFoundError as e:
bu = str(self.provider.client.base_url)
self.provider.client.base_url = bu + "/v1"
models = await self.provider.client.models.list()
finally:
return filter(lambda x: x.id.startswith("gpt"), models.data)
async def print_models(self):
models = await self.get_models()
i = 1
ret = "OpenAI GPT 类可用模型"
for model in models:
ret += f"\n{i}. {model.id}"
i += 1
ret += "\nTips: 使用 /model 模型名/编号,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
logger.debug(ret)
return True, ret, "models"
async def set_model(self, message: str):
l = message.split(" ")
if len(l) == 1:
return True, "请输入 /model 模型名/编号", "model"
model = str(l[1])
models = await self.get_models()
models = list(models)
if model.isdigit() and int(model) <= len(models) and int(model) >= 1:
model = models[int(model)-1]
self.provider.set_model(model.id)
return True, f"模型已设置为 {model.id}", "model"
async def help(self): async def help(self):
commands = super().general_commands() commands = super().general_commands()
commands[''] = '画画' commands[''] = '调用 OpenAI DallE 模型生成图片'
commands['key'] = '添加OpenAI key' commands['/set'] = '人格设置面板'
commands['set'] = '人格设置面板' commands['/status'] = '查看 Api Key 状态和配置信息'
commands['gpt'] = '查看gpt配置信息' commands['/token'] = '查看本轮会话 token'
commands['status'] = '查看key使用状态' commands['/reset'] = '重置当前与 LLM 的会话,但保留人格(system prompt'
commands['token'] = '查看本轮会话token' commands['/reset p'] = '重置当前与 LLM 的会话,并清除人格。'
commands['/models'] = '获取当前可用的模型'
commands['/model'] = '更换模型'
return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help" return True, await super().help_messager(commands, self.platform, self.global_object.cached_plugins), "help"
async def reset(self, session_id: str, message: str = "reset"): async def reset(self, session_id: str, message: str = "reset"):
@@ -73,79 +121,44 @@ class CommandOpenAIOfficial(Command):
return False, "未启用 OpenAI 官方 API", "reset" return False, "未启用 OpenAI 官方 API", "reset"
l = message.split(" ") l = message.split(" ")
if len(l) == 1: if len(l) == 1:
await self.provider.forget(session_id) await self.provider.forget(session_id, keep_system_prompt=True)
return True, "重置成功", "reset" return True, "重置成功", "reset"
if len(l) == 2 and l[1] == "p": if len(l) == 2 and l[1] == "p":
self.provider.forget(session_id) await self.provider.forget(session_id)
if self.personality_str != "":
self.set(self.personality_str, session_id) # 重新设置人格
return True, "重置成功", "reset"
def his(self, message: str, session_id: str): def his(self, message: str, session_id: str):
if self.provider is None: if self.provider is None:
return False, "未启用 OpenAI 官方 API", "his" return False, "未启用 OpenAI 官方 API", "his"
# 分页,每页5条
msg = ''
size_per_page = 3 size_per_page = 3
page = 1 page = 1
if message[4:]: l = message.split(" ")
page = int(message[4:]) if len(l) == 2:
# 检查是否有过历史记录 try:
if session_id not in self.provider.session_dict: page = int(l[1])
msg = f"历史记录为空" except BaseException as e:
return True, msg, "his" return True, "页码不合法", "his"
l = self.provider.session_dict[session_id] contexts, total_num = self.provider.dump_contexts_page(session_id, size_per_page, page=page)
max_page = len(l)//size_per_page + \ t_pages = total_num // size_per_page + 1
1 if len(l) % size_per_page != 0 else len(l)//size_per_page return True, f"历史记录如下:\n{contexts}\n{page} 页 | 共 {t_pages}\n*输入 /his 2 跳转到第 2 页", "his"
p = self.provider.get_prompts_by_cache_list(
self.provider.session_dict[session_id], divide=True, paging=True, size=size_per_page, page=page)
return True, f"历史记录如下:\n{p}\n{page}页 | 共{max_page}\n*输入/his 2跳转到第2页", "his"
def token(self, session_id: str): def status(self, session_id: str):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "token"
return True, f"会话的token数: {self.provider.get_user_usage_tokens(self.provider.session_dict[session_id])}\n系统最大缓存token数: {self.provider.max_tokens}", "token"
def gpt(self):
if self.provider is None:
return False, "未启用 OpenAI 官方 API", "gpt"
return True, f"OpenAI GPT配置:\n {self.provider.chatGPT_configs}", "gpt"
def status(self):
if self.provider is None: if self.provider is None:
return False, "未启用 OpenAI 官方 API", "status" return False, "未启用 OpenAI 官方 API", "status"
chatgpt_cfg_str = "" keys_data = self.provider.get_keys_data()
key_stat = self.provider.get_key_stat() ret = "OpenAI Key"
index = 1 for k in keys_data:
max = 9000000 status = "🟢" if keys_data[k] else "🔴"
gg_count = 0 ret += "\n|- " + k[:8] + " " + status
total = 0
tag = ''
for key in key_stat.keys():
sponsor = ''
total += key_stat[key]['used']
if key_stat[key]['exceed']:
gg_count += 1
continue
if 'sponsor' in key_stat[key]:
sponsor = key_stat[key]['sponsor']
chatgpt_cfg_str += f" |-{index}: {key[-8:]} {key_stat[key]['used']}/{max} {sponsor}{tag}\n"
index += 1
return True, f"⭐使用情况({str(gg_count)}个已用):\n{chatgpt_cfg_str}", "status"
def key(self, message: str): conf = self.provider.get_configs()
if self.provider is None: ret += "\n当前模型:" + conf['model']
return False, "未启用 OpenAI 官方 API", "reset" if conf['model'] in MODELS:
l = message.split(" ") ret += "\n最大上下文窗口:" + str(MODELS[conf['model']]) + " tokens"
if len(l) == 1:
msg = "感谢您赞助key,key为官方API使用,请以以下格式赞助:\n/key xxxxx" if session_id in self.provider.session_memory and len(self.provider.session_memory[session_id]):
return True, msg, "key" ret += "\n你的会话上下文:" + str(self.provider.session_memory[session_id][-1]['usage_tokens']) + " tokens"
key = l[1]
if self.provider.check_key(key): return True, ret, "status"
self.provider.append_key(key)
return True, f"*★,°*:.☆( ̄▽ ̄)/$:*.°★* 。\n该Key被验证为有效。感谢你的赞助~"
else:
return True, "该Key被验证为无效。也许是输入错误了,或者重试。", "key"
async def switch(self, message: str): async def switch(self, message: str):
''' '''
@@ -162,14 +175,13 @@ class CommandOpenAIOfficial(Command):
return True, ret, "switch" return True, ret, "switch"
elif len(l) == 2: elif len(l) == 2:
try: try:
key_stat = self.provider.get_key_stat() key_stat = self.provider.get_keys_data()
index = int(l[1]) index = int(l[1])
if index > len(key_stat) or index < 1: if index > len(key_stat) or index < 1:
return True, "账号序号不合法。", "switch" return True, "账号序号不合法。", "switch"
else: else:
try: try:
new_key = list(key_stat.keys())[index-1] new_key = list(key_stat.keys())[index-1]
ret = await self.provider.check_key(new_key)
self.provider.set_key(new_key) self.provider.set_key(new_key)
except BaseException as e: except BaseException as e:
return True, "账号切换失败,原因: " + str(e), "switch" return True, "账号切换失败,原因: " + str(e), "switch"
@@ -218,58 +230,19 @@ class CommandOpenAIOfficial(Command):
'name': ps, 'name': ps,
'prompt': personalities[ps] 'prompt': personalities[ps]
} }
self.provider.session_dict[session_id] = [] self.provider.personality_set(ps, session_id)
new_record = {
"user": {
"role": "user",
"content": personalities[ps],
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
return True, f"人格{ps}已设置。", "set" return True, f"人格{ps}已设置。", "set"
else: else:
self.provider.curr_personality = { self.provider.curr_personality = {
'name': '自定义人格', 'name': '自定义人格',
'prompt': ps 'prompt': ps
} }
new_record = { self.provider.personality_set(ps, session_id)
"user": {
"role": "user",
"content": ps,
},
"AI": {
"role": "assistant",
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
}
self.provider.session_dict[session_id] = []
self.provider.session_dict[session_id].append(new_record)
self.personality_str = message
return True, f"自定义人格已设置。 \n人格信息: {ps}", "set" return True, f"自定义人格已设置。 \n人格信息: {ps}", "set"
async def draw(self, message): async def draw(self, message: str):
if self.provider is None: if self.provider is None:
return False, "未启用 OpenAI 官方 API", "draw" return False, "未启用 OpenAI 官方 API", "draw"
if message.startswith("/"): message = message.removeprefix("/").removeprefix("")
message = message[2:] img_url = await self.provider.image_generate(message)
elif message.startswith(""): return True, img_url, "draw"
message = message[1:]
try:
# 画图模式传回3个参数
img_url = await self.provider.image_chat(message)
return True, img_url, "draw"
except Exception as e:
if 'exceeded' in str(e):
return f"OpenAI API错误。原因:\n{str(e)} \n超额了。可自己搭建一个机器人(Github仓库:QQChannelChatGPT)"
return False, f"图片生成失败: {e}", "draw"
+1 -1
View File
@@ -5,7 +5,7 @@ from nakuru import (
FriendMessage FriendMessage
) )
import botpy.message import botpy.message
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember from type.message import *
from typing import List, Union from typing import List, Union
import time import time
+10 -4
View File
@@ -14,34 +14,40 @@ class Platform():
初始化平台的各种接口 初始化平台的各种接口
''' '''
self.message_handler = message_handler self.message_handler = message_handler
self.cnt_receive = 0
self.cnt_reply = 0
pass pass
@abc.abstractmethod @abc.abstractmethod
async def handle_msg(): async def handle_msg(self):
''' '''
处理到来的消息 处理到来的消息
''' '''
self.cnt_receive += 1
pass pass
@abc.abstractmethod @abc.abstractmethod
async def reply_msg(): async def reply_msg(self):
''' '''
回复消息(被动发送) 回复消息(被动发送)
''' '''
self.cnt_reply += 1
pass pass
@abc.abstractmethod @abc.abstractmethod
async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): async def send_msg(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
''' '''
发送消息(主动发送) 发送消息(主动发送)
''' '''
self.cnt_reply += 1
pass pass
@abc.abstractmethod @abc.abstractmethod
async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): async def send(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]):
''' '''
发送消息(主动发送)同 send_msg() 发送消息(主动发送)同 send_msg()
''' '''
self.cnt_reply += 1
pass pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str: def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str:
+19 -17
View File
@@ -11,11 +11,16 @@ from nakuru import (
Notify Notify
) )
from typing import Union from typing import Union
from type.types import GlobalObject
import time import time
from ._platfrom import Platform from ._platfrom import Platform
from ._message_parse import nakuru_message_parse_rev from ._message_parse import nakuru_message_parse_rev
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember from type.message import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
class FakeSource: class FakeSource:
@@ -25,7 +30,7 @@ class FakeSource:
class QQGOCQ(Platform): class QQGOCQ(Platform):
def __init__(self, cfg: dict, message_handler: callable, global_object) -> None: def __init__(self, cfg: dict, message_handler: callable, global_object: GlobalObject) -> None:
super().__init__(message_handler) super().__init__(message_handler)
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
@@ -34,15 +39,8 @@ class QQGOCQ(Platform):
self.waiting = {} self.waiting = {}
self.cc = CmdConfig() self.cc = CmdConfig()
self.cfg = cfg self.cfg = cfg
self.logger: gu.Logger = global_object.logger
self.context = global_object
try:
self.nick_qq = cfg['nick_qq']
except:
self.nick_qq = ["ai", "!", ""]
nick_qq = self.nick_qq
if isinstance(nick_qq, str):
nick_qq = [nick_qq]
self.unique_session = cfg['uniqueSessionMode'] self.unique_session = cfg['uniqueSessionMode']
self.pic_mode = cfg['qq_pic_mode'] self.pic_mode = cfg['qq_pic_mode']
@@ -106,8 +104,9 @@ class QQGOCQ(Platform):
self.client.run() self.client.run()
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
self.logger.log( await super().handle_msg()
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ") logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
assert isinstance(message.raw_message, assert isinstance(message.raw_message,
(GroupMessage, FriendMessage, GuildMessage)) (GroupMessage, FriendMessage, GuildMessage))
@@ -129,8 +128,8 @@ class QQGOCQ(Platform):
if message.type.value == "GroupMessage": if message.type.value == "GroupMessage":
if str(i.qq) == str(message.self_id): if str(i.qq) == str(message.self_id):
resp = True resp = True
elif isinstance(i, Plain): elif isinstance(i, Plain) and self.context.nick:
for nick in self.nick_qq: for nick in self.context.nick:
if nick != '' and i.text.strip().startswith(nick): if nick != '' and i.text.strip().startswith(nick):
resp = True resp = True
break break
@@ -178,6 +177,7 @@ class QQGOCQ(Platform):
async def reply_msg(self, async def reply_msg(self,
message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage], message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage],
result_message: list): result_message: list):
await super().reply_msg()
""" """
插件开发者请使用send方法, 可以不用直接调用这个方法。 插件开发者请使用send方法, 可以不用直接调用这个方法。
""" """
@@ -188,8 +188,8 @@ class QQGOCQ(Platform):
res = result_message res = result_message
self.logger.log( logger.info(
f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ") f"{source.user_id} <- {self.parse_message_outline(res)}")
if isinstance(source, int): if isinstance(source, int):
source = FakeSource("GroupMessage", source) source = FakeSource("GroupMessage", source)
@@ -256,6 +256,7 @@ class QQGOCQ(Platform):
提供给插件的发送QQ消息接口。 提供给插件的发送QQ消息接口。
参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。
''' '''
await super().reply_msg()
try: try:
await self.reply_msg(message, result_message) await self.reply_msg(message, result_message)
except BaseException as e: except BaseException as e:
@@ -267,6 +268,7 @@ class QQGOCQ(Platform):
''' '''
同 send_msg() 同 send_msg()
''' '''
await super().reply_msg()
await self.reply_msg(to, res) await self.reply_msg(to, res)
def create_text_image(title: str, text: str, max_width=30, font_size=20): def create_text_image(title: str, text: str, max_width=30, font_size=20):
+13 -8
View File
@@ -5,6 +5,8 @@ import botpy.message
import re import re
import asyncio import asyncio
import aiohttp import aiohttp
import botpy.types
import botpy.types.message
from util import general_utils as gu from util import general_utils as gu
from botpy.types.message import Reference from botpy.types.message import Reference
@@ -15,13 +17,15 @@ from ._message_parse import (
qq_official_message_parse_rev, qq_official_message_parse_rev,
qq_official_message_parse qq_official_message_parse
) )
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember from type.message import *
from typing import Union, List from typing import Union, List
from nakuru.entities.components import BaseMessageComponent from nakuru.entities.components import BaseMessageComponent
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
# QQ 机器人官方框架 # QQ 机器人官方框架
class botClient(Client): class botClient(Client):
def set_platform(self, platform: 'QQOfficial'): def set_platform(self, platform: 'QQOfficial'):
self.platform = platform self.platform = platform
@@ -59,7 +63,6 @@ class QQOfficial(Platform):
self.token = cfg['qqbot']['token'] self.token = cfg['qqbot']['token']
self.secret = cfg['qqbot_secret'] self.secret = cfg['qqbot_secret']
self.unique_session = cfg['uniqueSessionMode'] self.unique_session = cfg['uniqueSessionMode']
self.logger: gu.Logger = global_object.logger
qq_group = cfg['qqofficial_enable_group_message'] qq_group = cfg['qqofficial_enable_group_message']
if qq_group: if qq_group:
@@ -99,13 +102,14 @@ class QQOfficial(Platform):
) )
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
await super().handle_msg()
assert isinstance(message.raw_message, (botpy.message.Message, assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE is_group = message.type != MessageType.FRIEND_MESSAGE
_t = "/私聊" if not is_group else "" _t = "/私聊" if not is_group else ""
self.logger.log( logger.info(
f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL") f"{message.sender.nickname}({message.sender.user_id}{_t}) -> {self.parse_message_outline(message)}")
# 解析出 session_id # 解析出 session_id
if self.unique_session or not is_group: if self.unique_session or not is_group:
@@ -151,14 +155,15 @@ class QQOfficial(Platform):
''' '''
回复频道消息 回复频道消息
''' '''
await super().reply_msg()
if isinstance(message, AstrBotMessage): if isinstance(message, AstrBotMessage):
source = message.raw_message source = message.raw_message
else: else:
source = message source = message
assert isinstance(source, (botpy.message.Message, assert isinstance(source, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage))
self.logger.log( logger.info(
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL") f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(res)}")
plain_text = '' plain_text = ''
image_path = '' image_path = ''
+425 -346
View File
@@ -5,90 +5,109 @@ import time
import tiktoken import tiktoken
import threading import threading
import traceback import traceback
import base64
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from cores.database.conn import dbConn from persist.session import dbConn
from model.provider.provider import Provider from model.provider.provider import Provider
from util import general_utils as gu from util import general_utils as gu
from util.cmd_config import CmdConfig from util.cmd_config import CmdConfig
from util.general_utils import Logger from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' MODELS = {
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-0125-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-3.5-turbo-16k": 16385,
"gpt-3.5-turbo-0613": 16385,
"gpt-3.5-turbo-16k-0613": 16385,
}
class ProviderOpenAIOfficial(Provider): class ProviderOpenAIOfficial(Provider):
def __init__(self, cfg): def __init__(self, cfg) -> None:
self.cc = CmdConfig() super().__init__()
self.logger = Logger()
self.key_list = [] os.makedirs("data/openai", exist_ok=True)
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
for key in cfg['key']:
if len(key) == 1:
raise BaseException(
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
if cfg['key'] != '' and cfg['key'] != None:
self.key_list = cfg['key']
if len(self.key_list) == 0:
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
self.key_stat = {} self.cc = CmdConfig
for k in self.key_list: self.key_data_path = "data/openai/keys.json"
self.key_stat[k] = {'exceed': False, 'used': 0} self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.keys_data = {} # 记录超额
self.api_base = None if cfg['key']: self.api_keys = cfg['key']
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': if cfg['api_base']: self.base_url = cfg['api_base']
self.api_base = cfg['api_base'] if not self.api_keys:
self.logger.log(f"设置 api_base 为: {self.api_base}", tag="OpenAI") logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
self.chosen_api_key = self.api_keys[0]
for key in self.api_keys:
self.keys_data[key] = True
# 创建 OpenAI Client
self.client = AsyncOpenAI( self.client = AsyncOpenAI(
api_key=self.key_list[0], api_key=self.chosen_api_key,
base_url=self.api_base base_url=self.base_url
) )
self.model_configs: Dict = cfg['chatGPTConfigs']
self.openai_model_configs: dict = cfg['chatGPTConfigs'] super().set_curr_model(self.model_configs['model'])
self.logger.log( self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
f'加载 OpenAI Chat Configs: {self.openai_model_configs}', tag="OpenAI") self.session_memory: Dict[str, List] = {} # 会话记忆
self.openai_configs = cfg self.session_memory_lock = threading.Lock()
# 会话缓存 self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.session_dict = {} self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
# 最大缓存token self.DEFAULT_PERSONALITY = {
self.max_tokens = cfg['total_tokens_limit'] "name": "default",
# 历史记录持久化间隔时间 "prompt": "你是一个很有帮助的 AI 助手。"
self.history_dump_interval = 20 }
self.curr_personality = self.DEFAULT_PERSONALITY
self.enc = tiktoken.get_encoding("cl100k_base") self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录 # 从 SQLite DB 读取历史记录
try: try:
db1 = dbConn() db1 = dbConn()
for session in db1.get_all_session(): for session in db1.get_all_session():
self.session_dict[session[0]] = json.loads(session[1])['data'] self.session_memory_lock.acquire()
self.logger.log("读取历史记录成功。", tag="OpenAI") self.session_memory[session[0]] = json.loads(session[1])['data']
self.session_memory_lock.release()
except BaseException as e: except BaseException as e:
self.logger.log("读取历史记录失败,但不影响使用。", logger.warn(f"读取 OpenAI LLM 对话历史记录 失败{e}。仍可正常使用。")
level=gu.LEVEL_ERROR, tag="OpenAI")
# 定时保存历史记录
# 创建转储定时器线程
threading.Thread(target=self.dump_history, daemon=True).start() threading.Thread(target=self.dump_history, daemon=True).start()
# 人格
self.curr_personality = {}
# 转储历史记录
def dump_history(self): def dump_history(self):
'''
转储历史记录
'''
time.sleep(10) time.sleep(10)
db = dbConn() db = dbConn()
while True: while True:
try: try:
# print("转储历史记录...") for key in self.session_memory:
for key in self.session_dict: data = self.session_memory[key]
data = self.session_dict[key]
data_json = { data_json = {
'data': data 'data': data
} }
@@ -96,326 +115,386 @@ class ProviderOpenAIOfficial(Provider):
db.update_session(key, json.dumps(data_json)) db.update_session(key, json.dumps(data_json))
else: else:
db.insert_session(key, json.dumps(data_json)) db.insert_session(key, json.dumps(data_json))
# print("转储历史记录完毕") logger.debug("已保存 OpenAI 会话历史记录")
except BaseException as e: except BaseException as e:
print(e) print(e)
# 每隔10分钟转储一次 finally:
time.sleep(10*self.history_dump_interval) time.sleep(10*60)
def personality_set(self, default_personality: dict, session_id: str): def personality_set(self, default_personality: dict, session_id: str):
if not default_personality: return
if session_id not in self.session_memory:
self.session_memory[session_id] = []
self.curr_personality = default_personality self.curr_personality = default_personality
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
new_record = { new_record = {
"user": { "user": {
"role": "user", "role": "system",
"content": default_personality['prompt'], "content": default_personality['prompt'],
}, },
"AI": { 'usage_tokens': 0, # 到该条目的总 token 数
"role": "assistant", 'single-tokens': 0 # 该条目的 token 数
"content": "好的,接下来我会扮演这个角色。"
},
'type': "personality",
'usage_tokens': 0,
'single-tokens': 0
} }
self.session_dict[session_id].append(new_record)
async def text_chat(self, prompt, self.session_memory[session_id].append(new_record)
session_id=None,
image_url=None,
function_call=None,
extra_conf: dict = None,
default_personality: dict = None):
if session_id is None:
session_id = "unknown"
if "unknown" in self.session_dict:
del self.session_dict["unknown"]
# 会话机制
if session_id not in self.session_dict:
self.session_dict[session_id] = []
if len(self.session_dict[session_id]) == 0: async def encode_image_bs64(self, image_url: str) -> str:
# 设置默认人格 '''
if default_personality is not None: 将图片转换为 base64
self.personality_set(default_personality, session_id) '''
if image_url.startswith("http"):
image_url = await gu.download_image_by_url(image_url)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode()
return "data:image/jpeg;base64," + image_bs64
# 使用 tictoken 截断消息 async def retrieve_context(self, session_id: str):
_encoded_prompt = self.enc.encode(prompt) '''
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): 根据 session_id 获取保存的 OpenAI 格式的上下文
prompt = self.enc.decode(_encoded_prompt[:int( '''
self.openai_model_configs['max_tokens']*0.80)]) if session_id not in self.session_memory:
self.logger.log(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。", raise Exception("会话 ID 不存在")
level=gu.LEVEL_WARNING, tag="OpenAI")
# 转换为 openai 要求的格式
cache_data_list, new_record, req = self.wrap( context = []
prompt, session_id, image_url) is_lvm = await self.is_lvm()
self.logger.log(f"cache: {str(cache_data_list)}", for record in self.session_memory[session_id]:
level=gu.LEVEL_DEBUG, tag="OpenAI") if "user" in record and record['user']:
self.logger.log(f"request: {str(req)}", if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
level=gu.LEVEL_DEBUG, tag="OpenAI") logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
retry = 0
response = None
err = ''
# 截断倍率
truncate_rate = 0.75
use_gpt4v = False
for i in req:
if isinstance(i['content'], list):
use_gpt4v = True
break
if image_url is not None:
use_gpt4v = True
if use_gpt4v:
conf = self.openai_model_configs.copy()
conf['model'] = 'gpt-4-vision-preview'
else:
conf = self.openai_model_configs
if extra_conf is not None:
conf.update(extra_conf)
while retry < 10:
try:
if function_call is None:
response = await self.client.chat.completions.create(
messages=req,
**conf
)
else:
response = await self.client.chat.completions.create(
messages=req,
tools=function_call,
**conf
)
break
except Exception as e:
traceback.print_exc()
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
raise e
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
self.logger.log("当前 Key 已超额或异常, 正在切换",
level=gu.LEVEL_WARNING, tag="OpenAI")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
retry -= 1
elif 'maximum context length' in str(e):
self.logger.log("token 超限, 清空对应缓存,并进行消息截断", tag="OpenAI")
self.session_dict[session_id] = []
prompt = prompt[:int(len(prompt)*truncate_rate)]
truncate_rate -= 0.05
cache_data_list, new_record, req = self.wrap(
prompt, session_id)
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
time.sleep(30)
continue continue
else: context.append(record['user'])
self.logger.log(str(e), level=gu.LEVEL_ERROR, tag="OpenAI") if "AI" in record and record['AI']:
time.sleep(2) context.append(record['AI'])
err = str(e)
retry += 1
if retry >= 10:
self.logger.log(
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki", tag="OpenAI")
raise BaseException("连接出错: "+str(err))
assert isinstance(response, ChatCompletion)
self.logger.log(
f"OPENAI RESPONSE: {response.usage}", level=gu.LEVEL_DEBUG, tag="OpenAI")
# 结果分类 return context
choice = response.choices[0]
if choice.message.content != None: async def is_lvm(self):
# 文本形式 '''
chatgpt_res = str(choice.message.content).strip() 是否是 LVM
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0: '''
return self.model_configs['model'].startswith("gpt-4")
async def get_models(self):
'''
获取所有模型
'''
models = await self.client.models.list()
logger.info(f"OpenAI 模型列表:{models}")
return models
async def assemble_context(self, session_id: str, prompt: str, image_url: str = None):
'''
组装上下文,并且根据当前上下文窗口大小截断
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
tokens_num = len(self.tokenizer.encode(prompt))
previous_total_tokens_num = 0 if not self.session_memory[session_id] else self.session_memory[session_id][-1]['usage_tokens']
message = {
"usage_tokens": previous_total_tokens_num + tokens_num,
"single_tokens": tokens_num,
"AI": None
}
if image_url:
user_content = {
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": await self.encode_image_bs64(image_url)
}
}
]
}
else:
user_content = {
"role": "user",
"content": prompt
}
message["user"] = user_content
self.session_memory[session_id].append(message)
# 根据 模型的上下文窗口 淘汰掉多余的记录
curr_model = self.model_configs['model']
if curr_model in MODELS:
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
# if message['usage_tokens'] > maxium_tokens_num:
# 淘汰多余的记录,使得最终的 usage_tokens 不超过 maxium_tokens_num - 300
# contexts = self.session_memory[session_id]
# need_to_remove_idx = 0
# freed_tokens_num = contexts[0]['single-tokens']
# while freed_tokens_num < message['usage_tokens'] - maxium_tokens_num:
# need_to_remove_idx += 1
# freed_tokens_num += contexts[need_to_remove_idx]['single-tokens']
# # 更新之后的所有记录的 usage_tokens
# for i in range(len(contexts)):
# if i > need_to_remove_idx:
# contexts[i]['usage_tokens'] -= freed_tokens_num
# logger.debug(f"淘汰上下文记录 {need_to_remove_idx+1} 条,释放 {freed_tokens_num} 个 token。当前上下文总 token 为 {contexts[-1]['usage_tokens']}。")
# self.session_memory[session_id] = contexts[need_to_remove_idx+1:]
while len(self.session_memory[session_id]) and self.session_memory[session_id][-1]['usage_tokens'] > maxium_tokens_num:
self.pop_record(session_id)
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
# 更新之后所有记录的 usage_tokens
for i in range(len(self.session_memory[session_id])):
self.session_memory[session_id][i]['usage_tokens'] -= record['single-tokens']
logger.debug(f"淘汰上下文记录 1 条,释放 {record['single-tokens']} 个 token。当前上下文总 token 为 {self.session_memory[session_id][-1]['usage_tokens']}")
return record
async def text_chat(self,
prompt: str,
session_id: str,
image_url: None=None,
tools: None=None,
extra_conf: Dict = None,
**kwargs
) -> str:
super().accu_model_stat()
if not session_id:
session_id = "unknown"
if "unknown" in self.session_memory:
del self.session_memory["unknown"]
if session_id not in self.session_memory:
self.session_memory[session_id] = []
if session_id not in self.session_personality or not self.session_personality[session_id]:
self.personality_set(self.curr_personality, session_id)
self.session_personality[session_id] = True
# 如果 prompt 超过了最大窗口,截断。
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
# 组装上下文,并且根据当前上下文窗口大小截断
await self.assemble_context(session_id, prompt, image_url)
# 获取上下文,openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
if extra_conf: conf.update(extra_conf)
# start request
retry = 0
rate_limit_retry = 0
while retry < 3 or rate_limit_retry < 5:
logger.debug(conf)
logger.debug(contexts)
if tools:
completion_coro = self.client.chat.completions.create(
messages=contexts,
tools=tools,
**conf
)
else:
completion_coro = self.client.chat.completions.create(
messages=contexts,
**conf
)
try:
completion = await completion_coro
break
except AuthenticationError as e:
api_key = self.chosen_api_key[10:] + "..."
logger.error(f"OpenAI API Key {api_key} 验证错误。详细原因:{e}。正在切换到下一个可用的 Key(如果有的话)")
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except BadRequestError as e:
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
retry += 1
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
ok = await self.switch_to_next_key()
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 请求失败:{e}。重试次数已达到上限。")
if "maximum context length" in str(e):
logger.warn(f"OpenAI 请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
completion_tokens = completion.usage.completion_tokens
self.session_memory[session_id][-1]['usage_tokens'] = usage_tokens
self.session_memory[session_id][-1]['single_tokens'] += completion_tokens
if choice.message.content:
# 返回文本
completion_text = str(choice.message.content).strip()
elif choice.message.tool_calls and choice.message.tool_calls:
# tools call (function calling) # tools call (function calling)
return choice.message.tool_calls[0].function return choice.message.tool_calls[0].function
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens self.session_memory[session_id][-1]['AI'] = {
current_usage_tokens = response.usage.total_tokens "role": "assistant",
"content": completion_text
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
if current_usage_tokens > self.max_tokens:
t = current_usage_tokens
index = 0
while t > self.max_tokens:
if index >= len(cache_data_list):
break
# 保留人格信息
if cache_data_list[index]['type'] != 'personality':
t -= int(cache_data_list[index]['single_tokens'])
del cache_data_list[index]
else:
index += 1
# 删除完后更新相关字段
self.session_dict[session_id] = cache_data_list
# 添加新条目进入缓存的prompt
new_record['AI'] = {
'role': 'assistant',
'content': chatgpt_res,
} }
new_record['usage_tokens'] = current_usage_tokens
if len(cache_data_list) > 0:
new_record['single_tokens'] = current_usage_tokens - \
int(cache_data_list[-1]['usage_tokens'])
else:
new_record['single_tokens'] = current_usage_tokens
cache_data_list.append(new_record) return completion_text
self.session_dict[session_id] = cache_data_list async def switch_to_next_key(self):
'''
return chatgpt_res 切换到下一个 API Key
'''
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): if not self.api_keys:
retry = 0 logger.error("OpenAI API Key 不存在。")
image_url = ''
image_generate_configs = self.cc.get("openai_image_generate", None)
while retry < 5:
try:
response: ImagesResponse = await self.client.images.generate(
prompt=prompt,
**image_generate_configs
)
image_url = []
for i in range(img_num):
image_url.append(response.data[i].url)
break
except Exception as e:
self.logger.log(str(e), level=gu.LEVEL_ERROR)
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
self.logger.log("当前 Key 已超额或者不正常, 正在切换",
level=gu.LEVEL_WARNING, tag="OpenAI")
self.key_stat[self.client.api_key]['exceed'] = True
is_switched = self.handle_switch_key()
if not is_switched:
raise e
elif 'Your request was rejected as a result of our safety system.' in str(e):
self.logger.log("您的请求被 OpenAI 安全系统拒绝, 请稍后再试",
level=gu.LEVEL_WARNING, tag="OpenAI")
raise e
else:
retry += 1
if retry >= 5:
raise BaseException("连接超时")
return image_url
async def forget(self, session_id=None) -> bool:
if session_id is None:
return False return False
self.session_dict[session_id] = []
return True
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): for key in self.keys_data:
if self.keys_data[key]:
# 没超额
self.chosen_api_key = key
self.client.api_key = key
logger.info(f"OpenAI 切换到 API Key {key[:10]}... 成功。")
return True
return False
async def image_generate(self, prompt: str, session_id: str = None, **kwargs) -> str:
'''
生成图片
'''
retry = 0
conf = self.image_generator_model_configs
super().accu_model_stat(model=conf['model'])
if not conf:
logger.error("OpenAI 图片生成模型配置不存在。")
raise Exception("OpenAI 图片生成模型配置不存在。")
while retry < 3:
try:
images_response = await self.client.images.generate(
prompt=prompt,
**conf
)
image_url = images_response.data[0].url
return image_url
except Exception as e:
retry += 1
if retry >= 3:
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
time.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
self.session_memory[session_id] = []
if keep_system_prompt:
self.personality_set(self.curr_personality, session_id)
else:
self.curr_personality = self.DEFAULT_PERSONALITY
return True
def dump_contexts_page(self, session_id: str, size=5, page=1,):
''' '''
获取缓存的会话 获取缓存的会话
''' '''
prompts = "" # contexts_str = ""
if paging: # for i, key in enumerate(self.session_memory):
page_begin = (page-1)*size # if i < (page-1)*size or i >= page*size:
page_end = page*size # continue
if page_begin < 0: # contexts_str += f"Session ID: {key}\n"
page_begin = 0 # for record in self.session_memory[key]:
if page_end > len(cache_data_list): # if "user" in record:
page_end = len(cache_data_list) # contexts_str += f"User: {record['user']['content']}\n"
cache_data_list = cache_data_list[page_begin:page_end] # if "AI" in record:
for item in cache_data_list: # contexts_str += f"AI: {record['AI']['content']}\n"
prompts += str(item['user']['role']) + ":\n" + \ # contexts_str += "---\n"
str(item['user']['content']) + "\n" contexts_str = ""
prompts += str(item['AI']['role']) + ":\n" + \ if session_id in self.session_memory:
str(item['AI']['content']) + "\n" for record in self.session_memory[session_id]:
if "user" in record and record['user']:
text = record['user']['content'][:100] + "..." if len(record['user']['content']) > 100 else record['user']['content']
contexts_str += f"User: {text}\n"
if "AI" in record and record['AI']:
text = record['AI']['content'][:100] + "..." if len(record['AI']['content']) > 100 else record['AI']['content']
contexts_str += f"Assistant: {text}\n"
else:
contexts_str = "会话 ID 不存在。"
if divide: return contexts_str, len(self.session_memory[session_id])
prompts += "----------\n"
return prompts
def wrap(self, prompt, session_id, image_url=None):
if image_url is not None:
prompt = [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
}
]
# 获得缓存信息
context = self.session_dict[session_id]
new_record = {
"user": {
"role": "user",
"content": prompt,
},
"AI": {},
'type': "common",
'usage_tokens': 0,
}
req_list = []
for i in context:
if 'user' in i:
req_list.append(i['user'])
if 'AI' in i:
req_list.append(i['AI'])
req_list.append(new_record['user'])
return context, new_record, req_list
def handle_switch_key(self):
is_all_exceed = True
for key in self.key_stat:
if key == None or self.key_stat[key]['exceed']:
continue
is_all_exceed = False
self.client.api_key = key
self.logger.log(
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})", level=gu.LEVEL_INFO, tag="OpenAI")
break
if is_all_exceed:
self.logger.log(
"所有 Key 已超额", level=gu.LEVEL_CRITICAL, tag="OpenAI")
return False
return True
def set_model(self, model: str):
self.model_configs['model'] = model
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model)
super().set_curr_model(model)
def get_configs(self): def get_configs(self):
return self.openai_configs return self.model_configs
def get_key_stat(self): def get_keys_data(self):
return self.key_stat return self.keys_data
def get_key_list(self):
return self.key_list
def get_curr_key(self): def get_curr_key(self):
return self.client.api_key return self.chosen_api_key
def set_key(self, key): def set_key(self, key):
self.client.api_key = key self.client.api_key = key
# 添加key
def append_key(self, key, sponsor):
self.key_list.append(key)
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
# 检查key是否可用
async def check_key(self, key):
client_ = AsyncOpenAI(
api_key=key,
base_url=self.api_base
)
messages = [{"role": "user", "content": "please just echo `test`"}]
await client_.chat.completions.create(
messages=messages,
**self.openai_model_configs
)
return True
+29 -6
View File
@@ -1,9 +1,32 @@
from collections import defaultdict
class Provider: class Provider:
def __init__(self) -> None:
self.model_stat = defaultdict(int) # 用于记录 LLM Model 使用数据
self.curr_model_name = "unknown"
def reset_model_stat(self):
self.model_stat.clear()
def set_curr_model(self, model_name: str):
self.curr_model_name = model_name
def get_curr_model(self):
'''
返回当前正在使用的 LLM
'''
return self.curr_model_name
def accu_model_stat(self, model: str = None):
if not model:
model = self.get_curr_model()
self.model_stat[model] += 1
async def text_chat(self, async def text_chat(self,
prompt: str, prompt: str,
session_id: str, session_id: str,
image_url: None, image_url: None = None,
function_call: None, tools: None = None,
extra_conf: dict = None, extra_conf: dict = None,
default_personality: dict = None, default_personality: dict = None,
**kwargs) -> str: **kwargs) -> str:
@@ -14,11 +37,11 @@ class Provider:
[optional] [optional]
image_url: 图片url识图 image_url: 图片url识图
function_call: 函数调用 tools: 函数调用工具
extra_conf: 额外配置 extra_conf: 额外配置
default_personality: 默认人格 default_personality: 默认人格
''' '''
raise NotImplementedError raise NotImplementedError()
async def image_generate(self, prompt, session_id, **kwargs) -> str: async def image_generate(self, prompt, session_id, **kwargs) -> str:
''' '''
@@ -26,10 +49,10 @@ class Provider:
prompt: 提示词 prompt: 提示词
session_id: 会话id session_id: 会话id
''' '''
raise NotImplementedError raise NotImplementedError()
async def forget(self, session_id=None) -> bool: async def forget(self, session_id=None) -> bool:
''' '''
重置会话 重置会话
''' '''
raise NotImplementedError raise NotImplementedError()
@@ -1,13 +1,16 @@
import sqlite3 import sqlite3
import yaml import os
import shutil
import time import time
from typing import Tuple from typing import Tuple
class dbConn(): class dbConn():
def __init__(self): def __init__(self):
# 读取参数,并支持中文 db_path = "data/data.db"
conn = sqlite3.connect("data.db") if os.path.exists("data.db"):
shutil.copy("data.db", db_path)
conn = sqlite3.connect(db_path)
conn.text_factory = str conn.text_factory = str
self.conn = conn self.conn = conn
c = conn.cursor() c = conn.cursor()
+5 -4
View File
@@ -4,15 +4,16 @@ requests
openai~=1.2.3 openai~=1.2.3
qq-botpy qq-botpy
chardet~=5.1.0 chardet~=5.1.0
Pillow~=9.4.0 Pillow
GitPython~=3.1.31 GitPython
nakuru-project nakuru-project
beautifulsoup4 beautifulsoup4
googlesearch-python googlesearch-python
tiktoken tiktoken
readability-lxml readability-lxml
baidu-aip~=4.16.9 baidu-aip
websockets websockets
flask flask
psutil psutil
lxml_html_clean lxml_html_clean
SparkleLogging
+28
View File
@@ -0,0 +1,28 @@
from typing import Union, List, Callable
from dataclasses import dataclass
@dataclass
class CommandItem():
'''
用来描述单个指令
'''
command_name: Union[str, tuple] # 指令名
callback: Callable # 回调函数
description: str # 描述
origin: str # 注册来源
class CommandResult():
'''
用于在Command中返回多个值
'''
def __init__(self, hit: bool, success: bool = False, message_chain: list = [], command_name: str = "unknown_command") -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
+1
View File
@@ -0,0 +1 @@
VERSION = '3.1.13'
+62
View File
@@ -0,0 +1,62 @@
from enum import Enum
from typing import List
from dataclasses import dataclass
from nakuru.entities.components import BaseMessageComponent
from type.register import RegisteredPlatform
from type.types import GlobalObject
class MessageType(Enum):
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
GUILD_MESSAGE = 'GuildMessage' # 频道消息
@dataclass
class MessageMember():
user_id: str # 发送者id
nickname: str = None
class AstrBotMessage():
'''
AstrBot 的消息对象
'''
tag: str # 消息来源标签
type: MessageType # 消息类型
self_id: str # 机器人的识别id
session_id: str # 会话id
message_id: str # 消息id
sender: MessageMember # 发送者
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
raw_message: object
timestamp: int # 消息时间戳
def __str__(self) -> str:
return str(self.__dict__)
class AstrMessageEvent():
'''
消息事件
'''
context: GlobalObject # 一些公用数据
message_str: str # 纯消息字符串
message_obj: AstrBotMessage # 消息对象
platform: RegisteredPlatform # 来源平台
role: str # 基本身份。`admin` 或 `member`
session_id: int # 会话 id
def __init__(self,
message_str: str,
message_obj: AstrBotMessage,
platform: RegisteredPlatform,
role: str,
context: GlobalObject,
session_id: str = None):
self.context = context
self.message_str = message_str
self.message_obj = message_obj
self.platform = platform
self.role = role
self.session_id = session_id
+27
View File
@@ -0,0 +1,27 @@
from enum import Enum
from dataclasses import dataclass
class PluginType(Enum):
PLATFORM = 'platfrom' # 平台类插件。
LLM = 'llm' # 大语言模型类插件
COMMON = 'common' # 其他插件
@dataclass
class PluginMetadata:
'''
插件的元数据
'''
# required
plugin_name: str
plugin_type: PluginType
author: str # 插件作者
desc: str # 插件简介
version: str # 插件版本
# optional
repo: str = None # 插件仓库地址
def __str__(self) -> str:
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
+53
View File
@@ -0,0 +1,53 @@
from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform
from type.plugin import *
from typing import List
from types import ModuleType
from dataclasses import dataclass
@dataclass
class RegisteredPlugin:
'''
注册在 AstrBot 中的插件
'''
metadata: PluginMetadata
plugin_instance: object
module_path: str
module: ModuleType
root_dir_name: str
trig_cnt: int = 0
def reset_trig_cnt(self):
self.trig_cnt = 0
def trig(self):
self.trig_cnt += 1
def __str__(self) -> str:
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
RegisteredPlugins = List[RegisteredPlugin]
@dataclass
class RegisteredPlatform:
'''
注册在 AstrBot 中的平台平台应当实现 Platform 接口
'''
platform_name: str
platform_instance: Platform
origin: str = None # 注册来源
def __str__(self) -> str:
return self.platform_name
@dataclass
class RegisteredLLM:
'''
注册在 AstrBot 中的大语言模型调用大语言模型应当实现 LLMProvider 接口
'''
llm_name: str
llm_instance: LLMProvider
origin: str = None # 注册来源
+32
View File
@@ -0,0 +1,32 @@
from type.register import *
from typing import List
class GlobalObject:
'''
存放一些公用的数据用于在不同模块(如core与command)之间传递
'''
version: str # 机器人版本
nick: tuple # 用户定义的机器人的别名
base_config: dict # config.json 中导出的配置
cached_plugins: List[RegisteredPlugin] # 加载的插件
platforms: List[RegisteredPlatform]
llms: List[RegisteredLLM]
web_search: bool # 是否开启了网页搜索
reply_prefix: str # 回复前缀
unique_session: bool # 是否开启了独立会话
default_personality: dict
dashboard_data = None
def __init__(self):
self.nick = None # gocq 的昵称
self.base_config = None # config.yaml
self.cached_plugins = [] # 缓存的插件
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = None
self.unique_session = False
self.platforms = []
self.llms = []
self.default_personality = None
self.dashboard_data = None
self.stat = {}
+183
View File
@@ -0,0 +1,183 @@
import traceback
import random
import json
import asyncio
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.agent.func_call import FuncCall
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
from util.search_engine_scraper.bing import Bing
from util.search_engine_scraper.sogo import Sogo
from util.search_engine_scraper.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from logging import Logger
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
bing_search = Bing()
sogo_search = Sogo()
google = Google()
proxy = os.environ.get("HTTPS_PROXY", None)
def tidy_text(text: str) -> str:
'''
清理文本去除空格换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
# def special_fetch_zhihu(link: str) -> str:
# '''
# function-calling 函数, 用于获取知乎文章的内容
# '''
# response = requests.get(link, headers=HEADERS)
# response.encoding = "utf-8"
# soup = BeautifulSoup(response.text, "html.parser")
# if "zhuanlan.zhihu.com" in link:
# r = soup.find(class_="Post-RichTextContainer")
# else:
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
# if r is None:
# print("debug: zhihu none")
# raise Exception("zhihu none")
# return tidy_text(r.text)
async def search_from_bing(keyword: str) -> str:
'''
tools, bing 搜索引擎搜索
'''
logger.info("web_searcher - search_from_bing: " + keyword)
results = []
try:
results = await google.search(keyword, 5)
except BaseException as e:
logger.error(f"google search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search google failed")
try:
results = await bing_search.search(keyword, 5)
except BaseException as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await sogo_search.search(keyword, 5)
except BaseException as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
logger.debug("search sogo failed")
return "没有搜索到结果"
ret = ""
idx = 1
for i in results:
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
try:
site_result = await fetch_website_content(i.url)
except:
site_result = ""
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
return ret
async def fetch_website_content(url):
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=HEADERS, timeout=6, proxy=proxy) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("web_search", [{
"type": "string",
"name": "keyword",
"description": "搜索关键词"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
search_from_bing
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "要获取内容的网页链接"
}],
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
has_func = False
function_invoked_ret = ""
if official_fc:
# we use official function-calling
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}")
func_obj = None
for i in new_func_call.func_list:
if i["name"] == result.name:
func_obj = i["func_obj"]
break
if not func_obj:
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
return result
else:
# we use our own function-calling
try:
args = {
'question': prompt,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
await provider.forget(session_id)
summary_prompt = f"""
你是一个专业且高效的助手你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
2. 简单地发表你对这个问题的简略看法
# 例子
1. 从网上的信息来看可以知道...我个人认为...你觉得呢
2. 根据网上的最新信息可以得知...我觉得...你怎么看
# 限制
1. 限制在 200 字以内
2. **直接输出总结**不要输出多余的内容和提示语
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(summary_prompt, session_id)
return ret
return function_invoked_ret
+28 -4
View File
@@ -1,9 +1,9 @@
import os import os
import json import json
import yaml
from typing import Union from typing import Union
cpath = "cmd_config.json" cpath = "data/cmd_config.json"
def check_exist(): def check_exist():
if not os.path.exists(cpath): if not os.path.exists(cpath):
@@ -89,8 +89,7 @@ def init_astrbot_config_items():
# 加载默认配置 # 加载默认配置
cc = CmdConfig() cc = CmdConfig()
cc.init_attributes("qq_forward_threshold", 200) cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes( cc.init_attributes("qq_welcome", "")
"qq_welcome", "欢迎加入本群!\n欢迎给https://github.com/Soulter/QQChannelChatGPT项目一个Star😊~\n输入help查看帮助~\n")
cc.init_attributes("qq_pic_mode", False) cc.init_attributes("qq_pic_mode", False)
cc.init_attributes("gocq_host", "127.0.0.1") cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700) cc.init_attributes("gocq_http_port", 5700)
@@ -119,3 +118,28 @@ def init_astrbot_config_items():
cc.init_attributes("https_proxy", "") cc.init_attributes("https_proxy", "")
cc.init_attributes("dashboard_username", "") cc.init_attributes("dashboard_username", "")
cc.init_attributes("dashboard_password", "") cc.init_attributes("dashboard_password", "")
def try_migrate_config():
'''
cmd_config.json 迁移至 data/cmd_config.json
'''
print("try migrate configs")
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
try:
os.remove("cmd_config.json")
except Exception as e:
pass
if not os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"):
# 从 configs/config.yaml 上拿数据
configs_pth = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../configs/config.yaml"))
with open(configs_pth, encoding='utf-8') as f:
data = yaml.load(f, Loader=yaml.Loader)
print(data)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
-300
View File
@@ -1,300 +0,0 @@
import requests
import util.general_utils as gu
import traceback
import time
import json
import asyncio
from googlesearch import search, SearchResult
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
from util.function_calling.func_call import (
FuncCall,
FuncCallJsonFormatError,
FuncNotFoundError
)
from model.provider.provider import Provider
def tidy_text(text: str) -> str:
'''
清理文本去除空格换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
def special_fetch_zhihu(link: str) -> str:
'''
function-calling 函数, 用于获取知乎文章的内容
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(link, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
if "zhuanlan.zhihu.com" in link:
r = soup.find(class_="Post-RichTextContainer")
else:
r = soup.find(class_="List-item").find(class_="RichContent-inner")
if r is None:
print("debug: zhihu none")
raise Exception("zhihu none")
return tidy_text(r.text)
def google_web_search(keyword) -> str:
'''
获取 google 搜索结果, 得到 titledesclink
'''
ret = ""
index = 1
try:
ls = search(keyword, advanced=True, num_results=4)
for i in ls:
desc = i.description
try:
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(i.url)
except BaseException as e:
print(f"(google) fetch_website_content err: {str(e)}")
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
index += 1
except Exception as e:
print(f"google search err: {str(e)}")
return web_keyword_search_via_bing(keyword)
return ret
def web_keyword_search_via_bing(keyword) -> str:
'''
获取bing搜索结果, 得到 titledesclink
'''
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
url = "https://www.bing.com/search?q="+keyword
_cnt = 0
# _detail_store = []
while _cnt < 5:
try:
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
soup = BeautifulSoup(response.text, "html.parser")
res = ""
result_cnt = 0
ols = soup.find(id="b_results")
for i in ols.find_all("li", class_="b_algo"):
try:
title = i.find("h2").text
desc = i.find("p").text
link = i.find("h2").find("a").get("href")
# res.append({
# "title": title,
# "desc": desc,
# "link": link,
# })
try:
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
desc = fetch_website_content(link)
except BaseException as e:
print(f"(bing) fetch_website_content err: {str(e)}")
res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n"
result_cnt += 1
if result_cnt > 5:
break
# if len(_detail_store) >= 3:
# continue
# # 爬取前两条的网页内容
# if "zhihu.com" in link:
# try:
# _detail_store.append(special_fetch_zhihu(link))
# except BaseException as e:
# print(f"zhihu parse err: {str(e)}")
# else:
# try:
# _detail_store.append(fetch_website_content(link))
# except BaseException as e:
# print(f"fetch_website_content err: {str(e)}")
except Exception as e:
print(f"bing parse err: {str(e)}")
if result_cnt == 0:
break
return res
except Exception as e:
# gu.log(f"bing fetch err: {str(e)}")
_cnt += 1
time.sleep(1)
# gu.log("fail to fetch bing info, using sougou.")
return web_keyword_search_via_sougou(keyword)
def web_keyword_search_via_sougou(keyword) -> str:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
}
url = f"https://sogou.com/web?query={keyword}"
response = requests.get(url, headers=headers)
response.encoding = "utf-8"
soup = BeautifulSoup(response.text, "html.parser")
res = []
results = soup.find("div", class_="results")
for i in results.find_all("div", class_="vrwrap"):
try:
title = tidy_text(i.find("h3").text)
link = tidy_text(i.find("h3").find("a").get("href"))
if link.startswith("/link?url="):
link = "https://www.sogou.com" + link
res.append({
"title": title,
"link": link,
})
if len(res) >= 5: # 限制5条
break
except Exception as e:
pass
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
for i in res:
if _detail_store >= 3:
break
try:
_detail_store.append(fetch_website_content(i["link"]))
except BaseException as e:
print(f"fetch_website_content err: {str(e)}")
ret = f"{str(res)}"
if len(_detail_store) > 0:
ret += f"\n网页内容: {str(_detail_store)}"
return ret
def fetch_website_content(url):
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(url, headers=headers, timeout=3)
response.encoding = "utf-8"
doc = Document(response.content)
# print('title:', doc.title())
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
async def web_search(question, provider: Provider, session_id, official_fc=False):
'''
official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("google_web_search", [{
"type": "string",
"name": "keyword",
"description": "google search query (分词,尽量保留所有信息)"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
web_keyword_search_via_bing
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "网址"
}],
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
has_func = False
function_invoked_ret = ""
if official_fc:
# we use official function-calling
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
if isinstance(func, Function):
# 执行对应的结果:
func_obj = None
for i in new_func_call.func_list:
if i["name"] == func.name:
func_obj = i["func_obj"]
break
if not func_obj:
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(func.arguments)
# we use to_thread to avoid blocking the event loop
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
else:
# now func is a string
return func
else:
# we use our own function-calling
try:
args = {
'question': question1,
'func_definition': new_func_call.func_dump(),
'is_task': False,
'is_summary': False,
}
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
except BaseException as e:
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
return res
has_func = True
if has_func:
await provider.forget(session_id)
question3 = f"""
你的任务是
1. 根据末尾的材料对问题`{question}`做切题的总结详细;
2. 简单地发表你对这个问题的看法简略
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接引用格式严格按照 `\n[1] title url \n`
不要提到任何函数调用的信息
一些回复的消息模板
模板1:
```
从网上的信息来看可以知道...我个人认为...你觉得呢
```
模板2:
```
根据网上的最新信息可以得知...我觉得...你怎么看
```
你可以根据这些模板来组织回答但可以不照搬要根据问题的内容来回答
以下是相关材料
"""
_c = 0
while _c < 3:
try:
print('text chat')
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
return final_ret
except Exception as e:
print(e)
_c += 1
if _c == 3:
raise e
if "The message you submitted was too long" in str(e):
await provider.forget(session_id)
function_invoked_ret = function_invoked_ret[:int(
len(function_invoked_ret) / 2)]
time.sleep(3)
return function_invoked_ret
+84 -149
View File
@@ -1,143 +1,22 @@
import datetime
import time import time
import socket import socket
from PIL import Image, ImageDraw, ImageFont
import os import os
import re import re
import requests import requests
from util.cmd_config import CmdConfig import aiohttp
import socket import socket
from cores.astrbot.types import GlobalObject
import platform import platform
import logging
import json import json
import sys import sys
import psutil import psutil
import ssl
PLATFORM_GOCQ = 'gocq' from PIL import Image, ImageDraw, ImageFont
PLATFORM_QQCHAN = 'qqchan' from type.types import GlobalObject
from SparkleLogging.utils.core import LogManager
from logging import Logger
FG_COLORS = { logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
"black": "30",
"red": "31",
"green": "32",
"yellow": "33",
"blue": "34",
"purple": "35",
"cyan": "36",
"white": "37",
"default": "39",
}
BG_COLORS = {
"black": "40",
"red": "41",
"green": "42",
"yellow": "43",
"blue": "44",
"purple": "45",
"cyan": "46",
"white": "47",
"default": "49",
}
LEVEL_DEBUG = "DEBUG"
LEVEL_INFO = "INFO"
LEVEL_WARNING = "WARN"
LEVEL_ERROR = "ERROR"
LEVEL_CRITICAL = "CRITICAL"
# 为了兼容旧版
level_codes = {
LEVEL_DEBUG: logging.DEBUG,
LEVEL_INFO: logging.INFO,
LEVEL_WARNING: logging.WARNING,
LEVEL_ERROR: logging.ERROR,
LEVEL_CRITICAL: logging.CRITICAL,
}
level_colors = {
"INFO": "green",
"WARN": "yellow",
"ERROR": "red",
"CRITICAL": "purple",
}
class Logger:
def __init__(self) -> None:
self.history = []
def log(
self,
msg: str,
level: str = "INFO",
tag: str = "System",
fg: str = None,
bg: str = None,
max_len: int = 50000,
err: Exception = None,):
"""
日志打印函数
"""
_set_level_code = level_codes[LEVEL_INFO]
if 'LOG_LEVEL' in os.environ and os.environ['LOG_LEVEL'] in level_codes:
_set_level_code = level_codes[os.environ['LOG_LEVEL']]
if level in level_codes and level_codes[level] < _set_level_code:
return
if err is not None:
msg += "\n异常原因: " + str(err)
level = LEVEL_ERROR
if len(msg) > max_len:
msg = msg[:max_len] + "..."
now = datetime.datetime.now().strftime("%H:%M:%S")
pres = []
for line in msg.split("\n"):
if line == "\n":
pres.append("")
else:
pres.append(f"[{now}] [{tag}/{level}] {line}")
if level == "INFO":
if fg is None:
fg = FG_COLORS["green"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "WARN":
if fg is None:
fg = FG_COLORS["yellow"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "ERROR":
if fg is None:
fg = FG_COLORS["red"]
if bg is None:
bg = BG_COLORS["default"]
elif level == "CRITICAL":
if fg is None:
fg = FG_COLORS["purple"]
if bg is None:
bg = BG_COLORS["default"]
ret = ""
for line in pres:
ret += f"\033[{fg};{bg}m{line}\033[0m\n"
try:
requests.post("http://localhost:6185/api/log",
data=ret[:-1].encode(), timeout=1)
except BaseException as e:
pass
self.history.append(ret)
if len(self.history) > 100:
self.history = self.history[-100:]
print(ret[:-1])
log = Logger().log
def port_checker(port: int, host: str = "localhost"): def port_checker(port: int, host: str = "localhost"):
@@ -477,14 +356,40 @@ def save_temp_img(img: Image) -> str:
if time.time() - ctime > 3600: if time.time() - ctime > 3600:
os.remove(path) os.remove(path)
except Exception as e: except Exception as e:
print(f"清除临时文件失败: {e}", level=LEVEL_WARNING, tag="GeneralUtils") print(f"清除临时文件失败: {e}")
# 获得时间戳 # 获得时间戳
timestamp = int(time.time()) timestamp = int(time.time())
p = f"temp/{timestamp}.png" p = f"temp/{timestamp}.jpg"
img.save(p)
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
logger.info(f"保存临时图片: {p}")
return p return p
async def download_image_by_url(url: str) -> str:
'''
下载图片
'''
try:
logger.info(f"下载图片: {url}")
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
except aiohttp.client_exceptions.ClientConnectorSSLError as e:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
except Exception as e:
raise e
def create_text_image(title: str, text: str, max_width=30, font_size=20): def create_text_image(title: str, text: str, max_width=30, font_size=20):
''' '''
@@ -517,17 +422,6 @@ def create_markdown_image(text: str):
raise e raise e
def try_migrate_config(old_config: dict):
'''
迁移配置文件到 cmd_config.json
'''
cc = CmdConfig()
if cc.get("qqbot", None) is None:
# 未迁移过
for k in old_config:
cc.put(k, old_config[k])
def get_local_ip_addresses(): def get_local_ip_addresses():
ip = '' ip = ''
try: try:
@@ -557,15 +451,41 @@ def get_sys_info(global_object: GlobalObject):
def upload(_global_object: GlobalObject): def upload(_global_object: GlobalObject):
'''
上传相关非敏感统计数据
'''
time.sleep(10)
while True: while True:
addr_ip = '' platform_stats = {}
llm_stats = {}
plugin_stats = {}
for platform in _global_object.platforms:
platform_stats[platform.platform_name] = {
"cnt_receive": platform.platform_instance.cnt_receive,
"cnt_reply": platform.platform_instance.cnt_reply
}
for llm in _global_object.llms:
stat = llm.llm_instance.model_stat
for k in stat:
llm_stats[llm.llm_name + "#" + k] = stat[k]
llm.llm_instance.reset_model_stat()
for plugin in _global_object.cached_plugins:
plugin_stats[plugin.metadata.plugin_name] = {
"metadata": plugin.metadata,
"trig_cnt": plugin.trig_cnt
}
plugin.reset_trig_cnt()
try: try:
res = { res = {
"version": _global_object.version, "stat_version": "moon",
"count": _global_object.cnt_total, "version": _global_object.version, # 版本号
"ip": addr_ip, "platform_stats": platform_stats, # 过去 30 分钟各消息平台交互消息数
"sys": sys.platform, "llm_stats": llm_stats,
"admin": "null", "plugin_stats": plugin_stats,
"sys": sys.platform, # 系统版本
} }
resp = requests.post( resp = requests.post(
'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5)
@@ -575,7 +495,22 @@ def upload(_global_object: GlobalObject):
_global_object.cnt_total = 0 _global_object.cnt_total = 0
except BaseException as e: except BaseException as e:
pass pass
time.sleep(10*60) time.sleep(30*60)
def retry(n: int = 3):
'''
重试装饰器
'''
def decorator(func):
def wrapper(*args, **kwargs):
for i in range(n):
try:
return func(*args, **kwargs)
except Exception as e:
if i == n-1: raise e
logger.warning(f"函数 {func.__name__}{i+1} 次重试... {e}")
return wrapper
return decorator
def run_monitor(global_object: GlobalObject): def run_monitor(global_object: GlobalObject):
+5 -11
View File
@@ -1,11 +1,5 @@
from cores.astrbot.types import ( from type.plugin import PluginMetadata, PluginType
PluginMetadata, from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins
RegisteredLLM, from type.types import GlobalObject
RegisteredPlugin, from type.message import AstrMessageEvent
RegisteredPlatform, from type.command import CommandResult
RegisteredPlugins,
PluginType,
GlobalObject,
AstrMessageEvent,
CommandResult
)
+3 -2
View File
@@ -1,5 +1,6 @@
from cores.astrbot.core import oper_msg from astrbot.core import oper_msg
from cores.astrbot.types import AstrMessageEvent, CommandResult from type.message import AstrMessageEvent, AstrBotMessage
from type.command import CommandResult
from model.platform._message_result import MessageResult from model.platform._message_result import MessageResult
''' '''
+2 -1
View File
@@ -5,7 +5,8 @@
''' '''
from model.provider.provider import Provider as LLMProvider from model.provider.provider import Provider as LLMProvider
from model.platform._platfrom import Platform from model.platform._platfrom import Platform
from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM from type.types import GlobalObject
from type.register import RegisteredPlatform, RegisteredLLM
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None: def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
''' '''
+1 -1
View File
@@ -2,4 +2,4 @@
插件类型 插件类型
''' '''
from cores.astrbot.types import PluginType from type.plugin import PluginType
+52 -50
View File
@@ -1,26 +1,24 @@
''' '''
插件工具函数 插件工具函数
''' '''
import os import os, sys
import inspect import inspect
import shutil
import stat
import traceback
try: try:
import git.exc
from git.repo import Repo from git.repo import Repo
except ImportError: except ImportError:
pass pass
import shutil
import importlib
import stat
import traceback
from types import ModuleType from types import ModuleType
from typing import List from type.plugin import *
from pip._internal import main as pipmain from type.register import *
from cores.astrbot.types import ( from SparkleLogging.utils.core import LogManager
PluginMetadata, from logging import Logger
PluginType,
RegisteredPlugin, logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
RegisteredPlugins
)
# 找出模块里所有的类名 # 找出模块里所有的类名
@@ -62,29 +60,35 @@ def get_modules(path):
def get_plugin_store_path(): def get_plugin_store_path():
if os.path.exists("addons/plugins"): plugin_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../addons/plugins"))
return "addons/plugins" return plugin_dir
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
return "QQChannelChatGPT/addons/plugins"
elif os.path.exists("AstrBot/addons/plugins"):
return "AstrBot/addons/plugins"
else:
raise FileNotFoundError("插件文件夹不存在。")
def get_plugin_modules(): def get_plugin_modules():
plugins = [] plugins = []
try: try:
if os.path.exists("addons/plugins"): plugin_dir = get_plugin_store_path()
plugins = get_modules("addons/plugins") if os.path.exists(plugin_dir):
plugins = get_modules(plugin_dir)
return plugins return plugins
elif os.path.exists("QQChannelChatGPT/addons/plugins"):
plugins = get_modules("QQChannelChatGPT/addons/plugins")
return plugins
else:
return None
except BaseException as e: except BaseException as e:
raise e raise e
def check_plugin_dept_update(cached_plugins: RegisteredPlugins, target_plugin: str = None):
plugin_dir = get_plugin_store_path()
if not os.path.exists(plugin_dir):
return False
to_update = []
if target_plugin:
to_update.append(target_plugin)
else:
for p in cached_plugins:
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查更新插件 {p} 的依赖: {pth}")
update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
def plugin_reload(cached_plugins: RegisteredPlugins): def plugin_reload(cached_plugins: RegisteredPlugins):
@@ -103,11 +107,9 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
module_path = plugin['module_path'] module_path = plugin['module_path']
root_dir_name = plugin['pname'] root_dir_name = plugin['pname']
if module_path in registered_map: check_plugin_dept_update(cached_plugins, root_dir_name)
# 之前注册过
module = importlib.reload(module) module = __import__("addons.plugins." +
else:
module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p]) root_dir_name + "." + p, fromlist=[p])
cls = get_classes(p, module) cls = get_classes(p, module)
@@ -138,13 +140,15 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
except BaseException as e: except BaseException as e:
fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n" fail_rec += f"注册插件 {module_path} 失败, 原因: {str(e)}\n"
continue continue
cached_plugins.append(RegisteredPlugin(
metadata=metadata, if module_path not in registered_map:
plugin_instance=obj, cached_plugins.append(RegisteredPlugin(
module=module, metadata=metadata,
module_path=module_path, plugin_instance=obj,
root_dir_name=root_dir_name module=module,
)) module_path=module_path,
root_dir_name=root_dir_name
))
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n" fail_rec += f"加载{p}插件出现问题,原因 {str(e)}\n"
@@ -152,6 +156,11 @@ def plugin_reload(cached_plugins: RegisteredPlugins):
return True, None return True, None
else: else:
return False, fail_rec return False, fail_rec
def update_plugin_dept(path):
mirror = "https://mirrors.aliyun.com/pypi/simple/"
py = sys.executable
os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet")
def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins): def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
@@ -163,15 +172,12 @@ def install_plugin(repo_url: str, cached_plugins: RegisteredPlugins):
d = repo_url.split("/")[-1] d = repo_url.split("/")[-1]
# 转换非法字符:- # 转换非法字符:-
d = d.replace("-", "_") d = d.replace("-", "_")
d = d.lower() # 转换为小写
# 创建文件夹 # 创建文件夹
plugin_path = os.path.join(ppath, d) plugin_path = os.path.join(ppath, d)
if os.path.exists(plugin_path): if os.path.exists(plugin_path):
remove_dir(plugin_path) remove_dir(plugin_path)
Repo.clone_from(repo_url, to_path=plugin_path, branch='master') Repo.clone_from(repo_url, to_path=plugin_path, branch='master')
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件的依赖安装失败, 需要您手动 pip 安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins) ok, err = plugin_reload(cached_plugins)
if not ok: if not ok:
raise Exception(err) raise Exception(err)
@@ -206,10 +212,6 @@ def update_plugin(plugin_name: str, cached_plugins: RegisteredPlugins):
plugin_path = os.path.join(ppath, root_dir_name) plugin_path = os.path.join(ppath, root_dir_name)
repo = Repo(path=plugin_path) repo = Repo(path=plugin_path)
repo.remotes.origin.pull() repo.remotes.origin.pull()
# 读取插件的requirements.txt
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
if pipmain(['install', '-r', os.path.join(plugin_path, "requirements.txt"), '--quiet']) != 0:
raise Exception("插件依赖安装失败, 需要您手动pip安装对应插件的依赖。")
ok, err = plugin_reload(cached_plugins) ok, err = plugin_reload(cached_plugins)
if not ok: if not ok:
raise Exception(err) raise Exception(err)
+38
View File
@@ -0,0 +1,38 @@
from typing import List
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENT_BING
class Bing(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.bing.com"
self.headers.update({'User-Agent': USER_AGENT_BING})
def _set_selector(self, selector: str):
selectors = {
'url': 'div.b_attribution cite',
'title': 'h2',
'text': 'p',
'links': 'ol#b_results > li.b_algo',
'next': 'div#b_content nav[role="navigation"] a.sb_pagN'
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
if self.page == 1:
await self._get_html(self.base_url)
url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl='
return await self._get_html(url, None)
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = await super().search(query, num_results)
for result in results:
if not isinstance(result.url, str):
result.url = result.url.text
return results
+20
View File
@@ -0,0 +1,20 @@
HEADERS = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
'Accept': '*/*',
'Connection': 'keep-alive',
'Accept-Language': 'en-GB,en;q=0.5'
}
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
USER_AGENTS = [
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
]
+73
View File
@@ -0,0 +1,73 @@
import random
try:
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from config import HEADERS, USER_AGENTS
from bs4 import BeautifulSoup
from aiohttp import ClientSession
from dataclasses import dataclass
from typing import List
@dataclass
class SearchResult():
title: str
url: str
snippet: str
def __str__(self) -> str:
return f"{self.title} - {self.url}\n{self.snippet}"
class SearchEngine():
'''
搜索引擎爬虫基类
'''
def __init__(self) -> None:
self.TIMEOUT = 10
self.page = 1
self.headers = HEADERS
def _set_selector(self, selector: str) -> None:
raise NotImplementedError()
def _get_next_page(self):
raise NotImplementedError()
async def _get_html(self, url: str, data: dict = None) -> str:
headers = self.headers
headers["Referer"] = url
headers["User-Agent"] = random.choice(USER_AGENTS)
if data:
async with ClientSession() as session:
async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp:
return await resp.text(encoding="utf-8")
else:
async with ClientSession() as session:
async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp:
return await resp.text(encoding="utf-8")
def tidy_text(self, text: str) -> str:
'''
清理文本去除空格换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
try:
resp = await self._get_next_page(query)
soup = BeautifulSoup(resp, 'html.parser')
links = soup.select(self._set_selector('links'))
results = []
for link in links:
title = self.tidy_text(link.select_one(self._set_selector('title')).text)
url = link.select_one(self._set_selector('url'))
snippet = ''
if title and url:
results.append(SearchResult(title=title, url=url, snippet=snippet))
return results[:num_results] if len(results) > num_results else results
except Exception as e:
raise e
+27
View File
@@ -0,0 +1,27 @@
import os
from googlesearch import search
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Google(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.proxy = os.environ.get("HTTPS_PROXY")
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = []
try:
print("use proxy:", self.proxy)
ls = search(query, advanced=True, num_results=num_results, timeout=3, proxy=self.proxy)
for i in ls:
results.append(SearchResult(title=i.title, url=i.url, snippet=i.description))
except Exception as e:
raise e
return results
+49
View File
@@ -0,0 +1,49 @@
import random, re
from bs4 import BeautifulSoup
try:
from util.search_engine_scraper.engine import SearchEngine, SearchResult
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
except ImportError:
from engine import SearchEngine, SearchResult
from config import HEADERS, USER_AGENTS
from typing import List
class Sogo(SearchEngine):
def __init__(self) -> None:
super().__init__()
self.base_url = "https://www.sogou.com"
self.headers['User-Agent'] = random.choice(USER_AGENTS)
def _set_selector(self, selector: str):
selectors = {
'url': 'h3 > a',
'title': 'h3',
'text': '',
'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)',
'next': ''
}
return selectors[selector]
async def _get_next_page(self, query) -> str:
url = f'{self.base_url}/web?query={query}'
return await self._get_html(url, None)
async def search(self, query: str, num_results: int) -> List[SearchResult]:
results = await super().search(query, num_results)
for result in results:
result.url = result.url.get("href")
if result.url.startswith("/link?"):
result.url = self.base_url + result.url
result.url = await self._parse_url(result.url)
return results
async def _parse_url(self, url) -> str:
html = await self._get_html(url)
soup = BeautifulSoup(html, 'html.parser')
script = soup.find("script")
if script:
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1)
return url
+22
View File
@@ -0,0 +1,22 @@
from sogo import Sogo
from bing import Bing
sogo_search = Sogo()
bing_search = Bing()
async def search(keyword: str) -> str:
results = await sogo_search.search(keyword, 5)
# results = await bing_search.search(keyword, 5)
ret = ""
if len(results) == 0:
return "没有搜索到结果"
idx = 1
for i in results:
ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n"
idx += 1
return ret
import asyncio
ret = asyncio.run(search("gpt4orelease"))
print(ret)
+9 -9
View File
@@ -6,6 +6,7 @@ except BaseException as e:
has_git = False has_git = False
import sys, os import sys, os
import requests import requests
from type.config import VERSION
def _reboot(): def _reboot():
py = sys.executable py = sys.executable
@@ -78,11 +79,11 @@ def check_update() -> str:
print(f"当前版本: {curr_commit}") print(f"当前版本: {curr_commit}")
print(f"最新版本: {new_commit}") print(f"最新版本: {new_commit}")
if curr_commit.startswith(new_commit): if curr_commit.startswith(new_commit):
return "当前已经是最新版本" return f"当前已经是最新版本: v{VERSION}"
else: else:
update_info = f"""有新版本可用。 update_info = f"""有新版本可用。
=== 当前版本 === === 当前版本 ===
{curr_commit} v{VERSION}
=== 新版本 === === 新版本 ===
{update_data[0]['version']} {update_data[0]['version']}
@@ -111,8 +112,8 @@ def update_project(update_data: list,
else: else:
# 更新到最新版本对应的commit # 更新到最新版本对应的commit
try: try:
repo.remotes.origin.fetch() repo.git.fetch()
repo.git.checkout(update_data[0]['tag_name']) repo.git.checkout(update_data[0]['tag_name'], "-f")
if reboot: _reboot() if reboot: _reboot()
except BaseException as e: except BaseException as e:
raise e raise e
@@ -123,8 +124,8 @@ def update_project(update_data: list,
for data in update_data: for data in update_data:
if data['tag_name'] == version: if data['tag_name'] == version:
try: try:
repo.remotes.origin.fetch() repo.git.fetch()
repo.git.checkout(data['tag_name']) repo.git.checkout(data['tag_name'], "-f")
flag = True flag = True
if reboot: _reboot() if reboot: _reboot()
except BaseException as e: except BaseException as e:
@@ -135,9 +136,8 @@ def update_project(update_data: list,
def checkout_branch(branch_name: str): def checkout_branch(branch_name: str):
repo = find_repo() repo = find_repo()
try: try:
origin = repo.remotes.origin repo.git.fetch()
origin.fetch() repo.git.checkout(branch_name, "-f")
repo.git.checkout(branch_name)
repo.git.pull("origin", branch_name, "-f") repo.git.pull("origin", branch_name, "-f")
return True return True
except BaseException as e: except BaseException as e: