|
@@ -4,7 +4,8 @@ from typing import Optional
|
|
|
|
|
|
from requests import post
|
|
|
|
|
|
-from core.model_runtime.entities.model_entities import PriceType
|
|
|
+from core.model_runtime.entities.common_entities import I18nObject
|
|
|
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
|
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
|
|
from core.model_runtime.errors.invoke import (
|
|
|
InvokeAuthorizationError,
|
|
@@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
"""
|
|
|
Model class for Jina text embedding model.
|
|
|
"""
|
|
|
- api_base: str = 'https://api.jina.ai/v1/embeddings'
|
|
|
- models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
|
|
|
+ api_base: str = 'https://api.jina.ai/v1'
|
|
|
|
|
|
def _invoke(self, model: str, credentials: dict,
|
|
|
texts: list[str], user: Optional[str] = None) \
|
|
@@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
:return: embeddings result
|
|
|
"""
|
|
|
api_key = credentials['api_key']
|
|
|
- if model not in self.models:
|
|
|
- raise InvokeBadRequestError('Invalid model name')
|
|
|
if not api_key:
|
|
|
raise CredentialsValidateFailedError('api_key is required')
|
|
|
- url = self.api_base
|
|
|
+
|
|
|
+ base_url = credentials.get('base_url', self.api_base)
|
|
|
+ if base_url.endswith('/'):
|
|
|
+ base_url = base_url[:-1]
|
|
|
+
|
|
|
+ url = base_url + '/embeddings'
|
|
|
headers = {
|
|
|
'Authorization': 'Bearer ' + api_key,
|
|
|
'Content-Type': 'application/json'
|
|
@@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
elif response.status_code == 500:
|
|
|
raise InvokeServerUnavailableError(msg)
|
|
|
else:
|
|
|
- raise InvokeError(msg)
|
|
|
+ raise InvokeBadRequestError(msg)
|
|
|
except JSONDecodeError as e:
|
|
|
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
|
|
|
|
@@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
"""
|
|
|
try:
|
|
|
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
|
|
- except InvokeAuthorizationError:
|
|
|
- raise CredentialsValidateFailedError('Invalid api key')
|
|
|
+ except Exception as e:
|
|
|
+ raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
|
|
|
|
|
@property
|
|
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
|
@@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
InvokeAuthorizationError
|
|
|
],
|
|
|
InvokeBadRequestError: [
|
|
|
- KeyError
|
|
|
+ KeyError,
|
|
|
+ InvokeBadRequestError
|
|
|
]
|
|
|
}
|
|
|
|
|
@@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
|
|
)
|
|
|
|
|
|
return usage
|
|
|
+
|
|
|
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
|
+ """
|
|
|
+ generate custom model entities from credentials
|
|
|
+ """
|
|
|
+ entity = AIModelEntity(
|
|
|
+ model=model,
|
|
|
+ label=I18nObject(en_US=model),
|
|
|
+ model_type=ModelType.TEXT_EMBEDDING,
|
|
|
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
+ model_properties={
|
|
|
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ return entity
|