Forráskód Böngészése

Enhancement: add model provider - Amazon Sagemaker (#6255)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: crazywoola <427733928@qq.com>
ybalbert001 9 hónapja
szülő
commit
4a026fa352

+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/__init__.py


BIN
api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png


BIN
api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png


+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/llm/__init__.py


+ 238 - 0
api/core/model_runtime/model_providers/sagemaker/llm/llm.py

@@ -0,0 +1,238 @@
+import json
+import logging
+from collections.abc import Generator
+from typing import Any, Optional, Union
+
+import boto3
+
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+logger = logging.getLogger(__name__)
+
+
+class SageMakerLargeLanguageModel(LargeLanguageModel):
+    """
+    Model class for Cohere large language model.
+    """
+    sagemaker_client: Any = None
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        """
+        Invoke large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :return: full response or stream response chunk generator result
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model, credentials)
+
+        if not self.sagemaker_client:
+            access_key = credentials.get('access_key')
+            secret_key = credentials.get('secret_key')
+            aws_region = credentials.get('aws_region')
+            if aws_region:
+                if access_key and secret_key:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime", 
+                        aws_access_key_id=access_key,
+                        aws_secret_access_key=secret_key,
+                        region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+            else:
+                self.sagemaker_client = boto3.client("sagemaker-runtime")
+
+
+        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
+        response_model = self.sagemaker_client.invoke_endpoint(
+                    EndpointName=sagemaker_endpoint,
+                    Body=json.dumps(
+                    {
+                        "inputs": prompt_messages[0].content,
+                        "parameters": { "stop" : stop},
+                        "history" : []
+                    }
+                    ),
+                    ContentType="application/json",
+                )
+
+        assistant_text = response_model['Body'].read().decode('utf8')
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=assistant_text
+        )
+
+        usage = self._calc_response_usage(model, credentials, 0, 0)
+
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        # get model mode
+        model_mode = self.get_model_mode(model)
+
+        try:
+            return 0
+        except Exception as e:
+            raise self._transform_invoke_error(e)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # get model mode
+            model_mode = self.get_model_mode(model)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                ),
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=credentials.get('context_length', 2048),
+                default=512,
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            )
+        ]
+
+        completion_type = LLMMode.value_of(credentials["mode"])
+
+        if completion_type == LLMMode.CHAT:
+            print(f"completion_type : {LLMMode.CHAT.value}") 
+
+        if completion_type == LLMMode.COMPLETION:
+            print(f"completion_type : {LLMMode.COMPLETION.value}") 
+
+        features = []
+
+        support_function_call = credentials.get('support_function_call', False)
+        if support_function_call:
+            features.append(ModelFeature.TOOL_CALL)
+
+        support_vision = credentials.get('support_vision', False)
+        if support_vision:
+            features.append(ModelFeature.VISION)
+
+        context_length = credentials.get('context_length', 2048)
+
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            features=features,
+            model_properties={
+                ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: context_length
+            },
+            parameter_rules=rules
+        )
+
+        return entity

+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py


+ 190 - 0
api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py

@@ -0,0 +1,190 @@
+import json
+import logging
+from typing import Any, Optional
+
+import boto3
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+logger = logging.getLogger(__name__)
+
+class SageMakerRerankModel(RerankModel):
+    """
+    Model class for Cohere rerank model.
+    """
+    sagemaker_client: Any = None
+
+    def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
+        inputs = [query_input]*len(docs)
+        response_model = self.sagemaker_client.invoke_endpoint(
+            EndpointName=rerank_endpoint,
+            Body=json.dumps(
+                {
+                    "inputs": inputs,
+                    "docs": docs
+                }
+            ),
+            ContentType="application/json",
+        )
+        json_str = response_model['Body'].read().decode('utf8')
+        json_obj = json.loads(json_str)
+        scores = json_obj['scores']
+        return scores if isinstance(scores, list) else [scores]
+
+
+    def _invoke(self, model: str, credentials: dict,
+                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                user: Optional[str] = None) \
+            -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        line = 0
+        try:
+            if len(docs) == 0:
+                return RerankResult(
+                    model=model,
+                    docs=docs
+                )
+
+            line = 1
+            if not self.sagemaker_client:
+                access_key = credentials.get('aws_access_key_id')
+                secret_key = credentials.get('aws_secret_access_key')
+                aws_region = credentials.get('aws_region')
+                if aws_region:
+                    if access_key and secret_key:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", 
+                            aws_access_key_id=access_key,
+                            aws_secret_access_key=secret_key,
+                            region_name=aws_region)
+                    else:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime")
+
+            line = 2
+
+            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
+            candidate_docs = []
+
+            scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
+            for idx in range(len(scores)):
+                candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
+
+            sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
+
+            line = 3
+            rerank_documents = []
+            for idx, result in enumerate(candidate_docs):
+                rerank_document = RerankDocument(
+                    index=idx,
+                    text=result.get('content'),
+                    score=result.get('score', -100.0)
+                )
+
+                if score_threshold is not None:
+                    if rerank_document.score >= score_threshold:
+                        rerank_documents.append(rerank_document)
+                else:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(
+                model=model,
+                docs=rerank_documents
+            )
+
+        except Exception as e:
+            logger.exception(f'Exception {e}, line : {line}')
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                InvokeBadRequestError,
+                KeyError,
+                ValueError
+            ]
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.RERANK,
+            model_properties={ },
+            parameter_rules=[]
+        )
+
+        return entity

+ 17 - 0
api/core/model_runtime/model_providers/sagemaker/sagemaker.py

@@ -0,0 +1,17 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class SageMakerProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        pass

+ 125 - 0
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml

@@ -0,0 +1,125 @@
+provider: sagemaker
+label:
+  zh_Hans: Sagemaker
+  en_US: Sagemaker
+icon_small:
+  en_US: icon_s_en.png
+icon_large:
+  en_US: icon_l_en.png
+description:
+  en_US: Customized model on Sagemaker
+  zh_Hans: Sagemaker上的私有化部署的模型
+background: "#ECE9E3"
+help:
+  title:
+    en_US: How to deploy customized model on Sagemaker
+    zh_Hans: 如何在Sagemaker上的私有化部署的模型
+  url:
+    en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
+    zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
+supported_model_types:
+  - llm
+  - text-embedding
+  - rerank
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: sagemaker_endpoint
+      label:
+        en_US: sagemaker endpoint
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 请输出你的Sagemaker推理端点
+        en_US: Enter your Sagemaker Inference endpoint
+    - variable: aws_access_key_id
+      required: false
+      label:
+        en_US: Access Key (If not provided, credentials are obtained from the running environment.)
+        zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
+      type: secret-input
+      placeholder:
+        en_US: Enter your Access Key
+        zh_Hans: 在此输入您的 Access Key
+    - variable: aws_secret_access_key
+      required: false
+      label:
+        en_US: Secret Access Key
+        zh_Hans: Secret Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Secret Access Key
+        zh_Hans: 在此输入您的 Secret Access Key
+    - variable: aws_region
+      required: false
+      label:
+        en_US: AWS Region
+        zh_Hans: AWS 地区
+      type: select
+      default: us-east-1
+      options:
+        - value: us-east-1
+          label:
+            en_US: US East (N. Virginia)
+            zh_Hans: 美国东部 (弗吉尼亚北部)
+        - value: us-west-2
+          label:
+            en_US: US West (Oregon)
+            zh_Hans: 美国西部 (俄勒冈州)
+        - value: ap-southeast-1
+          label:
+            en_US: Asia Pacific (Singapore)
+            zh_Hans: 亚太地区 (新加坡)
+        - value: ap-northeast-1
+          label:
+            en_US: Asia Pacific (Tokyo)
+            zh_Hans: 亚太地区 (东京)
+        - value: eu-central-1
+          label:
+            en_US: Europe (Frankfurt)
+            zh_Hans: 欧洲 (法兰克福)
+        - value: us-gov-west-1
+          label:
+            en_US: AWS GovCloud (US-West)
+            zh_Hans: AWS GovCloud (US-West)
+        - value: ap-southeast-2
+          label:
+            en_US: Asia Pacific (Sydney)
+            zh_Hans: 亚太地区 (悉尼)
+        - value: cn-north-1
+          label:
+            en_US: AWS Beijing (cn-north-1)
+            zh_Hans: 中国北京 (cn-north-1)
+        - value: cn-northwest-1
+          label:
+            en_US: AWS Ningxia (cn-northwest-1)
+            zh_Hans: 中国宁夏 (cn-northwest-1)

+ 0 - 0
api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py


+ 214 - 0
api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py

@@ -0,0 +1,214 @@
+import itertools
+import json
+import logging
+import time
+from typing import Any, Optional
+
+import boto3
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+        InvokeAuthorizationError,
+        InvokeBadRequestError,
+        InvokeConnectionError,
+        InvokeError,
+        InvokeRateLimitError,
+        InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+BATCH_SIZE = 20
+CONTEXT_SIZE=8192
+
+logger = logging.getLogger(__name__)
+
+def batch_generator(generator, batch_size):
+    while True:
+        batch = list(itertools.islice(generator, batch_size))
+        if not batch:
+            break
+        yield batch
+
+class SageMakerEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for Cohere text embedding model.
+    """
+    sagemaker_client: Any = None
+
+    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
+        response_model = sm_client.invoke_endpoint(
+            EndpointName=endpoint_name,
+            Body=json.dumps(
+                {
+                    "inputs": content_list,
+                    "parameters": {},
+                    "is_query" : False,
+                    "instruction" :  ''
+                }
+            ),
+            ContentType="application/json",
+        )
+        json_str = response_model['Body'].read().decode('utf8')
+        json_obj = json.loads(json_str)
+        embeddings = json_obj['embeddings']
+        return embeddings
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        # get model properties
+        try:
+            line = 1
+            if not self.sagemaker_client:
+                access_key = credentials.get('aws_access_key_id')
+                secret_key = credentials.get('aws_secret_access_key')
+                aws_region = credentials.get('aws_region')
+                if aws_region:
+                    if access_key and secret_key:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", 
+                            aws_access_key_id=access_key,
+                            aws_secret_access_key=secret_key,
+                            region_name=aws_region)
+                    else:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime")
+
+            line = 2
+            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
+
+            line = 3
+            truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
+
+            batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
+            all_embeddings = []
+
+            line = 4
+            for batch in batches:
+                embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
+                all_embeddings.extend(embeddings)
+
+            line = 5
+            # calc usage
+            usage = self._calc_response_usage(
+                model=model,
+                credentials=credentials,
+                tokens=0 # It's not SAAS API, usage is meaningless
+            )
+            line = 6
+
+            return TextEmbeddingResult(
+                embeddings=all_embeddings,
+                usage=usage,
+                model=model
+            )
+
+        except Exception as e:
+            logger.exception(f'Exception {e}, line : {line}')
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        return 0
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            print("validate_credentials ok....")
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        return {
+            InvokeConnectionError: [
+                InvokeConnectionError
+            ],
+            InvokeServerUnavailableError: [
+                InvokeServerUnavailableError
+            ],
+            InvokeRateLimitError: [
+                InvokeRateLimitError
+            ],
+            InvokeAuthorizationError: [
+                InvokeAuthorizationError
+            ],
+            InvokeBadRequestError: [
+                KeyError
+            ]
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
+                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
+            },
+            parameter_rules=[]
+        )
+
+        return entity

