feat: customized tool-use

This commit is contained in:
Soulter
2024-09-08 19:41:00 +08:00
parent 7f4f96f77b
commit b5cb5eb969
5 changed files with 129 additions and 114 deletions
+65 -12
View File
@@ -1,4 +1,4 @@
import time
import time, json
import re, os
import asyncio
import traceback
@@ -16,6 +16,8 @@ from logging import Logger
from nakuru.entities.components import Image
from util.agent.func_call import FuncCall
import util.agent.web_searcher as web_searcher
from openai._exceptions import *
from openai.types.chat.chat_completion_message_tool_call import Function
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -186,31 +188,82 @@ class MessageHandler():
image_url = comp.url if comp.url else comp.file
break
web_search = self.context.web_search
if not web_search and msg_plain.startswith("ws"):
# leverage web search feature
web_search = True
msg_plain = msg_plain.removeprefix("ws").strip()
# web_search = self.context.web_search
# if not web_search and msg_plain.startswith("ws"):
# # leverage web search feature
# web_search = True
# msg_plain = msg_plain.removeprefix("ws").strip()
try:
if web_search:
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
if not self.llm_tools.empty():
# tools-use
tool_use_flag = True
llm_result = await provider.text_chat(
prompt=msg_plain,
session_id=message.session_id,
tools=self.llm_tools.get_func()
)
if isinstance(llm_result, Function):
logger.debug(f"function-calling: {llm_result}")
func_obj = None
for i in self.llm_tools.func_list:
if i["name"] == llm_result.name:
func_obj = i["func_obj"]
break
if not func_obj:
return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
try:
args = json.loads(llm_result.arguments)
function_invoked_ret = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return MessageResult("AstrBot Function-calling 异常:" + str(e))
else:
return MessageResult(llm_result)
else:
# normal chat
tool_use_flag = False
llm_result = await provider.text_chat(
prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
)
except BadRequestError as e:
if tool_use_flag:
# seems like the model don't support function-calling
logger.error(f"error: {e}. Using local function-calling implementation")
try:
# use local function-calling implementation
args = {
'question': llm_result,
'func_definition': self.llm_tools.func_dump(),
}
_, has_func = await self.llm_tools.func_call(**args)
if not has_func:
# normal chat
llm_result = await provider.text_chat(
prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
)
except BaseException as e:
logger.error(traceback.format_exc())
return CommandResult("AstrBot Function-calling 异常:" + str(e))
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"LLM 调用失败。")
return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))
# concatenate the reply prefix
# concatenate reply prefix
if self.reply_prefix:
llm_result = self.reply_prefix + llm_result
# mask the unsafe content
# mask unsafe content
llm_result = self.content_safety_helper.filter_content(llm_result)
check = self.content_safety_helper.baidu_check(llm_result)
if not check:
+21
View File
@@ -9,6 +9,7 @@ from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from nakuru.entities.components import Image
from util.agent.web_searcher import search_from_bing, fetch_website_content
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -212,6 +213,23 @@ class InternalCommandHandler:
)
elif l[1] == 'on':
context.web_search = True
context.register_llm_tool("web_search", [{
"type": "string",
"name": "keyword",
"description": "搜索关键词"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
search_from_bing
)
context.register_llm_tool("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "要获取内容的网页链接"
}],
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
return CommandResult(
hit=True,
success=True,
@@ -219,6 +237,9 @@ class InternalCommandHandler:
)
elif l[1] == 'off':
context.web_search = False
context.unregister_llm_tool("web_search")
context.unregister_llm_tool("fetch_website_content")
return CommandResult(
hit=True,
success=True,
+12 -1
View File
@@ -110,6 +110,12 @@ class Context:
'''
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
def unregister_llm_tool(self, tool_name: str):
'''
删除一个函数调用工具。
'''
self.message_handler.llm_tools.remove_func(tool_name)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
if platform_name == platform.platform_name:
@@ -131,4 +137,9 @@ class Context:
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
def get_current_llm_provider(self) -> Provider:
'''
获取当前的 LLM Provider。
'''
return self.message_handler.provider
+16 -4
View File
@@ -23,6 +23,9 @@ class FuncCall():
def __init__(self, provider: Provider) -> None:
self.func_list = []
self.provider = provider
def empty(self) -> bool:
return len(self.func_list) == 0
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
'''
@@ -34,7 +37,7 @@ class FuncCall():
@param func_obj: 处理函数
'''
params = {
"type": "object", # hardcore here
"type": "object", # hard-coded here
"properties": {}
}
for param in func_args:
@@ -42,14 +45,23 @@ class FuncCall():
"type": param['type'],
"description": param['description']
}
self._func = {
_func = {
"name": name,
"parameters": params,
"description": desc,
"func_obj": func_obj,
}
self.func_list.append(self._func)
self.func_list.append(_func)
def remove_func(self, name: str) -> None:
'''
删除一个函数调用工具。
'''
for i, f in enumerate(self.func_list):
if f["name"] == name:
self.func_list.pop(i)
break
def func_dump(self) -> str:
_l = []
for f in self.func_list:
+15 -97
View File
@@ -16,6 +16,8 @@ from util.websearch.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from logging import Logger
from type.types import Context
from type.message_event import AstrMessageEvent
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -31,24 +33,7 @@ def tidy_text(text: str) -> str:
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
# def special_fetch_zhihu(link: str) -> str:
# '''
# function-calling 函数, 用于获取知乎文章的内容
# '''
# response = requests.get(link, headers=HEADERS)
# response.encoding = "utf-8"
# soup = BeautifulSoup(response.text, "html.parser")
# if "zhuanlan.zhihu.com" in link:
# r = soup.find(class_="Post-RichTextContainer")
# else:
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
# if r is None:
# print("debug: zhihu none")
# raise Exception("zhihu none")
# return tidy_text(r.text)
async def search_from_bing(keyword: str) -> str:
async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str:
'''
tools, 从 bing 搜索引擎搜索
'''
@@ -84,10 +69,11 @@ async def search_from_bing(keyword: str) -> str:
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
return ret
return await summarize(context, ame, ret)
async def fetch_website_content(url):
async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str):
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
@@ -97,81 +83,13 @@ async def fetch_website_content(url):
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
return ret
async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
'''
@param official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
new_func_call.add_func("web_search", [{
"type": "string",
"name": "keyword",
"description": "搜索关键词"
}],
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
search_from_bing
)
new_func_call.add_func("fetch_website_content", [{
"type": "string",
"name": "url",
"description": "要获取内容的网页链接"
}],
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
fetch_website_content
)
return await summarize(context, ame, ret)
has_func = False
function_invoked_ret = ""
if official_fc:
# we use official function-calling
try:
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
except BadRequestError as e:
# seems dont support function-calling
logger.error(f"error: {e}. Try to use local function-calling implementation")
return await web_search(prompt, provider, session_id, official_fc=False)
if isinstance(result, Function):
logger.debug(f"function-calling: {result}")
func_obj = None
for i in new_func_call.func_list:
if i["name"] == result.name:
func_obj = i["func_obj"]
break
if not func_obj:
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
try:
args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
else:
return result
else:
# we use our own function-calling
try:
args = {
'question': prompt,
'func_definition': new_func_call.func_dump(),
}
function_invoked_ret, has_func = await new_func_call.func_call(**args)
if not has_func:
return await provider.text_chat(prompt, session_id)
except BaseException as e:
logger.error(traceback.format_exc())
return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
if has_func:
await provider.forget(session_id=session_id)
summary_prompt = f"""
async def summarize(context: Context, ame: AstrMessageEvent, text: str):
summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结;
2. 简单地发表你对这个问题的看法。
# 例子
@@ -183,7 +101,7 @@ async def web_search(prompt: str, provider: Provider, session_id: str, official_
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
return ret
return function_invoked_ret
{text}"""
provider = context.get_current_llm_provider()
return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id)