|
@@ -1,7 +1,16 @@
|
|
|
import time
|
|
|
+from decimal import Decimal
|
|
|
from typing import Optional
|
|
|
|
|
|
-from core.model_runtime.entities.model_entities import PriceType
|
|
|
+from core.model_runtime.entities.common_entities import I18nObject
|
|
|
+from core.model_runtime.entities.model_entities import (
|
|
|
+ AIModelEntity,
|
|
|
+ FetchFrom,
|
|
|
+ ModelPropertyKey,
|
|
|
+ ModelType,
|
|
|
+ PriceConfig,
|
|
|
+ PriceType,
|
|
|
+)
|
|
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
|
from core.model_runtime.errors.invoke import (
|
|
|
InvokeAuthorizationError,
|
|
@@ -21,6 +30,7 @@ from core.model_runtime.model_providers.volcengine_maas.errors import (
|
|
|
RateLimitErrors,
|
|
|
ServerUnavailableErrors,
|
|
|
)
|
|
|
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs
|
|
|
from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
|
|
|
|
|
|
|
|
@@ -45,7 +55,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|
|
resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
|
|
|
|
|
|
usage = self._calc_response_usage(
|
|
|
- model=model, credentials=credentials, tokens=resp['total_tokens'])
|
|
|
+ model=model, credentials=credentials, tokens=resp['usage']['total_tokens'])
|
|
|
|
|
|
result = TextEmbeddingResult(
|
|
|
model=model,
|
|
@@ -101,6 +111,34 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
|
|
InvokeBadRequestError: BadRequestErrors.values(),
|
|
|
}
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    """
    Generate a custom model entity from credentials.

    :param model: model name, used as both the entity's model id and its en_US label
    :param credentials: model credentials; ``base_model_name`` is required, while
        ``context_size``, ``max_chunks``, ``input_price``, ``unit`` and ``currency``
        are optional overrides
    :return: text-embedding model entity flagged as a customizable model
    :raises KeyError: if ``credentials`` has no ``base_model_name`` key
    """
    # Work on a copy so the per-credential overrides below do not modify the
    # properties mapping returned by the ModelConfigs lookup.
    model_properties = ModelConfigs.get(
        credentials['base_model_name'], {}).get('model_properties', {}).copy()

    # Only truthy overrides are applied; missing or empty values keep whatever
    # the base model configuration provided.
    context_size = credentials.get('context_size')
    if context_size:
        model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(context_size)

    max_chunks = credentials.get('max_chunks')
    if max_chunks:
        model_properties[ModelPropertyKey.MAX_CHUNKS] = int(max_chunks)

    # Convert through str() so a float such as 0.001 becomes Decimal('0.001')
    # rather than its full binary expansion, and fall back to 0 when the value
    # is missing, None or an empty string instead of raising.
    input_price = Decimal(str(credentials.get('input_price') or 0))
    price_unit = Decimal(str(credentials.get('unit') or 0))

    return AIModelEntity(
        model=model,
        label=I18nObject(en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties=model_properties,
        parameter_rules=[],
        pricing=PriceConfig(
            input=input_price,
            unit=price_unit,
            currency=credentials.get('currency', "USD"),
        ),
    )
|
|
|
+
|
|
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
|
|
"""
|
|
|
Calculate response usage
|