|
@@ -0,0 +1,151 @@
|
|
|
+import time
|
|
|
+from collections.abc import Mapping
|
|
|
+from typing import Optional, Union
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+from openai import OpenAI
|
|
|
+
|
|
|
+from core.embedding.embedding_constant import EmbeddingInputType
|
|
|
+from core.model_runtime.entities.model_entities import PriceType
|
|
|
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
|
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
|
+from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
|
|
+
|
|
|
+
|
|
|
class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
    """
    Model class for Fireworks text embedding model.

    Fireworks exposes an OpenAI-compatible embeddings endpoint, so the OpenAI
    SDK client is used with Fireworks credentials throughout.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id, forwarded to the API when provided
        :param input_type: input type (not forwarded to the Fireworks API;
            kept for interface parity with other providers)
        :return: embeddings result with per-request usage
        """
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        extra_model_kwargs = {}
        if user:
            extra_model_kwargs["user"] = user
        # Request raw float vectors rather than base64-encoded payloads.
        extra_model_kwargs["encoding_format"] = "float"

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        used_tokens = 0

        for text in texts:
            # Here token count is only an approximation based on the GPT2 tokenizer
            # TODO: Optimize for better token estimation and chunking
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # If num tokens is larger than the context length, truncate the
                # text proportionally and only use its start.
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

        batched_embeddings = []

        # Send the inputs in batches no larger than the model's max chunk size.
        for i in range(0, len(inputs), max_chunks):
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                client=client,
                texts=inputs[i : i + max_chunks],
                extra_model_kwargs=extra_model_kwargs,
            )
            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
        return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages.

        The count is an approximation using the GPT-2 tokenizer, consistent
        with the estimate used for chunking in :meth:`_invoke`.

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate total token count across all texts
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """
        Validate model credentials by issuing a minimal embedding request.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the test request fails for
            any reason (bad key, unreachable endpoint, unknown model, ...)
        """
        try:
            # transform credentials to kwargs for model instance
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            # call embedding model with a tiny probe input
            self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
        except Exception as ex:
            # Chain the original exception so the root cause is preserved
            # in the traceback.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _embedding_invoke(
        self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
    ) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model via the OpenAI-compatible endpoint.

        :param model: model name
        :param client: model client
        :param texts: texts to embed
        :param extra_model_kwargs: extra model kwargs (e.g. user, encoding_format)
        :return: embeddings and used tokens
        """
        response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
        return [data.embedding for data in response.data], response.usage.total_tokens

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage.

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        input_price_info = self.get_price(
            model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
        )

        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): self.started_at is not set in this class — presumably
            # assigned by the base model's invoke wrapper; verify against caller.
            latency=time.perf_counter() - self.started_at,
        )

        return usage
|