Ver Fonte

feat: update model_provider jina to support custom url and model (#4110)

Co-authored-by: Gimling <huangjl@ruyi.ai>
Co-authored-by: takatost <takatost@gmail.com>
VoidIsVoid há 11 meses atrás
pai
commit
543a00e597

+ 38 - 0
api/core/model_runtime/model_providers/jina/jina.yaml

@@ -19,6 +19,7 @@ supported_model_types:
   - rerank
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: api_key
@@ -29,3 +30,40 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: base_url
+      label:
+        zh_Hans: 服务器 URL
+        en_US: Base URL
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: Base URL, e.g. https://api.jina.ai/v1
+        en_US: Base URL, e.g. https://api.jina.ai/v1
+      default: 'https://api.jina.ai/v1'
+    - variable: context_size
+      label:
+        zh_Hans: 上下文大小
+        en_US: Context size
+      placeholder:
+        zh_Hans: 输入上下文大小
+        en_US: Enter context size
+      required: false
+      type: text-input
+      default: '8192'

+ 23 - 1
api/core/model_runtime/model_providers/jina/rerank/rerank.py

@@ -2,6 +2,8 @@ from typing import Optional
 
 import httpx
 
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -38,9 +40,13 @@ class JinaRerankModel(RerankModel):
         if len(docs) == 0:
             return RerankResult(model=model, docs=[])
 
+        base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
+        if base_url.endswith('/'):
+            base_url = base_url[:-1]
+
         try:
             response = httpx.post(
-                "https://api.jina.ai/v1/rerank",
+                base_url + '/rerank',
                 json={
                     "model": model,
                     "query": query,
@@ -103,3 +109,19 @@ class JinaRerankModel(RerankModel):
             InvokeAuthorizationError: [httpx.HTTPStatusError],  
             InvokeBadRequestError: [httpx.RequestError]
         }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
+            }
+        )
+
+        return entity

+ 30 - 10
api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py

@@ -4,7 +4,8 @@ from typing import Optional
 
 from requests import post
 
-from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
     """
     Model class for Jina text embedding model.
     """
-    api_base: str = 'https://api.jina.ai/v1/embeddings'
-    models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
+    api_base: str = 'https://api.jina.ai/v1'
 
     def _invoke(self, model: str, credentials: dict,
                 texts: list[str], user: Optional[str] = None) \
@@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         :return: embeddings result
         """
         api_key = credentials['api_key']
-        if model not in self.models:
-            raise InvokeBadRequestError('Invalid model name')
         if not api_key:
             raise CredentialsValidateFailedError('api_key is required')
-        url = self.api_base
+
+        base_url = credentials.get('base_url', self.api_base)
+        if base_url.endswith('/'):
+            base_url = base_url[:-1]
+
+        url = base_url + '/embeddings'
         headers = {
             'Authorization': 'Bearer ' + api_key,
             'Content-Type': 'application/json'
@@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
                 elif response.status_code == 500:
                     raise InvokeServerUnavailableError(msg)
                 else:
-                    raise InvokeError(msg)
+                    raise InvokeBadRequestError(msg)
             except JSONDecodeError as e:
                 raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
 
@@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         """
         try:
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
-            raise CredentialsValidateFailedError('Invalid api key')
+        except Exception as e:
+            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
 
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
@@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
                 InvokeAuthorizationError
             ],
             InvokeBadRequestError: [
-                KeyError
+                KeyError,
+                InvokeBadRequestError
             ]
         }
     
@@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         )
 
         return usage
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+            generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.TEXT_EMBEDDING,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
+            }
+        )
+
+        return entity