feat: customized tool-use
This commit is contained in:
+65
-12
@@ -1,4 +1,4 @@
|
||||
import time
|
||||
import time, json
|
||||
import re, os
|
||||
import asyncio
|
||||
import traceback
|
||||
@@ -16,6 +16,8 @@ from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.func_call import FuncCall
|
||||
import util.agent.web_searcher as web_searcher
|
||||
from openai._exceptions import *
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -186,31 +188,82 @@ class MessageHandler():
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
|
||||
web_search = self.context.web_search
|
||||
if not web_search and msg_plain.startswith("ws"):
|
||||
# leverage web search feature
|
||||
web_search = True
|
||||
msg_plain = msg_plain.removeprefix("ws").strip()
|
||||
|
||||
# web_search = self.context.web_search
|
||||
# if not web_search and msg_plain.startswith("ws"):
|
||||
# # leverage web search feature
|
||||
# web_search = True
|
||||
# msg_plain = msg_plain.removeprefix("ws").strip()
|
||||
try:
|
||||
if web_search:
|
||||
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
|
||||
if not self.llm_tools.empty():
|
||||
# tools-use
|
||||
tool_use_flag = True
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
tools=self.llm_tools.get_func()
|
||||
)
|
||||
|
||||
if isinstance(llm_result, Function):
|
||||
logger.debug(f"function-calling: {llm_result}")
|
||||
func_obj = None
|
||||
for i in self.llm_tools.func_list:
|
||||
if i["name"] == llm_result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
|
||||
try:
|
||||
args = json.loads(llm_result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return MessageResult("AstrBot Function-calling 异常:" + str(e))
|
||||
else:
|
||||
return MessageResult(llm_result)
|
||||
|
||||
else:
|
||||
# normal chat
|
||||
tool_use_flag = False
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BadRequestError as e:
|
||||
if tool_use_flag:
|
||||
# seems like the model don't support function-calling
|
||||
logger.error(f"error: {e}. Using local function-calling implementation")
|
||||
|
||||
try:
|
||||
# use local function-calling implementation
|
||||
args = {
|
||||
'question': llm_result,
|
||||
'func_definition': self.llm_tools.func_dump(),
|
||||
}
|
||||
_, has_func = await self.llm_tools.func_call(**args)
|
||||
|
||||
if not has_func:
|
||||
# normal chat
|
||||
llm_result = await provider.text_chat(
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return CommandResult("AstrBot Function-calling 异常:" + str(e))
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"LLM 调用失败。")
|
||||
return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))
|
||||
|
||||
# concatenate the reply prefix
|
||||
|
||||
# concatenate reply prefix
|
||||
if self.reply_prefix:
|
||||
llm_result = self.reply_prefix + llm_result
|
||||
|
||||
# mask the unsafe content
|
||||
# mask unsafe content
|
||||
llm_result = self.content_safety_helper.filter_content(llm_result)
|
||||
check = self.content_safety_helper.baidu_check(llm_result)
|
||||
if not check:
|
||||
|
||||
@@ -9,6 +9,7 @@ from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -212,6 +213,23 @@ class InternalCommandHandler:
|
||||
)
|
||||
elif l[1] == 'on':
|
||||
context.web_search = True
|
||||
context.register_llm_tool("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
context.register_llm_tool("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
@@ -219,6 +237,9 @@ class InternalCommandHandler:
|
||||
)
|
||||
elif l[1] == 'off':
|
||||
context.web_search = False
|
||||
context.unregister_llm_tool("web_search")
|
||||
context.unregister_llm_tool("fetch_website_content")
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
|
||||
+12
-1
@@ -110,6 +110,12 @@ class Context:
|
||||
'''
|
||||
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
|
||||
|
||||
def unregister_llm_tool(self, tool_name: str):
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
self.message_handler.llm_tools.remove_func(tool_name)
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
if platform_name == platform.platform_name:
|
||||
@@ -131,4 +137,9 @@ class Context:
|
||||
platform_name, message_type, id = l
|
||||
platform = self.find_platform(platform_name)
|
||||
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
|
||||
|
||||
|
||||
def get_current_llm_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前的 LLM Provider。
|
||||
'''
|
||||
return self.message_handler.provider
|
||||
+16
-4
@@ -23,6 +23,9 @@ class FuncCall():
|
||||
def __init__(self, provider: Provider) -> None:
|
||||
self.func_list = []
|
||||
self.provider = provider
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
|
||||
'''
|
||||
@@ -34,7 +37,7 @@ class FuncCall():
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {}
|
||||
}
|
||||
for param in func_args:
|
||||
@@ -42,14 +45,23 @@ class FuncCall():
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
}
|
||||
self._func = {
|
||||
_func = {
|
||||
"name": name,
|
||||
"parameters": params,
|
||||
"description": desc,
|
||||
"func_obj": func_obj,
|
||||
}
|
||||
self.func_list.append(self._func)
|
||||
|
||||
self.func_list.append(_func)
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
for i, f in enumerate(self.func_list):
|
||||
if f["name"] == name:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def func_dump(self) -> str:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
|
||||
+15
-97
@@ -16,6 +16,8 @@ from util.websearch.google import Google
|
||||
from model.provider.provider import Provider
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.message_event import AstrMessageEvent
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -31,24 +33,7 @@ def tidy_text(text: str) -> str:
|
||||
'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
# def special_fetch_zhihu(link: str) -> str:
|
||||
# '''
|
||||
# function-calling 函数, 用于获取知乎文章的内容
|
||||
# '''
|
||||
# response = requests.get(link, headers=HEADERS)
|
||||
# response.encoding = "utf-8"
|
||||
# soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
# if "zhuanlan.zhihu.com" in link:
|
||||
# r = soup.find(class_="Post-RichTextContainer")
|
||||
# else:
|
||||
# r = soup.find(class_="List-item").find(class_="RichContent-inner")
|
||||
# if r is None:
|
||||
# print("debug: zhihu none")
|
||||
# raise Exception("zhihu none")
|
||||
# return tidy_text(r.text)
|
||||
|
||||
async def search_from_bing(keyword: str) -> str:
|
||||
async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str:
|
||||
'''
|
||||
tools, 从 bing 搜索引擎搜索
|
||||
'''
|
||||
@@ -84,10 +69,11 @@ async def search_from_bing(keyword: str) -> str:
|
||||
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
|
||||
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
|
||||
idx += 1
|
||||
return ret
|
||||
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
|
||||
async def fetch_website_content(url):
|
||||
async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str):
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -97,81 +83,13 @@ async def fetch_website_content(url):
|
||||
ret = doc.summary(html_partial=True)
|
||||
soup = BeautifulSoup(ret, 'html.parser')
|
||||
ret = tidy_text(soup.get_text())
|
||||
return ret
|
||||
|
||||
|
||||
async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
|
||||
'''
|
||||
@param official_fc: 使用官方 function-calling
|
||||
'''
|
||||
new_func_call = FuncCall(provider)
|
||||
|
||||
new_func_call.add_func("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
"description": "搜索关键词"
|
||||
}],
|
||||
"通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
|
||||
search_from_bing
|
||||
)
|
||||
new_func_call.add_func("fetch_website_content", [{
|
||||
"type": "string",
|
||||
"name": "url",
|
||||
"description": "要获取内容的网页链接"
|
||||
}],
|
||||
"获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
|
||||
fetch_website_content
|
||||
)
|
||||
return await summarize(context, ame, ret)
|
||||
|
||||
has_func = False
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
try:
|
||||
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
|
||||
except BadRequestError as e:
|
||||
# seems dont support function-calling
|
||||
logger.error(f"error: {e}. Try to use local function-calling implementation")
|
||||
return await web_search(prompt, provider, session_id, official_fc=False)
|
||||
if isinstance(result, Function):
|
||||
logger.debug(f"function-calling: {result}")
|
||||
func_obj = None
|
||||
for i in new_func_call.func_list:
|
||||
if i["name"] == result.name:
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
# we use our own function-calling
|
||||
try:
|
||||
args = {
|
||||
'question': prompt,
|
||||
'func_definition': new_func_call.func_dump(),
|
||||
}
|
||||
function_invoked_ret, has_func = await new_func_call.func_call(**args)
|
||||
|
||||
if not has_func:
|
||||
return await provider.text_chat(prompt, session_id)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id=session_id)
|
||||
summary_prompt = f"""
|
||||
async def summarize(context: Context, ame: AstrMessageEvent, text: str):
|
||||
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你的任务是
|
||||
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
|
||||
1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结;
|
||||
2. 简单地发表你对这个问题的看法。
|
||||
|
||||
# 例子
|
||||
@@ -183,7 +101,7 @@ async def web_search(prompt: str, provider: Provider, session_id: str, official_
|
||||
2. 请**直接输出总结**,不要输出多余的内容和提示语。
|
||||
|
||||
# 相关材料
|
||||
{function_invoked_ret}"""
|
||||
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
|
||||
return ret
|
||||
return function_invoked_ret
|
||||
{text}"""
|
||||
|
||||
provider = context.get_current_llm_provider()
|
||||
return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id)
|
||||
Reference in New Issue
Block a user