ソースを参照

add embedding input type parameter (#8724)

Jyong 7 ヶ月 前
コミット
4669eb24be
33 ファイル変更239 行追加35 行削除
  1. 7 2
      api/core/embedding/cached_embedding.py
  2. 10 0
      api/core/embedding/embedding_constant.py
  3. 6 1
      api/core/model_manager.py
  4. 18 5
      api/core/model_runtime/model_providers/__base/text_embedding_model.py
  5. 7 1
      api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py
  6. 7 1
      api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
  7. 7 1
      api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py
  8. 7 1
      api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
  9. 7 1
      api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py
  10. 7 1
      api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
  11. 7 1
      api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py
  12. 7 1
      api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
  13. 7 1
      api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
  14. 7 1
      api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
  15. 7 1
      api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py
  16. 2 0
      api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py
  17. 7 1
      api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py
  18. 8 1
      api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
  19. 7 1
      api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py
  20. 7 1
      api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
  21. 7 1
      api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
  22. 7 1
      api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
  23. 8 1
      api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py
  24. 7 1
      api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py
  25. 7 1
      api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
  26. 7 1
      api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py
  27. 2 0
      api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py
  28. 9 1
      api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py
  29. 7 1
      api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py
  30. 7 1
      api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py
  31. 7 1
      api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py
  32. 8 1
      api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
  33. 7 1
      api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py

+ 7 - 2
api/core/embedding/cached_embedding.py

@@ -5,6 +5,7 @@ from typing import Optional, cast
 import numpy as np
 from sqlalchemy.exc import IntegrityError
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
@@ -56,7 +57,9 @@ class CacheEmbedding(Embeddings):
                 for i in range(0, len(embedding_queue_texts), max_chunks):
                     batch_texts = embedding_queue_texts[i : i + max_chunks]
 
-                    embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
+                    embedding_result = self._model_instance.invoke_text_embedding(
+                        texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
+                    )
 
                     for vector in embedding_result.embeddings:
                         try:
@@ -100,7 +103,9 @@ class CacheEmbedding(Embeddings):
             redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
         try:
-            embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
+            embedding_result = self._model_instance.invoke_text_embedding(
+                texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
+            )
 
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()

+ 10 - 0
api/core/embedding/embedding_constant.py

@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class EmbeddingInputType(Enum):
+    """
+    Enum for embedding input type.
+    """
+
+    DOCUMENT = "document"
+    QUERY = "query"

+ 6 - 1
api/core/model_manager.py

@@ -3,6 +3,7 @@ import os
 from collections.abc import Callable, Generator, Sequence
 from typing import IO, Optional, Union, cast
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
@@ -158,12 +159,15 @@ class ModelInstance:
             tools=tools,
         )
 
