@@ -0,0 +1,212 @@
+import json
+import time
+from typing import Optional
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import (
+    ClientError,
+    EndpointConnectionError,
+    NoRegionError,
+    ServiceNotInRegionError,
+    UnknownServiceError,
+)
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+
+
+class BedrockTextEmbeddingModel(TextEmbeddingModel):
+
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
+        bedrock_runtime = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        embeddings = []
+        token_usage = 0
+
+        model_prefix = model.split('.')[0]
+        if model_prefix == "amazon":
+            # Amazon Titan embedding models accept a single inputText per request, so each text is embedded with its own call
+            for text in texts:
+                body = {
+                    "inputText": text,
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend([response_body.get('embedding')])
+                token_usage += response_body.get('inputTextTokenCount')
+            result = TextEmbeddingResult(
+                model=model,
+                embeddings=embeddings,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    tokens=token_usage
+                )
+            )
+        else:
+            raise ValueError(f"Got unknown model prefix {model_prefix} when handling embedding response")
+
+        return result
+
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for the given texts
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        num_tokens = 0
+        for text in texts:
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [],
+            InvokeServerUnavailableError: [],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [],
+            InvokeBadRequestError: []
+        }
+
+    def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+        """
+        Create the payload for a Bedrock API call, depending on the model provider
+        """
+        payload = dict()
+
+        if model_prefix == "amazon":
+            payload['inputText'] = texts
+        return payload
+
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        # latency is measured from self.started_at, which is expected to be set by the shared invoke() wrapper before _invoke runs
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> InvokeError:
+        """
+        Map client error to invoke error
+
+        :param error_code: error code
+        :param error_msg: error message
+        :return: invoke error
+        """
+
+        if error_code == "AccessDeniedException":
+            return InvokeAuthorizationError(error_msg)
+        elif error_code in ["ResourceNotFoundException", "ValidationException"]:
+            return InvokeBadRequestError(error_msg)
+        elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
+            return InvokeRateLimitError(error_msg)
+        elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
+            return InvokeServerUnavailableError(error_msg)
+        elif error_code == "ModelStreamErrorException":
+            return InvokeConnectionError(error_msg)
+
+        return InvokeError(error_msg)
+
+
+    def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict):
+        accept = 'application/json'
+        content_type = 'application/json'
+        try:
+            response = bedrock_runtime.invoke_model(
+                body=json.dumps(body),
+                modelId=model,
+                accept=accept,
+                contentType=content_type
+            )
+            response_body = json.loads(response.get('body').read().decode('utf-8'))
+            return response_body
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))