From d58c86f6fcda8ae31a2b45aa2cb7c6e8416ba07a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 19 May 2024 12:46:07 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20websearch=20=E4=BC=98=E5=8C=96=EF=BC=9B?= =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E7=BB=93=E6=9E=84=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + addons/dashboard/server.py | 5 +- {cores/astrbot => astrbot}/core.py | 15 +- cores/astrbot/types.py | 193 --------- main.py | 8 +- model/command/command.py | 15 +- model/command/openai_official.py | 5 +- model/platform/_message_parse.py | 2 +- model/platform/qq_gocq.py | 2 +- model/platform/qq_official.py | 2 +- model/provider/openai_official.py | 6 +- model/provider/openai_official_old.py | 409 ------------------ model/provider/provider.py | 4 +- cores/database/conn.py => persist/session.py | 0 type/command.py | 28 ++ type/message.py | 62 +++ type/plugin.py | 27 ++ type/register.py | 46 ++ type/types.py | 34 ++ util/{function_calling => agent}/func_call.py | 0 util/agent/web_searcher.py | 165 +++++++ util/function_calling/gplugin.py | 300 ------------- util/general_utils.py | 11 +- util/plugin_dev/api/v1/bot.py | 16 +- util/plugin_dev/api/v1/message.py | 5 +- util/plugin_dev/api/v1/register.py | 3 +- util/plugin_dev/api/v1/types.py | 2 +- util/plugin_util.py | 18 +- util/search_engine_scraper/bing.py | 38 ++ util/search_engine_scraper/config.py | 20 + util/search_engine_scraper/engine.py | 74 ++++ util/search_engine_scraper/google.py | 23 + util/search_engine_scraper/sogo.py | 49 +++ util/search_engine_scraper/test.py | 22 + 34 files changed, 649 insertions(+), 961 deletions(-) rename {cores/astrbot => astrbot}/core.py (97%) delete mode 100644 cores/astrbot/types.py delete mode 100644 model/provider/openai_official_old.py rename cores/database/conn.py => persist/session.py (100%) create mode 100644 type/command.py create mode 100644 type/message.py create mode 100644 type/plugin.py create mode 100644 type/register.py create mode 100644 type/types.py rename util/{function_calling => agent}/func_call.py (100%) create mode 100644 util/agent/web_searcher.py delete mode 100644 util/function_calling/gplugin.py create mode 100644 util/search_engine_scraper/bing.py create mode 100644 util/search_engine_scraper/config.py create mode 100644 util/search_engine_scraper/engine.py create mode 100644 util/search_engine_scraper/google.py create mode 100644 util/search_engine_scraper/sogo.py create mode 100644 util/search_engine_scraper/test.py diff --git a/.gitignore b/.gitignore index 33d25a74c..71c6b2ef9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ cmd_config.json addons/plugins/ data/* cookies.json +logs/ diff --git a/addons/dashboard/server.py b/addons/dashboard/server.py index 2bb775aec..7ec4ffe21 100644 --- a/addons/dashboard/server.py +++ b/addons/dashboard/server.py @@ -12,10 +12,11 @@ from flask.logging import default_handler from werkzeug.serving import make_server from util import general_utils as gu from dataclasses import dataclass -from cores.database.conn import dbConn +from persist.session import dbConn +from type.register import RegisteredPlugin +from typing import List from util.cmd_config import CmdConfig from util.updator import check_update, update_project, request_release_info -from cores.astrbot.types import * from SparkleLogging.utils.core import LogManager from logging import Logger logger: Logger = LogManager.GetLogger(log_name='astrbot-core') diff --git a/cores/astrbot/core.py b/astrbot/core.py similarity index 97% rename from cores/astrbot/core.py rename to astrbot/core.py index 17aac1df0..0012d3ea2 100644 --- a/cores/astrbot/core.py +++ b/astrbot/core.py @@ -2,17 +2,14 @@ import re import threading import asyncio import time -import aiohttp import util.unfit_words as uw import os import sys -import io import traceback -import util.function_calling.gplugin as gplugin +import util.agent.web_searcher as web_searcher import util.plugin_util as putil -from PIL import Image as PILImage from nakuru.entities.components import Plain, At, Image from addons.baidu_aip_judge import BaiduJudge @@ -22,10 +19,12 @@ from util import general_utils as gu from util.general_utils import upload, run_monitor from util.cmd_config import CmdConfig as cc from util.cmd_config import init_astrbot_config_items -from .types import * +from type.types import GlobalObject +from type.register import * +from type.message import AstrBotMessage from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData -from cores.database.conn import dbConn +from persist.session import dbConn from model.platform._message_result import MessageResult from SparkleLogging.utils.core import LogManager from logging import Logger @@ -134,7 +133,7 @@ def init(cfg): instance = llm_instance[OPENAI_OFFICIAL] assert isinstance(instance, ProviderOpenAIOfficial) instance.DEFAULT_PERSONALITY = _global_object.default_personality - instance.personality_set(_global_object.default_personality, session_id=None) + instance.curr_personality = instance.DEFAULT_PERSONALITY # 检查provider设置偏好 p = cc.get("chosen_provider", None) @@ -434,7 +433,7 @@ async def oper_msg(message: AstrBotMessage, if chosen_provider == OPENAI_OFFICIAL: if _global_object.web_search or web_sch_flag: official_fc = chosen_provider == OPENAI_OFFICIAL - llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) + llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc) else: llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url) diff --git a/cores/astrbot/types.py b/cores/astrbot/types.py deleted file mode 100644 index c87cb3043..000000000 --- a/cores/astrbot/types.py +++ /dev/null @@ -1,193 +0,0 @@ -from model.provider.provider import Provider as LLMProvider -from model.platform._platfrom import Platform -from nakuru import ( - GroupMessage, - FriendMessage, - GuildMessage, -) -from nakuru.entities.components import BaseMessageComponent -from typing import Union, List, ClassVar, Callable -from types import ModuleType -from enum import Enum -from dataclasses import dataclass - - -class MessageType(Enum): - GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 - FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - GUILD_MESSAGE = 'GuildMessage' # 频道消息 - - -@dataclass -class MessageMember(): - user_id: str # 发送者id - nickname: str = None - - -class AstrBotMessage(): - ''' - AstrBot 的消息对象 - ''' - tag: str # 消息来源标签 - type: MessageType # 消息类型 - self_id: str # 机器人的识别id - session_id: str # 会话id - message_id: str # 消息id - sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 - message_str: str # 最直观的纯文本消息字符串 - raw_message: object - timestamp: int # 消息时间戳 - - def __str__(self) -> str: - return str(self.__dict__) - - -class PluginType(Enum): - PLATFORM = 'platfrom' # 平台类插件。 - LLM = 'llm' # 大语言模型类插件 - COMMON = 'common' # 其他插件 - - -@dataclass -class PluginMetadata: - ''' - 插件的元数据。 - ''' - # required - plugin_name: str - plugin_type: PluginType - author: str # 插件作者 - desc: str # 插件简介 - version: str # 插件版本 - - # optional - repo: str = None # 插件仓库地址 - - def __str__(self) -> str: - return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" - - -@dataclass -class RegisteredPlugin: - ''' - 注册在 AstrBot 中的插件。 - ''' - metadata: PluginMetadata - plugin_instance: object - module_path: str - module: ModuleType - root_dir_name: str - - def __str__(self) -> str: - return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" - - -RegisteredPlugins = List[RegisteredPlugin] - - -@dataclass -class RegisteredPlatform: - ''' - 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 - ''' - platform_name: str - platform_instance: Platform - origin: str = None # 注册来源 - - def __str__(self) -> str: - return self.platform_name - - -@dataclass -class RegisteredLLM: - ''' - 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 - ''' - llm_name: str - llm_instance: LLMProvider - origin: str = None # 注册来源 - - -class GlobalObject: - ''' - 存放一些公用的数据,用于在不同模块(如core与command)之间传递 - ''' - version: str # 机器人版本 - nick: tuple # 用户定义的机器人的别名 - base_config: dict # config.json 中导出的配置 - cached_plugins: List[RegisteredPlugin] # 加载的插件 - platforms: List[RegisteredPlatform] - llms: List[RegisteredLLM] - - web_search: bool # 是否开启了网页搜索 - reply_prefix: str # 回复前缀 - unique_session: bool # 是否开启了独立会话 - cnt_total: int # 总消息数 - default_personality: dict - dashboard_data = None - - def __init__(self): - self.nick = None # gocq 的昵称 - self.base_config = None # config.yaml - self.cached_plugins = [] # 缓存的插件 - self.web_search = False # 是否开启了网页搜索 - self.reply_prefix = None - self.unique_session = False - self.cnt_total = 0 - self.platforms = [] - self.llms = [] - self.default_personality = None - self.dashboard_data = None - self.stat = {} - - -class AstrMessageEvent(): - ''' - 消息事件。 - ''' - context: GlobalObject # 一些公用数据 - message_str: str # 纯消息字符串 - message_obj: AstrBotMessage # 消息对象 - platform: RegisteredPlatform # 来源平台 - role: str # 基本身份。`admin` 或 `member` - session_id: int # 会话 id - - def __init__(self, - message_str: str, - message_obj: AstrBotMessage, - platform: RegisteredPlatform, - role: str, - context: GlobalObject, - session_id: str = None): - self.context = context - self.message_str = message_str - self.message_obj = message_obj - self.platform = platform - self.role = role - self.session_id = session_id - -@dataclass -class CommandItem(): - ''' - 用来描述单个指令 - ''' - - command_name: Union[str, tuple] # 指令名 - callback: Callable # 回调函数 - description: str # 描述 - origin: str # 注册来源 - -class CommandResult(): - ''' - 用于在Command中返回多个值 - ''' - - def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: - self.hit = hit - self.success = success - self.message_chain = message_chain - self.command_name = command_name - - def _result_tuple(self): - return (self.success, self.message_chain, self.command_name) diff --git a/main.py b/main.py index d8a9622c3..b78e36ffa 100644 --- a/main.py +++ b/main.py @@ -38,7 +38,7 @@ def main(): # config.yaml 配置文件加载和环境确认 try: import botpy, logging, yaml - import cores.astrbot.core as qqBot + import astrbot.core as bot_core # delete qqbotpy's logger for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) @@ -59,7 +59,9 @@ def main(): input("配置文件不存在,请检查是否已经下载配置文件。") exit() except BaseException as e: - raise e + logger.error(traceback.format_exc()) + input("未知错误。") + exit() # 设置代理 if 'http_proxy' in cfg and cfg['http_proxy'] != '': @@ -71,7 +73,7 @@ def main(): make_necessary_dirs() # 启动主程序(cores/qqbot/core.py) - qqBot.init(cfg) + bot_core.init(cfg) def check_env(): diff --git a/model/command/command.py b/model/command/command.py index 99eebb4a6..bc9132e0b 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -13,16 +13,13 @@ from nakuru.entities.components import ( from util import general_utils as gu from model.provider.provider import Provider from util.cmd_config import CmdConfig as cc -from cores.astrbot.types import ( - GlobalObject, - AstrMessageEvent, - PluginType, - CommandResult, - RegisteredPlugin, - RegisteredPlatform -) +from type.message import * +from type.types import GlobalObject +from type.command import * +from type.plugin import * +from type.register import * -from typing import List, Tuple +from typing import List from SparkleLogging.utils.core import LogManager from logging import Logger diff --git a/model/command/openai_official.py b/model/command/openai_official.py index 731c6941c..10d02fd6e 100644 --- a/model/command/openai_official.py +++ b/model/command/openai_official.py @@ -1,10 +1,11 @@ from model.command.command import Command from model.provider.openai_official import ProviderOpenAIOfficial, MODELS from util.personality import personalities -from cores.astrbot.types import GlobalObject, CommandItem +from type.types import GlobalObject +from type.command import CommandItem from SparkleLogging.utils.core import LogManager from logging import Logger -from openai._exceptions import NotFoundError, RateLimitError, APIError +from openai._exceptions import NotFoundError logger: Logger = LogManager.GetLogger(log_name='astrbot-core') diff --git a/model/platform/_message_parse.py b/model/platform/_message_parse.py index a55f0beb0..af25d5955 100644 --- a/model/platform/_message_parse.py +++ b/model/platform/_message_parse.py @@ -5,7 +5,7 @@ from nakuru import ( FriendMessage ) import botpy.message -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from typing import List, Union import time diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index 20955b101..68bd04557 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -15,7 +15,7 @@ import time from ._platfrom import Platform from ._message_parse import nakuru_message_parse_rev -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from SparkleLogging.utils.core import LogManager from logging import Logger diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 2995e93b3..daa5a0f5f 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -17,7 +17,7 @@ from ._message_parse import ( qq_official_message_parse_rev, qq_official_message_parse ) -from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember +from type.message import * from typing import Union, List from nakuru.entities.components import BaseMessageComponent from SparkleLogging.utils.core import LogManager diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index a39a49141..fff43ca15 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -12,7 +12,7 @@ from openai.types.images_response import ImagesResponse from openai.types.chat.chat_completion import ChatCompletion from openai._exceptions import * -from cores.database.conn import dbConn +from persist.session import dbConn from model.provider.provider import Provider from util import general_utils as gu from util.cmd_config import CmdConfig @@ -122,6 +122,8 @@ class ProviderOpenAIOfficial(Provider): def personality_set(self, default_personality: dict, session_id: str): if not default_personality: return + if session_id not in self.session_memory: + self.session_memory[session_id] = [] self.curr_personality = default_personality self.session_personality = {} # 重置 encoded_prompt = self.tokenizer.encode(default_personality['prompt']) @@ -282,7 +284,7 @@ class ProviderOpenAIOfficial(Provider): async def text_chat(self, prompt: str, session_id: str, - image_url: None, + image_url: None=None, tools: None=None, extra_conf: Dict = None, **kwargs diff --git a/model/provider/openai_official_old.py b/model/provider/openai_official_old.py deleted file mode 100644 index 72f1c2f89..000000000 --- a/model/provider/openai_official_old.py +++ /dev/null @@ -1,409 +0,0 @@ -import os -import sys -import json -import time -import tiktoken -import threading -import traceback - -from openai import AsyncOpenAI -from openai.types.images_response import ImagesResponse -from openai.types.chat.chat_completion import ChatCompletion - -from cores.database.conn import dbConn -from model.provider.provider import Provider -from util import general_utils as gu -from util.cmd_config import CmdConfig -from SparkleLogging.utils.core import LogManager -from logging import Logger - -logger: Logger = LogManager.GetLogger(log_name='astrbot-core') - - -abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/' - - -class ProviderOpenAIOfficial(Provider): - def __init__(self, cfg): - self.cc = CmdConfig() - - self.key_list = [] - # 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错 - for key in cfg['key']: - if len(key) == 1: - raise BaseException( - "检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。") - if cfg['key'] != '' and cfg['key'] != None: - self.key_list = cfg['key'] - if len(self.key_list) == 0: - raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。") - - self.key_stat = {} - for k in self.key_list: - self.key_stat[k] = {'exceed': False, 'used': 0} - - self.api_base = None - if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '': - self.api_base = cfg['api_base'] - logger.info(f"设置 api_base 为: {self.api_base}") - - # 创建 OpenAI Client - self.client = AsyncOpenAI( - api_key=self.key_list[0], - base_url=self.api_base - ) - - self.openai_model_configs: dict = cfg['chatGPTConfigs'] - self.openai_configs = cfg - # 会话缓存 - self.session_dict = {} - # 最大缓存token - self.max_tokens = cfg['total_tokens_limit'] - # 历史记录持久化间隔时间 - self.history_dump_interval = 20 - - self.enc = tiktoken.get_encoding("cl100k_base") - - # 从 SQLite DB 读取历史记录 - try: - db1 = dbConn() - for session in db1.get_all_session(): - self.session_dict[session[0]] = json.loads(session[1])['data'] - logger.info("读取历史记录成功。") - except BaseException as e: - logger.info("读取历史记录失败,但不影响使用。") - - # 创建转储定时器线程 - threading.Thread(target=self.dump_history, daemon=True).start() - - # 人格 - self.curr_personality = {} - - def make_tmp_client(self, api_key: str, base_url: str): - return AsyncOpenAI( - api_key=api_key, - base_url=base_url - ) - - # 转储历史记录 - def dump_history(self): - time.sleep(10) - db = dbConn() - while True: - try: - # print("转储历史记录...") - for key in self.session_dict: - data = self.session_dict[key] - data_json = { - 'data': data - } - if db.check_session(key): - db.update_session(key, json.dumps(data_json)) - else: - db.insert_session(key, json.dumps(data_json)) - # print("转储历史记录完毕") - except BaseException as e: - print(e) - # 每隔10分钟转储一次 - time.sleep(10*self.history_dump_interval) - - def personality_set(self, default_personality: dict, session_id: str): - self.curr_personality = default_personality - new_record = { - "user": { - "role": "user", - "content": default_personality['prompt'], - }, - "AI": { - "role": "assistant", - "content": "好的,接下来我会扮演这个角色。" - }, - 'type': "personality", - 'usage_tokens': 0, - 'single-tokens': 0 - } - self.session_dict[session_id].append(new_record) - - async def text_chat(self, prompt, - session_id=None, - image_url=None, - function_call=None, - extra_conf: dict = None, - default_personality: dict = None): - if session_id is None: - session_id = "unknown" - if "unknown" in self.session_dict: - del self.session_dict["unknown"] - # 会话机制 - if session_id not in self.session_dict: - self.session_dict[session_id] = [] - - if len(self.session_dict[session_id]) == 0: - # 设置默认人格 - if default_personality is not None: - self.personality_set(default_personality, session_id) - - # 使用 tictoken 截断消息 - _encoded_prompt = self.enc.encode(prompt) - if self.openai_model_configs['max_tokens'] < len(_encoded_prompt): - prompt = self.enc.decode(_encoded_prompt[:int( - self.openai_model_configs['max_tokens']*0.80)]) - logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。") - - cache_data_list, new_record, req = self.wrap( - prompt, session_id, image_url) - logger.debug(f"cache: {str(cache_data_list)}") - logger.debug(f"request: {str(req)}") - retry = 0 - response = None - err = '' - - # 截断倍率 - truncate_rate = 0.75 - - conf = self.openai_model_configs - if extra_conf is not None: - conf.update(extra_conf) - - while retry < 10: - try: - if function_call is None: - response = await self.client.chat.completions.create( - messages=req, - **conf - ) - else: - response = await self.client.chat.completions.create( - messages=req, - tools=function_call, - **conf - ) - break - except Exception as e: - traceback.print_exc() - if 'Invalid content type. image_url is only supported by certain models.' in str(e): - raise e - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.info("当前 Key 已超额或异常, 正在切换", - ) - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - retry -= 1 - elif 'maximum context length' in str(e): - logger.info("token 超限, 清空对应缓存,并进行消息截断") - self.session_dict[session_id] = [] - prompt = prompt[:int(len(prompt)*truncate_rate)] - truncate_rate -= 0.05 - cache_data_list, new_record, req = self.wrap( - prompt, session_id) - - elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e): - time.sleep(30) - continue - else: - logger.error(str(e)) - time.sleep(2) - err = str(e) - retry += 1 - if retry >= 10: - logger.warning( - r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki") - raise BaseException("连接出错: "+str(err)) - assert isinstance(response, ChatCompletion) - logger.debug( - f"OPENAI RESPONSE: {response.usage}") - - # 结果分类 - choice = response.choices[0] - if choice.message.content != None: - # 文本形式 - chatgpt_res = str(choice.message.content).strip() - elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0: - # tools call (function calling) - return choice.message.tool_calls[0].function - - self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens - current_usage_tokens = response.usage.total_tokens - - # 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens - if current_usage_tokens > self.max_tokens: - t = current_usage_tokens - index = 0 - while t > self.max_tokens: - if index >= len(cache_data_list): - break - # 保留人格信息 - if cache_data_list[index]['type'] != 'personality': - t -= int(cache_data_list[index]['single_tokens']) - del cache_data_list[index] - else: - index += 1 - # 删除完后更新相关字段 - self.session_dict[session_id] = cache_data_list - - # 添加新条目进入缓存的prompt - new_record['AI'] = { - 'role': 'assistant', - 'content': chatgpt_res, - } - new_record['usage_tokens'] = current_usage_tokens - if len(cache_data_list) > 0: - new_record['single_tokens'] = current_usage_tokens - \ - int(cache_data_list[-1]['usage_tokens']) - else: - new_record['single_tokens'] = current_usage_tokens - - cache_data_list.append(new_record) - - self.session_dict[session_id] = cache_data_list - - return chatgpt_res - - async def image_chat(self, prompt, img_num=1, img_size="1024x1024"): - retry = 0 - image_url = '' - - image_generate_configs = self.cc.get("openai_image_generate", None) - - while retry < 5: - try: - response: ImagesResponse = await self.client.images.generate( - prompt=prompt, - **image_generate_configs - ) - image_url = [] - for i in range(img_num): - image_url.append(response.data[i].url) - break - except Exception as e: - logger.warning(str(e)) - if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str( - e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e): - logger.warning("当前 Key 已超额或者不正常, 正在切换") - self.key_stat[self.client.api_key]['exceed'] = True - is_switched = self.handle_switch_key() - if not is_switched: - raise e - elif 'Your request was rejected as a result of our safety system.' in str(e): - logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试") - raise e - else: - retry += 1 - if retry >= 5: - raise BaseException("连接超时") - - return image_url - - async def forget(self, session_id=None) -> bool: - if session_id is None: - return False - self.session_dict[session_id] = [] - return True - - def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1): - ''' - 获取缓存的会话 - ''' - prompts = "" - if paging: - page_begin = (page-1)*size - page_end = page*size - if page_begin < 0: - page_begin = 0 - if page_end > len(cache_data_list): - page_end = len(cache_data_list) - cache_data_list = cache_data_list[page_begin:page_end] - for item in cache_data_list: - prompts += str(item['user']['role']) + ":\n" + \ - str(item['user']['content']) + "\n" - prompts += str(item['AI']['role']) + ":\n" + \ - str(item['AI']['content']) + "\n" - - if divide: - prompts += "----------\n" - return prompts - - def wrap(self, prompt, session_id, image_url=None): - if image_url is not None: - prompt = [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - } - ] - # 获得缓存信息 - context = self.session_dict[session_id] - new_record = { - "user": { - "role": "user", - "content": prompt, - }, - "AI": {}, - 'type': "common", - 'usage_tokens': 0, - } - req_list = [] - for i in context: - if 'user' in i: - req_list.append(i['user']) - if 'AI' in i: - req_list.append(i['AI']) - req_list.append(new_record['user']) - return context, new_record, req_list - - def handle_switch_key(self): - is_all_exceed = True - for key in self.key_stat: - if key == None or self.key_stat[key]['exceed']: - continue - is_all_exceed = False - self.client.api_key = key - logger.warning( - f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})") - break - if is_all_exceed: - logger.warning( - "所有 Key 已超额") - return False - return True - - def get_configs(self): - return self.openai_configs - - def get_key_stat(self): - return self.key_stat - - def get_key_list(self): - return self.key_list - - def get_curr_key(self): - return self.client.api_key - - def set_key(self, key): - self.client.api_key = key - - # 添加key - def append_key(self, key, sponsor): - self.key_list.append(key) - self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor} - - # 检查key是否可用 - async def check_key(self, key): - client_ = AsyncOpenAI( - api_key=key, - base_url=self.api_base - ) - messages = [{"role": "user", "content": "please just echo `test`"}] - await client_.chat.completions.create( - messages=messages, - **self.openai_model_configs - ) - return True diff --git a/model/provider/provider.py b/model/provider/provider.py index 022ffbedb..8bf437ae8 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -2,8 +2,8 @@ class Provider: async def text_chat(self, prompt: str, session_id: str, - image_url: None, - tools: None, + image_url: None = None, + tools: None = None, extra_conf: dict = None, default_personality: dict = None, **kwargs) -> str: diff --git a/cores/database/conn.py b/persist/session.py similarity index 100% rename from cores/database/conn.py rename to persist/session.py diff --git a/type/command.py b/type/command.py new file mode 100644 index 000000000..0504f947d --- /dev/null +++ b/type/command.py @@ -0,0 +1,28 @@ +from typing import Union, List, Callable +from dataclasses import dataclass + + +@dataclass +class CommandItem(): + ''' + 用来描述单个指令 + ''' + + command_name: Union[str, tuple] # 指令名 + callback: Callable # 回调函数 + description: str # 描述 + origin: str # 注册来源 + +class CommandResult(): + ''' + 用于在Command中返回多个值 + ''' + + def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None: + self.hit = hit + self.success = success + self.message_chain = message_chain + self.command_name = command_name + + def _result_tuple(self): + return (self.success, self.message_chain, self.command_name) diff --git a/type/message.py b/type/message.py new file mode 100644 index 000000000..ca49d51d6 --- /dev/null +++ b/type/message.py @@ -0,0 +1,62 @@ +from enum import Enum +from typing import List +from dataclasses import dataclass +from nakuru.entities.components import BaseMessageComponent + +from type.register import RegisteredPlatform +from type.types import GlobalObject + +class MessageType(Enum): + GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 + FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 + GUILD_MESSAGE = 'GuildMessage' # 频道消息 + + +@dataclass +class MessageMember(): + user_id: str # 发送者id + nickname: str = None + + +class AstrBotMessage(): + ''' + AstrBot 的消息对象 + ''' + tag: str # 消息来源标签 + type: MessageType # 消息类型 + self_id: str # 机器人的识别id + session_id: str # 会话id + message_id: str # 消息id + sender: MessageMember # 发送者 + message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message_str: str # 最直观的纯文本消息字符串 + raw_message: object + timestamp: int # 消息时间戳 + + def __str__(self) -> str: + return str(self.__dict__) + +class AstrMessageEvent(): + ''' + 消息事件。 + ''' + context: GlobalObject # 一些公用数据 + message_str: str # 纯消息字符串 + message_obj: AstrBotMessage # 消息对象 + platform: RegisteredPlatform # 来源平台 + role: str # 基本身份。`admin` 或 `member` + session_id: int # 会话 id + + def __init__(self, + message_str: str, + message_obj: AstrBotMessage, + platform: RegisteredPlatform, + role: str, + context: GlobalObject, + session_id: str = None): + self.context = context + self.message_str = message_str + self.message_obj = message_obj + self.platform = platform + self.role = role + self.session_id = session_id diff --git a/type/plugin.py b/type/plugin.py new file mode 100644 index 000000000..cbc878ed3 --- /dev/null +++ b/type/plugin.py @@ -0,0 +1,27 @@ +from enum import Enum +from dataclasses import dataclass + +class PluginType(Enum): + PLATFORM = 'platfrom' # 平台类插件。 + LLM = 'llm' # 大语言模型类插件 + COMMON = 'common' # 其他插件 + + +@dataclass +class PluginMetadata: + ''' + 插件的元数据。 + ''' + # required + plugin_name: str + plugin_type: PluginType + author: str # 插件作者 + desc: str # 插件简介 + version: str # 插件版本 + + # optional + repo: str = None # 插件仓库地址 + + def __str__(self) -> str: + return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})" + diff --git a/type/register.py b/type/register.py new file mode 100644 index 000000000..8d99fc666 --- /dev/null +++ b/type/register.py @@ -0,0 +1,46 @@ +from model.provider.provider import Provider as LLMProvider +from model.platform._platfrom import Platform +from type.plugin import * +from typing import List +from types import ModuleType +from dataclasses import dataclass + +@dataclass +class RegisteredPlugin: + ''' + 注册在 AstrBot 中的插件。 + ''' + metadata: PluginMetadata + plugin_instance: object + module_path: str + module: ModuleType + root_dir_name: str + + def __str__(self) -> str: + return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" + + +RegisteredPlugins = List[RegisteredPlugin] + + +@dataclass +class RegisteredPlatform: + ''' + 注册在 AstrBot 中的平台。平台应当实现 Platform 接口。 + ''' + platform_name: str + platform_instance: Platform + origin: str = None # 注册来源 + + def __str__(self) -> str: + return self.platform_name + + +@dataclass +class RegisteredLLM: + ''' + 注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。 + ''' + llm_name: str + llm_instance: LLMProvider + origin: str = None # 注册来源 diff --git a/type/types.py b/type/types.py new file mode 100644 index 000000000..2f9fd4928 --- /dev/null +++ b/type/types.py @@ -0,0 +1,34 @@ +from type.register import * +from typing import List + +class GlobalObject: + ''' + 存放一些公用的数据,用于在不同模块(如core与command)之间传递 + ''' + version: str # 机器人版本 + nick: tuple # 用户定义的机器人的别名 + base_config: dict # config.json 中导出的配置 + cached_plugins: List[RegisteredPlugin] # 加载的插件 + platforms: List[RegisteredPlatform] + llms: List[RegisteredLLM] + + web_search: bool # 是否开启了网页搜索 + reply_prefix: str # 回复前缀 + unique_session: bool # 是否开启了独立会话 + cnt_total: int # 总消息数 + default_personality: dict + dashboard_data = None + + def __init__(self): + self.nick = None # gocq 的昵称 + self.base_config = None # config.yaml + self.cached_plugins = [] # 缓存的插件 + self.web_search = False # 是否开启了网页搜索 + self.reply_prefix = None + self.unique_session = False + self.cnt_total = 0 + self.platforms = [] + self.llms = [] + self.default_personality = None + self.dashboard_data = None + self.stat = {} diff --git a/util/function_calling/func_call.py b/util/agent/func_call.py similarity index 100% rename from util/function_calling/func_call.py rename to util/agent/func_call.py diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py new file mode 100644 index 000000000..7fa9e35c7 --- /dev/null +++ b/util/agent/web_searcher.py @@ -0,0 +1,165 @@ +import traceback +import random +import json +import asyncio +import aiohttp + +from readability import Document +from bs4 import BeautifulSoup +from openai.types.chat.chat_completion_message_tool_call import Function +from util.agent.func_call import FuncCall +from util.search_engine_scraper.config import HEADERS, USER_AGENTS +from util.search_engine_scraper.bing import Bing +from util.search_engine_scraper.sogo import Sogo +from util.search_engine_scraper.google import Google +from model.provider.provider import Provider +from SparkleLogging.utils.core import LogManager +from logging import Logger + +logger: Logger = LogManager.GetLogger(log_name='astrbot-core') + + +bing_search = Bing() +sogo_search = Sogo() +google = Google() + +def tidy_text(text: str) -> str: + ''' + 清理文本,去除空格、换行符等 + ''' + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + +# def special_fetch_zhihu(link: str) -> str: +# ''' +# function-calling 函数, 用于获取知乎文章的内容 +# ''' +# response = requests.get(link, headers=HEADERS) +# response.encoding = "utf-8" +# soup = BeautifulSoup(response.text, "html.parser") + +# if "zhuanlan.zhihu.com" in link: +# r = soup.find(class_="Post-RichTextContainer") +# else: +# r = soup.find(class_="List-item").find(class_="RichContent-inner") +# if r is None: +# print("debug: zhihu none") +# raise Exception("zhihu none") +# return tidy_text(r.text) + +async def search_from_bing(keyword: str) -> str: + ''' + tools, 从 bing 搜索引擎搜索 + ''' + logger.info("web_searcher - search_from_bing: " + keyword) + results = await google.search(keyword, 5) + if len(results) == 0: + results = await bing_search.search(keyword, 5) + if len(results) == 0: + results = await sogo_search.search(keyword, 5) + if len(results) == 0: + return "没有搜索到结果" + ret = "" + idx = 1 + for i in results: + logger.info(f"web_searcher - scraping web: {i.title} - {i.url}") + site_result = await fetch_website_content(i.url) + site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result + ret += f"{idx}. {i.title}\n{site_result}\n\n" + idx += 1 + return ret + + +async def fetch_website_content(url): + header = HEADERS + header.update({'User-Agent': random.choice(USER_AGENTS)}) + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=HEADERS, timeout=3) as response: + html = await response.text() + doc = Document(html) + ret = doc.summary(html_partial=True) + soup = BeautifulSoup(ret, 'html.parser') + ret = tidy_text(soup.get_text()) + return ret + + +async def web_search(prompt, provider: Provider, session_id, official_fc=False): + ''' + official_fc: 使用官方 function-calling + ''' + new_func_call = FuncCall(provider) + + new_func_call.add_func("web_search", [{ + "type": "string", + "name": "keyword", + "description": "搜索关键词" + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + search_from_bing + ) + new_func_call.add_func("fetch_website_content", [{ + "type": "string", + "name": "url", + "description": "要获取内容的网页链接" + }], + "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content + ) + + has_func = False + function_invoked_ret = "" + if official_fc: + # we use official function-calling + result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func()) + if isinstance(result, Function): + logger.debug(f"web_searcher - function-calling: {result}") + func_obj = None + for i in new_func_call.func_list: + if i["name"] == result.name: + func_obj = i["func_obj"] + break + if not func_obj: + return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + try: + args = json.loads(result.arguments) + function_invoked_ret = await func_obj(**args) + has_func = True + except BaseException as e: + traceback.print_exc() + return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + else: + return result + else: + # we use our own function-calling + try: + args = { + 'question': prompt, + 'func_definition': new_func_call.func_dump(), + 'is_task': False, + 'is_summary': False, + } + function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) + except BaseException as e: + res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)" + return res + has_func = True + + if has_func: + await provider.forget(session_id) + summary_prompt = f""" +你是一个专业且高效的助手,你的任务是 +1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; +2. 简单地发表你对这个问题的简略看法。 + +# 例子 +1. 从网上的信息来看,可以知道...我个人认为...你觉得呢? +2. 根据网上的最新信息,可以得知...我觉得...你怎么看? + +# 限制 +1. 限制在 200 字以内; +2. 请**直接输出总结**,不要输出多余的内容和提示语。 + +# 相关材料 +{function_invoked_ret}""" + ret = await provider.text_chat(summary_prompt, session_id) + return ret + return function_invoked_ret diff --git a/util/function_calling/gplugin.py b/util/function_calling/gplugin.py deleted file mode 100644 index 5ddab894e..000000000 --- a/util/function_calling/gplugin.py +++ /dev/null @@ -1,300 +0,0 @@ -import requests -import util.general_utils as gu -import traceback -import time -import json -import asyncio -from googlesearch import search, SearchResult -from readability import Document -from bs4 import BeautifulSoup -from openai.types.chat.chat_completion_message_tool_call import Function -from util.function_calling.func_call import ( - FuncCall, - FuncCallJsonFormatError, - FuncNotFoundError -) -from model.provider.provider import Provider - - -def tidy_text(text: str) -> str: - ''' - 清理文本,去除空格、换行符等 - ''' - return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") - - -def special_fetch_zhihu(link: str) -> str: - ''' - function-calling 函数, 用于获取知乎文章的内容 - ''' - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - response = requests.get(link, headers=headers) - response.encoding = "utf-8" - soup = BeautifulSoup(response.text, "html.parser") - - if "zhuanlan.zhihu.com" in link: - r = soup.find(class_="Post-RichTextContainer") - else: - r = soup.find(class_="List-item").find(class_="RichContent-inner") - if r is None: - print("debug: zhihu none") - raise Exception("zhihu none") - return tidy_text(r.text) - - -def google_web_search(keyword) -> str: - ''' - 获取 google 搜索结果, 得到 title、desc、link - ''' - ret = "" - index = 1 - try: - ls = search(keyword, advanced=True, num_results=4) - for i in ls: - desc = i.description - try: - # gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO) - desc = fetch_website_content(i.url) - except BaseException as e: - print(f"(google) fetch_website_content err: {str(e)}") - # gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999) - ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n" - index += 1 - except Exception as e: - print(f"google search err: {str(e)}") - return web_keyword_search_via_bing(keyword) - return ret - - -def web_keyword_search_via_bing(keyword) -> str: - ''' - 获取bing搜索结果, 得到 title、desc、link - ''' - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - url = "https://www.bing.com/search?q="+keyword - _cnt = 0 - # _detail_store = [] - while _cnt < 5: - try: - response = requests.get(url, headers=headers) - response.encoding = "utf-8" - # gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999) - soup = BeautifulSoup(response.text, "html.parser") - res = "" - result_cnt = 0 - ols = soup.find(id="b_results") - for i in ols.find_all("li", class_="b_algo"): - try: - title = i.find("h2").text - desc = i.find("p").text - link = i.find("h2").find("a").get("href") - # res.append({ - # "title": title, - # "desc": desc, - # "link": link, - # }) - try: - # gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO) - desc = fetch_website_content(link) - except BaseException as e: - print(f"(bing) fetch_website_content err: {str(e)}") - - res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n" - result_cnt += 1 - if result_cnt > 5: - break - - # if len(_detail_store) >= 3: - # continue - # # 爬取前两条的网页内容 - # if "zhihu.com" in link: - # try: - # _detail_store.append(special_fetch_zhihu(link)) - # except BaseException as e: - # print(f"zhihu parse err: {str(e)}") - # else: - # try: - # _detail_store.append(fetch_website_content(link)) - # except BaseException as e: - # print(f"fetch_website_content err: {str(e)}") - - except Exception as e: - print(f"bing parse err: {str(e)}") - if result_cnt == 0: - break - return res - except Exception as e: - # gu.log(f"bing fetch err: {str(e)}") - _cnt += 1 - time.sleep(1) - - # gu.log("fail to fetch bing info, using sougou.") - return web_keyword_search_via_sougou(keyword) - - -def web_keyword_search_via_sougou(keyword) -> str: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", - } - url = f"https://sogou.com/web?query={keyword}" - response = requests.get(url, headers=headers) - response.encoding = "utf-8" - soup = BeautifulSoup(response.text, "html.parser") - - res = [] - results = soup.find("div", class_="results") - for i in results.find_all("div", class_="vrwrap"): - try: - title = tidy_text(i.find("h3").text) - link = tidy_text(i.find("h3").find("a").get("href")) - if link.startswith("/link?url="): - link = "https://www.sogou.com" + link - res.append({ - "title": title, - "link": link, - }) - if len(res) >= 5: # 限制5条 - break - except Exception as e: - pass - # gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR) - # 爬取网页内容 - _detail_store = [] - for i in res: - if _detail_store >= 3: - break - try: - _detail_store.append(fetch_website_content(i["link"])) - except BaseException as e: - print(f"fetch_website_content err: {str(e)}") - ret = f"{str(res)}" - if len(_detail_store) > 0: - ret += f"\n网页内容: {str(_detail_store)}" - return ret - - -def fetch_website_content(url): - # gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG) - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \ - AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - response = requests.get(url, headers=headers, timeout=3) - response.encoding = "utf-8" - doc = Document(response.content) - # print('title:', doc.title()) - ret = doc.summary(html_partial=True) - soup = BeautifulSoup(ret, 'html.parser') - ret = tidy_text(soup.get_text()) - return ret - - -async def web_search(question, provider: Provider, session_id, official_fc=False): - ''' - official_fc: 使用官方 function-calling - ''' - new_func_call = FuncCall(provider) - new_func_call.add_func("google_web_search", [{ - "type": "string", - "name": "keyword", - "description": "google search query (分词,尽量保留所有信息)" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - web_keyword_search_via_bing - ) - new_func_call.add_func("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "网址" - }], - "获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) - question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。" - has_func = False - function_invoked_ret = "" - if official_fc: - # we use official function-calling - func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func()) - if isinstance(func, Function): - # 执行对应的结果: - func_obj = None - for i in new_func_call.func_list: - if i["name"] == func.name: - func_obj = i["func_obj"] - break - if not func_obj: - # gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR) - return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" - try: - args = json.loads(func.arguments) - # we use to_thread to avoid blocking the event loop - function_invoked_ret = await asyncio.to_thread(func_obj, **args) - has_func = True - except BaseException as e: - traceback.print_exc() - return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)" - else: - # now func is a string - return func - else: - # we use our own function-calling - try: - args = { - 'question': question1, - 'func_definition': new_func_call.func_dump(), - 'is_task': False, - 'is_summary': False, - } - function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args) - except BaseException as e: - res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)" - return res - has_func = True - - if has_func: - await provider.forget(session_id) - question3 = f""" -你的任务是: -1. 根据末尾的材料对问题`{question}`做切题的总结(详细); -2. 简单地发表你对这个问题的看法(简略)。 -你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。 -不要提到任何函数调用的信息。 - -一些回复的消息模板: -模板1: -``` -从网上的信息来看,可以知道...我个人认为...你觉得呢? -``` -模板2: -``` -根据网上的最新信息,可以得知...我觉得...你怎么看? -``` -你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。 - -以下是相关材料: -""" - _c = 0 - while _c < 3: - try: - print('text chat') - final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id) - return final_ret - except Exception as e: - print(e) - _c += 1 - if _c == 3: - raise e - if "The message you submitted was too long" in str(e): - await provider.forget(session_id) - function_invoked_ret = function_invoked_ret[:int( - len(function_invoked_ret) / 2)] - time.sleep(3) - return function_invoked_ret diff --git a/util/general_utils.py b/util/general_utils.py index 1c2a0543f..d9f2d851c 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -9,9 +9,10 @@ import platform import json import sys import psutil +import ssl from PIL import Image, ImageDraw, ImageFont -from cores.astrbot.types import GlobalObject +from type.types import GlobalObject from SparkleLogging.utils.core import LogManager from logging import Logger @@ -378,6 +379,14 @@ async def download_image_by_url(url: str) -> str: async with aiohttp.ClientSession() as session: async with session.get(url) as resp: return save_temp_img(await resp.read()) + except aiohttp.client_exceptions.ClientConnectorSSLError as e: + # 关闭SSL验证 + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get(url, ssl=ssl_context) as resp: + return save_temp_img(await resp.read()) except Exception as e: raise e diff --git a/util/plugin_dev/api/v1/bot.py b/util/plugin_dev/api/v1/bot.py index f855b6132..46a3b7fce 100644 --- a/util/plugin_dev/api/v1/bot.py +++ b/util/plugin_dev/api/v1/bot.py @@ -1,11 +1,5 @@ -from cores.astrbot.types import ( - PluginMetadata, - RegisteredLLM, - RegisteredPlugin, - RegisteredPlatform, - RegisteredPlugins, - PluginType, - GlobalObject, - AstrMessageEvent, - CommandResult -) \ No newline at end of file +from type.plugin import PluginMetadata, PluginType +from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins +from type.types import GlobalObject +from type.message import AstrMessageEvent +from type.command import CommandResult \ No newline at end of file diff --git a/util/plugin_dev/api/v1/message.py b/util/plugin_dev/api/v1/message.py index a0f2853e9..39722f709 100644 --- a/util/plugin_dev/api/v1/message.py +++ b/util/plugin_dev/api/v1/message.py @@ -1,5 +1,6 @@ -from cores.astrbot.core import oper_msg -from cores.astrbot.types import AstrMessageEvent, CommandResult +from astrbot.core import oper_msg +from type.message import AstrMessageEvent, AstrBotMessage +from type.command import CommandResult from model.platform._message_result import MessageResult ''' diff --git a/util/plugin_dev/api/v1/register.py b/util/plugin_dev/api/v1/register.py index bf28cdd68..83ccd5460 100644 --- a/util/plugin_dev/api/v1/register.py +++ b/util/plugin_dev/api/v1/register.py @@ -5,7 +5,8 @@ ''' from model.provider.provider import Provider as LLMProvider from model.platform._platfrom import Platform -from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM +from type.types import GlobalObject +from type.register import RegisteredPlatform, RegisteredLLM def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None: ''' diff --git a/util/plugin_dev/api/v1/types.py b/util/plugin_dev/api/v1/types.py index 3c42be69b..96d07a775 100644 --- a/util/plugin_dev/api/v1/types.py +++ b/util/plugin_dev/api/v1/types.py @@ -2,4 +2,4 @@ 插件类型 ''' -from cores.astrbot.types import PluginType \ No newline at end of file +from type.plugin import PluginType \ No newline at end of file diff --git a/util/plugin_util.py b/util/plugin_util.py index 1045c108a..d8bf57a60 100644 --- a/util/plugin_util.py +++ b/util/plugin_util.py @@ -3,24 +3,18 @@ ''' import os import inspect +import shutil +import stat +import traceback + try: - import git.exc from git.repo import Repo except ImportError: pass -import shutil -import importlib -import stat -import traceback from types import ModuleType -from typing import List from pip._internal import main as pipmain -from cores.astrbot.types import ( - PluginMetadata, - PluginType, - RegisteredPlugin, - RegisteredPlugins -) +from type.plugin import * +from type.register import * # 找出模块里所有的类名 diff --git a/util/search_engine_scraper/bing.py b/util/search_engine_scraper/bing.py new file mode 100644 index 000000000..f51a2a042 --- /dev/null +++ b/util/search_engine_scraper/bing.py @@ -0,0 +1,38 @@ +from typing import List + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENT_BING + +class Bing(SearchEngine): + def __init__(self) -> None: + super().__init__() + self.base_url = "https://www.bing.com" + self.headers.update({'User-Agent': USER_AGENT_BING}) + + def _set_selector(self, selector: str): + selectors = { + 'url': 'div.b_attribution cite', + 'title': 'h2', + 'text': 'p', + 'links': 'ol#b_results > li.b_algo', + 'next': 'div#b_content nav[role="navigation"] a.sb_pagN' + } + return selectors[selector] + + async def _get_next_page(self, query) -> str: + if self.page == 1: + await self._get_html(self.base_url) + url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl=' + return await self._get_html(url, None) + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + results = await super().search(query, num_results) + for result in results: + if not isinstance(result.url, str): + result.url = result.url.text + + return results \ No newline at end of file diff --git a/util/search_engine_scraper/config.py b/util/search_engine_scraper/config.py new file mode 100644 index 000000000..ab9cec6f8 --- /dev/null +++ b/util/search_engine_scraper/config.py @@ -0,0 +1,20 @@ +HEADERS = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0', + 'Accept': '*/*', + 'Connection': 'keep-alive', + 'Accept-Language': 'en-GB,en;q=0.5' +} + +USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0' +USER_AGENTS = [ + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0', + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36', + 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36', + 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0', + 'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0' +] \ No newline at end of file diff --git a/util/search_engine_scraper/engine.py b/util/search_engine_scraper/engine.py new file mode 100644 index 000000000..9adbdea19 --- /dev/null +++ b/util/search_engine_scraper/engine.py @@ -0,0 +1,74 @@ +import random +try: + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from config import HEADERS, USER_AGENTS + +from bs4 import BeautifulSoup +from aiohttp import ClientSession +from dataclasses import dataclass +from typing import List + + +@dataclass +class SearchResult(): + title: str + url: str + snippet: str + + def __str__(self) -> str: + return f"{self.title} - {self.url}\n{self.snippet}" + +class SearchEngine(): + ''' + 搜索引擎爬虫基类 + ''' + + def __init__(self) -> None: + self.TIMEOUT = 10 + self.page = 1 + self.headers = HEADERS + + def _set_selector(self, selector: str) -> None: + raise NotImplementedError() + + def _get_next_page(self): + raise NotImplementedError() + + async def _get_html(self, url: str, data: dict = None) -> str: + headers = self.headers + headers["Referer"] = url + headers["User-Agent"] = random.choice(USER_AGENTS) + print(headers) + if data: + async with ClientSession() as session: + async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp: + return await resp.text() + else: + async with ClientSession() as session: + async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp: + return await resp.text() + + + def tidy_text(self, text: str) -> str: + ''' + 清理文本,去除空格、换行符等 + ''' + return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + try: + resp = await self._get_next_page(query) + soup = BeautifulSoup(resp, 'html.parser') + links = soup.select(self._set_selector('links')) + results = [] + for link in links: + title = self.tidy_text(link.select_one(self._set_selector('title')).text) + url = link.select_one(self._set_selector('url')) + snippet = '' + if title and url: + results.append(SearchResult(title=title, url=url, snippet=snippet)) + return results[:num_results] if len(results) > num_results else results + except Exception as e: + raise e \ No newline at end of file diff --git a/util/search_engine_scraper/google.py b/util/search_engine_scraper/google.py new file mode 100644 index 000000000..03109883a --- /dev/null +++ b/util/search_engine_scraper/google.py @@ -0,0 +1,23 @@ +from googlesearch import search + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENTS + +from typing import List + +class Google(SearchEngine): + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + index = 1 + results = [] + try: + ls = search(query, advanced=True, num_results=num_results, timeout=3) + for i in ls: + results.append(SearchResult(title=i.title, url=i.url, snippet=i.description)) + except: + pass + return results \ No newline at end of file diff --git a/util/search_engine_scraper/sogo.py b/util/search_engine_scraper/sogo.py new file mode 100644 index 000000000..4ddf88dfc --- /dev/null +++ b/util/search_engine_scraper/sogo.py @@ -0,0 +1,49 @@ +import random, re +from bs4 import BeautifulSoup + +try: + from util.search_engine_scraper.engine import SearchEngine, SearchResult + from util.search_engine_scraper.config import HEADERS, USER_AGENTS +except ImportError: + from engine import SearchEngine, SearchResult + from config import HEADERS, USER_AGENTS + +from typing import List + +class Sogo(SearchEngine): + def __init__(self) -> None: + super().__init__() + self.base_url = "https://www.sogou.com" + self.headers['User-Agent'] = random.choice(USER_AGENTS) + + + def _set_selector(self, selector: str): + selectors = { + 'url': 'h3 > a', + 'title': 'h3', + 'text': '', + 'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)', + 'next': '' + } + return selectors[selector] + + async def _get_next_page(self, query) -> str: + url = f'{self.base_url}/web?query={query}' + return await self._get_html(url, None) + + async def search(self, query: str, num_results: int) -> List[SearchResult]: + results = await super().search(query, num_results) + for result in results: + result.url = result.url.get("href") + if result.url.startswith("/link?"): + result.url = self.base_url + result.url + result.url = await self._parse_url(result.url) + return results + + async def _parse_url(self, url) -> str: + html = await self._get_html(url) + soup = BeautifulSoup(html, 'html.parser') + script = soup.find("script") + if script: + url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1) + return url \ No newline at end of file diff --git a/util/search_engine_scraper/test.py b/util/search_engine_scraper/test.py new file mode 100644 index 000000000..a1be25ee1 --- /dev/null +++ b/util/search_engine_scraper/test.py @@ -0,0 +1,22 @@ +from sogo import Sogo +from bing import Bing + +sogo_search = Sogo() +bing_search = Bing() +async def search(keyword: str) -> str: + results = await sogo_search.search(keyword, 5) + # results = await bing_search.search(keyword, 5) + ret = "" + if len(results) == 0: + return "没有搜索到结果" + + idx = 1 + for i in results: + ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n" + idx += 1 + + return ret + +import asyncio +ret = asyncio.run(search("gpt4orelease")) +print(ret) \ No newline at end of file