Преглед на файлове

feat: support more model types and builtin tools on aws/sagemaker (#8061)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
ybalbert001 преди 7 месеца
родител
ревизия
954580a4af

+ 274 - 43
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -1,17 +1,36 @@
 import json
 import logging
-from collections.abc import Generator
-from typing import Any, Optional, Union
+import re
+from collections.abc import Generator, Iterator
+from typing import Any, Optional, Union, cast
 
+# from openai.types.chat import ChatCompletion, ChatCompletionChunk
 import boto3
+from sagemaker import Predictor, serializers
+from sagemaker.session import Session
 
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    I18nObject,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
 )
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 
 logger = logging.getLogger(__name__)
 
+def inference(predictor, messages:list[dict[str,Any]], params:dict[str,Any], stop:list, stream=False):
+    """    
+    params:
+    predictor : Sagemaker Predictor 
+    messages (List[Dict[str,Any]]): message list。
+                messages = [
+                {"role": "system", "content":"please answer in Chinese"},
+                {"role": "user", "content": "who are you? what are you doing?"},
+            ]
+    params (Dict[str,Any]): model parameters for LLM。
+    stream (bool): False by default。
+    
+    response:
+    result of inference if stream is False
+    Iterator of Chunks if stream is True
+    """
+    payload = {
+        "model" : params.get('model_name'),
+        "stop" : stop,
+        "messages": messages,
+        "stream" : stream,
+        "max_tokens" : params.get('max_new_tokens', params.get('max_tokens', 2048)),
+        "temperature" : params.get('temperature', 0.1),
+        "top_p" : params.get('top_p', 0.9),
+    }
+
+    if not stream:
+        response = predictor.predict(payload)
+        return response
+    else:
+        response_stream = predictor.predict_stream(payload)
+        return response_stream
 
 class SageMakerLargeLanguageModel(LargeLanguageModel):
     """
     Model class for Cohere large language model.
     """
     sagemaker_client: Any = None
+    sagemaker_sess : Any = None
+    predictor : Any = None
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                       tools: list[PromptMessageTool],
+                                       resp: bytes) -> LLMResult:
+        """
+            handle normal chat generate response
+        """
+        resp_obj = json.loads(resp.decode('utf-8'))
+        resp_str = resp_obj.get('choices')[0].get('message').get('content')
+
+        if len(resp_str) == 0:
+            raise InvokeServerUnavailableError("Empty response")
+
+        assistant_prompt_message = AssistantPromptMessage(
+            content=resp_str,
+            tool_calls=[]
+        )
+
+        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
+
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=None,
+            usage=usage,
+            message=assistant_prompt_message,
+        )
+
+        return response
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[bytes]) -> Generator:
+        """
+            handle stream chat generate response
+        """
+        full_response = ''
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode('utf-8')
+            last_idx = 0
+            for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
+
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_content,
+                            tool_calls=[] 
+                        )
+
+                        if data["choices"][0]['finish_reason'] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(
+                                content=full_response,
+                                tool_calls=[]
+                            )
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
+                            usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]['finish_reason'],
+                                    usage=usage
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message
+                                ),
+                            )
+
+                            full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError) as e:
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+                    pass
+
+            buffer = buffer[last_idx:]
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # get model mode
-        model_mode = self.get_model_mode(model, credentials)
-
         if not self.sagemaker_client:
             access_key = credentials.get('access_key')
             secret_key = credentials.get('secret_key')
@@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             else:
                 self.sagemaker_client = boto3.client("sagemaker-runtime")
 
+            sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
+            self.predictor = Predictor(
+                endpoint_name=credentials.get('sagemaker_endpoint'),
+                sagemaker_session=sagemaker_session,
+                serializer=serializers.JSONSerializer(),
+            )
 
-        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
-        response_model = self.sagemaker_client.invoke_endpoint(
-                    EndpointName=sagemaker_endpoint,
-                    Body=json.dumps(
-                    {
-                        "inputs": prompt_messages[0].content,
-                        "parameters": { "stop" : stop},
-                        "history" : []
-                    }
-                    ),
-                    ContentType="application/json",
-                )
 
-        assistant_text = response_model['Body'].read().decode('utf8')
+        messages:list[dict[str,Any]] = [ {"role": p.role.value, "content": p.content} for p in prompt_messages ]
+        response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)
 
-        # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
-        )
+        if stream:
+            if tools and len(tools) > 0:
+                raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode")
 
-        usage = self._calc_response_usage(model, credentials, 0, 0)
+            return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                     prompt_messages=prompt_messages,
+                                                     tools=tools, resp=response)
+        return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                   prompt_messages=prompt_messages,
+                                                   tools=tools, resp=response)
 
