feat: supports to check the content safety of LLM output #474

This commit is contained in:
Soulter
2025-02-11 22:03:44 +08:00
parent 9fa00aff9a
commit 43cd34d94c
3 changed files with 31 additions and 5 deletions
+6
View File
@@ -66,6 +66,7 @@ DEFAULT_CONFIG = {
}
},
"content_safety": {
"also_use_in_response": False,
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
},
@@ -310,6 +311,11 @@ CONFIG_METADATA_2 = {
"description": "内容安全",
"type": "object",
"items": {
"also_use_in_response": {
"description": "对大模型响应安全审核",
"type": "bool",
"hint": "启用后,大模型的响应也会通过内容安全审核。",
},
"baidu_aip": {
"description": "百度内容审核配置",
"type": "object",
@@ -17,11 +17,13 @@ class ContentSafetyCheckStage(Stage):
config = ctx.astrbot_config['content_safety']
self.strategy_selector = StrategySelector(config)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]:
'''检查内容安全'''
ok, info = self.strategy_selector.check(event.get_message_str())
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
if not ok:
event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。"))
yield
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
+21 -3
View File
@@ -2,7 +2,7 @@ import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import Stage, register_stage, registered_stages
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
@@ -12,7 +12,7 @@ from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class ResultDecorateStage:
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
@@ -30,12 +30,30 @@ class ResultDecorateStage:
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
# exception
self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response']
self.content_safe_check_stage = None
if self.content_safe_check_reply:
for stage in registered_stages:
if stage.__class__.__name__ == "ContentSafetyCheckStage":
self.content_safe_check_stage = stage
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
return
# 回复时检查内容安全
if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result():
text = ""
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(event, check_text=text):
yield
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
await handler.handler(event)