+ 0 - 0
api/tests/integration_tests/model_runtime/sagemaker/__init__.py


+ 19 - 0
api/tests/integration_tests/model_runtime/sagemaker/test_provider.py

@@ -0,0 +1,19 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider
+
+
+def test_validate_provider_credentials():
+    provider = SageMakerProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={}
+        )
+
+    provider.validate_provider_credentials(
+        credentials={}
+    )

+ 55 - 0
api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py

@@ -0,0 +1,55 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel
+
+
+def test_validate_credentials():
+    model = SageMakerRerankModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='bge-m3-rerank-v2',
+            credentials={
+                "aws_region": os.getenv("AWS_REGION"),
+                "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+                "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+            },
+            query="What is the capital of the United States?",
+            docs=[
+                "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                "Census, Carson City had a population of 55,274.",
+                "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                "are a political division controlled by the United States. Its capital is Saipan.",
+            ],
+            score_threshold=0.8
+        )
+
+
+def test_invoke_model():
+    model = SageMakerRerankModel()
+
+    result = model.invoke(
+        model='bge-m3-rerank-v2',
+        credentials={
+            "aws_region": os.getenv("AWS_REGION"),
+            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
+            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
+        },
+        query="What is the capital of the United States?",
+        docs=[
+            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+            "Census, Carson City had a population of 55,274.",
+            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+            "are a political division controlled by the United States. Its capital is Saipan.",
+        ],
+        score_threshold=0.8
+    )
+
+    assert isinstance(result, RerankResult)
+    assert len(result.docs) == 1
+    assert result.docs[0].index == 1
+    assert result.docs[0].score >= 0.8

+ 55 - 0
api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py

@@ -0,0 +1,55 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel
+
+
+def test_validate_credentials():
+    model = SageMakerEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='bge-m3',
+            credentials={
+            }
+        )
+
+    model.validate_credentials(
+        model='bge-m3-embedding',
+        credentials={
+        }
+    )
+
+
+def test_invoke_model():
+    model = SageMakerEmbeddingModel()
+
+    result = model.invoke(
+        model='bge-m3-embedding',
+        credentials={
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+
+def test_get_num_tokens():
+    model = SageMakerEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='bge-m3-embedding',
+        credentials={
+        },
+        texts=[
+        ]
+    )
+
+    assert num_tokens == 0