refactor: 支持llm tool
This commit is contained in:
@@ -2,10 +2,12 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"personalities",
|
||||
"html_renderer"
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
]
|
||||
@@ -3,6 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
|
||||
# event
|
||||
from astrbot.core.message.message_event_result import (
|
||||
|
||||
@@ -2,22 +2,27 @@ from . import ContentSafetyStrategy
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class StrategySelector():
|
||||
class StrategySelector:
|
||||
def __init__(self, config: dict) -> None:
|
||||
self.enabled_strategies: List[ContentSafetyStrategy] = []
|
||||
if config['internal_keywords']['enable']:
|
||||
if config["internal_keywords"]["enable"]:
|
||||
from .keywords import KeywordsStrategy
|
||||
self.enabled_strategies.append(KeywordsStrategy(
|
||||
config['internal_keywords']['extra_keywords']))
|
||||
if config['baidu_aip']['enable']:
|
||||
|
||||
self.enabled_strategies.append(
|
||||
KeywordsStrategy(config["internal_keywords"]["extra_keywords"])
|
||||
)
|
||||
if config["baidu_aip"]["enable"]:
|
||||
try:
|
||||
from .baidu_aip import BaiduAipStrategy
|
||||
except ImportError:
|
||||
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
|
||||
self.enabled_strategies.append(BaiduAipStrategy(config['baidu_aip']['app_id'],
|
||||
config['baidu_aip']['api_key'],
|
||||
config['baidu_aip']['secret_key']
|
||||
))
|
||||
self.enabled_strategies.append(
|
||||
BaiduAipStrategy(
|
||||
config["baidu_aip"]["app_id"],
|
||||
config["baidu_aip"]["api_key"],
|
||||
config["baidu_aip"]["secret_key"],
|
||||
)
|
||||
)
|
||||
|
||||
def check(self, content: str) -> Tuple[bool, str]:
|
||||
for strategy in self.enabled_strategies:
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, Comman
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -39,7 +40,7 @@ class LLMRequestSubStage(Stage):
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_urls=image_urls,
|
||||
tools=tools
|
||||
func_tool=tools
|
||||
)
|
||||
await Metric.upload(llm_tick=1, model_name=self.curr_provider.get_model(), provider_type=self.curr_provider.meta().type)
|
||||
|
||||
@@ -50,16 +51,29 @@ class LLMRequestSubStage(Stage):
|
||||
# function calling
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = tools.get_func(func_tool_name)
|
||||
logger.debug(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
try:
|
||||
ret = await func_tool(event=event, *func_tool_args)
|
||||
# 尝试调用工具函数
|
||||
|
||||
star_cls_obj = star_map.get(func_tool.module_name).star_cls
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
if hasattr(func_tool.func_obj, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ret = await func_tool.func_obj(event, **func_tool_args)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
|
||||
else:
|
||||
ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args)
|
||||
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
yield
|
||||
event.clear_result() # 清除上一个 func tool 的结果
|
||||
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -40,7 +40,7 @@ class StarRequestSubStage(Stage):
|
||||
ret = await handler.handler(star_cls_obj, event, **params)
|
||||
logger.debug("star handler %s called" % handler.handler_full_name)
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,事件监听器的返回值必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
|
||||
@@ -58,7 +58,6 @@ class WakingCheckStage(Stage):
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
for handler in star_handlers_registry:
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
print(handler.handler_full_name)
|
||||
passed = True
|
||||
child_command_handler_md = None
|
||||
|
||||
|
||||
@@ -3,8 +3,7 @@ from .provider import Provider
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from collections import defaultdict
|
||||
from astrbot.core.provider.tool import FuncCall
|
||||
from .register import provider_cls_map
|
||||
from .register import provider_cls_map, llm_tools
|
||||
from astrbot.core import logger
|
||||
|
||||
class ProviderManager():
|
||||
@@ -13,7 +12,7 @@ class ProviderManager():
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.llm_tools: FuncCall = FuncCall()
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from typing import List, Dict, Type
|
||||
import docstring_parser
|
||||
from typing import List, Dict, Type, Awaitable
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from astrbot.core import logger
|
||||
from .tool import FuncCall, SUPPORTED_TYPES
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
provider_cls_map: Dict[str, Type] = {}
|
||||
'''维护了 Provider 类型名称和 Provider 类的映射'''
|
||||
|
||||
llm_tools = FuncCall()
|
||||
|
||||
def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
@@ -23,3 +27,42 @@ def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
|
||||
```
|
||||
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
|
||||
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
|
||||
\'\'\'获取天气信息。
|
||||
|
||||
Args:
|
||||
location(string): 地点
|
||||
\'\'\'
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
'''
|
||||
name_ = name
|
||||
|
||||
def decorator(func_obj: Awaitable):
|
||||
llm_tool_name = name_ if name_ else func_obj.__name__
|
||||
module_name = func_obj.__module__
|
||||
docstring = docstring_parser.parse(func_obj.__doc__)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
|
||||
args.append({
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description
|
||||
})
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return func_obj
|
||||
|
||||
return decorator
|
||||
@@ -93,6 +93,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
logger.debug("request with llm tools")
|
||||
payloads["tools"] = tools.get_func_desc_openai_style()
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
@@ -117,7 +118,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
func_name_ls = []
|
||||
for tool_call in choice.message.tool_calls:
|
||||
for tool in tools.func_list:
|
||||
if tool['name'] == tool_call.function.name:
|
||||
if tool.name == tool_call.function.name:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
args_ls.append(args)
|
||||
func_name_ls.append(tool_call.function.name)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import textwrap
|
||||
from typing import Awaitable, Dict, List
|
||||
from typing_extensions import TypedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
@@ -11,6 +11,7 @@ class FuncCallJsonFormatError(Exception):
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
|
||||
class FuncNotFoundError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
@@ -18,86 +19,115 @@ class FuncNotFoundError(Exception):
|
||||
def __str__(self):
|
||||
return self.msg
|
||||
|
||||
class FuncTool(TypedDict):
|
||||
'''
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
用于描述一个函数调用工具。
|
||||
'''
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
func_obj: Awaitable
|
||||
module_name: str = None
|
||||
|
||||
|
||||
class FuncCall():
|
||||
SUPPORTED_TYPES = [
|
||||
"string",
|
||||
"number",
|
||||
"object",
|
||||
"array",
|
||||
"boolean",
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
class FuncCall:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
|
||||
def add_func(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
def add_func(
|
||||
self,
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
func_obj: Awaitable,
|
||||
module_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
"""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {}
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params['properties'][param['name']] = {
|
||||
"type": param['type'],
|
||||
"description": param['description']
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FuncTool(name=name, parameters=params, description=desc, func_obj=func_obj)
|
||||
_func = FuncTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
func_obj=func_obj,
|
||||
module_name=module_name,
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
|
||||
|
||||
def remove_func(self, name: str) -> None:
|
||||
'''
|
||||
"""
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
"""
|
||||
for i, f in enumerate(self.func_list):
|
||||
if f["name"] == name:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
|
||||
def get_func(self, name) -> FuncTool:
|
||||
for f in self.func_list:
|
||||
if f["name"] == name:
|
||||
if f.name == name:
|
||||
return f
|
||||
return None
|
||||
|
||||
|
||||
def get_func_desc_openai_style(self) -> list:
|
||||
'''
|
||||
"""
|
||||
获得 OpenAI API 风格的工具描述
|
||||
'''
|
||||
"""
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
_l.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
)
|
||||
return _l
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append(
|
||||
{
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
}
|
||||
})
|
||||
return _l
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
_l.append({
|
||||
"name": f["name"],
|
||||
"parameters": f["parameters"],
|
||||
"description": f["description"],
|
||||
})
|
||||
)
|
||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
@@ -123,8 +153,8 @@ class FuncCall():
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
if res.find("```") != -1:
|
||||
res = res[res.find("```json") + 7 : res.rfind("```")]
|
||||
res = json.loads(res)
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -133,8 +163,8 @@ class FuncCall():
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
|
||||
if 'res' in res and not res['res']:
|
||||
|
||||
if "res" in res and not res["res"]:
|
||||
return "", False
|
||||
|
||||
tool_call_result = []
|
||||
@@ -149,8 +179,7 @@ class FuncCall():
|
||||
tool_callable = func["func_obj"]
|
||||
break
|
||||
if not tool_callable:
|
||||
raise FuncNotFoundError(
|
||||
f"Request function {func_name} not found.")
|
||||
raise FuncNotFoundError(f"Request function {func_name} not found.")
|
||||
ret = await tool_callable(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
|
||||
@@ -52,29 +52,25 @@ class Context:
|
||||
获取 LLM Tools。
|
||||
'''
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
|
||||
# def get_star_commands(self, star_name: str) -> List[]:
|
||||
# '''获得一个'''
|
||||
|
||||
# def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
# '''
|
||||
# 为函数调用(function-calling / tools-use)添加工具。
|
||||
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
# @param name: 函数名
|
||||
# @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
# @param desc: 函数描述
|
||||
# @param func_obj: 异步处理函数。
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 异步处理函数。
|
||||
|
||||
# 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
# '''
|
||||
# self.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
|
||||
|
||||
# def unregister_llm_tool(self, name: str) -> None:
|
||||
# '''
|
||||
# 删除一个函数调用工具。
|
||||
# '''
|
||||
# self.llm_tools.remove_func(name)
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
'''
|
||||
删除一个函数调用工具。
|
||||
'''
|
||||
self.provider_manager.llm_tools.remove_func(name)
|
||||
|
||||
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
|
||||
'''
|
||||
|
||||
+2
-1
@@ -14,4 +14,5 @@ lxml_html_clean
|
||||
colorlog
|
||||
aiocqhttp
|
||||
pyjwt
|
||||
apscheduler
|
||||
apscheduler
|
||||
docstring_parser
|
||||
Reference in New Issue
Block a user