|
@@ -0,0 +1,250 @@
|
|
|
+import json
|
|
|
+import time
|
|
|
+from decimal import Decimal
|
|
|
+from typing import Optional
|
|
|
+from urllib.parse import urljoin
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import requests
|
|
|
+
|
|
|
+from core.model_runtime.entities.common_entities import I18nObject
|
|
|
+from core.model_runtime.entities.model_entities import (
|
|
|
+ AIModelEntity,
|
|
|
+ FetchFrom,
|
|
|
+ ModelPropertyKey,
|
|
|
+ ModelType,
|
|
|
+ PriceConfig,
|
|
|
+ PriceType,
|
|
|
+)
|
|
|
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
|
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
|
|
+from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
|
|
+
|
|
|
+
|
|
|
+class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
|
|
+ """
|
|
|
+ Model class for an OpenAI API-compatible text embedding model.
|
|
|
+ """
|
|
|
+
|
|
|
+ def _invoke(self, model: str, credentials: dict,
|
|
|
+ texts: list[str], user: Optional[str] = None) \
|
|
|
+ -> TextEmbeddingResult:
|
|
|
+ """
|
|
|
+ Invoke text embedding model
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :param texts: texts to embed
|
|
|
+ :param user: unique user id
|
|
|
+ :return: embeddings result
|
|
|
+ """
|
|
|
+
|
|
|
+ # Prepare headers and payload for the request
|
|
|
+ headers = {
|
|
|
+ 'Content-Type': 'application/json'
|
|
|
+ }
|
|
|
+
|
|
|
+ api_key = credentials.get('api_key')
|
|
|
+ if api_key:
|
|
|
+ headers["Authorization"] = f"Bearer {api_key}"
|
|
|
+
|
|
|
+ if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
|
|
+ endpoint_url='https://cloud.perfxlab.cn/v1/'
|
|
|
+ else:
|
|
|
+ endpoint_url = credentials.get('endpoint_url')
|
|
|
+ if not endpoint_url.endswith('/'):
|
|
|
+ endpoint_url += '/'
|
|
|
+
|
|
|
+ endpoint_url = urljoin(endpoint_url, 'embeddings')
|
|
|
+
|
|
|
+ extra_model_kwargs = {}
|
|
|
+ if user:
|
|
|
+ extra_model_kwargs['user'] = user
|
|
|
+
|
|
|
+ extra_model_kwargs['encoding_format'] = 'float'
|
|
|
+
|
|
|
+ # get model properties
|
|
|
+ context_size = self._get_context_size(model, credentials)
|
|
|
+ max_chunks = self._get_max_chunks(model, credentials)
|
|
|
+
|
|
|
+ inputs = []
|
|
|
+ indices = []
|
|
|
+ used_tokens = 0
|
|
|
+
|
|
|
+ for i, text in enumerate(texts):
|
|
|
+
|
|
|
+ # Here token count is only an approximation based on the GPT2 tokenizer
|
|
|
+ # TODO: Optimize for better token estimation and chunking
|
|
|
+ num_tokens = self._get_num_tokens_by_gpt2(text)
|
|
|
+
|
|
|
+ if num_tokens >= context_size:
|
|
|
+ cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
|
|
|
+ # if num tokens is larger than context length, only use the start
|
|
|
+ inputs.append(text[0: cutoff])
|
|
|
+ else:
|
|
|
+ inputs.append(text)
|
|
|
+ indices += [i]
|
|
|
+
|
|
|
+ batched_embeddings = []
|
|
|
+ _iter = range(0, len(inputs), max_chunks)
|
|
|
+
|
|
|
+ for i in _iter:
|
|
|
+ # Prepare the payload for the request
|
|
|
+ payload = {
|
|
|
+ 'input': inputs[i: i + max_chunks],
|
|
|
+ 'model': model,
|
|
|
+ **extra_model_kwargs
|
|
|
+ }
|
|
|
+
|
|
|
+ # Make the request to the OpenAI API
|
|
|
+ response = requests.post(
|
|
|
+ endpoint_url,
|
|
|
+ headers=headers,
|
|
|
+ data=json.dumps(payload),
|
|
|
+ timeout=(10, 300)
|
|
|
+ )
|
|
|
+
|
|
|
+ response.raise_for_status() # Raise an exception for HTTP errors
|
|
|
+ response_data = response.json()
|
|
|
+
|
|
|
+ # Extract embeddings and used tokens from the response
|
|
|
+ embeddings_batch = [data['embedding'] for data in response_data['data']]
|
|
|
+ embedding_used_tokens = response_data['usage']['total_tokens']
|
|
|
+
|
|
|
+ used_tokens += embedding_used_tokens
|
|
|
+ batched_embeddings += embeddings_batch
|
|
|
+
|
|
|
+ # calc usage
|
|
|
+ usage = self._calc_response_usage(
|
|
|
+ model=model,
|
|
|
+ credentials=credentials,
|
|
|
+ tokens=used_tokens
|
|
|
+ )
|
|
|
+
|
|
|
+ return TextEmbeddingResult(
|
|
|
+ embeddings=batched_embeddings,
|
|
|
+ usage=usage,
|
|
|
+ model=model
|
|
|
+ )
|
|
|
+
|
|
|
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
|
|
+ """
|
|
|
+ Approximate number of tokens for given messages using GPT2 tokenizer
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :param texts: texts to embed
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
|
|
+
|
|
|
+ def validate_credentials(self, model: str, credentials: dict) -> None:
|
|
|
+ """
|
|
|
+ Validate model credentials
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ headers = {
|
|
|
+ 'Content-Type': 'application/json'
|
|
|
+ }
|
|
|
+
|
|
|
+ api_key = credentials.get('api_key')
|
|
|
+
|
|
|
+ if api_key:
|
|
|
+ headers["Authorization"] = f"Bearer {api_key}"
|
|
|
+
|
|
|
+ if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
|
|
+ endpoint_url='https://cloud.perfxlab.cn/v1/'
|
|
|
+ else:
|
|
|
+ endpoint_url = credentials.get('endpoint_url')
|
|
|
+ if not endpoint_url.endswith('/'):
|
|
|
+ endpoint_url += '/'
|
|
|
+
|
|
|
+ endpoint_url = urljoin(endpoint_url, 'embeddings')
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ 'input': 'ping',
|
|
|
+ 'model': model
|
|
|
+ }
|
|
|
+
|
|
|
+ response = requests.post(
|
|
|
+ url=endpoint_url,
|
|
|
+ headers=headers,
|
|
|
+ data=json.dumps(payload),
|
|
|
+ timeout=(10, 300)
|
|
|
+ )
|
|
|
+
|
|
|
+ if response.status_code != 200:
|
|
|
+ raise CredentialsValidateFailedError(
|
|
|
+ f'Credentials validation failed with status code {response.status_code}')
|
|
|
+
|
|
|
+ try:
|
|
|
+ json_result = response.json()
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
+ raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
|
|
|
+
|
|
|
+ if 'model' not in json_result:
|
|
|
+ raise CredentialsValidateFailedError(
|
|
|
+ 'Credentials validation failed: invalid response')
|
|
|
+ except CredentialsValidateFailedError:
|
|
|
+ raise
|
|
|
+ except Exception as ex:
|
|
|
+ raise CredentialsValidateFailedError(str(ex))
|
|
|
+
|
|
|
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
|
+ """
|
|
|
+ generate custom model entities from credentials
|
|
|
+ """
|
|
|
+ entity = AIModelEntity(
|
|
|
+ model=model,
|
|
|
+ label=I18nObject(en_US=model),
|
|
|
+ model_type=ModelType.TEXT_EMBEDDING,
|
|
|
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
+ model_properties={
|
|
|
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
|
|
|
+ ModelPropertyKey.MAX_CHUNKS: 1,
|
|
|
+ },
|
|
|
+ parameter_rules=[],
|
|
|
+ pricing=PriceConfig(
|
|
|
+ input=Decimal(credentials.get('input_price', 0)),
|
|
|
+ unit=Decimal(credentials.get('unit', 0)),
|
|
|
+ currency=credentials.get('currency', "USD")
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ return entity
|
|
|
+
|
|
|
+
|
|
|
+ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
|
+ """
|
|
|
+ Calculate response usage
|
|
|
+
|
|
|
+ :param model: model name
|
|
|
+ :param credentials: model credentials
|
|
|
+ :param tokens: input tokens
|
|
|
+ :return: usage
|
|
|
+ """
|
|
|
+ # get input price info
|
|
|
+ input_price_info = self.get_price(
|
|
|
+ model=model,
|
|
|
+ credentials=credentials,
|
|
|
+ price_type=PriceType.INPUT,
|
|
|
+ tokens=tokens
|
|
|
+ )
|
|
|
+
|
|
|
+ # transform usage
|
|
|
+ usage = EmbeddingUsage(
|
|
|
+ tokens=tokens,
|
|
|
+ total_tokens=tokens,
|
|
|
+ unit_price=input_price_info.unit_price,
|
|
|
+ price_unit=input_price_info.unit,
|
|
|
+ total_price=input_price_info.total_amount,
|
|
|
+ currency=input_price_info.currency,
|
|
|
+ latency=time.perf_counter() - self.started_at
|
|
|
+ )
|
|
|
+
|
|
|
+ return usage
|