-        response = LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=assistant_prompt_message,
-            usage=usage
-        )
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI Compatibility API
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls and len(message.tool_calls) > 0:
+                message_dict["function_call"] = {
+                    "name": message.tool_calls[0].function.name,
+                    "arguments": message.tool_calls[0].function.arguments
+                }
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
+        else:
+            raise ValueError(f"Unknown message type {type(message)}")
+
+        return message_dict
+
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
+                                  is_completion_model: bool = False) -> int:
+        def tokens(text: str):
+            return self._get_num_tokens_by_gpt2(text)
+
+        if is_completion_model:
+            return sum(tokens(str(message.content)) for message in messages)
+
+        tokens_per_message = 3
+        tokens_per_name = 1
+
+        num_tokens = 0
+        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if isinstance(value, list):
+                    text = ''
+                    for item in value:
+                        if isinstance(item, dict) and item['type'] == 'text':
+                            text += item['text']
+
+                    value = text
+
+                if key == "tool_calls":
+                    for tool_call in value:
+                        for t_key, t_value in tool_call.items():
+                            num_tokens += tokens(t_key)
+                            if t_key == "function":
+                                for f_key, f_value in t_value.items():
+                                    num_tokens += tokens(f_key)
+                                    num_tokens += tokens(f_value)
+                            else:
+                                num_tokens += tokens(t_key)
+                                num_tokens += tokens(t_value)
+                if key == "function_call":
+                    for t_key, t_value in value.items():
+                        num_tokens += tokens(t_key)
+                        if t_key == "function":
+                            for f_key, f_value in t_value.items():
+                                num_tokens += tokens(f_key)
+                                num_tokens += tokens(f_value)
+                        else:
+                            num_tokens += tokens(t_key)
+                            num_tokens += tokens(t_value)
+                else:
+                    num_tokens += tokens(str(value))
 
-        return response
+                if key == "name":
+                    num_tokens += tokens_per_name
+        num_tokens += 3
+
+        if tools:
+            num_tokens += self._num_tokens_for_tools(tools)
+
+        return num_tokens
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         # get model mode
-        model_mode = self.get_model_mode(model)
-
         try:
-            return 0
+            return self._num_tokens_from_messages(prompt_messages, tools)
         except Exception as e:
             raise self._transform_invoke_error(e)
 
@@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         try:
             # get model mode
-            model_mode = self.get_model_mode(model)
+            pass
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
@@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             )
         ]
 
-        completion_type = LLMMode.value_of(credentials["mode"])
-
-        if completion_type == LLMMode.CHAT:
-            print(f"completion_type : {LLMMode.CHAT.value}") 
-
-        if completion_type == LLMMode.COMPLETION:
-            print(f"completion_type : {LLMMode.COMPLETION.value}") 
+        completion_type = LLMMode.value_of(credentials["mode"]).value
 
         features = []
 

+ 1 - 1
api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 class SageMakerRerankModel(RerankModel):
     """
-    Model class for Cohere rerank model.
+    Model class for SageMaker rerank model.
     """
     sagemaker_client: Any = None
 

+ 27 - 1
api/core/model_runtime/model_providers/sagemaker/sagemaker.py

@@ -1,10 +1,11 @@
 import logging
+import uuid
+from typing import IO, Any
 
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 
 logger = logging.getLogger(__name__)
 
-
 class SageMakerProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
