feat: 添加 websearch

This commit is contained in:
Soulter
2024-12-11 15:02:29 +08:00
parent 92aa3123ec
commit 0b53eae4ad
14 changed files with 201 additions and 112 deletions
@@ -33,7 +33,7 @@ class LLMRequestSubStage(Stage):
image_url = comp.url if comp.url else comp.file
image_urls.append(image_url)
tools = self.ctx.plugin_manager.context.get_llm_tools()
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
try:
llm_response = await self.curr_provider.text_chat(
+10 -9
View File
@@ -133,16 +133,17 @@ class AstrMessageEvent(abc.ABC):
如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。
Example:
async def ban_handler(self, event: AstrMessageEvent):
if event.get_sender_id() in self.blacklist:
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
return
async def check_count(self, event: AstrMessageEvent):
self.count += 1
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
```
async def ban_handler(self, event: AstrMessageEvent):
if event.get_sender_id() in self.blacklist:
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
return
async def check_count(self, event: AstrMessageEvent):
self.count += 1
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
return
```
'''
self._result = result
+1
View File
@@ -44,6 +44,7 @@ def register_llm_tool(name: str = None):
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
'''
name_ = name
@@ -93,7 +93,6 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
logger.debug("request with llm tools")
payloads["tools"] = tools.get_func_desc_openai_style()
completion = await self.client.chat.completions.create(
+5 -1
View File
@@ -32,6 +32,8 @@ class FuncTool:
func_obj: Awaitable
module_name: str = None
active: bool = True
'''是否激活'''
SUPPORTED_TYPES = [
"string",
@@ -100,10 +102,12 @@ class FuncCall:
def get_func_desc_openai_style(self) -> list:
"""
获得 OpenAI API 风格的工具描述
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"type": "function",
@@ -1,95 +0,0 @@
import random
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from engines.config import HEADERS, USER_AGENTS
from engines.bing import Bing
from engines.sogo import Sogo
from engines.google import Google
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api.provider import Provider
from astrbot.api import logger
bing_search = Bing()
sogo_search = Sogo()
google = Google()
proxy = os.environ.get("HTTPS_PROXY", None)
def tidy_text(text: str) -> str:
'''
清理文本,去除空格、换行符等
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
'''
tools, 从 bing 搜索引擎搜索
'''
logger.info("web_searcher - search_from_bing: " + keyword)
results = []
try:
results = await google.search(keyword, 5)
except BaseException as e:
logger.error(f"google search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search google failed")
try:
results = await bing_search.search(keyword, 5)
except BaseException as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await sogo_search.search(keyword, 5)
except BaseException as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
logger.debug("search sogo failed")
return "没有搜索到结果"
ret = ""
idx = 1
for i in results:
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
try:
site_result = await fetch_website_content(i.url)
except BaseException:
site_result = ""
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
return await summarize(ret, event, provider)
async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=HEADERS, timeout=6, proxy=proxy) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return await summarize(ret, event, provider)
async def summarize(text: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
summary_prompt = f"""
你是一个专业且高效的助手,你擅长总结给定文本。你的任务是
1. 回答用户的问题 `{event.message_str}`,用户的问题相关的材料在下方;
2. 简略发表你的看法。
# 例子
1. 从网上的信息来看,可以知道...我个人认为...
2. 根据网上的最新信息,可以得知...我觉得...
# 限制
1. 限制在 200-300 字;
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
{text}"""
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
event.set_result(MessageEventResult().message(ret))
+26 -5
View File
@@ -47,9 +47,9 @@ class Context:
def get_all_stars(self) -> List[StarMetadata]:
return star_registry
def get_llm_tools(self) -> FuncCall:
def get_llm_tool_manager(self) -> FuncCall:
'''
获取 LLM Tools。
获取 LLM Tool Manager
'''
return self.provider_manager.llm_tools
@@ -67,10 +67,31 @@ class Context:
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
def unregister_llm_tool(self, name: str) -> None:
'''
删除一个函数调用工具。
'''
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
Returns:
如果没找到,会返回 False
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
return True
return False
def deactivate_llm_tool(self, name: str) -> bool:
'''停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
return True
return False
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
+30
View File
@@ -45,11 +45,41 @@ class Main(star.Star):
/reset: 重置 LLM 会话
/history: 获取会话历史记录
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@filter.command_group("tool")
def tool(self):
pass
@tool.command("ls")
async def tool_ls(self, event: AstrMessageEvent):
tm = self.context.get_llm_tool_manager()
msg = "函数工具:\n"
for tool in tm.func_list:
active = " (启用)" if tool.active else "(停用)"
msg += f"- {tool.name}: {tool.description} {active}\n"
msg += "\n使用 /tool on/off <工具名> 激活或者停用工具。"
event.set_result(MessageEventResult().message(msg).use_t2i(False))
@tool.command("on")
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
if self.context.activate_llm_tool(tool_name):
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 成功。"))
else:
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 失败,未找到此工具。"))
@tool.command("off")
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
if self.context.deactivate_llm_tool(tool_name):
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
else:
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent):
+128
View File
@@ -0,0 +1,128 @@
import aiohttp
import random
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from .engines.bing import Bing
from .engines.sogo import Sogo
from .engines.google import Google
from readability import Document
from bs4 import BeautifulSoup
from .engines.config import HEADERS, USER_AGENTS
@star.register(name="astrbot-web-searcher", desc="让 LLM 具有网页检索能力", author="Soulter", version="1.14.514")
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.bing_search = Bing()
self.sogo_search = Sogo()
self.google = Google()
async def _tidy_text(text: str) -> str:
'''清理文本,去除空格、换行符等'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
async def _get_from_url(self, url: str) -> str:
'''获取网页内容'''
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=header, timeout=6) as response:
html = await response.text(encoding="utf-8")
doc = Document(html)
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = self._tidy_text(soup.get_text())
return ret
async def _request_from_llm(self, event: AstrMessageEvent, resources: str) -> str:
'''使用 LLM 对文本进行生成'''
if self.context.get_using_provider() is None:
raise ValueError("未找到可用的 LLM Provider,无法进行摘要总结")
provider = self.context.get_using_provider()
summary_prompt = f"""{event.get_message_str()}
# Provided Sources:
{resources}"""
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
return ret.completion_text
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
websearch = self.context.get_config()['provider_settings']['web_search']
if oper is None:
status = "开启" if websearch else "关闭"
event.set_result(MessageEventResult().message("当前网页搜索功能状态:" + status + "。使用 /websearch on 或者 off 启用或者关闭。"))
return
if oper == "on":
self.context.get_config()['provider_settings']['web_search'] = True
self.context.get_config().save_config()
self.context.activate_llm_tool("web_search")
self.context.activate_llm_tool("fetch_url")
event.set_result(MessageEventResult().message("已开启网页搜索功能"))
elif oper == "off":
self.context.get_config()['provider_settings']['web_search'] = False
self.context.get_config().save_config()
self.context.deactivate_llm_tool("web_search")
self.context.deactivate_llm_tool("fetch_url")
event.set_result(MessageEventResult().message("已关闭网页搜索功能"))
else:
event.set_result(MessageEventResult().message("操作参数错误,应为 on 或 off"))
@llm_tool("web_search")
async def search_from_search_engine(self, event: AstrMessageEvent, query: str) -> str:
'''Search the web for answers to the user's query
Args:
query(string): A search query which will be used to fetch the most relevant snippets regarding the user's query
'''
logger.info("web_searcher - search_from_search_engine: " + query)
results = []
try:
results = await self.google.search(query, 3)
except BaseException as e:
logger.error(f"google search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search google failed")
try:
results = await self.bing_search.search(query, 3)
except BaseException as e:
logger.error(f"bing search error: {e}, try the next one...")
if len(results) == 0:
logger.debug("search bing failed")
try:
results = await self.sogo_search.search(query, 3)
except BaseException as e:
logger.error(f"sogo search error: {e}")
if len(results) == 0:
logger.debug("search sogo failed")
return "没有搜索到结果"
ret = ""
idx = 1
for i in results:
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
try:
site_result = await self._get_from_url(i.url)
except BaseException:
site_result = ""
site_result = site_result[:1000] + "..." if len(site_result) > 1000 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
resp = await self._request_from_llm(event, ret)
event.set_result(MessageEventResult().message(resp))
@llm_tool("fetch_url")
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
'''fetch the content of a website with the given web url
Args:
url(string): The url of the website to fetch content from
'''
resp = await self._get_from_url(url)
event.set_result(MessageEventResult().message(resp))