Forráskód Böngészése

feat: Add support for embed file with AWS Bedrock Titan Model (#3377)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
longzhihun 1 éve
szülő
commit
f7a417fdb4

+ 1 - 0
api/core/model_runtime/model_providers/bedrock/bedrock.yaml

@@ -15,6 +15,7 @@ help:
     en_US: https://console.aws.amazon.com/
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 0 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/__init__.py


+ 1 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml

@@ -0,0 +1 @@
+- amazon.titan-embed-text-v1

+ 8 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml

@@ -0,0 +1,8 @@
+model: amazon.titan-embed-text-v1
+model_type: text-embedding
+model_properties:
+  context_size: 8192
+pricing:
+  input: '0.0001'
+  unit: '0.001'
+  currency: USD

+ 209 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -0,0 +1,209 @@
+import json
+import time
+from typing import Optional
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import (
+    ClientError,
+    EndpointConnectionError,
+    NoRegionError,
+    ServiceNotInRegionError,
+    UnknownServiceError,
+)
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class BedrockTextEmbeddingModel(TextEmbeddingModel):
+
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
+        bedrock_runtime = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        embeddings = []
+        token_usage = 0
+
+        model_prefix = model.split('.')[0]
+        if model_prefix == "amazon":
+            for text in texts:
+                body = {
+                    "inputText": text,
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend([response_body.get('embedding')])
+                token_usage += response_body.get('inputTextTokenCount')
+            result = TextEmbeddingResult(
+                model=model,
+                embeddings=embeddings,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    tokens=token_usage
+                )
+            )
+        else:
+            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
+
+        return result
+
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        num_tokens = 0
+        for text in texts:
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+    
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
+        The value is the md = genai.GenerativeModel(model)error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke emd = genai.GenerativeModel(model)rror mapping
+        """
+        return {
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: []
+        }
+    
+    def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+        """
+        Create payload for bedrock api call depending on model provider
+        """
+        payload = dict()
+
+        if model_prefix == "amazon":
+            payload['inputText'] = texts
+
+    
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+    
+    def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
+        """
+        Map client error to invoke error
+
+        :param error_code: error code
+        :param error_msg: error message
+        :return: invoke error
+        """
+
+        if error_code == "AccessDeniedException":
+            return InvokeAuthorizationError(error_msg)
+        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+            return InvokeBadRequestError(error_msg)
+        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+            return InvokeRateLimitError(error_msg)
+        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+            return InvokeServerUnavailableError(error_msg)
+        elif error_code == "ModelStreamErrorException":
+            return InvokeConnectionError(error_msg)
+
+        return InvokeError(error_msg)
+    
+
+    def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
+        accept = 'application/json' 
+        content_type = 'application/json'
+        try:
+            response = bedrock_runtime.invoke_model(
+                body=json.dumps(body), 
+                modelId=model, 
+                accept=accept, 
+                contentType=content_type
+            )
+            response_body = json.loads(response.get('body').read().decode('utf-8'))
+            return response_body
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+        
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))