|
@@ -0,0 +1,418 @@
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import re
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+from typing import Any, Union
|
|
|
+from urllib.parse import urlparse
|
|
|
+
|
|
|
+import boto3
|
|
|
+import requests
|
|
|
+from botocore.exceptions import ClientError
|
|
|
+from requests.exceptions import RequestException
|
|
|
+
|
|
|
+from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
|
+from core.tools.tool.builtin_tool import BuiltinTool
|
|
|
+
|
|
|
+logging.basicConfig(level=logging.INFO)
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+LanguageCodeOptions = [
|
|
|
+ "af-ZA",
|
|
|
+ "ar-AE",
|
|
|
+ "ar-SA",
|
|
|
+ "da-DK",
|
|
|
+ "de-CH",
|
|
|
+ "de-DE",
|
|
|
+ "en-AB",
|
|
|
+ "en-AU",
|
|
|
+ "en-GB",
|
|
|
+ "en-IE",
|
|
|
+ "en-IN",
|
|
|
+ "en-US",
|
|
|
+ "en-WL",
|
|
|
+ "es-ES",
|
|
|
+ "es-US",
|
|
|
+ "fa-IR",
|
|
|
+ "fr-CA",
|
|
|
+ "fr-FR",
|
|
|
+ "he-IL",
|
|
|
+ "hi-IN",
|
|
|
+ "id-ID",
|
|
|
+ "it-IT",
|
|
|
+ "ja-JP",
|
|
|
+ "ko-KR",
|
|
|
+ "ms-MY",
|
|
|
+ "nl-NL",
|
|
|
+ "pt-BR",
|
|
|
+ "pt-PT",
|
|
|
+ "ru-RU",
|
|
|
+ "ta-IN",
|
|
|
+ "te-IN",
|
|
|
+ "tr-TR",
|
|
|
+ "zh-CN",
|
|
|
+ "zh-TW",
|
|
|
+ "th-TH",
|
|
|
+ "en-ZA",
|
|
|
+ "en-NZ",
|
|
|
+ "vi-VN",
|
|
|
+ "sv-SE",
|
|
|
+ "ab-GE",
|
|
|
+ "ast-ES",
|
|
|
+ "az-AZ",
|
|
|
+ "ba-RU",
|
|
|
+ "be-BY",
|
|
|
+ "bg-BG",
|
|
|
+ "bn-IN",
|
|
|
+ "bs-BA",
|
|
|
+ "ca-ES",
|
|
|
+ "ckb-IQ",
|
|
|
+ "ckb-IR",
|
|
|
+ "cs-CZ",
|
|
|
+ "cy-WL",
|
|
|
+ "el-GR",
|
|
|
+ "et-ET",
|
|
|
+ "eu-ES",
|
|
|
+ "fi-FI",
|
|
|
+ "gl-ES",
|
|
|
+ "gu-IN",
|
|
|
+ "ha-NG",
|
|
|
+ "hr-HR",
|
|
|
+ "hu-HU",
|
|
|
+ "hy-AM",
|
|
|
+ "is-IS",
|
|
|
+ "ka-GE",
|
|
|
+ "kab-DZ",
|
|
|
+ "kk-KZ",
|
|
|
+ "kn-IN",
|
|
|
+ "ky-KG",
|
|
|
+ "lg-IN",
|
|
|
+ "lt-LT",
|
|
|
+ "lv-LV",
|
|
|
+ "mhr-RU",
|
|
|
+ "mi-NZ",
|
|
|
+ "mk-MK",
|
|
|
+ "ml-IN",
|
|
|
+ "mn-MN",
|
|
|
+ "mr-IN",
|
|
|
+ "mt-MT",
|
|
|
+ "no-NO",
|
|
|
+ "or-IN",
|
|
|
+ "pa-IN",
|
|
|
+ "pl-PL",
|
|
|
+ "ps-AF",
|
|
|
+ "ro-RO",
|
|
|
+ "rw-RW",
|
|
|
+ "si-LK",
|
|
|
+ "sk-SK",
|
|
|
+ "sl-SI",
|
|
|
+ "so-SO",
|
|
|
+ "sr-RS",
|
|
|
+ "su-ID",
|
|
|
+ "sw-BI",
|
|
|
+ "sw-KE",
|
|
|
+ "sw-RW",
|
|
|
+ "sw-TZ",
|
|
|
+ "sw-UG",
|
|
|
+ "tl-PH",
|
|
|
+ "tt-RU",
|
|
|
+ "ug-CN",
|
|
|
+ "uk-UA",
|
|
|
+ "uz-UZ",
|
|
|
+ "wo-SN",
|
|
|
+ "zu-ZA",
|
|
|
+]
|
|
|
+
|
|
|
+MediaFormat = ["mp3", "mp4", "wav", "flac", "ogg", "amr", "webm", "m4a"]
|
|
|
+
|
|
|
+
|
|
|
+def is_url(text):
|
|
|
+ if not text:
|
|
|
+ return False
|
|
|
+ text = text.strip()
|
|
|
+ # Regular expression pattern for URL validation
|
|
|
+ pattern = re.compile(
|
|
|
+ r"^" # Start of the string
|
|
|
+ r"(?:http|https)://" # Protocol (http or https)
|
|
|
+ r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # Domain
|
|
|
+ r"localhost|" # localhost
|
|
|
+ r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address
|
|
|
+ r"(?::\d+)?" # Optional port
|
|
|
+ r"(?:/?|[/?]\S+)" # Path
|
|
|
+ r"$", # End of the string
|
|
|
+ re.IGNORECASE,
|
|
|
+ )
|
|
|
+ return bool(pattern.match(text))
|
|
|
+
|
|
|
+
|
|
|
+def upload_file_from_url_to_s3(s3_client, url, bucket_name, s3_key=None, max_retries=3):
|
|
|
+ """
|
|
|
+ Upload a file from a URL to an S3 bucket with retries and better error handling.
|
|
|
+
|
|
|
+ Parameters:
|
|
|
+ - s3_client
|
|
|
+ - url (str): The URL of the file to upload
|
|
|
+ - bucket_name (str): The name of the S3 bucket
|
|
|
+ - s3_key (str): The desired key (path) in S3. If None, will use the filename from URL
|
|
|
+ - max_retries (int): Maximum number of retry attempts
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ - tuple: (bool, str) - (Success status, Message)
|
|
|
+ """
|
|
|
+
|
|
|
+ # Validate inputs
|
|
|
+ if not url or not bucket_name:
|
|
|
+ return False, "URL and bucket name are required"
|
|
|
+
|
|
|
+ retry_count = 0
|
|
|
+ while retry_count < max_retries:
|
|
|
+ try:
|
|
|
+ # Download the file from URL
|
|
|
+ response = requests.get(url, stream=True, timeout=30)
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ # If s3_key is not provided, try to get filename from URL
|
|
|
+ if not s3_key:
|
|
|
+ parsed_url = urlparse(url)
|
|
|
+ filename = os.path.basename(parsed_url.path.split("/file-preview")[0])
|
|
|
+ s3_key = "transcribe-files/" + filename
|
|
|
+
|
|
|
+ # Upload the file to S3
|
|
|
+ s3_client.upload_fileobj(
|
|
|
+ response.raw,
|
|
|
+ bucket_name,
|
|
|
+ s3_key,
|
|
|
+ ExtraArgs={
|
|
|
+ "ContentType": response.headers.get("content-type"),
|
|
|
+ "ACL": "private", # Ensure the uploaded file is private
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ return f"s3://{bucket_name}/{s3_key}", f"Successfully uploaded file to s3://{bucket_name}/{s3_key}"
|
|
|
+
|
|
|
+ except RequestException as e:
|
|
|
+ retry_count += 1
|
|
|
+ if retry_count == max_retries:
|
|
|
+ return None, f"Failed to download file from URL after {max_retries} attempts: {str(e)}"
|
|
|
+ continue
|
|
|
+
|
|
|
+ except ClientError as e:
|
|
|
+ return None, f"AWS S3 error: {str(e)}"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return None, f"Unexpected error: {str(e)}"
|
|
|
+
|
|
|
+ return None, "Maximum retries exceeded"
|
|
|
+
|
|
|
+
|
|
|
+class TranscribeTool(BuiltinTool):
|
|
|
+ s3_client: Any = None
|
|
|
+ transcribe_client: Any = None
|
|
|
+
|
|
|
+ """
|
|
|
+ Note that you must include one of LanguageCode, IdentifyLanguage,
|
|
|
+ or IdentifyMultipleLanguages in your request.
|
|
|
+ If you include more than one of these parameters, your transcription job fails.
|
|
|
+ """
|
|
|
+
|
|
|
+ def _transcribe_audio(self, audio_file_uri, file_type, **extra_args):
|
|
|
+ uuid_str = str(uuid.uuid4())
|
|
|
+ job_name = f"{int(time.time())}-{uuid_str}"
|
|
|
+ try:
|
|
|
+ # Start transcription job
|
|
|
+ response = self.transcribe_client.start_transcription_job(
|
|
|
+ TranscriptionJobName=job_name, Media={"MediaFileUri": audio_file_uri}, **extra_args
|
|
|
+ )
|
|
|
+
|
|
|
+ # Wait for the job to complete
|
|
|
+ while True:
|
|
|
+ status = self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name)
|
|
|
+ if status["TranscriptionJob"]["TranscriptionJobStatus"] in ["COMPLETED", "FAILED"]:
|
|
|
+ break
|
|
|
+ time.sleep(5)
|
|
|
+
|
|
|
+ if status["TranscriptionJob"]["TranscriptionJobStatus"] == "COMPLETED":
|
|
|
+ return status["TranscriptionJob"]["Transcript"]["TranscriptFileUri"], None
|
|
|
+ else:
|
|
|
+ return None, f"Error: TranscriptionJobStatus:{status['TranscriptionJob']['TranscriptionJobStatus']} "
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return None, f"Error: {str(e)}"
|
|
|
+
|
|
|
+ def _download_and_read_transcript(self, transcript_file_uri: str, max_retries: int = 3) -> tuple[str, str]:
|
|
|
+ """
|
|
|
+ Download and read the transcript file from the given URI.
|
|
|
+
|
|
|
+ Parameters:
|
|
|
+ - transcript_file_uri (str): The URI of the transcript file
|
|
|
+ - max_retries (int): Maximum number of retry attempts
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ - tuple: (text, error) - (Transcribed text if successful, error message if failed)
|
|
|
+ """
|
|
|
+ retry_count = 0
|
|
|
+ while retry_count < max_retries:
|
|
|
+ try:
|
|
|
+ # Download the transcript file
|
|
|
+ response = requests.get(transcript_file_uri, timeout=30)
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ # Parse the JSON content
|
|
|
+ transcript_data = response.json()
|
|
|
+
|
|
|
+ # Check if speaker labels are present and enabled
|
|
|
+ has_speaker_labels = (
|
|
|
+ "results" in transcript_data
|
|
|
+ and "speaker_labels" in transcript_data["results"]
|
|
|
+ and "segments" in transcript_data["results"]["speaker_labels"]
|
|
|
+ )
|
|
|
+
|
|
|
+ if has_speaker_labels:
|
|
|
+ # Get speaker segments
|
|
|
+ segments = transcript_data["results"]["speaker_labels"]["segments"]
|
|
|
+ items = transcript_data["results"]["items"]
|
|
|
+
|
|
|
+ # Create a mapping of start_time -> speaker_label
|
|
|
+ time_to_speaker = {}
|
|
|
+ for segment in segments:
|
|
|
+ speaker_label = segment["speaker_label"]
|
|
|
+ for item in segment["items"]:
|
|
|
+ time_to_speaker[item["start_time"]] = speaker_label
|
|
|
+
|
|
|
+ # Build transcript with speaker labels
|
|
|
+ current_speaker = None
|
|
|
+ transcript_parts = []
|
|
|
+
|
|
|
+ for item in items:
|
|
|
+ # Skip non-pronunciation items (like punctuation)
|
|
|
+ if item["type"] == "punctuation":
|
|
|
+ transcript_parts.append(item["alternatives"][0]["content"])
|
|
|
+ continue
|
|
|
+
|
|
|
+ start_time = item["start_time"]
|
|
|
+ speaker = time_to_speaker.get(start_time)
|
|
|
+
|
|
|
+ if speaker != current_speaker:
|
|
|
+ current_speaker = speaker
|
|
|
+ transcript_parts.append(f"\n[{speaker}]: ")
|
|
|
+
|
|
|
+ transcript_parts.append(item["alternatives"][0]["content"])
|
|
|
+
|
|
|
+ return " ".join(transcript_parts).strip(), None
|
|
|
+ else:
|
|
|
+ # Extract the transcription text
|
|
|
+ # The transcript text is typically in the 'results' -> 'transcripts' array
|
|
|
+ if "results" in transcript_data and "transcripts" in transcript_data["results"]:
|
|
|
+ transcripts = transcript_data["results"]["transcripts"]
|
|
|
+ if transcripts:
|
|
|
+ # Combine all transcript segments
|
|
|
+ full_text = " ".join(t.get("transcript", "") for t in transcripts)
|
|
|
+ return full_text, None
|
|
|
+
|
|
|
+ return None, "No transcripts found in the response"
|
|
|
+
|
|
|
+ except requests.exceptions.RequestException as e:
|
|
|
+ retry_count += 1
|
|
|
+ if retry_count == max_retries:
|
|
|
+ return None, f"Failed to download transcript file after {max_retries} attempts: {str(e)}"
|
|
|
+ continue
|
|
|
+
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ return None, f"Failed to parse transcript JSON: {str(e)}"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return None, f"Unexpected error while processing transcript: {str(e)}"
|
|
|
+
|
|
|
+ return None, "Maximum retries exceeded"
|
|
|
+
|
|
|
+ def _invoke(
|
|
|
+ self,
|
|
|
+ user_id: str,
|
|
|
+ tool_parameters: dict[str, Any],
|
|
|
+ ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
|
+ """
|
|
|
+ invoke tools
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ if not self.transcribe_client:
|
|
|
+ aws_region = tool_parameters.get("aws_region")
|
|
|
+ if aws_region:
|
|
|
+ self.transcribe_client = boto3.client("transcribe", region_name=aws_region)
|
|
|
+ self.s3_client = boto3.client("s3", region_name=aws_region)
|
|
|
+ else:
|
|
|
+ self.transcribe_client = boto3.client("transcribe")
|
|
|
+ self.s3_client = boto3.client("s3")
|
|
|
+
|
|
|
+ file_url = tool_parameters.get("file_url")
|
|
|
+ file_type = tool_parameters.get("file_type")
|
|
|
+ language_code = tool_parameters.get("language_code")
|
|
|
+ identify_language = tool_parameters.get("identify_language", True)
|
|
|
+ identify_multiple_languages = tool_parameters.get("identify_multiple_languages", False)
|
|
|
+ language_options_str = tool_parameters.get("language_options")
|
|
|
+ s3_bucket_name = tool_parameters.get("s3_bucket_name")
|
|
|
+ ShowSpeakerLabels = tool_parameters.get("ShowSpeakerLabels", True)
|
|
|
+ MaxSpeakerLabels = tool_parameters.get("MaxSpeakerLabels", 2)
|
|
|
+
|
|
|
+ # Check the input params
|
|
|
+ if not s3_bucket_name:
|
|
|
+ return self.create_text_message(text="s3_bucket_name is required")
|
|
|
+ language_options = None
|
|
|
+ if language_options_str:
|
|
|
+ language_options = language_options_str.split("|")
|
|
|
+ for lang in language_options:
|
|
|
+ if lang not in LanguageCodeOptions:
|
|
|
+ return self.create_text_message(
|
|
|
+ text=f"{lang} is not supported, should be one of {LanguageCodeOptions}"
|
|
|
+ )
|
|
|
+ if language_code and language_code not in LanguageCodeOptions:
|
|
|
+ err_msg = f"language_code:{language_code} is not supported, should be one of {LanguageCodeOptions}"
|
|
|
+ return self.create_text_message(text=err_msg)
|
|
|
+
|
|
|
+ err_msg = f"identify_language:{identify_language}, \
|
|
|
+ identify_multiple_languages:{identify_multiple_languages}, \
|
|
|
+ Note that you must include one of LanguageCode, IdentifyLanguage, \
|
|
|
+ or IdentifyMultipleLanguages in your request. \
|
|
|
+ If you include more than one of these parameters, \
|
|
|
+ your transcription job fails."
|
|
|
+ if not language_code:
|
|
|
+ if identify_language and identify_multiple_languages:
|
|
|
+ return self.create_text_message(text=err_msg)
|
|
|
+ else:
|
|
|
+ if identify_language or identify_multiple_languages:
|
|
|
+ return self.create_text_message(text=err_msg)
|
|
|
+
|
|
|
+ extra_args = {
|
|
|
+ "IdentifyLanguage": identify_language,
|
|
|
+ "IdentifyMultipleLanguages": identify_multiple_languages,
|
|
|
+ }
|
|
|
+ if language_code:
|
|
|
+ extra_args["LanguageCode"] = language_code
|
|
|
+ if language_options:
|
|
|
+ extra_args["LanguageOptions"] = language_options
|
|
|
+ if ShowSpeakerLabels:
|
|
|
+ extra_args["Settings"] = {"ShowSpeakerLabels": ShowSpeakerLabels, "MaxSpeakerLabels": MaxSpeakerLabels}
|
|
|
+
|
|
|
+ # upload to s3 bucket
|
|
|
+ s3_path_result, error = upload_file_from_url_to_s3(self.s3_client, url=file_url, bucket_name=s3_bucket_name)
|
|
|
+ if not s3_path_result:
|
|
|
+ return self.create_text_message(text=error)
|
|
|
+
|
|
|
+ transcript_file_uri, error = self._transcribe_audio(
|
|
|
+ audio_file_uri=s3_path_result,
|
|
|
+ file_type=file_type,
|
|
|
+ **extra_args,
|
|
|
+ )
|
|
|
+ if not transcript_file_uri:
|
|
|
+ return self.create_text_message(text=error)
|
|
|
+
|
|
|
+ # Download and read the transcript
|
|
|
+ transcript_text, error = self._download_and_read_transcript(transcript_file_uri)
|
|
|
+ if not transcript_text:
|
|
|
+ return self.create_text_message(text=error)
|
|
|
+
|
|
|
+ return self.create_text_message(text=transcript_text)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return self.create_text_message(f"Exception {str(e)}")
|