123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- from typing import Optional
- from pydantic import BaseModel
- from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
- from core.helper.encrypter import decrypt_token
- from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
- from extensions.ext_database import db
- from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
    """Request payload for input moderation via an API-based extension.

    An instance is built in ``ApiModeration.moderation_for_inputs`` and sent
    (as ``model_dump()``) to the extension's ``APP_MODERATION_INPUT`` point.
    """

    # Id of the app whose inputs are being moderated (taken from the provider's app_id).
    app_id: str = ""
    # The user inputs mapping passed to moderation_for_inputs; pydantic copies this default per instance.
    inputs: dict = {}
    # The user query passed to moderation_for_inputs; empty when none was given.
    query: str = ""
class ModerationOutputParams(BaseModel):
    """Request payload for output moderation via an API-based extension.

    An instance is built in ``ApiModeration.moderation_for_outputs`` and sent
    (as ``model_dump()``) to the extension's ``APP_MODERATION_OUTPUT`` point.
    """

    # Id of the app whose output is being moderated (taken from the provider's app_id).
    app_id: str = ""
    # The generated text to check; required, so constructing without it fails validation.
    text: str
class ApiModeration(Moderation):
    """Moderation provider that delegates checks to a tenant's API-based extension.

    The extension is looked up by ``config["api_based_extension_id"]`` within the
    current tenant; its endpoint is called with the (decrypted) API key and the
    response dict is turned directly into a moderation result.
    """

    # Identifier of this moderation provider (presumably matched against the
    # configured moderation type elsewhere — confirm in the Moderation base/registry).
    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Check that a submitted form config is usable by this provider.

        :param tenant_id: id of the workspace that owns the extension
        :param config: form config data; must reference an existing API-based extension
        :raises ValueError: if the shared inputs/outputs checks fail, the extension id
            is missing/empty, or no such extension exists for the tenant
        :return: None
        """
        # Shared inputs_config/outputs_config checks from the base class; False is
        # passed positionally exactly as before (a preset response is not demanded here).
        cls._validate_inputs_and_outputs_config(config, False)

        extension_id = config.get("api_based_extension_id")
        if not extension_id:
            raise ValueError("api_based_extension_id is required")

        if not cls._get_api_based_extension(tenant_id, extension_id):
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderate user inputs (and query) through the API-based extension.

        :param inputs: user inputs mapping forwarded to the extension
        :param query: user query forwarded to the extension
        :raises ValueError: if the provider config is not set, or the extension lookup fails
        :return: the extension's verdict when input moderation is enabled, otherwise an
            unflagged DIRECT_OUTPUT result with an empty preset response
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        # Short-circuit: input moderation switched off -> nothing is flagged.
        if not self.config["inputs_config"]["enabled"]:
            return ModerationInputsResult(
                flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response=""
            )

        payload = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query).model_dump()
        response = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, payload)
        # The extension's response keys are passed straight through as result fields.
        return ModerationInputsResult(**response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderate generated output text through the API-based extension.

        :param text: output text forwarded to the extension
        :raises ValueError: if the provider config is not set, or the extension lookup fails
        :return: the extension's verdict when output moderation is enabled, otherwise an
            unflagged DIRECT_OUTPUT result with an empty preset response
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        # Short-circuit: output moderation switched off -> nothing is flagged.
        if not self.config["outputs_config"]["enabled"]:
            return ModerationOutputsResult(
                flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response=""
            )

        payload = ModerationOutputParams(app_id=self.app_id, text=text).model_dump()
        response = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, payload)
        # The extension's response keys are passed straight through as result fields.
        return ModerationOutputsResult(**response)

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        """
        Call the configured API-based extension at ``extension_point`` with ``params``.

        :param extension_point: which extension hook to invoke
        :param params: request body for the extension
        :raises ValueError: if the provider config is not set or the extension is not found
        :return: whatever ``APIBasedExtensionRequestor.request`` returns
        """
        if self.config is None:
            raise ValueError("The config is not set.")

        extension_id = self.config.get("api_based_extension_id", "")
        extension = self._get_api_based_extension(self.tenant_id, extension_id)
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

        endpoint = extension.api_endpoint
        # The stored API key is encrypted per tenant; decrypt it just for this request.
        api_key = decrypt_token(self.tenant_id, extension.api_key)
        requestor = APIBasedExtensionRequestor(endpoint, api_key)
        return requestor.request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
        """
        Fetch the extension with the given id, scoped to ``tenant_id``.

        :param tenant_id: owning workspace id; prevents reading another tenant's extension
        :param api_based_extension_id: primary key of the extension
        :return: the first matching row, or None when there is no match
        """
        tenant_scoped = db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id,
        )
        return tenant_scoped.first()
|