|
@@ -1,4 +1,5 @@
|
|
|
import json
|
|
|
+import logging
|
|
|
import time
|
|
|
from typing import Optional
|
|
|
|
|
@@ -24,6 +25,7 @@ from core.model_runtime.errors.invoke import (
|
|
|
)
|
|
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|
|
|
|
@@ -53,17 +55,19 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|
|
|
|
|
embeddings = []
|
|
|
token_usage = 0
|
|
|
-
|
|
|
+
|
|
|
model_prefix = model.split('.')[0]
|
|
|
- if model_prefix == "amazon":
|
|
|
- for text in texts:
|
|
|
- body = {
|
|
|
- "inputText": text,
|
|
|
- }
|
|
|
- response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
|
|
- embeddings.extend([response_body.get('embedding')])
|
|
|
- token_usage += response_body.get('inputTextTokenCount')
|
|
|
- result = TextEmbeddingResult(
|
|
|
+
|
|
|
+ if model_prefix == "amazon" :
|
|
|
+ for text in texts:
|
|
|
+ body = {
|
|
|
+ "inputText": text,
|
|
|
+ }
|
|
|
+ response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
|
|
+ embeddings.extend([response_body.get('embedding')])
|
|
|
+ token_usage += response_body.get('inputTextTokenCount')
|
|
|
+ logger.warning(f'Total Tokens: {token_usage}')
|
|
|
+ result = TextEmbeddingResult(
|
|
|
model=model,
|
|
|
embeddings=embeddings,
|
|
|
usage=self._calc_response_usage(
|
|
@@ -71,11 +75,32 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|
|
credentials=credentials,
|
|
|
tokens=token_usage
|
|
|
)
|
|
|
- )
|
|
|
- else:
|
|
|
- raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
|
|
-
|
|
|
- return result
|
|
|
+ )
|
|
|
+ return result
|
|
|
+
|
|
|
+ if model_prefix == "cohere" :
|
|
|
+ input_type = 'search_document' if len(texts) > 1 else 'search_query'
|
|
|
+ for text in texts:
|
|
|
+ body = {
|
|
|
+ "texts": [text],
|
|
|
+ "input_type": input_type,
|
|
|
+ }
|
|
|
+ response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
|
|
+ embeddings.extend(response_body.get('embeddings'))
|
|
|
+ token_usage += len(text)
|
|
|
+ result = TextEmbeddingResult(
|
|
|
+ model=model,
|
|
|
+ embeddings=embeddings,
|
|
|
+ usage=self._calc_response_usage(
|
|
|
+ model=model,
|
|
|
+ credentials=credentials,
|
|
|
+ tokens=token_usage
|
|
|
+ )
|
|
|
+ )
|
|
|
+ return result
|
|
|
+
|
|
|
+ #others
|
|
|
+ raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
|
|
|
|
|
|
|
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|