|
@@ -4,6 +4,7 @@ from typing import Optional
|
|
|
|
|
|
from pydantic import ConfigDict
|
|
|
|
|
|
+from core.embedding.embedding_constant import EmbeddingInputType
|
|
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
|
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
@@ -20,35 +21,47 @@ class TextEmbeddingModel(AIModel):
|
|
|
model_config = ConfigDict(protected_namespaces=())
|
|
|
|
|
|
def invoke(
|
|
|
- self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
|
|
+ self,
|
|
|
+ model: str,
|
|
|
+ credentials: dict,
|
|
|
+ texts: list[str],
|
|
|
+ user: Optional[str] = None,
|
|
|
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
|
|
) -> TextEmbeddingResult:
|
|
|
"""
|
|
|
- Invoke large language model
|
|
|
+ Invoke text embedding model
|
|
|
|
|
|
:param model: model name
|
|
|
:param credentials: model credentials
|
|
|
:param texts: texts to embed
|
|
|
:param user: unique user id
|
|
|
+ :param input_type: input type
|
|
|
:return: embeddings result
|
|
|
"""
|
|
|
self.started_at = time.perf_counter()
|
|
|
|
|
|
try:
|
|
|
- return self._invoke(model, credentials, texts, user)
|
|
|
+ return self._invoke(model, credentials, texts, user, input_type)
|
|
|
except Exception as e:
|
|
|
raise self._transform_invoke_error(e)
|
|
|
|
|
|
@abstractmethod
|
|
|
def _invoke(
|
|
|
- self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
|
|
|
+ self,
|
|
|
+ model: str,
|
|
|
+ credentials: dict,
|
|
|
+ texts: list[str],
|
|
|
+ user: Optional[str] = None,
|
|
|
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
|
|
) -> TextEmbeddingResult:
|
|
|
"""
|
|
|
- Invoke large language model
|
|
|
+ Invoke text embedding model
|
|
|
|
|
|
:param model: model name
|
|
|
:param credentials: model credentials
|
|
|
:param texts: texts to embed
|
|
|
:param user: unique user id
|
|
|
+ :param input_type: input type
|
|
|
:return: embeddings result
|
|
|
"""
|
|
|
raise NotImplementedError
|