perf: websearch 优化;项目结构调整
This commit is contained in:
@@ -10,3 +10,4 @@ cmd_config.json
|
||||
addons/plugins/
|
||||
data/*
|
||||
cookies.json
|
||||
logs/
|
||||
|
||||
@@ -12,10 +12,11 @@ from flask.logging import default_handler
|
||||
from werkzeug.serving import make_server
|
||||
from util import general_utils as gu
|
||||
from dataclasses import dataclass
|
||||
from cores.database.conn import dbConn
|
||||
from persist.session import dbConn
|
||||
from type.register import RegisteredPlugin
|
||||
from typing import List
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.updator import check_update, update_project, request_release_info
|
||||
from cores.astrbot.types import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
@@ -2,17 +2,14 @@ import re
|
||||
import threading
|
||||
import asyncio
|
||||
import time
|
||||
import aiohttp
|
||||
import util.unfit_words as uw
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
import traceback
|
||||
|
||||
import util.function_calling.gplugin as gplugin
|
||||
import util.agent.web_searcher as web_searcher
|
||||
import util.plugin_util as putil
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from nakuru.entities.components import Plain, At, Image
|
||||
|
||||
from addons.baidu_aip_judge import BaiduJudge
|
||||
@@ -22,10 +19,12 @@ from util import general_utils as gu
|
||||
from util.general_utils import upload, run_monitor
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
from util.cmd_config import init_astrbot_config_items
|
||||
from .types import *
|
||||
from type.types import GlobalObject
|
||||
from type.register import *
|
||||
from type.message import AstrBotMessage
|
||||
from addons.dashboard.helper import DashBoardHelper
|
||||
from addons.dashboard.server import DashBoardData
|
||||
from cores.database.conn import dbConn
|
||||
from persist.session import dbConn
|
||||
from model.platform._message_result import MessageResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
@@ -134,7 +133,7 @@ def init(cfg):
|
||||
instance = llm_instance[OPENAI_OFFICIAL]
|
||||
assert isinstance(instance, ProviderOpenAIOfficial)
|
||||
instance.DEFAULT_PERSONALITY = _global_object.default_personality
|
||||
instance.personality_set(_global_object.default_personality, session_id=None)
|
||||
instance.curr_personality = instance.DEFAULT_PERSONALITY
|
||||
|
||||
# 检查provider设置偏好
|
||||
p = cc.get("chosen_provider", None)
|
||||
@@ -434,7 +433,7 @@ async def oper_msg(message: AstrBotMessage,
|
||||
if chosen_provider == OPENAI_OFFICIAL:
|
||||
if _global_object.web_search or web_sch_flag:
|
||||
official_fc = chosen_provider == OPENAI_OFFICIAL
|
||||
llm_result_str = await gplugin.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
||||
llm_result_str = await web_searcher.web_search(message_str, llm_instance[chosen_provider], session_id, official_fc)
|
||||
else:
|
||||
llm_result_str = await llm_instance[chosen_provider].text_chat(message_str, session_id, image_url)
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
from model.platform._platfrom import Platform
|
||||
from nakuru import (
|
||||
GroupMessage,
|
||||
FriendMessage,
|
||||
GuildMessage,
|
||||
)
|
||||
from nakuru.entities.components import BaseMessageComponent
|
||||
from typing import Union, List, ClassVar, Callable
|
||||
from types import ModuleType
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
|
||||
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
|
||||
GUILD_MESSAGE = 'GuildMessage' # 频道消息
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageMember():
|
||||
user_id: str # 发送者id
|
||||
nickname: str = None
|
||||
|
||||
|
||||
class AstrBotMessage():
|
||||
'''
|
||||
AstrBot 的消息对象
|
||||
'''
|
||||
tag: str # 消息来源标签
|
||||
type: MessageType # 消息类型
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id
|
||||
message_id: str # 消息id
|
||||
sender: MessageMember # 发送者
|
||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
raw_message: object
|
||||
timestamp: int # 消息时间戳
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
|
||||
class PluginType(Enum):
|
||||
PLATFORM = 'platfrom' # 平台类插件。
|
||||
LLM = 'llm' # 大语言模型类插件
|
||||
COMMON = 'common' # 其他插件
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMetadata:
|
||||
'''
|
||||
插件的元数据。
|
||||
'''
|
||||
# required
|
||||
plugin_name: str
|
||||
plugin_type: PluginType
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
|
||||
# optional
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredPlugin:
|
||||
'''
|
||||
注册在 AstrBot 中的插件。
|
||||
'''
|
||||
metadata: PluginMetadata
|
||||
plugin_instance: object
|
||||
module_path: str
|
||||
module: ModuleType
|
||||
root_dir_name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
|
||||
|
||||
|
||||
RegisteredPlugins = List[RegisteredPlugin]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredPlatform:
|
||||
'''
|
||||
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
|
||||
'''
|
||||
platform_name: str
|
||||
platform_instance: Platform
|
||||
origin: str = None # 注册来源
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.platform_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredLLM:
|
||||
'''
|
||||
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
|
||||
'''
|
||||
llm_name: str
|
||||
llm_instance: LLMProvider
|
||||
origin: str = None # 注册来源
|
||||
|
||||
|
||||
class GlobalObject:
|
||||
'''
|
||||
存放一些公用的数据,用于在不同模块(如core与command)之间传递
|
||||
'''
|
||||
version: str # 机器人版本
|
||||
nick: tuple # 用户定义的机器人的别名
|
||||
base_config: dict # config.json 中导出的配置
|
||||
cached_plugins: List[RegisteredPlugin] # 加载的插件
|
||||
platforms: List[RegisteredPlatform]
|
||||
llms: List[RegisteredLLM]
|
||||
|
||||
web_search: bool # 是否开启了网页搜索
|
||||
reply_prefix: str # 回复前缀
|
||||
unique_session: bool # 是否开启了独立会话
|
||||
cnt_total: int # 总消息数
|
||||
default_personality: dict
|
||||
dashboard_data = None
|
||||
|
||||
def __init__(self):
|
||||
self.nick = None # gocq 的昵称
|
||||
self.base_config = None # config.yaml
|
||||
self.cached_plugins = [] # 缓存的插件
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = None
|
||||
self.unique_session = False
|
||||
self.cnt_total = 0
|
||||
self.platforms = []
|
||||
self.llms = []
|
||||
self.default_personality = None
|
||||
self.dashboard_data = None
|
||||
self.stat = {}
|
||||
|
||||
|
||||
class AstrMessageEvent():
|
||||
'''
|
||||
消息事件。
|
||||
'''
|
||||
context: GlobalObject # 一些公用数据
|
||||
message_str: str # 纯消息字符串
|
||||
message_obj: AstrBotMessage # 消息对象
|
||||
platform: RegisteredPlatform # 来源平台
|
||||
role: str # 基本身份。`admin` 或 `member`
|
||||
session_id: int # 会话 id
|
||||
|
||||
def __init__(self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform: RegisteredPlatform,
|
||||
role: str,
|
||||
context: GlobalObject,
|
||||
session_id: str = None):
|
||||
self.context = context
|
||||
self.message_str = message_str
|
||||
self.message_obj = message_obj
|
||||
self.platform = platform
|
||||
self.role = role
|
||||
self.session_id = session_id
|
||||
|
||||
@dataclass
|
||||
class CommandItem():
|
||||
'''
|
||||
用来描述单个指令
|
||||
'''
|
||||
|
||||
command_name: Union[str, tuple] # 指令名
|
||||
callback: Callable # 回调函数
|
||||
description: str # 描述
|
||||
origin: str # 注册来源
|
||||
|
||||
class CommandResult():
|
||||
'''
|
||||
用于在Command中返回多个值
|
||||
'''
|
||||
|
||||
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
|
||||
self.hit = hit
|
||||
self.success = success
|
||||
self.message_chain = message_chain
|
||||
self.command_name = command_name
|
||||
|
||||
def _result_tuple(self):
|
||||
return (self.success, self.message_chain, self.command_name)
|
||||
@@ -38,7 +38,7 @@ def main():
|
||||
# config.yaml 配置文件加载和环境确认
|
||||
try:
|
||||
import botpy, logging, yaml
|
||||
import cores.astrbot.core as qqBot
|
||||
import astrbot.core as bot_core
|
||||
# delete qqbotpy's logger
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
@@ -59,7 +59,9 @@ def main():
|
||||
input("配置文件不存在,请检查是否已经下载配置文件。")
|
||||
exit()
|
||||
except BaseException as e:
|
||||
raise e
|
||||
logger.error(traceback.format_exc())
|
||||
input("未知错误。")
|
||||
exit()
|
||||
|
||||
# 设置代理
|
||||
if 'http_proxy' in cfg and cfg['http_proxy'] != '':
|
||||
@@ -71,7 +73,7 @@ def main():
|
||||
make_necessary_dirs()
|
||||
|
||||
# 启动主程序(cores/qqbot/core.py)
|
||||
qqBot.init(cfg)
|
||||
bot_core.init(cfg)
|
||||
|
||||
|
||||
def check_env():
|
||||
|
||||
@@ -13,16 +13,13 @@ from nakuru.entities.components import (
|
||||
from util import general_utils as gu
|
||||
from model.provider.provider import Provider
|
||||
from util.cmd_config import CmdConfig as cc
|
||||
from cores.astrbot.types import (
|
||||
GlobalObject,
|
||||
AstrMessageEvent,
|
||||
PluginType,
|
||||
CommandResult,
|
||||
RegisteredPlugin,
|
||||
RegisteredPlatform
|
||||
)
|
||||
from type.message import *
|
||||
from type.types import GlobalObject
|
||||
from type.command import *
|
||||
from type.plugin import *
|
||||
from type.register import *
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from model.command.command import Command
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
|
||||
from util.personality import personalities
|
||||
from cores.astrbot.types import GlobalObject, CommandItem
|
||||
from type.types import GlobalObject
|
||||
from type.command import CommandItem
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from openai._exceptions import NotFoundError, RateLimitError, APIError
|
||||
from openai._exceptions import NotFoundError
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from nakuru import (
|
||||
FriendMessage
|
||||
)
|
||||
import botpy.message
|
||||
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
|
||||
from type.message import *
|
||||
from typing import List, Union
|
||||
import time
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import time
|
||||
|
||||
from ._platfrom import Platform
|
||||
from ._message_parse import nakuru_message_parse_rev
|
||||
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
|
||||
from type.message import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from ._message_parse import (
|
||||
qq_official_message_parse_rev,
|
||||
qq_official_message_parse
|
||||
)
|
||||
from cores.astrbot.types import MessageType, AstrBotMessage, MessageMember
|
||||
from type.message import *
|
||||
from typing import Union, List
|
||||
from nakuru.entities.components import BaseMessageComponent
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
|
||||
@@ -12,7 +12,7 @@ from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import *
|
||||
|
||||
from cores.database.conn import dbConn
|
||||
from persist.session import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util import general_utils as gu
|
||||
from util.cmd_config import CmdConfig
|
||||
@@ -122,6 +122,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
def personality_set(self, default_personality: dict, session_id: str):
|
||||
if not default_personality: return
|
||||
if session_id not in self.session_memory:
|
||||
self.session_memory[session_id] = []
|
||||
self.curr_personality = default_personality
|
||||
self.session_personality = {} # 重置
|
||||
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
|
||||
@@ -282,7 +284,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_url: None,
|
||||
image_url: None=None,
|
||||
tools: None=None,
|
||||
extra_conf: Dict = None,
|
||||
**kwargs
|
||||
|
||||
@@ -1,409 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import tiktoken
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from cores.database.conn import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util import general_utils as gu
|
||||
from util.cmd_config import CmdConfig
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
|
||||
abs_path = os.path.dirname(os.path.realpath(sys.argv[0])) + '/'
|
||||
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, cfg):
|
||||
self.cc = CmdConfig()
|
||||
|
||||
self.key_list = []
|
||||
# 如果 cfg['key'] 中有长度为 1 的字符串,那么是格式错误,直接报错
|
||||
for key in cfg['key']:
|
||||
if len(key) == 1:
|
||||
raise BaseException(
|
||||
"检查到了长度为 1 的Key。配置文件中的 openai.key 处的格式错误 (符号 - 的后面要加空格)。")
|
||||
if cfg['key'] != '' and cfg['key'] != None:
|
||||
self.key_list = cfg['key']
|
||||
if len(self.key_list) == 0:
|
||||
raise Exception("您打开了 OpenAI 模型服务,但是未填写 key。请前往填写。")
|
||||
|
||||
self.key_stat = {}
|
||||
for k in self.key_list:
|
||||
self.key_stat[k] = {'exceed': False, 'used': 0}
|
||||
|
||||
self.api_base = None
|
||||
if 'api_base' in cfg and cfg['api_base'] != 'none' and cfg['api_base'] != '':
|
||||
self.api_base = cfg['api_base']
|
||||
logger.info(f"设置 api_base 为: {self.api_base}")
|
||||
|
||||
# 创建 OpenAI Client
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.key_list[0],
|
||||
base_url=self.api_base
|
||||
)
|
||||
|
||||
self.openai_model_configs: dict = cfg['chatGPTConfigs']
|
||||
self.openai_configs = cfg
|
||||
# 会话缓存
|
||||
self.session_dict = {}
|
||||
# 最大缓存token
|
||||
self.max_tokens = cfg['total_tokens_limit']
|
||||
# 历史记录持久化间隔时间
|
||||
self.history_dump_interval = 20
|
||||
|
||||
self.enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# 从 SQLite DB 读取历史记录
|
||||
try:
|
||||
db1 = dbConn()
|
||||
for session in db1.get_all_session():
|
||||
self.session_dict[session[0]] = json.loads(session[1])['data']
|
||||
logger.info("读取历史记录成功。")
|
||||
except BaseException as e:
|
||||
logger.info("读取历史记录失败,但不影响使用。")
|
||||
|
||||
# 创建转储定时器线程
|
||||
threading.Thread(target=self.dump_history, daemon=True).start()
|
||||
|
||||
# 人格
|
||||
self.curr_personality = {}
|
||||
|
||||
def make_tmp_client(self, api_key: str, base_url: str):
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
# 转储历史记录
|
||||
def dump_history(self):
|
||||
time.sleep(10)
|
||||
db = dbConn()
|
||||
while True:
|
||||
try:
|
||||
# print("转储历史记录...")
|
||||
for key in self.session_dict:
|
||||
data = self.session_dict[key]
|
||||
data_json = {
|
||||
'data': data
|
||||
}
|
||||
if db.check_session(key):
|
||||
db.update_session(key, json.dumps(data_json))
|
||||
else:
|
||||
db.insert_session(key, json.dumps(data_json))
|
||||
# print("转储历史记录完毕")
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
# 每隔10分钟转储一次
|
||||
time.sleep(10*self.history_dump_interval)
|
||||
|
||||
def personality_set(self, default_personality: dict, session_id: str):
|
||||
self.curr_personality = default_personality
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": default_personality['prompt'],
|
||||
},
|
||||
"AI": {
|
||||
"role": "assistant",
|
||||
"content": "好的,接下来我会扮演这个角色。"
|
||||
},
|
||||
'type': "personality",
|
||||
'usage_tokens': 0,
|
||||
'single-tokens': 0
|
||||
}
|
||||
self.session_dict[session_id].append(new_record)
|
||||
|
||||
async def text_chat(self, prompt,
|
||||
session_id=None,
|
||||
image_url=None,
|
||||
function_call=None,
|
||||
extra_conf: dict = None,
|
||||
default_personality: dict = None):
|
||||
if session_id is None:
|
||||
session_id = "unknown"
|
||||
if "unknown" in self.session_dict:
|
||||
del self.session_dict["unknown"]
|
||||
# 会话机制
|
||||
if session_id not in self.session_dict:
|
||||
self.session_dict[session_id] = []
|
||||
|
||||
if len(self.session_dict[session_id]) == 0:
|
||||
# 设置默认人格
|
||||
if default_personality is not None:
|
||||
self.personality_set(default_personality, session_id)
|
||||
|
||||
# 使用 tictoken 截断消息
|
||||
_encoded_prompt = self.enc.encode(prompt)
|
||||
if self.openai_model_configs['max_tokens'] < len(_encoded_prompt):
|
||||
prompt = self.enc.decode(_encoded_prompt[:int(
|
||||
self.openai_model_configs['max_tokens']*0.80)])
|
||||
logger.info(f"注意,有一部分 prompt 文本由于超出 token 限制而被截断。")
|
||||
|
||||
cache_data_list, new_record, req = self.wrap(
|
||||
prompt, session_id, image_url)
|
||||
logger.debug(f"cache: {str(cache_data_list)}")
|
||||
logger.debug(f"request: {str(req)}")
|
||||
retry = 0
|
||||
response = None
|
||||
err = ''
|
||||
|
||||
# 截断倍率
|
||||
truncate_rate = 0.75
|
||||
|
||||
conf = self.openai_model_configs
|
||||
if extra_conf is not None:
|
||||
conf.update(extra_conf)
|
||||
|
||||
while retry < 10:
|
||||
try:
|
||||
if function_call is None:
|
||||
response = await self.client.chat.completions.create(
|
||||
messages=req,
|
||||
**conf
|
||||
)
|
||||
else:
|
||||
response = await self.client.chat.completions.create(
|
||||
messages=req,
|
||||
tools=function_call,
|
||||
**conf
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if 'Invalid content type. image_url is only supported by certain models.' in str(e):
|
||||
raise e
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
logger.info("当前 Key 已超额或异常, 正在切换",
|
||||
)
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
raise e
|
||||
retry -= 1
|
||||
elif 'maximum context length' in str(e):
|
||||
logger.info("token 超限, 清空对应缓存,并进行消息截断")
|
||||
self.session_dict[session_id] = []
|
||||
prompt = prompt[:int(len(prompt)*truncate_rate)]
|
||||
truncate_rate -= 0.05
|
||||
cache_data_list, new_record, req = self.wrap(
|
||||
prompt, session_id)
|
||||
|
||||
elif 'Limit: 3 / min. Please try again in 20s.' in str(e) or "OpenAI response error" in str(e):
|
||||
time.sleep(30)
|
||||
continue
|
||||
else:
|
||||
logger.error(str(e))
|
||||
time.sleep(2)
|
||||
err = str(e)
|
||||
retry += 1
|
||||
if retry >= 10:
|
||||
logger.warning(
|
||||
r"如果报错, 且您的机器在中国大陆内, 请确保您的电脑已经设置好代理软件(梯子), 并在配置文件设置了系统代理地址。详见 https://github.com/Soulter/QQChannelChatGPT/wiki")
|
||||
raise BaseException("连接出错: "+str(err))
|
||||
assert isinstance(response, ChatCompletion)
|
||||
logger.debug(
|
||||
f"OPENAI RESPONSE: {response.usage}")
|
||||
|
||||
# 结果分类
|
||||
choice = response.choices[0]
|
||||
if choice.message.content != None:
|
||||
# 文本形式
|
||||
chatgpt_res = str(choice.message.content).strip()
|
||||
elif choice.message.tool_calls != None and len(choice.message.tool_calls) > 0:
|
||||
# tools call (function calling)
|
||||
return choice.message.tool_calls[0].function
|
||||
|
||||
self.key_stat[self.client.api_key]['used'] += response.usage.total_tokens
|
||||
current_usage_tokens = response.usage.total_tokens
|
||||
|
||||
# 超过指定tokens, 尽可能的保留最多的条目,直到小于max_tokens
|
||||
if current_usage_tokens > self.max_tokens:
|
||||
t = current_usage_tokens
|
||||
index = 0
|
||||
while t > self.max_tokens:
|
||||
if index >= len(cache_data_list):
|
||||
break
|
||||
# 保留人格信息
|
||||
if cache_data_list[index]['type'] != 'personality':
|
||||
t -= int(cache_data_list[index]['single_tokens'])
|
||||
del cache_data_list[index]
|
||||
else:
|
||||
index += 1
|
||||
# 删除完后更新相关字段
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
|
||||
# 添加新条目进入缓存的prompt
|
||||
new_record['AI'] = {
|
||||
'role': 'assistant',
|
||||
'content': chatgpt_res,
|
||||
}
|
||||
new_record['usage_tokens'] = current_usage_tokens
|
||||
if len(cache_data_list) > 0:
|
||||
new_record['single_tokens'] = current_usage_tokens - \
|
||||
int(cache_data_list[-1]['usage_tokens'])
|
||||
else:
|
||||
new_record['single_tokens'] = current_usage_tokens
|
||||
|
||||
cache_data_list.append(new_record)
|
||||
|
||||
self.session_dict[session_id] = cache_data_list
|
||||
|
||||
return chatgpt_res
|
||||
|
||||
async def image_chat(self, prompt, img_num=1, img_size="1024x1024"):
|
||||
retry = 0
|
||||
image_url = ''
|
||||
|
||||
image_generate_configs = self.cc.get("openai_image_generate", None)
|
||||
|
||||
while retry < 5:
|
||||
try:
|
||||
response: ImagesResponse = await self.client.images.generate(
|
||||
prompt=prompt,
|
||||
**image_generate_configs
|
||||
)
|
||||
image_url = []
|
||||
for i in range(img_num):
|
||||
image_url.append(response.data[i].url)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
if 'You exceeded' in str(e) or 'Billing hard limit has been reached' in str(
|
||||
e) or 'No API key provided' in str(e) or 'Incorrect API key provided' in str(e):
|
||||
logger.warning("当前 Key 已超额或者不正常, 正在切换")
|
||||
self.key_stat[self.client.api_key]['exceed'] = True
|
||||
is_switched = self.handle_switch_key()
|
||||
if not is_switched:
|
||||
raise e
|
||||
elif 'Your request was rejected as a result of our safety system.' in str(e):
|
||||
logger.warning("您的请求被 OpenAI 安全系统拒绝, 请稍后再试")
|
||||
raise e
|
||||
else:
|
||||
retry += 1
|
||||
if retry >= 5:
|
||||
raise BaseException("连接超时")
|
||||
|
||||
return image_url
|
||||
|
||||
async def forget(self, session_id=None) -> bool:
|
||||
if session_id is None:
|
||||
return False
|
||||
self.session_dict[session_id] = []
|
||||
return True
|
||||
|
||||
def get_prompts_by_cache_list(self, cache_data_list, divide=False, paging=False, size=5, page=1):
|
||||
'''
|
||||
获取缓存的会话
|
||||
'''
|
||||
prompts = ""
|
||||
if paging:
|
||||
page_begin = (page-1)*size
|
||||
page_end = page*size
|
||||
if page_begin < 0:
|
||||
page_begin = 0
|
||||
if page_end > len(cache_data_list):
|
||||
page_end = len(cache_data_list)
|
||||
cache_data_list = cache_data_list[page_begin:page_end]
|
||||
for item in cache_data_list:
|
||||
prompts += str(item['user']['role']) + ":\n" + \
|
||||
str(item['user']['content']) + "\n"
|
||||
prompts += str(item['AI']['role']) + ":\n" + \
|
||||
str(item['AI']['content']) + "\n"
|
||||
|
||||
if divide:
|
||||
prompts += "----------\n"
|
||||
return prompts
|
||||
|
||||
def wrap(self, prompt, session_id, image_url=None):
|
||||
if image_url is not None:
|
||||
prompt = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}
|
||||
]
|
||||
# 获得缓存信息
|
||||
context = self.session_dict[session_id]
|
||||
new_record = {
|
||||
"user": {
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
"AI": {},
|
||||
'type': "common",
|
||||
'usage_tokens': 0,
|
||||
}
|
||||
req_list = []
|
||||
for i in context:
|
||||
if 'user' in i:
|
||||
req_list.append(i['user'])
|
||||
if 'AI' in i:
|
||||
req_list.append(i['AI'])
|
||||
req_list.append(new_record['user'])
|
||||
return context, new_record, req_list
|
||||
|
||||
def handle_switch_key(self):
|
||||
is_all_exceed = True
|
||||
for key in self.key_stat:
|
||||
if key == None or self.key_stat[key]['exceed']:
|
||||
continue
|
||||
is_all_exceed = False
|
||||
self.client.api_key = key
|
||||
logger.warning(
|
||||
f"切换到 Key: {key}(已使用 token: {self.key_stat[key]['used']})")
|
||||
break
|
||||
if is_all_exceed:
|
||||
logger.warning(
|
||||
"所有 Key 已超额")
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_configs(self):
|
||||
return self.openai_configs
|
||||
|
||||
def get_key_stat(self):
|
||||
return self.key_stat
|
||||
|
||||
def get_key_list(self):
|
||||
return self.key_list
|
||||
|
||||
def get_curr_key(self):
|
||||
return self.client.api_key
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
# 添加key
|
||||
def append_key(self, key, sponsor):
|
||||
self.key_list.append(key)
|
||||
self.key_stat[key] = {'exceed': False, 'used': 0, 'sponsor': sponsor}
|
||||
|
||||
# 检查key是否可用
|
||||
async def check_key(self, key):
|
||||
client_ = AsyncOpenAI(
|
||||
api_key=key,
|
||||
base_url=self.api_base
|
||||
)
|
||||
messages = [{"role": "user", "content": "please just echo `test`"}]
|
||||
await client_.chat.completions.create(
|
||||
messages=messages,
|
||||
**self.openai_model_configs
|
||||
)
|
||||
return True
|
||||
@@ -2,8 +2,8 @@ class Provider:
|
||||
async def text_chat(self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_url: None,
|
||||
tools: None,
|
||||
image_url: None = None,
|
||||
tools: None = None,
|
||||
extra_conf: dict = None,
|
||||
default_personality: dict = None,
|
||||
**kwargs) -> str:
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import Union, List, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandItem():
|
||||
'''
|
||||
用来描述单个指令
|
||||
'''
|
||||
|
||||
command_name: Union[str, tuple] # 指令名
|
||||
callback: Callable # 回调函数
|
||||
description: str # 描述
|
||||
origin: str # 注册来源
|
||||
|
||||
class CommandResult():
|
||||
'''
|
||||
用于在Command中返回多个值
|
||||
'''
|
||||
|
||||
def __init__(self, hit: bool, success: bool, message_chain: list, command_name: str = "unknown_command") -> None:
|
||||
self.hit = hit
|
||||
self.success = success
|
||||
self.message_chain = message_chain
|
||||
self.command_name = command_name
|
||||
|
||||
def _result_tuple(self):
|
||||
return (self.success, self.message_chain, self.command_name)
|
||||
@@ -0,0 +1,62 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
from nakuru.entities.components import BaseMessageComponent
|
||||
|
||||
from type.register import RegisteredPlatform
|
||||
from type.types import GlobalObject
|
||||
|
||||
class MessageType(Enum):
|
||||
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
|
||||
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
|
||||
GUILD_MESSAGE = 'GuildMessage' # 频道消息
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageMember():
|
||||
user_id: str # 发送者id
|
||||
nickname: str = None
|
||||
|
||||
|
||||
class AstrBotMessage():
|
||||
'''
|
||||
AstrBot 的消息对象
|
||||
'''
|
||||
tag: str # 消息来源标签
|
||||
type: MessageType # 消息类型
|
||||
self_id: str # 机器人的识别id
|
||||
session_id: str # 会话id
|
||||
message_id: str # 消息id
|
||||
sender: MessageMember # 发送者
|
||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||
message_str: str # 最直观的纯文本消息字符串
|
||||
raw_message: object
|
||||
timestamp: int # 消息时间戳
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.__dict__)
|
||||
|
||||
class AstrMessageEvent():
|
||||
'''
|
||||
消息事件。
|
||||
'''
|
||||
context: GlobalObject # 一些公用数据
|
||||
message_str: str # 纯消息字符串
|
||||
message_obj: AstrBotMessage # 消息对象
|
||||
platform: RegisteredPlatform # 来源平台
|
||||
role: str # 基本身份。`admin` 或 `member`
|
||||
session_id: int # 会话 id
|
||||
|
||||
def __init__(self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform: RegisteredPlatform,
|
||||
role: str,
|
||||
context: GlobalObject,
|
||||
session_id: str = None):
|
||||
self.context = context
|
||||
self.message_str = message_str
|
||||
self.message_obj = message_obj
|
||||
self.platform = platform
|
||||
self.role = role
|
||||
self.session_id = session_id
|
||||
@@ -0,0 +1,27 @@
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
class PluginType(Enum):
|
||||
PLATFORM = 'platfrom' # 平台类插件。
|
||||
LLM = 'llm' # 大语言模型类插件
|
||||
COMMON = 'common' # 其他插件
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMetadata:
|
||||
'''
|
||||
插件的元数据。
|
||||
'''
|
||||
# required
|
||||
plugin_name: str
|
||||
plugin_type: PluginType
|
||||
author: str # 插件作者
|
||||
desc: str # 插件简介
|
||||
version: str # 插件版本
|
||||
|
||||
# optional
|
||||
repo: str = None # 插件仓库地址
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"PluginMetadata({self.plugin_name}, {self.plugin_type}, {self.desc}, {self.version}, {self.repo})"
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
from model.platform._platfrom import Platform
|
||||
from type.plugin import *
|
||||
from typing import List
|
||||
from types import ModuleType
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class RegisteredPlugin:
|
||||
'''
|
||||
注册在 AstrBot 中的插件。
|
||||
'''
|
||||
metadata: PluginMetadata
|
||||
plugin_instance: object
|
||||
module_path: str
|
||||
module: ModuleType
|
||||
root_dir_name: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})"
|
||||
|
||||
|
||||
RegisteredPlugins = List[RegisteredPlugin]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredPlatform:
|
||||
'''
|
||||
注册在 AstrBot 中的平台。平台应当实现 Platform 接口。
|
||||
'''
|
||||
platform_name: str
|
||||
platform_instance: Platform
|
||||
origin: str = None # 注册来源
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.platform_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredLLM:
|
||||
'''
|
||||
注册在 AstrBot 中的大语言模型调用。大语言模型应当实现 LLMProvider 接口。
|
||||
'''
|
||||
llm_name: str
|
||||
llm_instance: LLMProvider
|
||||
origin: str = None # 注册来源
|
||||
@@ -0,0 +1,34 @@
|
||||
from type.register import *
|
||||
from typing import List
|
||||
|
||||
class GlobalObject:
|
||||
'''
|
||||
存放一些公用的数据,用于在不同模块(如core与command)之间传递
|
||||
'''
|
||||
version: str # 机器人版本
|
||||
nick: tuple # 用户定义的机器人的别名
|
||||
base_config: dict # config.json 中导出的配置
|
||||
cached_plugins: List[RegisteredPlugin] # 加载的插件
|
||||
platforms: List[RegisteredPlatform]
|
||||
llms: List[RegisteredLLM]
|
||||
|
||||
web_search: bool # 是否开启了网页搜索
|
||||
reply_prefix: str # 回复前缀
|
||||
unique_session: bool # 是否开启了独立会话
|
||||
cnt_total: int # 总消息数
|
||||
default_personality: dict
|
||||
dashboard_data = None
|
||||
|
||||
def __init__(self):
|
||||
self.nick = None # gocq 的昵称
|
||||
self.base_config = None # config.yaml
|
||||
self.cached_plugins = [] # 缓存的插件
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = None
|
||||
self.unique_session = False
|
||||
self.cnt_total = 0
|
||||
self.platforms = []
|
||||
self.llms = []
|
||||
self.default_personality = None
|
||||
self.dashboard_data = None
|
||||
self.stat = {}
|
||||
@@ -0,0 +1,165 @@
|
||||
import traceback
|
||||
import random
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from util.agent.func_call import FuncCall
|
||||
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
|
||||
from util.search_engine_scraper.bing import Bing
|
||||
from util.search_engine_scraper.sogo import Sogo
|
||||
from util.search_engine_scraper.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot-core')
|
||||
|
||||
|
||||
bing_search = Bing()
|
||||
sogo_search = Sogo()
|
||||
google = Google()
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
# def special_fetch_zhihu(link: str) -> str:
|
||||
# '''
|
||||
# function-calling 函数, 用于获取知乎文章的内容
|
||||
# '''
|
||||
# response = requests.get(link, headers=HEADERS)
|
||||
# response.encoding = "utf-8"
|
||||
# soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# if "zhuanlan.zhihu.com" in link:
|
||||
# r = soup.find(class_="Post-RichTextContainer")
|
||||
# else:
|
||||
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
# if r is None:
|
||||
# print("debug: zhihu none")
|
||||
# raise Exception("zhihu none")
|
||||
# return tidy_text(r.text)
|
||||
|
||||
async def search_from_bing(keyword: str) -> str:
|
||||
'''
|
||||
tools, 从 bing 搜索引擎搜索
|
||||
'''
|
||||
logger.info("web_searcher - search_from_bing: " + keyword)
|
||||
results = await google.search(keyword, 5)
|
||||
if len(results) == 0:
|
||||
results = await bing_search.search(keyword, 5)
|
||||
if len(results) == 0:
|
||||
results = await sogo_search.search(keyword, 5)
|
||||
if len(results) == 0:
|
||||
return "没有搜索到结果"
|
||||
ret = ""
|
||||
idx = 1
|
||||
for i in results:
|
||||
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
|
||||
site_result = await fetch_website_content(i.url)
|
||||
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
|
||||
ret += f"{idx}. {i.title}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
return ret
|
||||
|
||||
|
||||
async def fetch_website_content(url):
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=HEADERS, timeout=3) as response:
|
||||
html = await response.text()
|
||||
doc = Document(html)
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
|
||||
async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
|
||||
new_func_call.add_func("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
|
||||
if isinstance(result, Function):
|
||||
logger.debug(f"web_searcher - function-calling: {result}")
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
args = {
|
||||
'question': prompt,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
'is_task': False,
|
||||
'is_summary': False,
|
||||
}
|
||||
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
|
||||
except BaseException as e:
|
||||
res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id)
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你的任务是
|
||||
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
|
||||
2. 简单地发表你对这个问题的简略看法。
|
||||
|
||||
# 例子
|
||||
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
|
||||
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
|
||||
|
||||
# 限制
|
||||
1. 限制在 200 字以内;
|
||||
2. 请**直接输出总结**,不要输出多余的内容和提示语。
|
||||
|
||||
# 相关材料
|
||||
{function_invoked_ret}"""
|
||||
ret = await provider.text_chat(summary_prompt, session_id)
|
||||
return ret
|
||||
return function_invoked_ret
|
||||
@@ -1,300 +0,0 @@
|
||||
import requests
|
||||
import util.general_utils as gu
|
||||
import traceback
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from googlesearch import search, SearchResult
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from util.function_calling.func_call import (
|
||||
FuncCall,
|
||||
FuncCallJsonFormatError,
|
||||
FuncNotFoundError
|
||||
)
|
||||
from model.provider.provider import Provider
|
||||
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
|
||||
def special_fetch_zhihu(link: str) -> str:
|
||||
'''
|
||||
function-calling 函数, 用于获取知乎文章的内容
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
response = requests.get(link, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
if "zhuanlan.zhihu.com" in link:
|
||||
r = soup.find(class_="Post-RichTextContainer")
|
||||
else:
|
||||
r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
if r is None:
|
||||
print("debug: zhihu none")
|
||||
raise Exception("zhihu none")
|
||||
return tidy_text(r.text)
|
||||
|
||||
|
||||
def google_web_search(keyword) -> str:
|
||||
'''
|
||||
获取 google 搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
ret = ""
|
||||
index = 1
|
||||
try:
|
||||
ls = search(keyword, advanced=True, num_results=4)
|
||||
for i in ls:
|
||||
desc = i.description
|
||||
try:
|
||||
# gu.log(f"搜索网页: {i.url}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
desc = fetch_website_content(i.url)
|
||||
except BaseException as e:
|
||||
print(f"(google) fetch_website_content err: {str(e)}")
|
||||
# gu.log(f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
ret += f"# No.{str(index)}\ntitle: {i.title}\nurl: {i.url}\ncontent: {desc}\n\n"
|
||||
index += 1
|
||||
except Exception as e:
|
||||
print(f"google search err: {str(e)}")
|
||||
return web_keyword_search_via_bing(keyword)
|
||||
return ret
|
||||
|
||||
|
||||
def web_keyword_search_via_bing(keyword) -> str:
|
||||
'''
|
||||
获取bing搜索结果, 得到 title、desc、link
|
||||
'''
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
url = "https://www.bing.com/search?q="+keyword
|
||||
_cnt = 0
|
||||
# _detail_store = []
|
||||
while _cnt < 5:
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
# gu.log(f"bing response: {response.text}", tag="bing", level=gu.LEVEL_DEBUG, max_len=9999)
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
res = ""
|
||||
result_cnt = 0
|
||||
ols = soup.find(id="b_results")
|
||||
for i in ols.find_all("li", class_="b_algo"):
|
||||
try:
|
||||
title = i.find("h2").text
|
||||
desc = i.find("p").text
|
||||
link = i.find("h2").find("a").get("href")
|
||||
# res.append({
|
||||
# "title": title,
|
||||
# "desc": desc,
|
||||
# "link": link,
|
||||
# })
|
||||
try:
|
||||
# gu.log(f"搜索网页: {link}", tag="网页搜索", level=gu.LEVEL_INFO)
|
||||
desc = fetch_website_content(link)
|
||||
except BaseException as e:
|
||||
print(f"(bing) fetch_website_content err: {str(e)}")
|
||||
|
||||
res += f"# No.{str(result_cnt + 1)}\ntitle: {title}\nurl: {link}\ncontent: {desc}\n\n"
|
||||
result_cnt += 1
|
||||
if result_cnt > 5:
|
||||
break
|
||||
|
||||
# if len(_detail_store) >= 3:
|
||||
# continue
|
||||
# # 爬取前两条的网页内容
|
||||
# if "zhihu.com" in link:
|
||||
# try:
|
||||
# _detail_store.append(special_fetch_zhihu(link))
|
||||
# except BaseException as e:
|
||||
# print(f"zhihu parse err: {str(e)}")
|
||||
# else:
|
||||
# try:
|
||||
# _detail_store.append(fetch_website_content(link))
|
||||
# except BaseException as e:
|
||||
# print(f"fetch_website_content err: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"bing parse err: {str(e)}")
|
||||
if result_cnt == 0:
|
||||
break
|
||||
return res
|
||||
except Exception as e:
|
||||
# gu.log(f"bing fetch err: {str(e)}")
|
||||
_cnt += 1
|
||||
time.sleep(1)
|
||||
|
||||
# gu.log("fail to fetch bing info, using sougou.")
|
||||
return web_keyword_search_via_sougou(keyword)
|
||||
|
||||
|
||||
def web_keyword_search_via_sougou(keyword) -> str:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
}
|
||||
url = f"https://sogou.com/web?query={keyword}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
res = []
|
||||
results = soup.find("div", class_="results")
|
||||
for i in results.find_all("div", class_="vrwrap"):
|
||||
try:
|
||||
title = tidy_text(i.find("h3").text)
|
||||
link = tidy_text(i.find("h3").find("a").get("href"))
|
||||
if link.startswith("/link?url="):
|
||||
link = "https://www.sogou.com" + link
|
||||
res.append({
|
||||
"title": title,
|
||||
"link": link,
|
||||
})
|
||||
if len(res) >= 5: # 限制5条
|
||||
break
|
||||
except Exception as e:
|
||||
pass
|
||||
# gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
|
||||
# 爬取网页内容
|
||||
_detail_store = []
|
||||
for i in res:
|
||||
if _detail_store >= 3:
|
||||
break
|
||||
try:
|
||||
_detail_store.append(fetch_website_content(i["link"]))
|
||||
except BaseException as e:
|
||||
print(f"fetch_website_content err: {str(e)}")
|
||||
ret = f"{str(res)}"
|
||||
if len(_detail_store) > 0:
|
||||
ret += f"\n网页内容: {str(_detail_store)}"
|
||||
return ret
|
||||
|
||||
|
||||
def fetch_website_content(url):
|
||||
# gu.log(f"fetch_website_content: {url}", tag="fetch_website_content", level=gu.LEVEL_DEBUG)
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
response = requests.get(url, headers=headers, timeout=3)
|
||||
response.encoding = "utf-8"
|
||||
doc = Document(response.content)
|
||||
# print('title:', doc.title())
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
|
||||
async def web_search(question, provider: Provider, session_id, official_fc=False):
|
||||
'''
|
||||
official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
new_func_call.add_func("google_web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "google search query (分词,尽量保留所有信息)"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
web_keyword_search_via_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "网址"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
question1 = f"{question} \n> hint: 最多只能调用1个function, 并且存在不会调用任何function的可能性。"
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
func = await provider.text_chat(question1, session_id, function_call=new_func_call.get_func())
|
||||
if isinstance(func, Function):
|
||||
# 执行对应的结果:
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == func.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
# gu.log("找不到返回的 func name " + func.name, level=gu.LEVEL_ERROR)
|
||||
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(func.arguments)
|
||||
# we use to_thread to avoid blocking the event loop
|
||||
function_invoked_ret = await asyncio.to_thread(func_obj, **args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(question1, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
# now func is a string
|
||||
return func
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
args = {
|
||||
'question': question1,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
'is_task': False,
|
||||
'is_summary': False,
|
||||
}
|
||||
function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
|
||||
except BaseException as e:
|
||||
res = await provider.text_chat(question) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return res
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id)
|
||||
question3 = f"""
|
||||
你的任务是:
|
||||
1. 根据末尾的材料对问题`{question}`做切题的总结(详细);
|
||||
2. 简单地发表你对这个问题的看法(简略)。
|
||||
你的总结末尾应当有对材料的引用, 如果有链接, 请在末尾附上引用网页链接。引用格式严格按照 `\n[1] title url \n`。
|
||||
不要提到任何函数调用的信息。
|
||||
|
||||
一些回复的消息模板:
|
||||
模板1:
|
||||
```
|
||||
从网上的信息来看,可以知道...我个人认为...你觉得呢?
|
||||
```
|
||||
模板2:
|
||||
```
|
||||
根据网上的最新信息,可以得知...我觉得...你怎么看?
|
||||
```
|
||||
你可以根据这些模板来组织回答,但可以不照搬,要根据问题的内容来回答。
|
||||
|
||||
以下是相关材料:
|
||||
"""
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
print('text chat')
|
||||
final_ret = await provider.text_chat(question3 + "```" + function_invoked_ret + "```", session_id)
|
||||
return final_ret
|
||||
except Exception as e:
|
||||
print(e)
|
||||
_c += 1
|
||||
if _c == 3:
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
await provider.forget(session_id)
|
||||
function_invoked_ret = function_invoked_ret[:int(
|
||||
len(function_invoked_ret) / 2)]
|
||||
time.sleep(3)
|
||||
return function_invoked_ret
|
||||
+10
-1
@@ -9,9 +9,10 @@ import platform
|
||||
import json
|
||||
import sys
|
||||
import psutil
|
||||
import ssl
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from cores.astrbot.types import GlobalObject
|
||||
from type.types import GlobalObject
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
|
||||
@@ -378,6 +379,14 @@ async def download_image_by_url(url: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
except aiohttp.client_exceptions.ClientConnectorSSLError as e:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
from cores.astrbot.types import (
|
||||
PluginMetadata,
|
||||
RegisteredLLM,
|
||||
RegisteredPlugin,
|
||||
RegisteredPlatform,
|
||||
RegisteredPlugins,
|
||||
PluginType,
|
||||
GlobalObject,
|
||||
AstrMessageEvent,
|
||||
CommandResult
|
||||
)
|
||||
from type.plugin import PluginMetadata, PluginType
|
||||
from type.register import RegisteredLLM, RegisteredPlatform, RegisteredPlugin, RegisteredPlugins
|
||||
from type.types import GlobalObject
|
||||
from type.message import AstrMessageEvent
|
||||
from type.command import CommandResult
|
||||
@@ -1,5 +1,6 @@
|
||||
from cores.astrbot.core import oper_msg
|
||||
from cores.astrbot.types import AstrMessageEvent, CommandResult
|
||||
from astrbot.core import oper_msg
|
||||
from type.message import AstrMessageEvent, AstrBotMessage
|
||||
from type.command import CommandResult
|
||||
from model.platform._message_result import MessageResult
|
||||
|
||||
'''
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
'''
|
||||
from model.provider.provider import Provider as LLMProvider
|
||||
from model.platform._platfrom import Platform
|
||||
from cores.astrbot.types import GlobalObject, RegisteredPlatform, RegisteredLLM
|
||||
from type.types import GlobalObject
|
||||
from type.register import RegisteredPlatform, RegisteredLLM
|
||||
|
||||
def register_platform(platform_name: str, platform_instance: Platform, context: GlobalObject) -> None:
|
||||
'''
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
插件类型
|
||||
'''
|
||||
|
||||
from cores.astrbot.types import PluginType
|
||||
from type.plugin import PluginType
|
||||
+6
-12
@@ -3,24 +3,18 @@
|
||||
'''
|
||||
import os
|
||||
import inspect
|
||||
import shutil
|
||||
import stat
|
||||
import traceback
|
||||
|
||||
try:
|
||||
import git.exc
|
||||
from git.repo import Repo
|
||||
except ImportError:
|
||||
pass
|
||||
import shutil
|
||||
import importlib
|
||||
import stat
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from pip._internal import main as pipmain
|
||||
from cores.astrbot.types import (
|
||||
PluginMetadata,
|
||||
PluginType,
|
||||
RegisteredPlugin,
|
||||
RegisteredPlugins
|
||||
)
|
||||
from type.plugin import *
|
||||
from type.register import *
|
||||
|
||||
|
||||
# 找出模块里所有的类名
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import List
|
||||
|
||||
try:
|
||||
from util.search_engine_scraper.engine import SearchEngine, SearchResult
|
||||
from util.search_engine_scraper.config import HEADERS, USER_AGENT_BING
|
||||
except ImportError:
|
||||
from engine import SearchEngine, SearchResult
|
||||
from config import HEADERS, USER_AGENT_BING
|
||||
|
||||
class Bing(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.base_url = "https://www.bing.com"
|
||||
self.headers.update({'User-Agent': USER_AGENT_BING})
|
||||
|
||||
def _set_selector(self, selector: str):
|
||||
selectors = {
|
||||
'url': 'div.b_attribution cite',
|
||||
'title': 'h2',
|
||||
'text': 'p',
|
||||
'links': 'ol#b_results > li.b_algo',
|
||||
'next': 'div#b_content nav[role="navigation"] a.sb_pagN'
|
||||
}
|
||||
return selectors[selector]
|
||||
|
||||
async def _get_next_page(self, query) -> str:
|
||||
if self.page == 1:
|
||||
await self._get_html(self.base_url)
|
||||
url = f'{self.base_url}/search?q={query}&form=QBLH&sp=-1&lq=0&pq=hi&sc=10-2&qs=n&sk=&cvid=DE75965E2D6346D681288933984DE48F&ghsh=0&ghacc=0&ghpl='
|
||||
return await self._get_html(url, None)
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
if not isinstance(result.url, str):
|
||||
result.url = result.url.text
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,20 @@
|
||||
HEADERS = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0',
|
||||
'Accept': '*/*',
|
||||
'Connection': 'keep-alive',
|
||||
'Accept-Language': 'en-GB,en;q=0.5'
|
||||
}
|
||||
|
||||
USER_AGENT_BING = 'Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0'
|
||||
USER_AGENTS = [
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:88.0) Gecko/20100101 Firefox/88.0',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1.2 Safari/537.36',
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Version/14.1 Safari/537.36',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:89.0) Gecko/20100101 Firefox/89.0',
|
||||
'Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'
|
||||
]
|
||||
@@ -0,0 +1,74 @@
|
||||
import random
|
||||
try:
|
||||
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
|
||||
except ImportError:
|
||||
from config import HEADERS, USER_AGENTS
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from aiohttp import ClientSession
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult():
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.title} - {self.url}\n{self.snippet}"
|
||||
|
||||
class SearchEngine():
|
||||
'''
|
||||
搜索引擎爬虫基类
|
||||
'''
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.TIMEOUT = 10
|
||||
self.page = 1
|
||||
self.headers = HEADERS
|
||||
|
||||
def _set_selector(self, selector: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_next_page(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _get_html(self, url: str, data: dict = None) -> str:
|
||||
headers = self.headers
|
||||
headers["Referer"] = url
|
||||
headers["User-Agent"] = random.choice(USER_AGENTS)
|
||||
print(headers)
|
||||
if data:
|
||||
async with ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=data, timeout=self.TIMEOUT) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
async with ClientSession() as session:
|
||||
async with session.get(url, headers=headers, timeout=self.TIMEOUT) as resp:
|
||||
return await resp.text()
|
||||
|
||||
|
||||
def tidy_text(self, text: str) -> str:
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
try:
|
||||
resp = await self._get_next_page(query)
|
||||
soup = BeautifulSoup(resp, 'html.parser')
|
||||
links = soup.select(self._set_selector('links'))
|
||||
results = []
|
||||
for link in links:
|
||||
title = self.tidy_text(link.select_one(self._set_selector('title')).text)
|
||||
url = link.select_one(self._set_selector('url'))
|
||||
snippet = ''
|
||||
if title and url:
|
||||
results.append(SearchResult(title=title, url=url, snippet=snippet))
|
||||
return results[:num_results] if len(results) > num_results else results
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,23 @@
|
||||
from googlesearch import search
|
||||
|
||||
try:
|
||||
from util.search_engine_scraper.engine import SearchEngine, SearchResult
|
||||
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
|
||||
except ImportError:
|
||||
from engine import SearchEngine, SearchResult
|
||||
from config import HEADERS, USER_AGENTS
|
||||
|
||||
from typing import List
|
||||
|
||||
class Google(SearchEngine):
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
index = 1
|
||||
results = []
|
||||
try:
|
||||
ls = search(query, advanced=True, num_results=num_results, timeout=3)
|
||||
for i in ls:
|
||||
results.append(SearchResult(title=i.title, url=i.url, snippet=i.description))
|
||||
except:
|
||||
pass
|
||||
return results
|
||||
@@ -0,0 +1,49 @@
|
||||
import random, re
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
try:
|
||||
from util.search_engine_scraper.engine import SearchEngine, SearchResult
|
||||
from util.search_engine_scraper.config import HEADERS, USER_AGENTS
|
||||
except ImportError:
|
||||
from engine import SearchEngine, SearchResult
|
||||
from config import HEADERS, USER_AGENTS
|
||||
|
||||
from typing import List
|
||||
|
||||
class Sogo(SearchEngine):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.base_url = "https://www.sogou.com"
|
||||
self.headers['User-Agent'] = random.choice(USER_AGENTS)
|
||||
|
||||
|
||||
def _set_selector(self, selector: str):
|
||||
selectors = {
|
||||
'url': 'h3 > a',
|
||||
'title': 'h3',
|
||||
'text': '',
|
||||
'links': 'div.results > div.vrwrap:not(.middle-better-hintBox)',
|
||||
'next': ''
|
||||
}
|
||||
return selectors[selector]
|
||||
|
||||
async def _get_next_page(self, query) -> str:
|
||||
url = f'{self.base_url}/web?query={query}'
|
||||
return await self._get_html(url, None)
|
||||
|
||||
async def search(self, query: str, num_results: int) -> List[SearchResult]:
|
||||
results = await super().search(query, num_results)
|
||||
for result in results:
|
||||
result.url = result.url.get("href")
|
||||
if result.url.startswith("/link?"):
|
||||
result.url = self.base_url + result.url
|
||||
result.url = await self._parse_url(result.url)
|
||||
return results
|
||||
|
||||
async def _parse_url(self, url) -> str:
|
||||
html = await self._get_html(url)
|
||||
soup = BeautifulSoup(html, 'html.parser')
|
||||
script = soup.find("script")
|
||||
if script:
|
||||
url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group(1)
|
||||
return url
|
||||
@@ -0,0 +1,22 @@
|
||||
from sogo import Sogo
|
||||
from bing import Bing
|
||||
|
||||
sogo_search = Sogo()
|
||||
bing_search = Bing()
|
||||
async def search(keyword: str) -> str:
|
||||
results = await sogo_search.search(keyword, 5)
|
||||
# results = await bing_search.search(keyword, 5)
|
||||
ret = ""
|
||||
if len(results) == 0:
|
||||
return "没有搜索到结果"
|
||||
|
||||
idx = 1
|
||||
for i in results:
|
||||
ret += f"{idx}. {i.title}({i.url})\n{i.snippet}\n\n"
|
||||
idx += 1
|
||||
|
||||
return ret
|
||||
|
||||
import asyncio
|
||||
ret = asyncio.run(search("gpt4orelease"))
|
||||
print(ret)
|
||||
Reference in New Issue
Block a user