Kaynağa Gözat

Add suuport for AWS Bedrock Cohere embedding (#3444)

kerlion 1 yıl önce
ebeveyn
işleme
200010be19

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml

@@ -1 +1,3 @@
 - amazon.titan-embed-text-v1
+- cohere.embed-english-v3
+- cohere.embed-multilingual-v3

+ 8 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml

@@ -0,0 +1,8 @@
+model: cohere.embed-english-v3
+model_type: text-embedding
+model_properties:
+  context_size: 512
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 8 - 0
api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml

@@ -0,0 +1,8 @@
+model: cohere.embed-multilingual-v3
+model_type: text-embedding
+model_properties:
+  context_size: 512
+pricing:
+  input: '0.1'
+  unit: '0.000001'
+  currency: USD

+ 40 - 15
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -1,4 +1,5 @@
 import json
+import logging
 import time
 from typing import Optional
 
@@ -24,6 +25,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 
+logger = logging.getLogger(__name__)
 
 class BedrockTextEmbeddingModel(TextEmbeddingModel):
 
@@ -53,17 +55,19 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
 
         embeddings = []
         token_usage = 0
-
+        
         model_prefix = model.split('.')[0]
-        if model_prefix == "amazon":
-            for text in texts:
-                body = {
-                    "inputText": text,
-                }
-                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-                embeddings.extend([response_body.get('embedding')])
-                token_usage += response_body.get('inputTextTokenCount')
-            result = TextEmbeddingResult(
+         
+        if model_prefix == "amazon" :
+           for text in texts:
+              body = {
+                 "inputText": text,
+              }
+              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+              embeddings.extend([response_body.get('embedding')])
+              token_usage += response_body.get('inputTextTokenCount')
+           logger.warning(f'Total Tokens: {token_usage}')
+           result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
                 usage=self._calc_response_usage(
@@ -71,11 +75,32 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
                     credentials=credentials,
                     tokens=token_usage
                 )
-            )
-        else:
-            raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
-
-        return result
+           )
+           return result
+           
+        if model_prefix == "cohere" :
+           input_type = 'search_document' if len(texts) > 1 else 'search_query'
+           for text in texts:
+              body = {
+                 "texts": [text],
+                 "input_type": input_type,
+              }
+              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+              embeddings.extend(response_body.get('embeddings'))
+              token_usage += len(text)
+           result = TextEmbeddingResult(
+                model=model,
+                embeddings=embeddings,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    tokens=token_usage
+                )
+           )
+           return result
+        
+        #others
+        raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
 
 
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: