feat: 添加 websearch
This commit is contained in:
@@ -33,7 +33,7 @@ class LLMRequestSubStage(Stage):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
image_urls.append(image_url)
|
||||
|
||||
tools = self.ctx.plugin_manager.context.get_llm_tools()
|
||||
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
|
||||
try:
|
||||
llm_response = await self.curr_provider.text_chat(
|
||||
|
||||
@@ -133,16 +133,17 @@ class AstrMessageEvent(abc.ABC):
|
||||
如果没有设置 `MessageEventResult` 中的 result_type,默认为 CONTINUE。即事件将会继续向后面的 listener 或者 command 传播。
|
||||
|
||||
Example:
|
||||
|
||||
async def ban_handler(self, event: AstrMessageEvent):
|
||||
if event.get_sender_id() in self.blacklist:
|
||||
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
|
||||
return
|
||||
|
||||
async def check_count(self, event: AstrMessageEvent):
|
||||
self.count += 1
|
||||
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
|
||||
```
|
||||
async def ban_handler(self, event: AstrMessageEvent):
|
||||
if event.get_sender_id() in self.blacklist:
|
||||
event.set_result(MessageEventResult().set_console_log("由于用户在黑名单,因此消息事件中断处理。")).set_result_type(EventResultType.STOP)
|
||||
return
|
||||
|
||||
async def check_count(self, event: AstrMessageEvent):
|
||||
self.count += 1
|
||||
event.set_result(MessageEventResult().set_console_log("数量已增加", logging.DEBUG).set_result_type(EventResultType.CONTINUE))
|
||||
return
|
||||
```
|
||||
'''
|
||||
self._result = result
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ def register_llm_tool(name: str = None):
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
可接受的参数类型有:string, number, object, array, boolean。
|
||||
'''
|
||||
name_ = name
|
||||
|
||||
|
||||
@@ -93,7 +93,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
logger.debug("request with llm tools")
|
||||
payloads["tools"] = tools.get_func_desc_openai_style()
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
|
||||
@@ -32,6 +32,8 @@ class FuncTool:
|
||||
func_obj: Awaitable
|
||||
module_name: str = None
|
||||
|
||||
active: bool = True
|
||||
'''是否激活'''
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
"string",
|
||||
@@ -100,10 +102,12 @@ class FuncCall:
|
||||
|
||||
def get_func_desc_openai_style(self) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的工具描述
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"type": "function",
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import random
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from engines.config import HEADERS, USER_AGENTS
|
||||
from engines.bing import Bing
|
||||
from engines.sogo import Sogo
|
||||
from engines.google import Google
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.api import logger
|
||||
|
||||
bing_search = Bing()
|
||||
sogo_search = Sogo()
|
||||
google = Google()
|
||||
proxy = os.environ.get("HTTPS_PROXY", None)
|
||||
|
||||
def tidy_text(text: str) -> str:
|
||||
'''
|
||||
清理文本,去除空格、换行符等
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
async def search_from_bing(keyword: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
|
||||
'''
|
||||
tools, 从 bing 搜索引擎搜索
|
||||
'''
|
||||
logger.info("web_searcher - search_from_bing: " + keyword)
|
||||
results = []
|
||||
try:
|
||||
results = await google.search(keyword, 5)
|
||||
except BaseException as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
results = await bing_search.search(keyword, 5)
|
||||
except BaseException as e:
|
||||
logger.error(f"bing search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search bing failed")
|
||||
try:
|
||||
results = await sogo_search.search(keyword, 5)
|
||||
except BaseException as e:
|
||||
logger.error(f"sogo search error: {e}")
|
||||
if len(results) == 0:
|
||||
logger.debug("search sogo failed")
|
||||
return "没有搜索到结果"
|
||||
ret = ""
|
||||
idx = 1
|
||||
for i in results:
|
||||
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
|
||||
try:
|
||||
site_result = await fetch_website_content(i.url)
|
||||
except BaseException:
|
||||
site_result = ""
|
||||
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
|
||||
return await summarize(ret, event, provider)
|
||||
|
||||
async def fetch_website_content(url: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=HEADERS, timeout=6, proxy=proxy) as response:
|
||||
html = await response.text(encoding="utf-8")
|
||||
doc = Document(html)
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return await summarize(ret, event, provider)
|
||||
|
||||
async def summarize(text: str, event: AstrMessageEvent = None, provider: Provider = None) -> str:
|
||||
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你擅长总结给定文本。你的任务是
|
||||
1. 回答用户的问题 `{event.message_str}`,用户的问题相关的材料在下方;
|
||||
2. 简略发表你的看法。
|
||||
|
||||
# 例子
|
||||
1. 从网上的信息来看,可以知道...我个人认为...
|
||||
2. 根据网上的最新信息,可以得知...我觉得...
|
||||
|
||||
# 限制
|
||||
1. 限制在 200-300 字;
|
||||
2. 请**直接输出总结**,不要输出多余的内容和提示语。
|
||||
|
||||
# 相关材料
|
||||
{text}"""
|
||||
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
@@ -47,9 +47,9 @@ class Context:
|
||||
def get_all_stars(self) -> List[StarMetadata]:
|
||||
return star_registry
|
||||
|
||||
def get_llm_tools(self) -> FuncCall:
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
'''
|
||||
获取 LLM Tools。
|
||||
获取 LLM Tool Manager
|
||||
'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
@@ -67,10 +67,31 @@ class Context:
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
def activate_llm_tool(self, name: str) -> bool:
|
||||
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
'''停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False'''
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
return True
|
||||
return False
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
|
||||
@@ -45,11 +45,41 @@ class Main(star.Star):
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 获取会话历史记录
|
||||
/persona: 情境人格设置
|
||||
/tool ls: 查看、激活、停用当前注册的函数工具
|
||||
|
||||
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
{notice}"""
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
|
||||
@filter.command_group("tool")
|
||||
def tool(self):
|
||||
pass
|
||||
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
msg = "函数工具:\n"
|
||||
for tool in tm.func_list:
|
||||
active = " (启用)" if tool.active else "(停用)"
|
||||
msg += f"- {tool.name}: {tool.description} {active}\n"
|
||||
|
||||
msg += "\n使用 /tool on/off <工具名> 激活或者停用工具。"
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
|
||||
@tool.command("on")
|
||||
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
|
||||
if self.context.activate_llm_tool(tool_name):
|
||||
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 成功。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message(f"激活工具 {tool_name} 失败,未找到此工具。"))
|
||||
|
||||
@tool.command("off")
|
||||
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
|
||||
if self.context.deactivate_llm_tool(tool_name):
|
||||
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 成功。"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
|
||||
|
||||
@filter.command("plugin")
|
||||
async def plugin(self, event: AstrMessageEvent):
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
import aiohttp
|
||||
import random
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
from .engines.bing import Bing
|
||||
from .engines.sogo import Sogo
|
||||
from .engines.google import Google
|
||||
from readability import Document
|
||||
from bs4 import BeautifulSoup
|
||||
from .engines.config import HEADERS, USER_AGENTS
|
||||
|
||||
|
||||
@star.register(name="astrbot-web-searcher", desc="让 LLM 具有网页检索能力", author="Soulter", version="1.14.514")
|
||||
class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
|
||||
self.bing_search = Bing()
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
|
||||
async def _tidy_text(text: str) -> str:
|
||||
'''清理文本,去除空格、换行符等'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
async def _get_from_url(self, url: str) -> str:
|
||||
'''获取网页内容'''
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=header, timeout=6) as response:
|
||||
html = await response.text(encoding="utf-8")
|
||||
doc = Document(html)
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = self._tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
async def _request_from_llm(self, event: AstrMessageEvent, resources: str) -> str:
|
||||
'''使用 LLM 对文本进行生成'''
|
||||
|
||||
if self.context.get_using_provider() is None:
|
||||
raise ValueError("未找到可用的 LLM Provider,无法进行摘要总结")
|
||||
provider = self.context.get_using_provider()
|
||||
summary_prompt = f"""{event.get_message_str()}
|
||||
|
||||
# Provided Sources:
|
||||
{resources}"""
|
||||
ret = await provider.text_chat(summary_prompt, session_id=event.session_id)
|
||||
return ret.completion_text
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str:
|
||||
websearch = self.context.get_config()['provider_settings']['web_search']
|
||||
if oper is None:
|
||||
status = "开启" if websearch else "关闭"
|
||||
event.set_result(MessageEventResult().message("当前网页搜索功能状态:" + status + "。使用 /websearch on 或者 off 启用或者关闭。"))
|
||||
return
|
||||
|
||||
if oper == "on":
|
||||
self.context.get_config()['provider_settings']['web_search'] = True
|
||||
self.context.get_config().save_config()
|
||||
self.context.activate_llm_tool("web_search")
|
||||
self.context.activate_llm_tool("fetch_url")
|
||||
event.set_result(MessageEventResult().message("已开启网页搜索功能"))
|
||||
elif oper == "off":
|
||||
self.context.get_config()['provider_settings']['web_search'] = False
|
||||
self.context.get_config().save_config()
|
||||
self.context.deactivate_llm_tool("web_search")
|
||||
self.context.deactivate_llm_tool("fetch_url")
|
||||
event.set_result(MessageEventResult().message("已关闭网页搜索功能"))
|
||||
else:
|
||||
event.set_result(MessageEventResult().message("操作参数错误,应为 on 或 off"))
|
||||
|
||||
@llm_tool("web_search")
|
||||
async def search_from_search_engine(self, event: AstrMessageEvent, query: str) -> str:
|
||||
'''Search the web for answers to the user's query
|
||||
|
||||
Args:
|
||||
query(string): A search query which will be used to fetch the most relevant snippets regarding the user's query
|
||||
'''
|
||||
logger.info("web_searcher - search_from_search_engine: " + query)
|
||||
results = []
|
||||
try:
|
||||
results = await self.google.search(query, 3)
|
||||
except BaseException as e:
|
||||
logger.error(f"google search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search google failed")
|
||||
try:
|
||||
results = await self.bing_search.search(query, 3)
|
||||
except BaseException as e:
|
||||
logger.error(f"bing search error: {e}, try the next one...")
|
||||
if len(results) == 0:
|
||||
logger.debug("search bing failed")
|
||||
try:
|
||||
results = await self.sogo_search.search(query, 3)
|
||||
except BaseException as e:
|
||||
logger.error(f"sogo search error: {e}")
|
||||
if len(results) == 0:
|
||||
logger.debug("search sogo failed")
|
||||
return "没有搜索到结果"
|
||||
ret = ""
|
||||
idx = 1
|
||||
for i in results:
|
||||
logger.info(f"web_searcher - scraping web: {i.title} - {i.url}")
|
||||
try:
|
||||
site_result = await self._get_from_url(i.url)
|
||||
except BaseException:
|
||||
site_result = ""
|
||||
site_result = site_result[:1000] + "..." if len(site_result) > 1000 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
|
||||
resp = await self._request_from_llm(event, ret)
|
||||
event.set_result(MessageEventResult().message(resp))
|
||||
|
||||
@llm_tool("fetch_url")
|
||||
async def fetch_website_content(self, event: AstrMessageEvent, url: str) -> str:
|
||||
'''fetch the content of a website with the given web url
|
||||
|
||||
Args:
|
||||
url(string): The url of the website to fetch content from
|
||||
'''
|
||||
resp = await self._get_from_url(url)
|
||||
event.set_result(MessageEventResult().message(resp))
|
||||
Reference in New Issue
Block a user