Forráskód Böngészése

feat: add hosted moderation (#1158)

takatost 1 éve
szülő
commit
f9082104ed

+ 5 - 0
api/config.py

@@ -61,6 +61,8 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -230,6 +232,9 @@ class Config:
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
 
+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
+
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
 

+ 17 - 1
api/core/agent/agent_executor.py

@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -116,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)
 
     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -128,7 +142,9 @@ class AgentExecutor:
 
         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception as ex:
             logging.exception("agent_executor run failed")
             output = None
 

+ 22 - 7
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # Agent start with a LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
             else:
-                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.prompt)]
                 )
             completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
             else:
-                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.completion)]
                 )
 
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)

+ 0 - 1
api/core/callback_handler/entity/llm_message.py

@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0

+ 0 - 9
api/core/callback_handler/llm_callback_handler.py

@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
             messages: List[List[BaseMessage]],
             **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )

+ 52 - 8
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,15 +1,38 @@
+import enum
+import logging
 from typing import List, Dict, Optional, Any
 
+import openai
+from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
+from openai import InvalidRequestError
+from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
+    AuthenticationError, OpenAIError
+from pydantic import BaseModel
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation import openai_moderation
+
+
+class SensitiveWordAvoidanceRule(BaseModel):
+    class Type(enum.Enum):
+        MODERATION = "moderation"
+        KEYWORDS = "keywords"
+
+    type: Type
+    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
+    extra_params: dict = {}
 
 
 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:
 
-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
 
     @property
     def _chain_type(self) -> str:
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
         """
         return [self.output_key]
 
-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
             if word in text:
-                return self.canned_response
-        return text
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')
 
     def _call(
             self,
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
             run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}

+ 3 - 6
api/core/completion.py

@@ -1,9 +1,7 @@
 import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union
 
-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -81,7 +78,7 @@ class Completion:
 
         # parse sensitive_word_avoidance_chain
         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
 

+ 13 - 10
api/core/conversation_message_task.py

@@ -1,5 +1,5 @@
-import decimal
 import json
+import time
 from typing import Optional, Union, List
 
 from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id
 
         self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
+
         override_model_configs = None
         if self.is_override:
             override_model_configs = self.app_model_config.to_dict()
@@ -165,7 +168,7 @@ class ConversationMessageTask:
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = llm_message.latency
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price
 
         db.session.commit()
@@ -220,18 +223,18 @@ class ConversationMessageTask:
 
         return message_agent_thought
 
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
 
-        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
-        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price
 
         message_agent_thought.observation = agent_loop.tool_output
@@ -245,7 +248,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()
         db.session.flush()
 
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):

+ 32 - 0
api/core/helper/moderation.py

@@ -0,0 +1,32 @@
+import logging
+
+import openai
+from flask import current_app
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from models.provider import ProviderType
+
+
+def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
+        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+
+        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and model_provider.provider_name in moderation_providers:
+            # 2000 text per chunk
+            length = 2000
+            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            try:
+                moderation_result = openai.Moderation.create(input=chunks,
+                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
+            except Exception as ex:
+                logging.exception(ex)
+                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
+
+    return True

+ 2 - 1
api/core/model_providers/model_factory.py

@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
     def get_moderation_model(cls,
                              tenant_id: str,
                              model_provider_name: str,
-                             model_name: str) -> Optional[BaseProviderModel]:
+                             model_name: str) -> Optional[BaseModeration]:
         """
         get moderation model.
 

+ 10 - 0
api/core/model_providers/models/llm/base.py