-    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
+    def invoke_text_embedding(
+        self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
+    ) -> TextEmbeddingResult:
         """
         Invoke large language model
 
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
@@ -176,6 +180,7 @@ class ModelInstance:
             credentials=self.credentials,
             texts=texts,
             user=user,
+            input_type=input_type,
         )
 
     def get_text_embedding_num_tokens(self, texts: list[str]) -> int:

+ 18 - 5
api/core/model_runtime/model_providers/__base/text_embedding_model.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from pydantic import ConfigDict
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -20,35 +21,47 @@ class TextEmbeddingModel(AIModel):
     model_config = ConfigDict(protected_namespaces=())
 
     def invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
-        Invoke large language model
+        Invoke text embedding model
 
         :param model: model name
         :param credentials: model credentials
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
         self.started_at = time.perf_counter()
 
         try:
-            return self._invoke(model, credentials, texts, user)
+            return self._invoke(model, credentials, texts, user, input_type)
         except Exception as e:
             raise self._transform_invoke_error(e)
 
     @abstractmethod
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
-        Invoke large language model
+        Invoke text embedding model
 
         :param model: model name
         :param credentials: model credentials
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
         raise NotImplementedError

+ 7 - 1
api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py

@@ -7,6 +7,7 @@ import numpy as np
 import tiktoken
 from openai import AzureOpenAI
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -17,7 +18,12 @@ from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_
 
 class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         base_model_name = credentials["base_model_name"]
         credentials_kwargs = self._to_credential_kwargs(credentials)

+ 7 - 1
api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from requests import post
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -35,7 +36,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
     api_base: str = "http://api.baichuan-ai.com/v1/embeddings"
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -13,6 +13,7 @@ from botocore.exceptions import (
     UnknownServiceError,
 )
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -30,7 +31,12 @@ logger = logging.getLogger(__name__)
 
 class BedrockTextEmbeddingModel(TextEmbeddingModel):
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py

@@ -5,6 +5,7 @@ import cohere
 import numpy as np
 from cohere.core import RequestOptions
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -25,7 +26,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py

@@ -6,6 +6,7 @@ import numpy as np
 import requests
 from huggingface_hub import HfApi, InferenceClient
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -18,7 +19,12 @@ HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/
 
 class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         client = InferenceClient(token=credentials["huggingfacehub_api_token"])
 

+ 7 - 1
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py

@@ -1,6 +1,7 @@
 import time
 from typing import Optional
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -23,7 +24,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py

@@ -9,6 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile
 from tencentcloud.common.profile.http_profile import HttpProfile
 from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -26,7 +27,12 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from requests import post
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -60,7 +61,12 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         return data
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py

@@ -5,6 +5,7 @@ from typing import Optional
 from requests import post
 from yarl import URL
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -26,7 +27,12 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from requests import post
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -34,7 +35,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
     api_base: str = "https://api.minimax.chat/v1/embeddings"
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 import requests
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -27,7 +28,12 @@ class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
     api_base: str = "https://api.mixedbread.ai/v1"
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 2 - 0
api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py

@@ -5,6 +5,7 @@ from typing import Optional
 from nomic import embed
 from nomic import login as nomic_login
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import (
     EmbeddingUsage,
@@ -46,6 +47,7 @@ class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel):
         credentials: dict,
         texts: list[str],
         user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from requests import post
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -27,7 +28,12 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel):
     models: list[str] = ["NV-Embed-QA"]
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 8 - 1
api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py

@@ -6,6 +6,7 @@ from typing import Optional
 import numpy as np
 import oci
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -41,7 +42,12 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
@@ -50,6 +56,7 @@ class OCITextEmbeddingModel(TextEmbeddingModel):
         :param credentials: model credentials
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
         # get model properties

+ 7 - 1
api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

@@ -8,6 +8,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -38,7 +39,12 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py

@@ -6,6 +6,7 @@ import numpy as np
 import tiktoken
 from openai import OpenAI
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -19,7 +20,12 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py

@@ -7,6 +7,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -28,7 +29,12 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py

@@ -5,6 +5,7 @@ from typing import Optional
 from requests import post
 from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
@@ -25,7 +26,12 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 8 - 1
api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py

@@ -7,6 +7,7 @@ from urllib.parse import urljoin
 import numpy as np
 import requests
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -28,7 +29,12 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
@@ -37,6 +43,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
         :param credentials: model credentials
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
 

+ 7 - 1
api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from replicate import Client as ReplicateClient
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -14,7 +15,12 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat
 
 class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30)
 

+ 7 - 1
api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py

@@ -6,6 +6,7 @@ from typing import Any, Optional
 
 import boto3
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -53,7 +54,12 @@ class SageMakerEmbeddingModel(TextEmbeddingModel):
         return embeddings
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py

@@ -1,5 +1,6 @@
 from typing import Optional
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
     OAICompatEmbeddingModel,
@@ -16,7 +17,12 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel):
         super().validate_credentials(model, credentials)
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         self._add_custom_parameters(credentials)
         return super()._invoke(model, credentials, texts, user)

+ 2 - 0
api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py

@@ -4,6 +4,7 @@ from typing import Optional
 import dashscope
 import numpy as np
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import (
     EmbeddingUsage,
@@ -27,6 +28,7 @@ class TongyiTextEmbeddingModel(_CommonTongyi, TextEmbeddingModel):
         credentials: dict,
         texts: list[str],
         user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 9 - 1
api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py

@@ -7,6 +7,7 @@ import numpy as np
 from openai import OpenAI
 from tokenizers import Tokenizer
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -22,7 +23,14 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
     def _get_tokenizer(self) -> Tokenizer:
         return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer")
 
-    def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | None = None) -> TextEmbeddingResult:
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: str | None = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+    ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
 

+ 7 - 1
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py

@@ -9,6 +9,7 @@ from google.cloud import aiplatform
 from google.oauth2 import service_account
 from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -30,7 +31,12 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -2,6 +2,7 @@ import time
 from decimal import Decimal
 from typing import Optional
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -41,7 +42,12 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 7 - 1
api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py

@@ -7,6 +7,7 @@ from typing import Any, Optional
 import numpy as np
 from requests import Response, post
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import InvokeError
@@ -70,7 +71,12 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
         return WenxinTextEmbedding(api_key, secret_key)
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model

+ 8 - 1
api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

@@ -3,6 +3,7 @@ from typing import Optional
 
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -25,7 +26,12 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model
@@ -40,6 +46,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :param credentials: model credentials
         :param texts: texts to embed
         :param user: unique user id
+        :param input_type: input type
         :return: embeddings result
         """
         server_url = credentials["server_url"]

+ 7 - 1
api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py

@@ -1,6 +1,7 @@
 import time
 from typing import Optional
 
+from core.embedding.embedding_constant import EmbeddingInputType
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -15,7 +16,12 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
     """
 
     def _invoke(
-        self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
     ) -> TextEmbeddingResult:
         """
         Invoke text embedding model