
refactor: reduce duplicate code by inheritance (#13073)

Yingchun Lai 2 months ago
parent commit d44882c1b5

api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py  +9 -187

@@ -1,29 +1,13 @@
-import json
-import time
-from decimal import Decimal
 from typing import Optional
-from urllib.parse import urljoin
-
-import numpy as np
-import requests
 
 from core.entities.embedding_type import EmbeddingInputType
-from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import (
-    AIModelEntity,
-    FetchFrom,
-    ModelPropertyKey,
-    ModelType,
-    PriceConfig,
-    PriceType,
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
+    OAICompatEmbeddingModel,
 )
-from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
-from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 
 
-class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
+class PerfXCloudEmbeddingModel(OAICompatEmbeddingModel):
     """
     Model class for an OpenAI API-compatible text embedding model.
     """
@@ -47,86 +31,10 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
         :return: embeddings result
         """
 
-        # Prepare headers and payload for the request
-        headers = {"Content-Type": "application/json"}
-
-        api_key = credentials.get("api_key")
-        if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
-        endpoint_url: Optional[str]
         if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
-            endpoint_url = "https://cloud.perfxlab.cn/v1/"
-        else:
-            endpoint_url = credentials.get("endpoint_url")
-            assert endpoint_url is not None, "endpoint_url is required in credentials"
-            if not endpoint_url.endswith("/"):
-                endpoint_url += "/"
-
-        assert isinstance(endpoint_url, str)
-        endpoint_url = urljoin(endpoint_url, "embeddings")
-
-        extra_model_kwargs = {}
-        if user:
-            extra_model_kwargs["user"] = user
-
-        extra_model_kwargs["encoding_format"] = "float"
-
-        # get model properties
-        context_size = self._get_context_size(model, credentials)
-        max_chunks = self._get_max_chunks(model, credentials)
-
-        inputs = []
-        indices = []
-        used_tokens = 0
-
-        for i, text in enumerate(texts):
-            # Here token count is only an approximation based on the GPT2 tokenizer
-            # TODO: Optimize for better token estimation and chunking
-            num_tokens = self._get_num_tokens_by_gpt2(text)
-
-            if num_tokens >= context_size:
-                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
-                # if num tokens is larger than context length, only use the start
-                inputs.append(text[0:cutoff])
-            else:
-                inputs.append(text)
-            indices += [i]
-
-        batched_embeddings = []
-        _iter = range(0, len(inputs), max_chunks)
-
-        for i in _iter:
-            # Prepare the payload for the request
-            payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs}
-
-            # Make the request to the OpenAI API
-            response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
+            credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
 
-            response.raise_for_status()  # Raise an exception for HTTP errors
-            response_data = response.json()
-
-            # Extract embeddings and used tokens from the response
-            embeddings_batch = [data["embedding"] for data in response_data["data"]]
-            embedding_used_tokens = response_data["usage"]["total_tokens"]
-
-            used_tokens += embedding_used_tokens
-            batched_embeddings += embeddings_batch
-
-        # calc usage
-        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
-
-        return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model)
-
-    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
-        """
-        Approximate number of tokens for given messages using GPT2 tokenizer
-
-        :param model: model name
-        :param credentials: model credentials
-        :param texts: texts to embed
-        :return:
-        """
-        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
+        return OAICompatEmbeddingModel._invoke(self, model, credentials, texts, user, input_type)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -136,93 +44,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
-        try:
-            headers = {"Content-Type": "application/json"}
-
-            api_key = credentials.get("api_key")
-
-            if api_key:
-                headers["Authorization"] = f"Bearer {api_key}"
-
-            endpoint_url: Optional[str]
-            if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
-                endpoint_url = "https://cloud.perfxlab.cn/v1/"
-            else:
-                endpoint_url = credentials.get("endpoint_url")
-                assert endpoint_url is not None, "endpoint_url is required in credentials"
-                if not endpoint_url.endswith("/"):
-                    endpoint_url += "/"
-
-            assert isinstance(endpoint_url, str)
-            endpoint_url = urljoin(endpoint_url, "embeddings")
-
-            payload = {"input": "ping", "model": model}
-
-            response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300))
-
-            if response.status_code != 200:
-                raise CredentialsValidateFailedError(
-                    f"Credentials validation failed with status code {response.status_code}"
-                )
-
-            try:
-                json_result = response.json()
-            except json.JSONDecodeError as e:
-                raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
-
-            if "model" not in json_result:
-                raise CredentialsValidateFailedError("Credentials validation failed: invalid response")
-        except CredentialsValidateFailedError:
-            raise
-        except Exception as ex:
-            raise CredentialsValidateFailedError(str(ex))
-
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
-        """
-        generate custom model entities from credentials
-        """
-        entity = AIModelEntity(
-            model=model,
-            label=I18nObject(en_US=model),
-            model_type=ModelType.TEXT_EMBEDDING,
-            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
-                ModelPropertyKey.MAX_CHUNKS: 1,
-            },
-            parameter_rules=[],
-            pricing=PriceConfig(
-                input=Decimal(credentials.get("input_price", 0)),
-                unit=Decimal(credentials.get("unit", 0)),
-                currency=credentials.get("currency", "USD"),
-            ),
-        )
-
-        return entity
-
-    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
-        """
-        Calculate response usage
-
-        :param model: model name
-        :param credentials: model credentials
-        :param tokens: input tokens
-        :return: usage
-        """
-        # get input price info
-        input_price_info = self.get_price(
-            model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
-        )
-
-        # transform usage
-        usage = EmbeddingUsage(
-            tokens=tokens,
-            total_tokens=tokens,
-            unit_price=input_price_info.unit_price,
-            price_unit=input_price_info.unit,
-            total_price=input_price_info.total_amount,
-            currency=input_price_info.currency,
-            latency=time.perf_counter() - self.started_at,
-        )
+        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
+            credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"
 
-        return usage
+        OAICompatEmbeddingModel.validate_credentials(self, model, credentials)
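For reference, the provider class after this refactor reduces to roughly the sketch below, reassembled from the added lines in the diff above. The _invoke signature is assumed to mirror the OAICompatEmbeddingModel base class, since it sits outside the shown hunks; only the PerfXCloud endpoint fallback is provider-specific.

from typing import Optional

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
    OAICompatEmbeddingModel,
)


class PerfXCloudEmbeddingModel(OAICompatEmbeddingModel):
    """
    Model class for an OpenAI API-compatible text embedding model.
    """

    # Signature assumed to match the base class; it is not visible in the diff hunks above.
    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        # Default to the PerfXCloud endpoint when none is configured, then delegate
        # the HTTP request, batching, and usage calculation to the OpenAI-compatible
        # base implementation.
        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
            credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"

        return OAICompatEmbeddingModel._invoke(self, model, credentials, texts, user, input_type)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        # Same endpoint fallback before reusing the base-class credential check.
        if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
            credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/"

        OAICompatEmbeddingModel.validate_credentials(self, model, credentials)

Delegating to OAICompatEmbeddingModel means that improvements to token estimation, request batching, and usage pricing in the shared implementation now apply to PerfXCloud automatically, instead of having to be copied into this file.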