@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
         :param callbacks:
         :return:
         """
+        moderation_result = moderation.check_moderation(
+            self.model_provider,
+            "\n".join([message.content for message in messages])
+        )
+
+        if not moderation_result:
+            kwargs['fake_response'] = "I apologize for any confusion, " \
+                                      "but I'm an AI assistant to be helpful, harmless, and honest."
+
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
 

+ 29 - 0
api/core/model_providers/models/moderation/base.py

@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing import Any
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseModeration(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.MODERATION
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def run(self, text: str) -> bool:
+        try:
+            return self._run(text)
+        except Exception as ex:
+            raise self.handle_exceptions(ex)
+
+    @abstractmethod
+    def _run(self, text: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError

+ 18 - 12
api/core/model_providers/models/moderation/openai_moderation.py

@@ -4,29 +4,35 @@ import openai
 
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, LLMAuthorizationError
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.providers.base import BaseModelProvider
 
-DEFAULT_AUDIO_MODEL = 'whisper-1'
+DEFAULT_MODEL = 'whisper-1'
 
 
-class OpenAIModeration(BaseProviderModel):
-    type: ModelType = ModelType.MODERATION
+class OpenAIModeration(BaseModeration):
 
     def __init__(self, model_provider: BaseModelProvider, name: str):
-        super().__init__(model_provider, openai.Moderation)
+        super().__init__(model_provider, openai.Moderation, name)
 
-    def run(self, text):
+    def _run(self, text: str) -> bool:
         credentials = self.model_provider.get_model_credentials(
-            model_name=DEFAULT_AUDIO_MODEL,
+            model_name=self.name,
             model_type=self.type
         )
 
-        try:
-            return self._client.create(input=text, api_key=credentials['openai_api_key'])
-        except Exception as ex:
-            raise self.handle_exceptions(ex)
+        # 2000 text per chunk
+        length = 2000
+        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+        moderation_result = self._client.create(input=chunks,
+                                                api_key=credentials['openai_api_key'])
+
+        for result in moderation_result.results:
+            if result['flagged'] is True:
+                return False
+
+        return True
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):

+ 35 - 11
api/core/orchestrator_rule_parser.py

@@ -1,6 +1,7 @@
 import math
 from typing import Optional
 
+from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
+from models.provider import ProviderType
 
 
 class OrchestratorRuleParser:
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
 
             # add agent callback to record agent thoughts
             agent_callback = AgentLoopGatherCallbackHandler(
-                model_instant=agent_model_instance,
+                model_instance=agent_model_instance,
                 conversation_message_task=conversation_message_task
             )
 
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:
 
         return chain
 
-    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
 
+        :param model_instance: model instance
+        :param callbacks: callbacks for the chain
         :param kwargs:
         :return:
         """
-        if not self.app_model_config.sensitive_word_avoidance_dict:
-            return None
-
-        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-        sensitive_words = sensitive_word_avoidance_config.get("words", "")
-        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
+        sensitive_word_avoidance_rule = None
+
+        if self.app_model_config.sensitive_word_avoidance_dict:
+            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+            if sensitive_word_avoidance_config.get("enabled", False):
+                if sensitive_word_avoidance_config.get('type') == 'moderation':
+                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
+                        canned_response=sensitive_word_avoidance_config.get("canned_response")
+                        if sensitive_word_avoidance_config.get("canned_response")
+                        else 'Your content violates our usage policy. Please revise and try again.',
+                    )
+                else:
+                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
+                    if sensitive_words:
+                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
+                            canned_response=sensitive_word_avoidance_config.get("canned_response")
+                            if sensitive_word_avoidance_config.get("canned_response")
+                            else 'Your content violates our usage policy. Please revise and try again.',
+                            extra_params={
+                                'sensitive_words': sensitive_words.split(','),
+                            }
+                        )
+
+        if sensitive_word_avoidance_rule:
             return SensitiveWordAvoidanceChain(
-                sensitive_words=sensitive_words.split(","),
-                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
+                model_instance=model_instance,
+                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                 output_key="sensitive_word_avoidance_output",
                 callbacks=callbacks,
                 **kwargs

+ 3 - 4
api/tests/integration_tests/models/moderation/test_openai_moderation.py

@@ -2,7 +2,7 @@ import json
 import os
 from unittest.mock import patch
 
-from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
+from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.provider import Provider, ProviderType
 
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
     openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
     return OpenAIModeration(
         model_provider=openai_provider,
-        name=DEFAULT_AUDIO_MODEL
+        name=DEFAULT_MODEL
     )
 
 
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
     model = get_mock_openai_moderation_model()
     rst = model.run('hello')
 
-    assert isinstance(rst, dict)
-    assert 'id' in rst
+    assert rst is True