diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index c3d26305e..89383a463 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -1,4 +1,4 @@ -import time +import time, json import re, os import asyncio import traceback @@ -16,6 +16,8 @@ from logging import Logger from nakuru.entities.components import Image from util.agent.func_call import FuncCall import util.agent.web_searcher as web_searcher +from openai._exceptions import * +from openai.types.chat.chat_completion_message_tool_call import Function logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -186,31 +188,82 @@ class MessageHandler(): image_url = comp.url if comp.url else comp.file break - web_search = self.context.web_search - if not web_search and msg_plain.startswith("ws"): - # leverage web search feature - web_search = True - msg_plain = msg_plain.removeprefix("ws").strip() - + # web_search = self.context.web_search + # if not web_search and msg_plain.startswith("ws"): + # # leverage web search feature + # web_search = True + # msg_plain = msg_plain.removeprefix("ws").strip() try: - if web_search: - llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True) + if not self.llm_tools.empty(): + # tools-use + tool_use_flag = True + llm_result = await provider.text_chat( + prompt=msg_plain, + session_id=message.session_id, + tools=self.llm_tools.get_func() + ) + + if isinstance(llm_result, Function): + logger.debug(f"function-calling: {llm_result}") + func_obj = None + for i in self.llm_tools.func_list: + if i["name"] == llm_result.name: + func_obj = i["func_obj"] + break + if not func_obj: + return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。") + try: + args = json.loads(llm_result.arguments) + function_invoked_ret = await func_obj(**args) + has_func = True + except BaseException as e: + traceback.print_exc() + return MessageResult("AstrBot Function-calling 异常:" + str(e)) + else: + return MessageResult(llm_result) + else: + # normal chat + tool_use_flag = False llm_result = await provider.text_chat( prompt=msg_plain, session_id=message.session_id, image_url=image_url ) + except BadRequestError as e: + if tool_use_flag: + # seems like the model don't support function-calling + logger.error(f"error: {e}. Using local function-calling implementation") + + try: + # use local function-calling implementation + args = { + 'question': llm_result, + 'func_definition': self.llm_tools.func_dump(), + } + _, has_func = await self.llm_tools.func_call(**args) + + if not has_func: + # normal chat + llm_result = await provider.text_chat( + prompt=msg_plain, + session_id=message.session_id, + image_url=image_url + ) + except BaseException as e: + logger.error(traceback.format_exc()) + return CommandResult("AstrBot Function-calling 异常:" + str(e)) + except BaseException as e: logger.error(traceback.format_exc()) logger.error(f"LLM 调用失败。") return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e)) - - # concatenate the reply prefix + + # concatenate reply prefix if self.reply_prefix: llm_result = self.reply_prefix + llm_result - # mask the unsafe content + # mask unsafe content llm_result = self.content_safety_helper.filter_content(llm_result) check = self.content_safety_helper.baidu_check(llm_result) if not check: diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 422c226f8..9d16f8dca 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -9,6 +9,7 @@ from type.config import VERSION from SparkleLogging.utils.core import LogManager from logging import Logger from nakuru.entities.components import Image +from util.agent.web_searcher import search_from_bing, fetch_website_content logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -212,6 +213,23 @@ class InternalCommandHandler: ) elif l[1] == 'on': context.web_search = True + context.register_llm_tool("web_search", [{ + "type": "string", + "name": "keyword", + "description": "搜索关键词" + }], + "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", + search_from_bing + ) + context.register_llm_tool("fetch_website_content", [{ + "type": "string", + "name": "url", + "description": "要获取内容的网页链接" + }], + "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", + fetch_website_content + ) + return CommandResult( hit=True, success=True, @@ -219,6 +237,9 @@ class InternalCommandHandler: ) elif l[1] == 'off': context.web_search = False + context.unregister_llm_tool("web_search") + context.unregister_llm_tool("fetch_website_content") + return CommandResult( hit=True, success=True, diff --git a/type/types.py b/type/types.py index fd4d07d9c..542ab4fbf 100644 --- a/type/types.py +++ b/type/types.py @@ -110,6 +110,12 @@ class Context: ''' self.message_handler.llm_tools.add_func(tool_name, params, desc, func) + def unregister_llm_tool(self, tool_name: str): + ''' + 删除一个函数调用工具。 + ''' + self.message_handler.llm_tools.remove_func(tool_name) + def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: if platform_name == platform.platform_name: @@ -131,4 +137,9 @@ class Context: platform_name, message_type, id = l platform = self.find_platform(platform_name) await platform.platform_instance.send_msg_new(MessageType(message_type), id, message) - \ No newline at end of file + + def get_current_llm_provider(self) -> Provider: + ''' + 获取当前的 LLM Provider。 + ''' + return self.message_handler.provider \ No newline at end of file diff --git a/util/agent/func_call.py b/util/agent/func_call.py index e805f5bfc..5283ee4d6 100644 --- a/util/agent/func_call.py +++ b/util/agent/func_call.py @@ -23,6 +23,9 @@ class FuncCall(): def __init__(self, provider: Provider) -> None: self.func_list = [] self.provider = provider + + def empty(self) -> bool: + return len(self.func_list) == 0 def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None: ''' @@ -34,7 +37,7 @@ class FuncCall(): @param func_obj: 处理函数 ''' params = { - "type": "object", # hardcore here + "type": "object", # hard-coded here "properties": {} } for param in func_args: @@ -42,14 +45,23 @@ class FuncCall(): "type": param['type'], "description": param['description'] } - self._func = { + _func = { "name": name, "parameters": params, "description": desc, "func_obj": func_obj, } - self.func_list.append(self._func) - + self.func_list.append(_func) + + def remove_func(self, name: str) -> None: + ''' + 删除一个函数调用工具。 + ''' + for i, f in enumerate(self.func_list): + if f["name"] == name: + self.func_list.pop(i) + break + def func_dump(self) -> str: _l = [] for f in self.func_list: diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py index c8634c2b7..e519ca697 100644 --- a/util/agent/web_searcher.py +++ b/util/agent/web_searcher.py @@ -16,6 +16,8 @@ from util.websearch.google import Google from model.provider.provider import Provider from SparkleLogging.utils.core import LogManager from logging import Logger +from type.types import Context +from type.message_event import AstrMessageEvent logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -31,24 +33,7 @@ def tidy_text(text: str) -> str: ''' return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") -# def special_fetch_zhihu(link: str) -> str: -# ''' -# function-calling 函数, 用于获取知乎文章的内容 -# ''' -# response = requests.get(link, headers=HEADERS) -# response.encoding = "utf-8" -# soup = BeautifulSoup(response.text, "html.parser") - -# if "zhuanlan.zhihu.com" in link: -# r = soup.find(class_="Post-RichTextContainer") -# else: -# r = soup.find(class_="List-item").find(class_="RichContent-inner") -# if r is None: -# print("debug: zhihu none") -# raise Exception("zhihu none") -# return tidy_text(r.text) - -async def search_from_bing(keyword: str) -> str: +async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str: ''' tools, 从 bing 搜索引擎搜索 ''' @@ -84,10 +69,11 @@ async def search_from_bing(keyword: str) -> str: site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n" idx += 1 - return ret + + return await summarize(context, ame, ret) -async def fetch_website_content(url): +async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str): header = HEADERS header.update({'User-Agent': random.choice(USER_AGENTS)}) async with aiohttp.ClientSession() as session: @@ -97,81 +83,13 @@ async def fetch_website_content(url): ret = doc.summary(html_partial=True) soup = BeautifulSoup(ret, 'html.parser') ret = tidy_text(soup.get_text()) - return ret - - -async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False): - ''' - @param official_fc: 使用官方 function-calling - ''' - new_func_call = FuncCall(provider) - - new_func_call.add_func("web_search", [{ - "type": "string", - "name": "keyword", - "description": "搜索关键词" - }], - "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。", - search_from_bing - ) - new_func_call.add_func("fetch_website_content", [{ - "type": "string", - "name": "url", - "description": "要获取内容的网页链接" - }], - "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。", - fetch_website_content - ) + return await summarize(context, ame, ret) - has_func = False - function_invoked_ret = "" - if official_fc: - # we use official function-calling - try: - result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func()) - except BadRequestError as e: - # seems dont support function-calling - logger.error(f"error: {e}. Try to use local function-calling implementation") - return await web_search(prompt, provider, session_id, official_fc=False) - if isinstance(result, Function): - logger.debug(f"function-calling: {result}") - func_obj = None - for i in new_func_call.func_list: - if i["name"] == result.name: - func_obj = i["func_obj"] - break - if not func_obj: - return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" - try: - args = json.loads(result.arguments) - function_invoked_ret = await func_obj(**args) - has_func = True - except BaseException as e: - traceback.print_exc() - return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" - else: - return result - else: - # we use our own function-calling - try: - args = { - 'question': prompt, - 'func_definition': new_func_call.func_dump(), - } - function_invoked_ret, has_func = await new_func_call.func_call(**args) - - if not has_func: - return await provider.text_chat(prompt, session_id) - - except BaseException as e: - logger.error(traceback.format_exc()) - return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)" - - if has_func: - await provider.forget(session_id=session_id) - summary_prompt = f""" +async def summarize(context: Context, ame: AstrMessageEvent, text: str): + + summary_prompt = f""" 你是一个专业且高效的助手,你的任务是 -1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; +1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结; 2. 简单地发表你对这个问题的看法。 # 例子 @@ -183,7 +101,7 @@ async def web_search(prompt: str, provider: Provider, session_id: str, official_ 2. 请**直接输出总结**,不要输出多余的内容和提示语。 # 相关材料 -{function_invoked_ret}""" - ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id) - return ret - return function_invoked_ret +{text}""" + + provider = context.get_current_llm_provider() + return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id) \ No newline at end of file