@@ -15,3 +16,28 @@ class SageMakerProvider(ModelProvider):
         :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
         """
         pass
+
+def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str:
+    '''
+        return s3_uri of this file
+    '''
+    s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
+    s3_client.put_object(
+        Body=file.read(),
+        Bucket=bucket,
+        Key=s3_key,
+        ContentType='audio/mp3'
+    )
+    return s3_key
+
+def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str:
+    object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
+    try:
+        response = s3_client.generate_presigned_url('get_object',
+                                                    Params={'Bucket': bucket_name, 'Key': object_key},
+                                                    ExpiresIn=expiration)
+    except Exception as e:
+        print(f"Error generating presigned URL: {e}")
+        return None
+
+    return response

+ 73 - 5
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml

@@ -21,6 +21,8 @@ supported_model_types:
   - llm
   - text-embedding
   - rerank
+  - speech2text
+  - tts
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -45,14 +47,10 @@ model_credential_schema:
         zh_Hans: 选择对话类型
         en_US: Select completion mode
       options:
-        - value: completion
-          label:
-            en_US: Completion
-            zh_Hans: 补全
         - value: chat
           label:
             en_US: Chat
-            zh_Hans: 对话
+            zh_Hans: Chat
     - variable: sagemaker_endpoint
       label:
         en_US: sagemaker endpoint
@@ -61,6 +59,76 @@ model_credential_schema:
       placeholder:
         zh_Hans: 请输出你的Sagemaker推理端点
         en_US: Enter your Sagemaker Inference endpoint
+    - variable: audio_s3_cache_bucket
+      show_on:
+        - variable: __model_type
+          value: speech2text
+      label:
+        zh_Hans: 音频缓存桶(s3 bucket)
+        en_US: audio cache bucket(s3 bucket)
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: sagemaker-us-east-1-******207838
+        en_US: sagemaker-us-east-1-*******7838
+    - variable: audio_model_type
+      show_on:
+        - variable: __model_type
+          value: tts
+      label:
+        en_US: Audio model type
+      type: select
+      required: true
+      placeholder:
+        zh_Hans: 语音模型类型
+        en_US: Audio model type
+      options:
+        - value: PresetVoice
+          label:
+            en_US: preset voice
+            zh_Hans: 内置音色
+        - value: CloneVoice
+          label:
+            en_US: clone voice
+            zh_Hans: 克隆音色
+        - value: CloneVoice_CrossLingual
+          label:
+            en_US: crosslingual clone voice
+            zh_Hans: 跨语种克隆音色
+        - value: InstructVoice
+          label:
+            en_US: Instruct voice
+            zh_Hans: 文字指令音色
+    - variable: prompt_audio
+      show_on:
+        - variable: __model_type
+          value: tts
+      label:
+        en_US: Mock Audio Source
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 被模仿的音色音频
+        en_US: source audio to be mocked
+    - variable: prompt_text
+      show_on:
+        - variable: __model_type
+          value: tts
+      label:
+        en_US: Prompt Audio Text
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 模仿音色的对应文本
+        en_US: text for the mocked source audio
+    - variable: instruct_text
+      show_on:
+        - variable: __model_type
+          value: tts
+      label:
+        en_US: instruct text for speaker
+      type: text-input
+      required: false
     - variable: aws_access_key_id
       required: false
       label:

+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py


+ 142 - 0
api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py

@@ -0,0 +1,142 @@
+import json
+import logging
+from typing import IO, Any, Optional
+
+import boto3
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
+
+logger = logging.getLogger(__name__)
+
+class SageMakerSpeech2TextModel(Speech2TextModel):
+    """
+    Model class for Xinference speech to text model.
+    """
+    sagemaker_client: Any = None
+    s3_client : Any = None
+
+    def _invoke(self, model: str, credentials: dict,
+                file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        asr_text = None
+
+        try:
+            if not self.sagemaker_client:
+                access_key = credentials.get('aws_access_key_id')
+                secret_key = credentials.get('aws_secret_access_key')
+                aws_region = credentials.get('aws_region')
+                if aws_region:
+                    if access_key and secret_key:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", 
+                            aws_access_key_id=access_key,
+                            aws_secret_access_key=secret_key,
+                            region_name=aws_region)
+                        self.s3_client = boto3.client("s3",
+                            aws_access_key_id=access_key,
+                            aws_secret_access_key=secret_key,
+                            region_name=aws_region)
+                    else:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                        self.s3_client = boto3.client("s3", region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime")
+                    self.s3_client = boto3.client("s3")
+
+            s3_prefix='dify/speech2text/'
+            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
+            bucket = credentials.get('audio_s3_cache_bucket')
+
+            s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
+            payload = {
+                "audio_s3_presign_uri" : s3_presign_url
+            }
+
+            response_model = self.sagemaker_client.invoke_endpoint(
+                EndpointName=sagemaker_endpoint,
+                Body=json.dumps(payload),
+                ContentType="application/json"
+            )
+            json_str = response_model['Body'].read().decode('utf8')
+            json_obj = json.loads(json_str)
+            asr_text = json_obj['text']
+        except Exception as e:
+            logger.exception(f'Exception {e}, line : {line}')
+
+        return asr_text
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        pass
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={ },
+            parameter_rules=[]
+        )
+
+        return entity

+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/tts/__init__.py


+ 287 - 0
api/core/model_runtime/model_providers/sagemaker/tts/tts.py

@@ -0,0 +1,287 @@
+import concurrent.futures
+import copy
+import json
+import logging
+from enum import Enum
+from typing import Any, Optional
+
+import boto3
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+
+logger = logging.getLogger(__name__)
+
+class TTSModelType(Enum):
+    PresetVoice = "PresetVoice"
+    CloneVoice = "CloneVoice"
+    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
+    InstructVoice = "InstructVoice"
+
+class SageMakerText2SpeechModel(TTSModel):
+
+    sagemaker_client: Any = None
+    s3_client : Any = None
+    comprehend_client : Any = None
+
+    def __init__(self):
+        # preset voices, need support custom voice
+        self.model_voices = {
+            '__default': {
+                'all': [
+                    {'name': 'Default', 'value': 'default'},
+                ]
+            },
+            'CosyVoice': {
+                'zh-Hans': [
+                    {'name': '中文男', 'value': '中文男'},
+                    {'name': '中文女', 'value': '中文女'},
+                    {'name': '粤语女', 'value': '粤语女'},
+                ],
+                'zh-Hant': [
+                    {'name': '中文男', 'value': '中文男'},
+                    {'name': '中文女', 'value': '中文女'},
+                    {'name': '粤语女', 'value': '粤语女'},
+                ],
+                'en-US': [
+                    {'name': '英文男', 'value': '英文男'},
+                    {'name': '英文女', 'value': '英文女'},
+                ],
+                'ja-JP': [
+                    {'name': '日语男', 'value': '日语男'},
+                ],
+                'ko-KR': [
+                    {'name': '韩语女', 'value': '韩语女'},
+                ]
+            }
+        }
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+                Validate model credentials
+
+                :param model: model name
+                :param credentials: model credentials
+                :return:
+                """
+        pass
+
+    def _detect_lang_code(self, content:str, map_dict:dict=None):
+        map_dict = {
+            "zh" : "<|zh|>",
+            "en" : "<|en|>",
+            "ja" : "<|jp|>",
+            "zh-TW" : "<|yue|>",
+            "ko" : "<|ko|>"
+        }
+
+        response = self.comprehend_client.detect_dominant_language(Text=content)
+        language_code = response['Languages'][0]['LanguageCode']
+
+        return map_dict.get(language_code, '<|zh|>')
+
+    def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
+        if model_type == TTSModelType.PresetVoice.value and model_role:
+            return { "tts_text" : content_text, "role" : model_role }
+        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
+            return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
+        if model_type ==  TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
+            lang_tag = self._detect_lang_code(content_text)
+            return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
+        if model_type ==  TTSModelType.InstructVoice.value and instruct_text and model_role:
+            return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
+
+        raise RuntimeError(f"Invalid params for {model_type}")
+
+    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
+                user: Optional[str] = None):
+        """
+        _invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+        if not self.sagemaker_client:
+            access_key = credentials.get('aws_access_key_id')
+            secret_key = credentials.get('aws_secret_access_key')
+            aws_region = credentials.get('aws_region')
+            if aws_region:
+                if access_key and secret_key:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime", 
+                        aws_access_key_id=access_key,
+                        aws_secret_access_key=secret_key,
+                        region_name=aws_region)
+                    self.s3_client = boto3.client("s3",
+                        aws_access_key_id=access_key,
+                        aws_secret_access_key=secret_key,
+                        region_name=aws_region)
+                    self.comprehend_client = boto3.client('comprehend',
+                        aws_access_key_id=access_key,
+                        aws_secret_access_key=secret_key,
+                        region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                    self.s3_client = boto3.client("s3", region_name=aws_region)
+                    self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
+            else:
+                self.sagemaker_client = boto3.client("sagemaker-runtime")
+                self.s3_client = boto3.client("s3")
+                self.comprehend_client = boto3.client('comprehend')
+
+        model_type = credentials.get('audio_model_type', 'PresetVoice')
+        prompt_text = credentials.get('prompt_text')
+        prompt_audio = credentials.get('prompt_audio')
+        instruct_text = credentials.get('instruct_text')
+        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
+        payload = self._build_tts_payload(
+            model_type, 
+            content_text, 
+            voice, 
+            prompt_text, 
+            prompt_audio, 
+            instruct_text
+        )
+
+        return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TTS,
+            model_properties={},
+            parameter_rules=[]
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def _get_model_default_voice(self, model: str, credentials: dict) -> any:
+        return ""
+
+    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
+        return 15
+
+    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
+        return "mp3"
+
+    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
+        return 5
+
+    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
+        audio_model_name = 'CosyVoice'
+        for key, voices in self.model_voices.items():
+            if key in audio_model_name:
+                if language and language in voices:
+                    return voices[language]
+                elif 'all' in voices:
+                    return voices['all']
+
+        return self.model_voices['__default']['all']
+
+    def _invoke_sagemaker(self, payload:dict, endpoint:str):
+        response_model = self.sagemaker_client.invoke_endpoint(
+            EndpointName=endpoint,
+            Body=json.dumps(payload),
+            ContentType="application/json",
+        )
+        json_str = response_model['Body'].read().decode('utf8')
+        json_obj = json.loads(json_str)
+        return json_obj
+
+    def _tts_invoke_streaming(self, model_type:str, payload:dict, sagemaker_endpoint:str) -> any:
+        """
+        _tts_invoke_streaming text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :return: text translated to audio file
+        """
+        try:
+            lang_tag = ''
+            if model_type == TTSModelType.CloneVoice_CrossLingual.value:
+                lang_tag = payload.pop('lang_tag')
+            
+            word_limit = self._get_model_word_limit(model='', credentials={})
+            content_text = payload.get("tts_text")
+            if len(content_text) > word_limit:
+                split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                sentences = [ f"{lang_tag}{s}" for s in split_sentences if len(s) ]
+                len_sent = len(sentences)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent))
+                payloads = [ copy.deepcopy(payload) for i in range(len_sent) ]
+                for idx in range(len_sent):
+                    payloads[idx]["tts_text"] = sentences[idx]
+
+                futures = [ executor.submit(
+                    self._invoke_sagemaker,
+                    payload=payload,
+                    endpoint=sagemaker_endpoint,
+                )
+                    for payload in payloads]
+
+                for index, future in enumerate(futures):
+                    resp = future.result()
+                    audio_bytes = requests.get(resp.get('s3_presign_url')).content
+                    for i in range(0, len(audio_bytes), 1024):
+                        yield audio_bytes[i:i + 1024]
+            else:
+                resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
+                audio_bytes = requests.get(resp.get('s3_presign_url')).content
+
+                for i in range(0, len(audio_bytes), 1024):
+                    yield audio_bytes[i:i + 1024]
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))

+ 6 - 3
api/core/tools/provider/builtin/aws/tools/apply_guardrail.py

@@ -3,6 +3,7 @@ import logging
 from typing import Any, Union
 
 import boto3
+from botocore.exceptions import BotoCoreError
 from pydantic import BaseModel, Field
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel):
     guardrail_version: str = Field(..., description="The version of the guardrail")
     source: str = Field(..., description="The source of the content")
     text: str = Field(..., description="The text to apply the guardrail to")
-    aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
+    aws_region: str = Field(..., description="AWS region for the Bedrock client")
 
 class ApplyGuardrailTool(BuiltinTool):
     def _invoke(self,
@@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool):
                 source=params.source,
                 content=[{"text": {"text": params.text}}]
             )
+            
+            logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
 
             # Check for empty response
             if not response:
@@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool):
 
             return self.create_text_message(text=result)
 
-        except boto3.exceptions.BotoCoreError as e:
+        except BotoCoreError as e:
             error_message = f'AWS service error: {str(e)}'
             logger.error(error_message, exc_info=True)
             return self.create_text_message(text=error_message)
@@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool):
         except Exception as e:
             error_message = f'An unexpected error occurred: {str(e)}'
             logger.error(error_message, exc_info=True)
-            return self.create_text_message(text=error_message)
+            return self.create_text_message(text=error_message)

+ 11 - 0
api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml

@@ -54,3 +54,14 @@ parameters:
       zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
     llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
     form: llm
+  - name: aws_region
+    type: string
+    required: true
+    label:
+      en_US: AWS Region
+      zh_Hans: AWS 区域
+    human_description:
+      en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
+      zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
+    llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
+    form: form

+ 71 - 0
api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py

@@ -0,0 +1,71 @@
+import json
+import logging
+from typing import Any, Union
+
+import boto3
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+console_handler = logging.StreamHandler()
+logger.addHandler(console_handler)
+
+
+class LambdaYamlToJsonTool(BuiltinTool):
+    lambda_client: Any = None
+
+    def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
+        msg = { 
+            "body": yaml_content
+        }
+        logger.info(json.dumps(msg))
+
+        invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
+                                               InvocationType='RequestResponse',
+                                               Payload=json.dumps(msg))
+        response_body = invoke_response['Payload']
+
+        response_str = response_body.read().decode("utf-8")
+        resp_json = json.loads(response_str)
+
+        logger.info(resp_json)
+        if resp_json['statusCode'] != 200:
+            raise Exception(f"Invalid status code: {response_str}")
+
+        return resp_json['body']
+
+    def _invoke(self, 
+                user_id: str, 
+               tool_parameters: dict[str, Any], 
+        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        try:
+            if not self.lambda_client:
+                aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region 
+                if aws_region:
+                    self.lambda_client = boto3.client("lambda", region_name=aws_region)
+                else:
+                    self.lambda_client = boto3.client("lambda")
+
+            yaml_content = tool_parameters.get('yaml_content', '')
+            if not yaml_content:
+                return self.create_text_message('Please input yaml_content')
+
+            lambda_name = tool_parameters.get('lambda_name', '')
+            if not lambda_name:
+                return self.create_text_message('Please input lambda_name')
+            logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
+            
+            result = self._invoke_lambda(lambda_name, yaml_content)
+            logger.debug(result)
+            
+            return self.create_text_message(result)
+        except Exception as e:
+            return self.create_text_message(f'Exception: {str(e)}')
+
+        console_handler.flush()

+ 53 - 0
api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml

@@ -0,0 +1,53 @@
+identity:
+  name: lambda_yaml_to_json
+  author: AWS
+  label:
+    en_US: LambdaYamlToJson
+    zh_Hans: LambdaYamlToJson
+    pt_BR: LambdaYamlToJson
+  icon: icon.svg
+description:
+  human:
+    en_US: A tool to convert yaml to json using AWS Lambda.
+    zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
+    pt_BR: A tool to convert yaml to json using AWS Lambda.
+  llm: A tool to convert yaml to json.
+parameters:
+  - name: yaml_content
+    type: string
+    required: true
+    label:
+      en_US: YAML content to convert for
+      zh_Hans: YAML 内容
+      pt_BR: YAML content to convert for
+    human_description:
+      en_US: YAML content to convert for
+      zh_Hans: YAML 内容
+      pt_BR: YAML content to convert for
+    llm_description: YAML content to convert for
+    form: llm
+  - name: aws_region
+    type: string
+    required: false
+    label:
+      en_US: region of lambda
+      zh_Hans: Lambda 所在的region
+      pt_BR: region of lambda
+    human_description:
+      en_US: region of lambda
+      zh_Hans: Lambda 所在的region
+      pt_BR: region of lambda
+    llm_description: region of lambda
+    form: form
+  - name: lambda_name
+    type: string
+    required: false
+    label:
+      en_US: name of lambda
+      zh_Hans: Lambda 名称
+      pt_BR: name of lambda
+    human_description:
+      en_US: name of lambda
+      zh_Hans: Lambda 名称
+      pt_BR: name of lambda
+    form: form

+ 2 - 4
api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py

@@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool):
             sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
 
             line = 9
-            results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False)
-            return self.create_text_message(text=results_str)
+            return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
             
         except Exception as e:
-            return self.create_text_message(f'Exception {str(e)}, line : {line}')
-    
+            return self.create_text_message(f'Exception {str(e)}, line : {line}')

+ 95 - 0
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py

@@ -0,0 +1,95 @@
+import json
+from enum import Enum
+from typing import Any, Union
+
+import boto3
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class TTSModelType(Enum):
+    PresetVoice = "PresetVoice"
+    CloneVoice = "CloneVoice"
+    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
+    InstructVoice = "InstructVoice"
+
+class SageMakerTTSTool(BuiltinTool):
+    sagemaker_client: Any = None
+    sagemaker_endpoint:str = None
+    s3_client : Any = None
+    comprehend_client : Any = None
+
+    def _detect_lang_code(self, content:str, map_dict:dict=None):
+        map_dict = {
+            "zh" : "<|zh|>",
+            "en" : "<|en|>",
+            "ja" : "<|jp|>",
+            "zh-TW" : "<|yue|>",
+            "ko" : "<|ko|>"
+        }
+
+        response = self.comprehend_client.detect_dominant_language(Text=content)
+        language_code = response['Languages'][0]['LanguageCode']
+        return map_dict.get(language_code, '<|zh|>')
+
+    def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
+        if model_type == TTSModelType.PresetVoice.value and model_role:
+            return { "tts_text" : content_text, "role" : model_role }
+        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
+            return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
+        if model_type ==  TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
+            lang_tag = self._detect_lang_code(content_text)
+            return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
+        if model_type ==  TTSModelType.InstructVoice.value and instruct_text and model_role:
+            return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
+
+        raise RuntimeError(f"Invalid params for {model_type}")
+
+    def _invoke_sagemaker(self, payload:dict, endpoint:str):
+        response_model = self.sagemaker_client.invoke_endpoint(
+            EndpointName=endpoint,
+            Body=json.dumps(payload),
+            ContentType="application/json",
+        )
+        json_str = response_model['Body'].read().decode('utf8')
+        json_obj = json.loads(json_str)
+        return json_obj
+
+    def _invoke(self, 
+                user_id: str, 
+               tool_parameters: dict[str, Any], 
+        ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        try:
+            if not self.sagemaker_client:
+                aws_region = tool_parameters.get('aws_region')
+                if aws_region:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                    self.s3_client = boto3.client("s3", region_name=aws_region)
+                    self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime")
+                    self.s3_client = boto3.client("s3")
+                    self.comprehend_client = boto3.client('comprehend')
+
+            if not self.sagemaker_endpoint:
+                self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
+
+            tts_text = tool_parameters.get('tts_text')
+            tts_infer_type = tool_parameters.get('tts_infer_type')
+
+            voice = tool_parameters.get('voice')
+            mock_voice_audio = tool_parameters.get('mock_voice_audio')
+            mock_voice_text = tool_parameters.get('mock_voice_text')
+            voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
+            payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
+
+            result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
+
+            return self.create_text_message(text=result['s3_presign_url'])
+            
+        except Exception as e:
+            return self.create_text_message(f'Exception {str(e)}')

+ 149 - 0
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml

@@ -0,0 +1,149 @@
+identity:
+  name: sagemaker_tts
+  author: AWS
+  label:
+    en_US: SagemakerTTS
+    zh_Hans: Sagemaker语音合成
+    pt_BR: SagemakerTTS
+  icon: icon.svg
+description:
+  human:
+    en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
+    zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
+    pt_BR: A tool for Speech synthesis.
+  llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
+parameters:
+  - name: sagemaker_endpoint
+    type: string
+    required: true
+    label:
+      en_US: sagemaker endpoint for tts
+      zh_Hans: 语音生成的SageMaker端点
+      pt_BR: sagemaker endpoint for tts
+    human_description:
+      en_US: sagemaker endpoint for tts
+      zh_Hans: 语音生成的SageMaker端点
+      pt_BR: sagemaker endpoint for tts
+    llm_description: sagemaker endpoint for tts
+    form: form
+  - name: tts_text
+    type: string
+    required: true
+    label:
+      en_US: tts text
+      zh_Hans: 语音合成原文
+      pt_BR: tts text
+    human_description:
+      en_US: tts text
+      zh_Hans: 语音合成原文
+      pt_BR: tts text
+    llm_description: tts text
+    form: llm
+  - name: tts_infer_type
+    type: select
+    required: false
+    label:
+      en_US: tts infer type
+      zh_Hans: 合成方式
+      pt_BR: tts infer type
+    human_description:
+      en_US: tts infer type
+      zh_Hans: 合成方式
+      pt_BR: tts infer type
+    llm_description: tts infer type
+    options:
+      - value: PresetVoice
+        label:
+          en_US: preset voice
+          zh_Hans: 预置音色
+      - value: CloneVoice
+        label:
+          en_US: clone voice
+          zh_Hans: 克隆音色
+      - value: CloneVoice_CrossLingual
+        label:
+          en_US: clone crossLingual voice
+          zh_Hans: 克隆音色(跨语言)
+      - value: InstructVoice
+        label:
+          en_US: instruct voice
+          zh_Hans: 指令音色
+    form: form
+  - name: voice
+    type: select
+    required: false
+    label:
+      en_US: preset voice
+      zh_Hans: 预置音色
+      pt_BR: preset voice
+    human_description:
+      en_US: preset voice
+      zh_Hans: 预置音色
+      pt_BR: preset voice
+    llm_description: preset voice
+    options:
+      - value: 中文男
+        label:
+          en_US: zh-cn male
+          zh_Hans: 中文男
+      - value: 中文女
+        label:
+          en_US: zh-cn female
+          zh_Hans: 中文女
+      - value: 粤语女
+        label:
+          en_US: zh-TW female
+          zh_Hans: 粤语女
+    form: form
+  - name: mock_voice_audio
+    type: string
+    required: false
+    label:
+      en_US: clone voice link
+      zh_Hans: 克隆音频链接
+      pt_BR: clone voice link
+    human_description:
+      en_US: clone voice link
+      zh_Hans: 克隆音频链接
+      pt_BR: clone voice link
+    llm_description: clone voice link
+    form: llm
+  - name: mock_voice_text
+    type: string
+    required: false
+    label:
+      en_US: text of clone voice
+      zh_Hans: 克隆音频对应文本
+      pt_BR: text of clone voice
+    human_description:
+      en_US: text of clone voice
+      zh_Hans: 克隆音频对应文本
+      pt_BR: text of clone voice
+    llm_description: text of clone voice
+    form: llm
+  - name: voice_instruct_prompt
+    type: string
+    required: false
+    label:
+      en_US: instruct prompt for voice
+      zh_Hans: 音色指令文本
+      pt_BR: instruct prompt for voice
+    human_description:
+      en_US: instruct prompt for voice
+      zh_Hans: 音色指令文本
+      pt_BR: instruct prompt for voice
+    llm_description: instruct prompt for voice
+    form: llm
+  - name: aws_region
+    type: string
+    required: false
+    label:
+      en_US: region of sagemaker endpoint
+      zh_Hans: SageMaker 端点所在的region
+      pt_BR: region of sagemaker endpoint
+    human_description:
+      en_US: region of sagemaker endpoint
+      zh_Hans: SageMaker 端点所在的region
+      pt_BR: region of sagemaker endpoint
+    llm_description: region of sagemaker endpoint
+    form: form

+ 260 - 15
api/poetry.lock

@@ -520,22 +520,22 @@ files = [
 
 [[package]]
 name = "attrs"
-version = "24.2.0"
+version = "23.2.0"
 description = "Classes Without Boilerplate"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"},
-    {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"},
+    {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
+    {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
 ]
 
 [package.extras]
-benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
-cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
-dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
-docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
-tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
-tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
+cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
+dev = ["attrs[tests]", "pre-commit"]
+docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
+tests = ["attrs[tests-no-zope]", "zope-interface"]
+tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
+tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
 
 [[package]]
 name = "authlib"
@@ -1719,6 +1719,17 @@ lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"]
 numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"]
 zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"]
 
+[[package]]
+name = "cloudpickle"
+version = "2.2.1"
+description = "Extended pickling support for Python objects"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
+    {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
+]
+
 [[package]]
 name = "cloudscraper"
 version = "1.2.71"
@@ -2151,6 +2162,21 @@ wrapt = ">=1.10,<2"
 [package.extras]
 dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
 
+[[package]]
+name = "dill"
+version = "0.3.8"
+description = "serialize all of Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
+    {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
+]
+
+[package.extras]
+graph = ["objgraph (>=1.7.2)"]
+profile = ["gprof2dot (>=2022.7.29)"]
+
 [[package]]
 name = "distro"
 version = "1.9.0"
@@ -2162,6 +2188,28 @@ files = [
     {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
 ]
 
+[[package]]
+name = "docker"
+version = "7.1.0"
+description = "A Python library for the Docker Engine API."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"},
+    {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"},
+]
+
+[package.dependencies]
+pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""}
+requests = ">=2.26.0"
+urllib3 = ">=1.26.0"
+
+[package.extras]
+dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"]
+docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"]
+ssh = ["paramiko (>=2.4.3)"]
+websockets = ["websocket-client (>=1.3.0)"]
+
 [[package]]
 name = "docstring-parser"
 version = "0.16"
@@ -3309,6 +3357,21 @@ typing-extensions = "*"
 [package.extras]
 dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
 
+[[package]]
+name = "google-pasta"
+version = "0.2.0"
+description = "pasta is an AST-based Python refactoring library"
+optional = false
+python-versions = "*"
+files = [
+    {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"},
+    {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"},
+    {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"},
+]
+
+[package.dependencies]
+six = "*"
+
 [[package]]
 name = "google-resumable-media"
 version = "2.7.2"
@@ -3930,22 +3993,22 @@ files = [
 
 [[package]]
 name = "importlib-metadata"
-version = "8.4.0"
+version = "6.11.0"
 description = "Read metadata from Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
-    {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},
+    {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"},
+    {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"},
 ]
 
 [package.dependencies]
 zipp = ">=0.5"
 
 [package.extras]
-doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
 perf = ["ipython"]
-test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"]
 
 [[package]]
 name = "importlib-resources"
@@ -4929,6 +4992,22 @@ files = [
 [package.extras]
 test = ["mypy (>=1.0)", "pytest (>=7.0.0)"]
 
+[[package]]
+name = "mock"
+version = "4.0.3"
+description = "Rolling backport of unittest.mock for all Pythons"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "mock-4.0.3-py3-none-any.whl", hash = "sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"},
+    {file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"},
+]
+
+[package.extras]
+build = ["blurb", "twine", "wheel"]
+docs = ["sphinx"]
+test = ["pytest (<5.4)", "pytest-cov"]
+
 [[package]]
 name = "monotonic"
 version = "1.6"
@@ -5128,6 +5207,30 @@ files = [
     {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
 ]
 
+[[package]]
+name = "multiprocess"
+version = "0.70.16"
+description = "better multiprocessing and multithreading in Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
+    {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
+    {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
+    {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
+    {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
+    {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
+    {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
+    {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
+    {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
+    {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
+    {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
+    {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
+]
+
+[package.dependencies]
+dill = ">=0.3.8"
+
 [[package]]
 name = "multitasking"
 version = "0.0.11"
@@ -5955,6 +6058,23 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
 test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
 xml = ["lxml (>=4.9.2)"]
 
+[[package]]
+name = "pathos"
+version = "0.3.2"
+description = "parallel graph management and execution in heterogeneous computing"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe"},
+    {file = "pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a"},
+]
+
+[package.dependencies]
+dill = ">=0.3.8"
+multiprocess = ">=0.70.16"
+pox = ">=0.3.4"
+ppft = ">=1.7.6.8"
+
 [[package]]
 name = "peewee"
 version = "3.17.6"
@@ -6196,6 +6316,31 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"]
 sentry = ["django", "sentry-sdk"]
 test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"]
 
+[[package]]
+name = "pox"
+version = "0.3.4"
+description = "utilities for filesystem exploration and automated builds"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pox-0.3.4-py3-none-any.whl", hash = "sha256:651b8ae8a7b341b7bfd267f67f63106daeb9805f1ac11f323d5280d2da93fdb6"},
+    {file = "pox-0.3.4.tar.gz", hash = "sha256:16e6eca84f1bec3828210b06b052adf04cf2ab20c22fd6fbef5f78320c9a6fed"},
+]
+
+[[package]]
+name = "ppft"
+version = "1.7.6.8"
+description = "distributed and parallel Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "ppft-1.7.6.8-py3-none-any.whl", hash = "sha256:de2dd4b1b080923dd9627fbdea52649fd741c752fce4f3cf37e26f785df23d9b"},
+    {file = "ppft-1.7.6.8.tar.gz", hash = "sha256:76a429a7d7b74c4d743f6dba8351e58d62b6432ed65df9fe204790160dab996d"},
+]
+
+[package.extras]
+dill = ["dill (>=0.3.8)"]
+
 [[package]]
 name = "primp"
 version = "0.6.1"
@@ -8004,6 +8149,84 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
 testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
 torch = ["safetensors[numpy]", "torch (>=1.10)"]
 
+[[package]]
+name = "sagemaker"
+version = "2.231.0"
+description = "Open source library for training and deploying models on Amazon SageMaker."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "sagemaker-2.231.0-py3-none-any.whl", hash = "sha256:5b6d84484a58c6ac8b22af42c6c5e0ea3c5f42d719345fe6aafba42f93635000"},
+    {file = "sagemaker-2.231.0.tar.gz", hash = "sha256:d49ee9c35725832dd9810708938af723201b831e82924a3a6ac1c4260a3d8239"},
+]
+
+[package.dependencies]
+attrs = ">=23.1.0,<24"
+boto3 = ">=1.34.142,<2.0"
+cloudpickle = "2.2.1"
+docker = "*"
+google-pasta = "*"
+importlib-metadata = ">=1.4.0,<7.0"
+jsonschema = "*"
+numpy = ">=1.9.0,<2.0"
+packaging = ">=20.0"
+pandas = "*"
+pathos = "*"
+platformdirs = "*"
+protobuf = ">=3.12,<5.0"
+psutil = "*"
+pyyaml = ">=6.0,<7.0"
+requests = "*"
+sagemaker-core = ">=1.0.0,<2.0.0"
+schema = "*"
+smdebug-rulesconfig = "1.0.1"
+tblib = ">=1.7.0,<4"
+tqdm = "*"
+urllib3 = ">=1.26.8,<3.0.0"
+
+[package.extras]
+all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"]
+feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"]
+huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"]
+local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"]
+scipy = ["scipy (==1.10.1)"]
+test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"]
+
+[[package]]
+name = "sagemaker-core"
+version = "1.0.2"
+description = "An python package for sagemaker core functionalities"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "sagemaker_core-1.0.2-py3-none-any.whl", hash = "sha256:ce8d38a4a32efa83e4bc037a8befc7e29f87cd3eaf99acc4472b607f75a0f45a"},
+    {file = "sagemaker_core-1.0.2.tar.gz", hash = "sha256:8fb942aac5e7ed928dab512ffe6facf8c6bdd4595df63c59c0bd0795ea434f8d"},
+]
+
+[package.dependencies]
+boto3 = ">=1.34.0,<2.0.0"
+importlib-metadata = ">=1.4.0,<7.0"
+jsonschema = "<5.0.0"
+mock = ">4.0,<5.0"
+platformdirs = ">=4.0.0,<5.0.0"
+pydantic = ">=1.7.0,<3.0.0"
+PyYAML = ">=6.0,<7.0"
+rich = ">=13.0.0,<14.0.0"
+
+[package.extras]
+codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"]
+
+[[package]]
+name = "schema"
+version = "0.7.7"
+description = "Simple data validation library"
+optional = false
+python-versions = "*"
+files = [
+    {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"},
+    {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"},
+]
+
 [[package]]
 name = "scikit-learn"
 version = "1.5.1"
@@ -8276,6 +8499,17 @@ files = [
     {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 ]
 
+[[package]]
+name = "smdebug-rulesconfig"
+version = "1.0.1"
+description = "SMDebug RulesConfig"
+optional = false
+python-versions = ">=2.7"
+files = [
+    {file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = "sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"},
+    {file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"},
+]
+
 [[package]]
 name = "sniffio"
 version = "1.3.1"
@@ -8473,6 +8707,17 @@ files = [
 [package.extras]
 widechars = ["wcwidth"]
 
+[[package]]
+name = "tblib"
+version = "3.0.0"
+description = "Traceback serialization library."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"},
+    {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"},
+]
+
 [[package]]
 name = "tcvectordb"
 version = "1.3.2"
@@ -10126,4 +10371,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "78c7db0bf525a72f4c8309e3363304d1a0a23cf0a6836bfb974a38a1fcde9158"
+content-hash = "c3c637d643f4dcb3e35d0e7f2a3a4fbaf2a730512a4ca31adce5884c94f07f57"

+ 1 - 0
api/pyproject.toml

@@ -113,6 +113,7 @@ azure-identity = "1.16.1"
 azure-storage-blob = "12.13.0"
 beautifulsoup4 = "4.12.2"
 boto3 = "1.34.148"
+sagemaker = "2.231.0"
 bs4 = "~0.0.1"
 cachetools = "~5.3.0"
 celery = "~5.3.6"