refactor: 支持llm tool

This commit is contained in:
Soulter
2024-12-11 13:20:21 +08:00
parent e9e789da20
commit 92aa3123ec
12 changed files with 176 additions and 86 deletions
+3 -1
View File
@@ -2,10 +2,12 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
"logger",
"personalities",
"html_renderer"
"html_renderer",
"llm_tool",
]
+1
View File
@@ -3,6 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (
@@ -2,22 +2,27 @@ from . import ContentSafetyStrategy
from typing import List, Tuple
class StrategySelector():
class StrategySelector:
def __init__(self, config: dict) -> None:
self.enabled_strategies: List[ContentSafetyStrategy] = []
if config['internal_keywords']['enable']:
if config["internal_keywords"]["enable"]:
from .keywords import KeywordsStrategy
self.enabled_strategies.append(KeywordsStrategy(
config['internal_keywords']['extra_keywords']))
if config['baidu_aip']['enable']:
self.enabled_strategies.append(
KeywordsStrategy(config["internal_keywords"]["extra_keywords"])
)
if config["baidu_aip"]["enable"]:
try:
from .baidu_aip import BaiduAipStrategy
except ImportError:
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'],
config['baidu_aip']['api_key'],
config['baidu_aip']['secret_key']
))
self.enabled_strategies.append(
BaiduAipStrategy(
config["baidu_aip"]["app_id"],
config["baidu_aip"]["api_key"],
config["baidu_aip"]["secret_key"],
)
)
def check(self, content: str) -> Tuple[bool, str]:
for strategy in self.enabled_strategies:
@@ -7,6 +7,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Comman
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.star.star import star_map
class LLMRequestSubStage(Stage):
@@ -39,7 +40,7 @@ class LLMRequestSubStage(Stage):
prompt=event.message_str,
session_id=event.session_id,
image_urls=image_urls,
tools=tools
func_tool=tools
)
await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type)
@@ -50,16 +51,29 @@ class LLMRequestSubStage(Stage):
# function calling
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = tools.get_func(func_tool_name)
logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
ret = await func_tool(event=event, *func_tool_args)
# 尝试调用工具函数
star_cls_obj = star_map.get(func_tool.module_name).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
if hasattr(func_tool.func_obj, '__self__'):
# 猜测没有通过装饰器去注册
try:
ret = await func_tool.func_obj(event, **func_tool_args)
except TypeError:
# 向下兼容
ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
else:
ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if ret:
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.stop_event()
event.set_result(ret)
# 执行后续步骤来发送消息
yield
event.clear_result() # 清除上一个 func tool 的结果
except BaseException:
logger.error(traceback.format_exc())
@@ -40,7 +40,7 @@ class StarRequestSubStage(Stage):
ret = await handler.handler(star_cls_obj, event, **params)
logger.debug("star handler %s called" % handler.handler_full_name)
if ret:
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.stop_event()
event.set_result(ret)
# 执行后续步骤来发送消息
@@ -58,7 +58,6 @@ class WakingCheckStage(Stage):
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry:
# filter 需要满足 AND 的逻辑关系
print(handler.handler_full_name)
passed = True
child_command_handler_md = None
+2 -3
View File
@@ -3,8 +3,7 @@ from .provider import Provider
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from astrbot.core.provider.tool import FuncCall
from .register import provider_cls_map
from .register import provider_cls_map, llm_tools
from astrbot.core import logger
class ProviderManager():
@@ -13,7 +12,7 @@ class ProviderManager():
self.provider_settings: dict = config['provider_settings']
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.llm_tools: FuncCall = FuncCall()
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
+44 -1
View File
@@ -1,12 +1,16 @@
from typing import List, Dict, Type
import docstring_parser
from typing import List, Dict, Type, Awaitable
from .provider_metadata import ProviderMetaData
from astrbot.core import logger
from .tool import FuncCall, SUPPORTED_TYPES
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 类的映射'''
llm_tools = FuncCall()
def register_provider_adapter(provider_type_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
@@ -23,3 +27,42 @@ def register_provider_adapter(provider_type_name: str, desc: str):
return cls
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
'''
name_ = name
def decorator(func_obj: Awaitable):
llm_tool_name = name_ if name_ else func_obj.__name__
module_name = func_obj.__module__
docstring = docstring_parser.parse(func_obj.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return func_obj
return decorator
@@ -93,6 +93,7 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
logger.debug("request with llm tools")
payloads["tools"] = tools.get_func_desc_openai_style()
completion = await self.client.chat.completions.create(
@@ -117,7 +118,7 @@ class ProviderOpenAIOfficial(Provider):
func_name_ls = []
for tool_call in choice.message.tool_calls:
for tool in tools.func_list:
if tool['name'] == tool_call.function.name:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
+74 -45
View File
@@ -1,7 +1,7 @@
import json
import textwrap
from typing import Awaitable, Dict, List
from typing_extensions import TypedDict
from dataclasses import dataclass
class FuncCallJsonFormatError(Exception):
@@ -11,6 +11,7 @@ class FuncCallJsonFormatError(Exception):
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
@@ -18,86 +19,115 @@ class FuncNotFoundError(Exception):
def __str__(self):
return self.msg
class FuncTool(TypedDict):
'''
@dataclass
class FuncTool:
"""
用于描述一个函数调用工具。
'''
"""
name: str
parameters: Dict
description: str
func_obj: Awaitable
module_name: str = None
class FuncCall():
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
def empty(self) -> bool:
return len(self.func_list) == 0
def add_func(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
def add_func(
self,
name: str,
func_args: list,
desc: str,
func_obj: Awaitable,
module_name: str = None,
) -> None:
"""
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
'''
"""
params = {
"type": "object", # hard-coded here
"properties": {}
"properties": {},
}
for param in func_args:
params['properties'][param['name']] = {
"type": param['type'],
"description": param['description']
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj)
_func = FuncTool(
name=name,
parameters=params,
description=desc,
func_obj=func_obj,
module_name=module_name,
)
self.func_list.append(_func)
def remove_func(self, name: str) -> None:
'''
"""
删除一个函数调用工具。
'''
"""
for i, f in enumerate(self.func_list):
if f["name"] == name:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool:
for f in self.func_list:
if f["name"] == name:
if f.name == name:
return f
return None
def get_func_desc_openai_style(self) -> list:
'''
"""
获得 OpenAI API 风格的工具描述
'''
"""
_l = []
for f in self.func_list:
_l.append({
"type": "function",
"function": {
_l.append(
{
"type": "function",
"function": {
"name": f.name,
"parameters": f.parameters,
"description": f.description,
},
}
)
return _l
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
_l.append(
{
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
}
})
return _l
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
_l.append({
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
})
)
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
@@ -123,8 +153,8 @@ class FuncCall():
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
if res.find("```") != -1:
res = res[res.find("```json") + 7 : res.rfind("```")]
res = json.loads(res)
break
except Exception as e:
@@ -133,8 +163,8 @@ class FuncCall():
raise e
if "The message you submitted was too long" in str(e):
raise e
if 'res' in res and not res['res']:
if "res" in res and not res["res"]:
return "", False
tool_call_result = []
@@ -149,8 +179,7 @@ class FuncCall():
tool_callable = func["func_obj"]
break
if not tool_callable:
raise FuncNotFoundError(
f"Request function {func_name} not found.")
raise FuncNotFoundError(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
+15 -19
View File
@@ -52,29 +52,25 @@ class Context:
获取 LLM Tools。
'''
return self.provider_manager.llm_tools
# def get_star_commands(self, star_name: str) -> List[]:
# '''获得一个'''
# def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
# '''
# 为函数调用(function-calling / tools-use)添加工具。
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
# @param name: 函数名
# @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
# @param desc: 函数描述
# @param func_obj: 异步处理函数。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
# 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
# '''
# self.llm_tools.add_func(name, func_args, desc, func_obj)
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
# def unregister_llm_tool(self, name: str) -> None:
# '''
# 删除一个函数调用工具。
# '''
# self.llm_tools.remove_func(name)
def unregister_llm_tool(self, name: str) -> None:
'''
删除一个函数调用工具。
'''
self.provider_manager.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
+2 -1
View File
@@ -14,4 +14,5 @@ lxml_html_clean
colorlog
aiocqhttp
pyjwt
apscheduler
apscheduler
docstring_parser