From 92aa3123ec0cb184f580b116fc986731bbdc4f44 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 11 Dec 2024 13:20:21 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=94=AF=E6=8C=81llm=20tool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/api/__init__.py | 4 +- astrbot/api/all.py | 1 + .../strategies/strategy.py | 23 ++-- .../process_stage/method/llm_request.py | 22 +++- .../process_stage/method/star_request.py | 2 +- astrbot/core/pipeline/waking_check/stage.py | 1 - astrbot/core/provider/manager.py | 5 +- astrbot/core/provider/register.py | 45 ++++++- .../core/provider/sources/openai_source.py | 3 +- astrbot/core/provider/tool.py | 119 +++++++++++------- astrbot/core/star/context.py | 34 +++-- requirements.txt | 3 +- 12 files changed, 176 insertions(+), 86 deletions(-) diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index 960495093..d484e7c9e 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -2,10 +2,12 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core.utils.personality import personalities from astrbot.core import html_renderer +from astrbot.core.provider.register import register_llm_tool as llm_tool __all__ = [ "AstrBotConfig", "logger", "personalities", - "html_renderer" + "html_renderer", + "llm_tool", ] \ No newline at end of file diff --git a/astrbot/api/all.py b/astrbot/api/all.py index 814c54cce..226e2b30f 100644 --- a/astrbot/api/all.py +++ b/astrbot/api/all.py @@ -3,6 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core.utils.personality import personalities from astrbot.core import html_renderer +from astrbot.core.provider.register import register_llm_tool as llm_tool # event from astrbot.core.message.message_event_result import ( diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index 21bb58535..dc1bf7e09 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -2,22 +2,27 @@ from . import ContentSafetyStrategy from typing import List, Tuple -class StrategySelector(): +class StrategySelector: def __init__(self, config: dict) -> None: self.enabled_strategies: List[ContentSafetyStrategy] = [] - if config['internal_keywords']['enable']: + if config["internal_keywords"]["enable"]: from .keywords import KeywordsStrategy - self.enabled_strategies.append(KeywordsStrategy( - config['internal_keywords']['extra_keywords'])) - if config['baidu_aip']['enable']: + + self.enabled_strategies.append( + KeywordsStrategy(config["internal_keywords"]["extra_keywords"]) + ) + if config["baidu_aip"]["enable"]: try: from .baidu_aip import BaiduAipStrategy except ImportError: raise ImportError("使用百度内容审核应该先 pip install baidu-aip") - self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'], - config['baidu_aip']['api_key'], - config['baidu_aip']['secret_key'] - )) + self.enabled_strategies.append( + BaiduAipStrategy( + config["baidu_aip"]["app_id"], + config["baidu_aip"]["api_key"], + config["baidu_aip"]["secret_key"], + ) + ) def check(self, content: str) -> Tuple[bool, str]: for strategy in self.enabled_strategies: diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index f764609d2..7e67cab3b 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -7,6 +7,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Comman from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric +from astrbot.core.star.star import star_map class LLMRequestSubStage(Stage): @@ -39,7 +40,7 @@ class LLMRequestSubStage(Stage): prompt=event.message_str, session_id=event.session_id, image_urls=image_urls, - tools=tools + func_tool=tools ) await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type) @@ -50,16 +51,29 @@ class LLMRequestSubStage(Stage): # function calling for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args): func_tool = tools.get_func(func_tool_name) - logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") + logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") try: - ret = await func_tool(event=event, *func_tool_args) + # 尝试调用工具函数 + + star_cls_obj = star_map.get(func_tool.module_name).star_cls + # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) + if hasattr(func_tool.func_obj, '__self__'): + # 猜测没有通过装饰器去注册 + try: + ret = await func_tool.func_obj(event, **func_tool_args) + except TypeError: + # 向下兼容 + ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args) + else: + ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args) if ret: - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。" + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" event.stop_event() event.set_result(ret) # 执行后续步骤来发送消息 yield + event.clear_result() # 清除上一个 func tool 的结果 except BaseException: logger.error(traceback.format_exc()) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 516b9076d..2394ff7ac 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -40,7 +40,7 @@ class StarRequestSubStage(Stage): ret = await handler.handler(star_cls_obj, event, **params) logger.debug("star handler %s called" % handler.handler_full_name) if ret: - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。" + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" event.stop_event() event.set_result(ret) # 执行后续步骤来发送消息 diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index f04e8a9bc..739d9de18 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -58,7 +58,6 @@ class WakingCheckStage(Stage): handlers_parsed_params = {} # 注册了指令的 handler for handler in star_handlers_registry: # filter 需要满足 AND 的逻辑关系 - print(handler.handler_full_name) passed = True child_command_handler_md = None diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 577341818..c4e543e7d 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -3,8 +3,7 @@ from .provider import Provider from typing import List from astrbot.core.db import BaseDatabase from collections import defaultdict -from astrbot.core.provider.tool import FuncCall -from .register import provider_cls_map +from .register import provider_cls_map, llm_tools from astrbot.core import logger class ProviderManager(): @@ -13,7 +12,7 @@ class ProviderManager(): self.provider_settings: dict = config['provider_settings'] self.provider_insts: List[Provider] = [] '''加载的 Provider 的实例''' - self.llm_tools: FuncCall = FuncCall() + self.llm_tools = llm_tools self.curr_provider_inst: Provider = None self.loaded_ids = defaultdict(bool) self.db_helper = db_helper diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 1ed812b68..150338f1d 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,12 +1,16 @@ -from typing import List, Dict, Type +import docstring_parser +from typing import List, Dict, Type, Awaitable from .provider_metadata import ProviderMetaData from astrbot.core import logger +from .tool import FuncCall, SUPPORTED_TYPES provider_registry: List[ProviderMetaData] = [] '''维护了通过装饰器注册的 Provider''' provider_cls_map: Dict[str, Type] = {} '''维护了 Provider 类型名称和 Provider 类的映射''' +llm_tools = FuncCall() + def register_provider_adapter(provider_type_name: str, desc: str): '''用于注册平台适配器的带参装饰器''' def decorator(cls): @@ -23,3 +27,42 @@ def register_provider_adapter(provider_type_name: str, desc: str): return cls return decorator + +def register_llm_tool(name: str = None): + '''为函数调用(function-calling / tools-use)添加工具。 + + 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) + + ``` + @llm_tool(name="get_weather") # 如果 name 不填,将使用函数名 + async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult: + \'\'\'获取天气信息。 + + Args: + location(string): 地点 + \'\'\' + # 处理逻辑 + ``` + + ''' + name_ = name + + def decorator(func_obj: Awaitable): + llm_tool_name = name_ if name_ else func_obj.__name__ + module_name = func_obj.__module__ + docstring = docstring_parser.parse(func_obj.__doc__) + args = [] + for arg in docstring.params: + if arg.type_name not in SUPPORTED_TYPES: + raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}") + args.append({ + "type": arg.type_name, + "name": arg.arg_name, + "description": arg.description + }) + llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name) + + logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") + return func_obj + + return decorator \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index ec0de2364..79aba2ae6 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -93,6 +93,7 @@ class ProviderOpenAIOfficial(Provider): async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse: if tools: + logger.debug("request with llm tools") payloads["tools"] = tools.get_func_desc_openai_style() completion = await self.client.chat.completions.create( @@ -117,7 +118,7 @@ class ProviderOpenAIOfficial(Provider): func_name_ls = [] for tool_call in choice.message.tool_calls: for tool in tools.func_list: - if tool['name'] == tool_call.function.name: + if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name) diff --git a/astrbot/core/provider/tool.py b/astrbot/core/provider/tool.py index 429038ee3..38e599c2d 100644 --- a/astrbot/core/provider/tool.py +++ b/astrbot/core/provider/tool.py @@ -1,7 +1,7 @@ import json import textwrap from typing import Awaitable, Dict, List -from typing_extensions import TypedDict +from dataclasses import dataclass class FuncCallJsonFormatError(Exception): @@ -11,6 +11,7 @@ class FuncCallJsonFormatError(Exception): def __str__(self): return self.msg + class FuncNotFoundError(Exception): def __init__(self, msg): self.msg = msg @@ -18,86 +19,115 @@ class FuncNotFoundError(Exception): def __str__(self): return self.msg -class FuncTool(TypedDict): - ''' + +@dataclass +class FuncTool: + """ 用于描述一个函数调用工具。 - ''' + """ + name: str parameters: Dict description: str func_obj: Awaitable + module_name: str = None -class FuncCall(): +SUPPORTED_TYPES = [ + "string", + "number", + "object", + "array", + "boolean", +] # json schema 支持的数据类型 + + +class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - + def empty(self) -> bool: return len(self.func_list) == 0 - def add_func(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: - ''' + def add_func( + self, + name: str, + func_args: list, + desc: str, + func_obj: Awaitable, + module_name: str = None, + ) -> None: + """ 为函数调用(function-calling / tools-use)添加工具。 - + @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param desc: 函数描述 @param func_obj: 处理函数 - ''' + """ params = { "type": "object", # hard-coded here - "properties": {} + "properties": {}, } for param in func_args: - params['properties'][param['name']] = { - "type": param['type'], - "description": param['description'] + params["properties"][param["name"]] = { + "type": param["type"], + "description": param["description"], } - _func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj) + _func = FuncTool( + name=name, + parameters=params, + description=desc, + func_obj=func_obj, + module_name=module_name, + ) self.func_list.append(_func) - + def remove_func(self, name: str) -> None: - ''' + """ 删除一个函数调用工具。 - ''' + """ for i, f in enumerate(self.func_list): if f["name"] == name: self.func_list.pop(i) break - + def get_func(self, name) -> FuncTool: for f in self.func_list: - if f["name"] == name: + if f.name == name: return f return None - + def get_func_desc_openai_style(self) -> list: - ''' + """ 获得 OpenAI API 风格的工具描述 - ''' + """ _l = [] for f in self.func_list: - _l.append({ - "type": "function", - "function": { + _l.append( + { + "type": "function", + "function": { + "name": f.name, + "parameters": f.parameters, + "description": f.description, + }, + } + ) + return _l + + async def func_call(self, question: str, session_id: str, provider) -> tuple: + _l = [] + for f in self.func_list: + _l.append( + { "name": f["name"], "parameters": f["parameters"], "description": f["description"], } - }) - return _l - - async def func_call(self, question: str, session_id: str, provider) -> tuple: - - _l = [] - for f in self.func_list: - _l.append({ - "name": f["name"], - "parameters": f["parameters"], - "description": f["description"], - }) + ) func_definition = json.dumps(_l, ensure_ascii=False) - + prompt = textwrap.dedent(f""" ROLE: 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 @@ -123,8 +153,8 @@ class FuncCall(): while _c < 3: try: res = await provider.text_chat(prompt, session_id) - if res.find('```') != -1: - res = res[res.find('```json') + 7: res.rfind('```')] + if res.find("```") != -1: + res = res[res.find("```json") + 7 : res.rfind("```")] res = json.loads(res) break except Exception as e: @@ -133,8 +163,8 @@ class FuncCall(): raise e if "The message you submitted was too long" in str(e): raise e - - if 'res' in res and not res['res']: + + if "res" in res and not res["res"]: return "", False tool_call_result = [] @@ -149,8 +179,7 @@ class FuncCall(): tool_callable = func["func_obj"] break if not tool_callable: - raise FuncNotFoundError( - f"Request function {func_name} not found.") + raise FuncNotFoundError(f"Request function {func_name} not found.") ret = await tool_callable(**args) if ret: tool_call_result.append(str(ret)) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index e2dc8d6b8..a203e1215 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -52,29 +52,25 @@ class Context: 获取 LLM Tools。 ''' return self.provider_manager.llm_tools - - - # def get_star_commands(self, star_name: str) -> List[]: - # '''获得一个''' - # def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: - # ''' - # 为函数调用(function-calling / tools-use)添加工具。 + def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None: + ''' + 为函数调用(function-calling / tools-use)添加工具。 - # @param name: 函数名 - # @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] - # @param desc: 函数描述 - # @param func_obj: 异步处理函数。 + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 异步处理函数。 - # 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 - # ''' - # self.llm_tools.add_func(name, func_args, desc, func_obj) + 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 + ''' + self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__) - # def unregister_llm_tool(self, name: str) -> None: - # ''' - # 删除一个函数调用工具。 - # ''' - # self.llm_tools.remove_func(name) + def unregister_llm_tool(self, name: str) -> None: + ''' + 删除一个函数调用工具。 + ''' + self.provider_manager.llm_tools.remove_func(name) def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False): ''' diff --git a/requirements.txt b/requirements.txt index 89f0f0edc..14b842df6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ lxml_html_clean colorlog aiocqhttp pyjwt -apscheduler \ No newline at end of file +apscheduler +docstring_parser \ No newline at end of file