feat: add hosted moderation (#1158)

takatost 1 year ago
parent
commit
f9082104ed

+ 5 - 0
api/config.py

@@ -61,6 +61,8 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -230,6 +232,9 @@ class Config:
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
 
+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
+
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
 
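For reference, a minimal runnable sketch of how the two new settings would be supplied and read. The `get_env`/`get_bool_env` stand-ins only approximate the real helpers in `api/config.py`; the fallback-to-`DEFAULTS` behavior and the truthy-string convention are assumptions:

```python
import os

DEFAULTS = {'HOSTED_MODERATION_ENABLED': 'False', 'HOSTED_MODERATION_PROVIDERS': ''}

def get_env(key: str) -> str:
    # Fall back to the DEFAULTS table when the variable is unset (assumed behavior).
    return os.environ.get(key, DEFAULTS.get(key, ''))

def get_bool_env(key: str) -> bool:
    # 'True'/'true' -> True, anything else -> False (assumed convention).
    return get_env(key).lower() == 'true'

# Example: enable hosted moderation for the hosted OpenAI and Anthropic quotas.
os.environ['HOSTED_MODERATION_ENABLED'] = 'True'
os.environ['HOSTED_MODERATION_PROVIDERS'] = 'openai,anthropic'

assert get_bool_env('HOSTED_MODERATION_ENABLED') is True
assert get_env('HOSTED_MODERATION_PROVIDERS').split(',') == ['openai', 'anthropic']
```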

+ 17 - 1
api/core/agent/agent_executor.py

@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -116,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)
 
     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -128,7 +142,9 @@
 
         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception as ex:
             logging.exception("agent_executor run failed")
             output = None
 
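The change to `run()` adds two behaviors: flagged input short-circuits to a canned answer before any tool executes, and `LLMError` is re-raised instead of being swallowed by the catch-all. A standalone sketch of that control flow, with simplified stand-ins for the executor and result types:

```python
# The canned response text is taken from the diff; everything else is a stand-in.
class LLMError(Exception):
    pass

CANNED = ("I apologize for any confusion, but I'm an AI assistant "
          "to be helpful, harmless, and honest.")

def run(query, check_moderation, run_agent):
    if not check_moderation(query):   # flagged input never reaches the agent
        return CANNED
    try:
        return run_agent(query)
    except LLMError:                  # provider errors now propagate to the caller
        raise
    except Exception:                 # anything else still degrades to a None output
        return None

print(run("hello", check_moderation=lambda q: True, run_agent=lambda q: q.upper()))  # HELLO
print(run("flagged", check_moderation=lambda q: False, run_agent=lambda q: q))       # canned answer
```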

+ 22 - 7
api/core/callback_handler/agent_loop_gather_callback_handler.py

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True
 
+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # Agent start with a LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
             else:
-                self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.prompt)]
                 )
             completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
             else:
-                self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self._current_loop.completion)]
                 )
 
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
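Context for the new `on_chat_model_start` hook: LangChain routes chat-model calls there (with a `List[List[BaseMessage]]` payload) rather than to `on_llm_start`, so without the override the first agent loop was never opened for chat models. A minimal sketch of the prompt flattening it performs, with a stand-in message type:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Message:  # stand-in for langchain.schema.BaseMessage
    content: str

def flatten_prompt(messages: List[List[Message]]) -> str:
    # Only the first batch is recorded, matching the handler above.
    return "\n".join(m.content for m in messages[0])

batch = [[Message("You are a helpful agent."), Message("What is 2 + 2?")]]
assert flatten_prompt(batch) == "You are a helpful agent.\nWhat is 2 + 2?"
```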

+ 0 - 1
api/core/callback_handler/entity/llm_message.py

@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0

+ 0 - 9
api/core/callback_handler/llm_callback_handler.py

@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
             messages: List[List[BaseMessage]],
             **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)
             self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )

+ 52 - 8
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,15 +1,38 @@
+import enum
+import logging
 from typing import List, Dict, Optional, Any
 
+import openai
+from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.chains.base import Chain
+from openai import InvalidRequestError
+from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
+    AuthenticationError, OpenAIError
+from pydantic import BaseModel
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation import openai_moderation
+
+
+class SensitiveWordAvoidanceRule(BaseModel):
+    class Type(enum.Enum):
+        MODERATION = "moderation"
+        KEYWORDS = "keywords"
+
+    type: Type
+    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
+    extra_params: dict = {}
 
 
 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:
 
-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
 
     @property
     def _chain_type(self) -> str:
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
         """
         return [self.output_key]
 
-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
             if word in text:
-                return self.canned_response
-        return text
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')
 
     def _call(
             self,
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
             run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}
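A usage sketch of the new rule object and the keyword branch. The pydantic model is copied from the diff; the helper mirrors `_check_sensitive_word`, while the real chain raises `LLMBadRequestError(rule.canned_response)` on a failed check instead of returning the flag to its caller:

```python
import enum
from pydantic import BaseModel

class SensitiveWordAvoidanceRule(BaseModel):
    class Type(enum.Enum):
        MODERATION = "moderation"
        KEYWORDS = "keywords"

    type: Type
    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
    extra_params: dict = {}

def check_keywords(text: str, rule: SensitiveWordAvoidanceRule) -> bool:
    # True means the text passed; False triggers the canned response.
    return not any(w in text for w in rule.extra_params.get('sensitive_words', []))

rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
    extra_params={'sensitive_words': ['forbidden']},
)

assert check_keywords("hello", rule) is True
assert check_keywords("a forbidden topic", rule) is False
```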

+ 3 - 6
api/core/completion.py

@@ -1,9 +1,7 @@
 import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union
 
-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -81,7 +78,7 @@ class Completion:
 
         # parse sensitive_word_avoidance_chain
         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
 

+ 13 - 10
api/core/conversation_message_task.py

@@ -1,5 +1,5 @@
-import decimal
 import json
+import time
 from typing import Optional, Union, List
 
 from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id
 
         self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
+
         override_model_configs = None
         if self.is_override:
             override_model_configs = self.app_model_config.to_dict()
@@ -165,7 +168,7 @@ class ConversationMessageTask:
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = llm_message.latency
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price
 
         db.session.commit()
@@ -220,18 +223,18 @@
 
         return message_agent_thought
 
-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
 
         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens
 
-        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
-        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price
 
         message_agent_thought.observation = agent_loop.tool_output
@@ -245,7 +248,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()
         db.session.flush()
 
     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
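The latency bookkeeping moves here as well: instead of each callback handler timing its own LLM call (the `time` code removed from `llm_callback_handler.py` above), the task stamps `perf_counter()` in `__init__` and derives `provider_response_latency` once when the message is saved. A minimal sketch of the pattern:

```python
import time

class Task:  # stand-in for ConversationMessageTask
    def __init__(self):
        self.start_at = time.perf_counter()

    def save_message(self) -> float:
        # provider_response_latency now covers the whole task, not a single call
        return time.perf_counter() - self.start_at

task = Task()
time.sleep(0.01)
print(f"latency: {task.save_message():.3f}s")
```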

+ 32 - 0
api/core/helper/moderation.py

@@ -0,0 +1,32 @@
+import logging
+
+import openai
+from flask import current_app
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from models.provider import ProviderType
+
+
+def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
+        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+
+        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and model_provider.provider_name in moderation_providers:
+            # split the text into 2,000-character chunks
+            length = 2000
+            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            try:
+                moderation_result = openai.Moderation.create(input=chunks,
+                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
+            except Exception as ex:
+                logging.exception(ex)
+                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
+
+    return True
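A runnable sketch of the gating condition in `check_moderation`: the hosted check only fires for system-quota providers named in `HOSTED_MODERATION_PROVIDERS`. Flask config lookups are replaced with plain arguments, and `'system'` stands in for `ProviderType.SYSTEM.value`:

```python
def should_moderate(provider_type: str, provider_name: str,
                    enabled: bool, providers_csv: str) -> bool:
    return (enabled
            and bool(providers_csv)
            and provider_type == 'system'
            and provider_name in providers_csv.split(','))

assert should_moderate('system', 'openai', True, 'openai,anthropic') is True
assert should_moderate('custom', 'openai', True, 'openai,anthropic') is False  # BYO keys skip the hosted check
assert should_moderate('system', 'openai', False, 'openai,anthropic') is False
```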

+ 2 - 1
api/core/model_providers/model_factory.py

@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
     def get_moderation_model(cls,
                              tenant_id: str,
                              model_provider_name: str,
-                             model_name: str) -> Optional[BaseProviderModel]:
+                             model_name: str) -> Optional[BaseModeration]:
         """
         get moderation model.
 

+ 10 - 0
api/core/model_providers/models/llm/base.py

@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
         :param callbacks:
         :return:
         """
+        moderation_result = moderation.check_moderation(
+            self.model_provider,
+            "\n".join([message.content for message in messages])
+        )
+
+        if not moderation_result:
+            kwargs['fake_response'] = "I apologize for any confusion, " \
+                                      "but I'm an AI assistant to be helpful, harmless, and honest."
+
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
 
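Where `kwargs['fake_response']` is consumed is not visible in this diff; the sketch below assumes downstream generation returns it verbatim instead of calling the provider, which matches the apparent intent of the gate:

```python
CANNED = ("I apologize for any confusion, but I'm an AI assistant "
          "to be helpful, harmless, and honest.")

def run(messages, generate, moderation_ok: bool, **kwargs):
    if not moderation_ok:
        kwargs['fake_response'] = CANNED
    if 'fake_response' in kwargs:      # assumed downstream handling
        return kwargs['fake_response']
    return generate(messages)

assert run(["hi"], generate=lambda m: "real answer", moderation_ok=True) == "real answer"
assert run(["hi"], generate=lambda m: "real answer", moderation_ok=False) == CANNED
```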

+ 29 - 0
api/core/model_providers/models/moderation/base.py

@@ -0,0 +1,29 @@
+from abc import abstractmethod
+from typing import Any
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseModeration(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.MODERATION
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def run(self, text: str) -> bool:
+        try:
+            return self._run(text)
+        except Exception as ex:
+            raise self.handle_exceptions(ex)
+
+    @abstractmethod
+    def _run(self, text: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError
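A sketch of the template-method contract this base class sets up, with plain `ABC` stand-ins for the Dify base classes and a hypothetical keyword-based subclass in place of `OpenAIModeration`:

```python
from abc import ABC, abstractmethod

class BaseModeration(ABC):
    def run(self, text: str) -> bool:
        try:
            return self._run(text)
        except Exception as ex:
            # subclasses translate provider errors into Dify LLM errors
            raise self.handle_exceptions(ex)

    @abstractmethod
    def _run(self, text: str) -> bool: ...

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception: ...

class KeywordModeration(BaseModeration):  # hypothetical, for illustration only
    def _run(self, text: str) -> bool:
        return "banned" not in text

    def handle_exceptions(self, ex: Exception) -> Exception:
        return RuntimeError(f"moderation failed: {ex}")

assert KeywordModeration().run("all good") is True
```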

+ 18 - 12
api/core/model_providers/models/moderation/openai_moderation.py

@@ -4,29 +4,35 @@ import openai
 
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, LLMAuthorizationError
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.providers.base import BaseModelProvider
 
-DEFAULT_AUDIO_MODEL = 'whisper-1'
+DEFAULT_MODEL = 'whisper-1'
 
 
-class OpenAIModeration(BaseProviderModel):
-    type: ModelType = ModelType.MODERATION
+class OpenAIModeration(BaseModeration):
 
     def __init__(self, model_provider: BaseModelProvider, name: str):
-        super().__init__(model_provider, openai.Moderation)
+        super().__init__(model_provider, openai.Moderation, name)
 
-    def run(self, text):
+    def _run(self, text: str) -> bool:
         credentials = self.model_provider.get_model_credentials(
-            model_name=DEFAULT_AUDIO_MODEL,
+            model_name=self.name,
             model_type=self.type
         )
 
-        try:
-            return self._client.create(input=text, api_key=credentials['openai_api_key'])
-        except Exception as ex:
-            raise self.handle_exceptions(ex)
+        # split the text into 2,000-character chunks
+        length = 2000
+        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+        moderation_result = self._client.create(input=chunks,
+                                                api_key=credentials['openai_api_key'])
+
+        for result in moderation_result.results:
+            if result['flagged'] is True:
+                return False
+
+        return True
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
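The reworked `run()` now returns a plain bool, `False` if any 2,000-character chunk is flagged, rather than the raw Moderation payload, which is exactly what the updated integration test below asserts. A stubbed sketch of that contract:

```python
def run(text: str, create) -> bool:
    # create() stands in for openai.Moderation.create(input=chunks, api_key=...)
    chunks = [text[i:i + 2000] for i in range(0, len(text), 2000)]
    response = create(input=chunks)
    return not any(r['flagged'] for r in response['results'])

stub_create = lambda input: {'results': [{'flagged': False} for _ in input]}
assert run('hello', stub_create) is True
```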

+ 35 - 11
api/core/orchestrator_rule_parser.py

@@ -1,6 +1,7 @@
 import math
 from typing import Optional
 
+from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
+from models.provider import ProviderType
 
 
 class OrchestratorRuleParser:
@@ -63,7 +65,7 @@
 
             # add agent callback to record agent thoughts
             agent_callback = AgentLoopGatherCallbackHandler(
-                model_instant=agent_model_instance,
+                model_instance=agent_model_instance,
                 conversation_message_task=conversation_message_task
             )
 
@@ -123,23 +125,45 @@
 
         return chain
 
-    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain
 
+        :param model_instance: model instance
+        :param callbacks: callbacks for the chain
         :param kwargs:
         :return:
         """
-        if not self.app_model_config.sensitive_word_avoidance_dict:
-            return None
-
-        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-        sensitive_words = sensitive_word_avoidance_config.get("words", "")
-        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
+        sensitive_word_avoidance_rule = None
+
+        if self.app_model_config.sensitive_word_avoidance_dict:
+            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+            if sensitive_word_avoidance_config.get("enabled", False):
+                if sensitive_word_avoidance_config.get('type') == 'moderation':
+                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
+                        canned_response=sensitive_word_avoidance_config.get("canned_response")
+                        if sensitive_word_avoidance_config.get("canned_response")
+                        else 'Your content violates our usage policy. Please revise and try again.',
+                    )
+                else:
+                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
+                    if sensitive_words:
+                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
+                            canned_response=sensitive_word_avoidance_config.get("canned_response")
+                            if sensitive_word_avoidance_config.get("canned_response")
+                            else 'Your content violates our usage policy. Please revise and try again.',
+                            extra_params={
+                                'sensitive_words': sensitive_words.split(','),
+                            }
+                        )
+
+        if sensitive_word_avoidance_rule:
             return SensitiveWordAvoidanceChain(
-                sensitive_words=sensitive_words.split(","),
-                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
+                model_instance=model_instance,
+                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                 output_key="sensitive_word_avoidance_output",
                 callbacks=callbacks,
                 **kwargs
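A sketch of how an app's `sensitive_word_avoidance_dict` maps onto a rule under the branching above; the config keys come from the diff, the values are illustrative, and plain tuples stand in for `SensitiveWordAvoidanceRule`:

```python
DEFAULT_CANNED = 'Your content violates our usage policy. Please revise and try again.'

def parse(config: dict):
    if not config.get('enabled', False):
        return None
    canned = config.get('canned_response') or DEFAULT_CANNED
    if config.get('type') == 'moderation':
        return ('moderation', canned, {})
    words = config.get('words', '')
    if words:
        return ('keywords', canned, {'sensitive_words': words.split(',')})
    return None  # keyword rule without words yields no chain

print(parse({'enabled': True, 'type': 'moderation'}))
print(parse({'enabled': True, 'type': 'keywords', 'words': 'foo,bar',
             'canned_response': 'Please rephrase your request.'}))
```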

+ 3 - 4
api/tests/integration_tests/models/moderation/test_openai_moderation.py

@@ -2,7 +2,7 @@ import json
 import os
 from unittest.mock import patch
 
-from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
+from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.provider import Provider, ProviderType
 
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
     openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
     return OpenAIModeration(
         model_provider=openai_provider,
-        name=DEFAULT_AUDIO_MODEL
+        name=DEFAULT_MODEL
     )
 
 
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
     model = get_mock_openai_moderation_model()
     rst = model.run('hello')
 
-    assert isinstance(rst, dict)
-    assert 'id' in rst
+    assert rst is True