|
@@ -0,0 +1,83 @@
|
|
|
+import json
|
|
|
+import logging
|
|
|
+from typing import Any, Union
|
|
|
+
|
|
|
+import boto3
|
|
|
+from pydantic import BaseModel, Field
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+logging.basicConfig(level=logging.INFO)
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+class GuardrailParameters(BaseModel):
|
|
|
+ guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
|
|
+ guardrail_version: str = Field(..., description="The version of the guardrail")
|
|
|
+ source: str = Field(..., description="The source of the content")
|
|
|
+ text: str = Field(..., description="The text to apply the guardrail to")
|
|
|
+ aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
|
|
|
+
|
|
|
+class ApplyGuardrailTool(BuiltinTool):
|
|
|
+ def _invoke(self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any]
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ Invoke the ApplyGuardrail tool
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # Validate and parse input parameters
|
|
|
+ params = GuardrailParameters(**tool_parameters)
|
|
|
+
|
|
|
+ # Initialize AWS client
|
|
|
+ bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region)
|
|
|
+
|
|
|
+ # Apply guardrail
|
|
|
+ response = bedrock_client.apply_guardrail(
|
|
|
+ guardrailIdentifier=params.guardrail_id,
|
|
|
+ guardrailVersion=params.guardrail_version,
|
|
|
+ source=params.source,
|
|
|
+ content=[{"text": {"text": params.text}}]
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check for empty response
|
|
|
+ if not response:
|
|
|
+ return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
|
|
+
|
|
|
+ # Process the result
|
|
|
+ action = response.get("action", "No action specified")
|
|
|
+ outputs = response.get("outputs", [])
|
|
|
+ output = outputs[0].get("text", "No output received") if outputs else "No output received"
|
|
|
+ assessments = response.get("assessments", [])
|
|
|
+
|
|
|
+ # Format assessments
|
|
|
+ formatted_assessments = []
|
|
|
+ for assessment in assessments:
|
|
|
+ for policy_type, policy_data in assessment.items():
|
|
|
+ if isinstance(policy_data, dict) and 'topics' in policy_data:
|
|
|
+ for topic in policy_data['topics']:
|
|
|
+ formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}")
|
|
|
+ else:
|
|
|
+ formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
|
|
+
|
|
|
+ result = f"Action: {action}\n "
|
|
|
+ result += f"Output: {output}\n "
|
|
|
+ if formatted_assessments:
|
|
|
+ result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
|
|
+# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
|
|
+
|
|
|
+ return self.create_text_message(text=result)
|
|
|
+
|
|
|
+ except boto3.exceptions.BotoCoreError as e:
|
|
|
+ error_message = f'AWS service error: {str(e)}'
|
|
|
+ logger.error(error_message, exc_info=True)
|
|
|
+ return self.create_text_message(text=error_message)
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ error_message = f'JSON parsing error: {str(e)}'
|
|
|
+ logger.error(error_message, exc_info=True)
|
|
|
+ return self.create_text_message(text=error_message)
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f'An unexpected error occurred: {str(e)}'
|
|
|
+ logger.error(error_message, exc_info=True)
|
|
|
+ return self.create_text_message(text=error_message)
|