
fix: add context_size and max_chunks to Tongyi embedding to resolve issue #7189 (#7227)

Onelevenvy, 8 months ago
Parent
Commit
0f59d76997
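
This commit gives the Tongyi (DashScope) embedding models an explicit `max_chunks` limit plus pricing metadata, and rewrites `_invoke` to (1) truncate any input that exceeds the model's `context_size` and (2) send requests in batches of at most `max_chunks` texts. A standalone sketch of the batching pattern the new loop uses (hypothetical data, not part of the diff):

    # 60 texts with max_chunks=25 split into batches of 25, 25 and 10,
    # mirroring range(0, len(inputs), max_chunks) in the hunk below.
    texts = [f"text {n}" for n in range(60)]
    max_chunks = 25  # value added to the YAML files in this commit

    batches = [texts[i : i + max_chunks] for i in range(0, len(texts), max_chunks)]
    print([len(b) for b in batches])  # [25, 25, 10]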

+ 5 - 0
api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml

@@ -2,3 +2,8 @@ model: text-embedding-v1
 model_type: text-embedding
 model_properties:
   context_size: 2048
+  max_chunks: 25
+pricing:
+  input: "0.0007"
+  unit: "0.001"
+  currency: RMB
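
The `_get_context_size` and `_get_max_chunks` helpers used in the Python hunk below read these `model_properties`. A sketch of what those lookups amount to once the YAML is loaded; the loader shown here is an assumption for illustration, not Dify's actual code path:

    import yaml  # assumption: PyYAML-style loading, for illustration only

    with open("text-embedding-v1.yaml") as f:
        spec = yaml.safe_load(f)

    context_size = spec["model_properties"]["context_size"]  # 2048
    max_chunks = spec["model_properties"]["max_chunks"]      # 25, added by this commit

The same five lines are added to text-embedding-v2.yaml in the next hunk.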

+ 5 - 0
api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml

@@ -2,3 +2,8 @@ model: text-embedding-v2
 model_type: text-embedding
 model_properties:
   context_size: 2048
+  max_chunks: 25
+pricing:
+  input: "0.0007"
+  unit: "0.001"
+  currency: RMB

+ 52 - 19
api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py

@@ -2,6 +2,7 @@ import time
 from typing import Optional
 
 import dashscope
+import numpy as np
 
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import (
@@ -21,11 +22,11 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
     """
 
     def _invoke(
-            self,
-            model: str,
-            credentials: dict,
-            texts: list[str],
-            user: Optional[str] = None,
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
@@ -37,16 +38,44 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        embeddings, embedding_used_tokens = self.embed_documents(
-            credentials_kwargs=credentials_kwargs,
-            model=model,
-            texts=texts
-        )
 
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+        inputs = []
+        indices = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+
+            # Here token count is only an approximation based on the GPT2 tokenizer
+            num_tokens = self._get_num_tokens_by_gpt2(text)
+
+            if num_tokens >= context_size:
+                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
+                # if num tokens is larger than context length, only use the start
+                inputs.append(text[0:cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+
+        for i in _iter:
+            embeddings_batch, embedding_used_tokens = self.embed_documents(
+                credentials_kwargs=credentials_kwargs,
+                model=model,
+                texts=inputs[i : i + max_chunks],
+            )
+            used_tokens += embedding_used_tokens
+            batched_embeddings += embeddings_batch
+
+        # calc usage
+        usage = self._calc_response_usage(
+            model=model, credentials=credentials, tokens=used_tokens
+        )
         return TextEmbeddingResult(
-            embeddings=embeddings,
-            usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens),
-            model=model
+            embeddings=batched_embeddings, usage=usage, model=model
         )
 
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
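
The rewritten `_invoke` approximates each text's token count with the GPT-2 tokenizer and, when that count reaches `context_size`, keeps only a character prefix proportional to the token budget before batching. (Note that `indices` is populated but never read in this hunk.) A worked example of the cutoff arithmetic, with hypothetical numbers:

    import numpy as np

    # Hypothetical: a 10,000-character text estimated at 2,500 GPT-2 tokens,
    # against the 2,048-token context_size declared in the YAML above.
    text_len, num_tokens, context_size = 10_000, 2_500, 2_048

    cutoff = int(np.floor(text_len * (context_size / num_tokens)))
    print(cutoff)  # 8192 -> only text[0:8192] is embedded; the tail is dropped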
@@ -79,12 +108,16 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
             credentials_kwargs = self._to_credential_kwargs(credentials)
 
             # call embedding model
-            self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"])
+            self.embed_documents(
+                credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]
+            )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
     @staticmethod
-    def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]:
+    def embed_documents(
+        credentials_kwargs: dict, model: str, texts: list[str]
+    ) -> tuple[list[list[float]], int]:
         """Call out to Tongyi's embedding endpoint.
 
         Args:
@@ -102,7 +135,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
                 api_key=credentials_kwargs["dashscope_api_key"],
                 model=model,
                 input=text,
-                text_type="document"
+                text_type="document",
             )
             data = response.output["embeddings"][0]
             embeddings.append(data["embedding"])
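
As the hunk above shows, `embed_documents` issues one `dashscope.TextEmbedding.call` per input text. A minimal standalone call with the same parameters (the API key and model name here are placeholders):

    import dashscope

    response = dashscope.TextEmbedding.call(
        api_key="sk-...",  # placeholder DashScope key
        model="text-embedding-v1",
        input="ping",
        text_type="document",
    )
    # Per the loop above, each vector is read from the response like this:
    vector = response.output["embeddings"][0]["embedding"]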
@@ -111,7 +144,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
         return [list(map(float, e)) for e in embeddings], embedding_used_tokens
 
     def _calc_response_usage(
-            self, model: str, credentials: dict, tokens: int
+        self, model: str, credentials: dict, tokens: int
     ) -> EmbeddingUsage:
         """
         Calculate response usage
@@ -125,7 +158,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
             model=model,
             credentials=credentials,
             price_type=PriceType.INPUT,
-            tokens=tokens
+            tokens=tokens,
         )
 
         # transform usage
@@ -136,7 +169,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
             price_unit=input_price_info.unit,
             total_price=input_price_info.total_amount,
             currency=input_price_info.currency,
-            latency=time.perf_counter() - self.started_at
+            latency=time.perf_counter() - self.started_at,
         )
 
         return usage
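
With the pricing fields added above, `_calc_response_usage` can price a request via `get_price`. A back-of-the-envelope check, assuming the usual reading of `unit: "0.001"` as "the price is quoted per 1,000 tokens" (an assumption about Dify's pricing convention, not taken from this diff):

    from decimal import Decimal

    # Assumed formula: total_price = tokens * unit * input_price,
    # i.e. 0.0007 RMB per 1,000 tokens from the YAML above.
    tokens = 10_000
    total = tokens * Decimal("0.001") * Decimal("0.0007")
    print(total)  # 0.0070000 RMB for embedding 10,000 tokens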