From 43cd34d94c2209680879dae8999bb07ca13981bf Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 11 Feb 2025 22:03:44 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8feat:=20supports=20to=20check=20the=20?= =?UTF-8?q?content=20safety=20of=20LLM=20output=20#474?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 6 +++++ .../pipeline/content_safety_check/stage.py | 6 +++-- .../core/pipeline/result_decorate/stage.py | 24 ++++++++++++++++--- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 55d8fbd16..94bd7fdfe 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -66,6 +66,7 @@ DEFAULT_CONFIG = { } }, "content_safety": { + "also_use_in_response": False, "internal_keywords": {"enable": True, "extra_keywords": []}, "baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""}, }, @@ -310,6 +311,11 @@ CONFIG_METADATA_2 = { "description": "内容安全", "type": "object", "items": { + "also_use_in_response": { + "description": "对大模型响应安全审核", + "type": "bool", + "hint": "启用后,大模型的响应也会通过内容安全审核。", + }, "baidu_aip": { "description": "百度内容审核配置", "type": "object", diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index 8e9a0ad25..4c2b4e82e 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -17,11 +17,13 @@ class ContentSafetyCheckStage(Stage): config = ctx.astrbot_config['content_safety'] self.strategy_selector = StrategySelector(config) - async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + async def process(self, event: AstrMessageEvent, check_text: str = None) -> Union[None, AsyncGenerator[None, None]]: '''检查内容安全''' - ok, info = self.strategy_selector.check(event.get_message_str()) + text = check_text if check_text else event.get_message_str() + ok, info = self.strategy_selector.check(text) if not ok: event.set_result(MessageEventResult().message("你的消息中包含不适当的内容,已被屏蔽。")) + yield event.stop_event() logger.info(f"内容安全检查不通过,原因:{info}") return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 31e284765..1d9cca8ad 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -2,7 +2,7 @@ import time import re import traceback from typing import Union, AsyncGenerator -from ..stage import register_stage +from ..stage import Stage, register_stage, registered_stages from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType @@ -12,7 +12,7 @@ from astrbot.core import html_renderer from astrbot.core.star.star_handler import star_handlers_registry, EventType @register_stage -class ResultDecorateStage: +class ResultDecorateStage(Stage): async def initialize(self, ctx: PipelineContext): self.ctx = ctx self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix'] @@ -30,12 +30,30 @@ class ResultDecorateStage: self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable'] self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result'] self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex'] - + + # exception + self.content_safe_check_reply = ctx.astrbot_config['content_safety']['also_use_in_response'] + self.content_safe_check_stage = None + if self.content_safe_check_reply: + for stage in registered_stages: + if stage.__class__.__name__ == "ContentSafetyCheckStage": + self.content_safe_check_stage = stage + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: result = event.get_result() if result is None: return + # 回复时检查内容安全 + if self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result(): + text = "" + for comp in result.chain: + if isinstance(comp, Plain): + text += comp.text + async for _ in self.content_safe_check_stage.process(event, check_text=text): + yield + handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent) for handler in handlers: await handler.handler(event)