Ver código fonte

Add embedding models in fireworks provider (#8728)

ice yao 7 meses atrás
pai
commit
91f70d0bd9

+ 1 - 0
api/core/model_runtime/model_providers/fireworks/fireworks.yaml

@@ -15,6 +15,7 @@ help:
     en_US: https://fireworks.ai/account/api-keys
 supported_model_types:
   - llm
+  - text-embedding
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 12 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/UAE-Large-V1.yaml

@@ -0,0 +1,12 @@
+model: WhereIsAI/UAE-Large-V1
+label:
+  zh_Hans: UAE-Large-V1
+  en_US: UAE-Large-V1
+model_type: text-embedding
+model_properties:
+  context_size: 512
+  max_chunks: 1
+pricing:
+  input: '0.008'
+  unit: '0.000001'
+  currency: 'USD'

+ 0 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/__init__.py


+ 12 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/gte-base.yaml

@@ -0,0 +1,12 @@
+model: thenlper/gte-base
+label:
+  zh_Hans: GTE-base
+  en_US: GTE-base
+model_type: text-embedding
+model_properties:
+  context_size: 512
+  max_chunks: 1
+pricing:
+  input: '0.008'
+  unit: '0.000001'
+  currency: 'USD'

+ 12 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/gte-large.yaml

@@ -0,0 +1,12 @@
+model: thenlper/gte-large
+label:
+  zh_Hans: GTE-large
+  en_US: GTE-large
+model_type: text-embedding
+model_properties:
+  context_size: 512
+  max_chunks: 1
+pricing:
+  input: '0.008'
+  unit: '0.000001'
+  currency: 'USD'

+ 12 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.5.yaml

@@ -0,0 +1,12 @@
+model: nomic-ai/nomic-embed-text-v1.5
+label:
+  zh_Hans: nomic-embed-text-v1.5
+  en_US: nomic-embed-text-v1.5
+model_type: text-embedding
+model_properties:
+  context_size: 8192
+  max_chunks: 16
+pricing:
+  input: '0.008'
+  unit: '0.000001'
+  currency: 'USD'

+ 12 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.yaml

@@ -0,0 +1,12 @@
+model: nomic-ai/nomic-embed-text-v1
+label:
+  zh_Hans: nomic-embed-text-v1
+  en_US: nomic-embed-text-v1
+model_type: text-embedding
+model_properties:
+  context_size: 8192
+  max_chunks: 16
+pricing:
+  input: '0.008'
+  unit: '0.000001'
+  currency: 'USD'

+ 151 - 0
api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py

@@ -0,0 +1,151 @@
+import time
+from collections.abc import Mapping
+from typing import Optional, Union
+
+import numpy as np
+from openai import OpenAI
+
+from core.embedding.embedding_constant import EmbeddingInputType
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
+
+
+class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel):
+    """
+    Model class for Fireworks text embedding model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        texts: list[str],
+        user: Optional[str] = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+    ) -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :param input_type: input type
+        :return: embeddings result
+        """
+
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        client = OpenAI(**credentials_kwargs)
+
+        extra_model_kwargs = {}
+        if user:
+            extra_model_kwargs["user"] = user
+
+        extra_model_kwargs["encoding_format"] = "float"
+
+        context_size = self._get_context_size(model, credentials)
+        max_chunks = self._get_max_chunks(model, credentials)
+
+        inputs = []
+        indices = []
+        used_tokens = 0
+
+        for i, text in enumerate(texts):
+            # Here token count is only an approximation based on the GPT2 tokenizer
+            # TODO: Optimize for better token estimation and chunking
+            num_tokens = self._get_num_tokens_by_gpt2(text)
+
+            if num_tokens >= context_size:
+                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
+                # if num tokens is larger than context length, only use the start
+                inputs.append(text[0:cutoff])
+            else:
+                inputs.append(text)
+            indices += [i]
+
+        batched_embeddings = []
+        _iter = range(0, len(inputs), max_chunks)
+
+        for i in _iter:
+            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
+                model=model,
+                client=client,
+                texts=inputs[i : i + max_chunks],
+                extra_model_kwargs=extra_model_kwargs,
+            )
+            used_tokens += embedding_used_tokens
+            batched_embeddings += embeddings_batch
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
+        return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+
+    def validate_credentials(self, model: str, credentials: Mapping) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            # transform credentials to kwargs for model instance
+            credentials_kwargs = self._to_credential_kwargs(credentials)
+            client = OpenAI(**credentials_kwargs)
+
+            # call embedding model
+            self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={})
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _embedding_invoke(
+        self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict
+    ) -> tuple[list[list[float]], int]:
+        """
+        Invoke embedding model
+        :param model: model name
+        :param client: model client
+        :param texts: texts to embed
+        :param extra_model_kwargs: extra model kwargs
+        :return: embeddings and used tokens
+        """
+        response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs)
+        return [data.embedding for data in response.data], response.usage.total_tokens
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        input_price_info = self.get_price(
+            model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT
+        )
+
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at,
+        )
+
+        return usage

+ 54 - 0
api/tests/integration_tests/model_runtime/fireworks/test_text_embedding.py

@@ -0,0 +1,54 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fireworks.text_embedding.text_embedding import FireworksTextEmbeddingModel
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
+def test_validate_credentials(setup_openai_mock):
+    model = FireworksTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": "invalid_key"}
+        )
+
+    model.validate_credentials(
+        model="nomic-ai/nomic-embed-text-v1.5", credentials={"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY")}
+    )
+
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
+def test_invoke_model(setup_openai_mock):
+    model = FireworksTextEmbeddingModel()
+
+    result = model.invoke(
+        model="nomic-ai/nomic-embed-text-v1.5",
+        credentials={
+            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
+        },
+        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
+        user="foo",
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 4
+    assert result.usage.total_tokens == 2
+
+
+def test_get_num_tokens():
+    model = FireworksTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model="nomic-ai/nomic-embed-text-v1.5",
+        credentials={
+            "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
+        },
+        texts=["hello", "world"],
+    )
+
+    assert num_tokens == 2