feat: 增加模型响应后的插件钩子

remove: 移除了默认的r1过滤
This commit is contained in:
Soulter
2025-02-02 16:42:21 +08:00
parent 5581eae957
commit ef44d4471a
7 changed files with 51 additions and 16 deletions
+3 -1
View File
@@ -6,6 +6,7 @@ from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type,
register_on_llm_request as on_llm_request,
register_on_llm_response as on_llm_response,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent
@@ -31,5 +32,6 @@ __all__ = [
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent'
'after_message_sent',
'on_llm_response'
]
@@ -68,6 +68,15 @@ class LLMRequestSubStage(Stage):
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
+4 -1
View File
@@ -2,6 +2,7 @@ import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
class ProviderType(enum.Enum):
@@ -51,4 +52,6 @@ class LLMResponse:
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
'''工具调用名称'''
raw_completion: ChatCompletion = None
+9 -14
View File
@@ -101,17 +101,16 @@ class ProviderOpenAIOfficial(Provider):
stream=False
)
except BaseException as e:
if 'does not support Function Calling' in e \
or 'does not support tools' in e: # ollama
if 'does not support Function Calling' in str(e) \
or 'does not support tools' in str(e): # ollama
del payloads['tools']
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
if not completion:
raise Exception("API 返回的 completion 为空。")
else:
raise e
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
@@ -123,14 +122,8 @@ class ProviderOpenAIOfficial(Provider):
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
# 适配 deepseek-r1 模型
if r'<think>' in completion_text or r'</think>' in completion_text:
completion_text = re.sub(r'<think>.*?</think>', '', completion_text, flags=re.DOTALL).strip()
# 可能有单标签情况
completion_text = completion_text.replace(r'<think>', '').replace(r'</think>', '').strip()
return LLMResponse("assistant", completion_text)
return LLMResponse("assistant", completion_text, raw_completion=completion)
elif choice.message.tool_calls:
# tools call (function calling)
args_ls = []
@@ -141,8 +134,9 @@ class ProviderOpenAIOfficial(Provider):
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls)
return LLMResponse(role="tool", tools_call_args=args_ls, tools_call_name=func_name_ls, raw_completion=completion)
else:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception("Internal Error")
async def text_chat(
@@ -195,6 +189,7 @@ class ProviderOpenAIOfficial(Provider):
else:
raise e
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
raise e
+2
View File
@@ -7,6 +7,7 @@ from .star_handler import (
register_regex,
register_permission_type,
register_on_llm_request,
register_on_llm_response,
register_llm_tool,
register_on_decorating_result,
register_after_message_sent
@@ -21,6 +22,7 @@ __all__ = [
'register_regex',
'register_permission_type',
'register_on_llm_request',
'register_on_llm_response',
'register_llm_tool',
'register_on_decorating_result',
'register_after_message_sent'
@@ -139,6 +139,8 @@ def register_on_llm_request():
Examples:
```py
from astrbot.api.provider import ProviderRequest
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
@@ -152,6 +154,27 @@ def register_on_llm_request():
return decorator
def register_on_llm_response():
'''当有 LLM 请求后的事件
Examples:
```py
from astrbot.api.provider import LLMResponse
@on_llm_response()
async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
...
```
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMResponseEvent)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
+1
View File
@@ -47,6 +47,7 @@ class EventType(enum.Enum):
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnLLMResponseEvent = enum.auto() # LLM 响应后
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后