Bläddra i källkod

feat: replicate supports default version. (#3884)

Garfield Dai 1 år sedan
förälder
incheckning
cefe156811

+ 30 - 12
api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
                 user: Optional[str] = None) -> Union[LLMResult, Generator]:
 
-        version = credentials['model_version']
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         model_info = client.models.get(model)
-        model_info_version = model_info.versions.get(version)
+
+        if model_version:
+            model_info_version = model_info.versions.get(model_version)
+        else:
+            model_info_version = model_info.latest_version
 
         inputs = {**model_parameters}
 
@@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
         if 'replicate_api_token' not in credentials:
             raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
 
-        if 'model_version' not in credentials:
-            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         if model.count("/") != 1:
             raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
                                                  'format: {user_name}/{model_name}')
 
-        version = credentials['model_version']
-
         try:
             client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
             model_info = client.models.get(model)
-            model_info_version = model_info.versions.get(version)
 
-            self._check_text_generation_model(model_info_version, model, version)
+            if model_version:
+                model_info_version = model_info.versions.get(model_version)
+            else:
+                model_info_version = model_info.latest_version
+
+            self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
         except ReplicateError as e:
             raise CredentialsValidateFailedError(
-                f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
+                f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
         except Exception as e:
             raise CredentialsValidateFailedError(str(e))
 
     @staticmethod
-    def _check_text_generation_model(model_info_version, model_name, version):
+    def _check_text_generation_model(model_info_version, model_name, version, description):
+        if 'language model' in description.lower():
+            return
+
         if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
                 or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
                 or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
@@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
 
     @classmethod
     def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
-        version = credentials['model_version']
+        model_version = ''
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
         model_info = client.models.get(model)
-        model_info_version = model_info.versions.get(version)
+
+        if model_version:
+            model_info_version = model_info.versions.get(model_version)
+        else:
+            model_info_version = model_info.latest_version
 
         parameter_rules = []
 

+ 3 - 3
api/core/model_runtime/model_providers/replicate/replicate.yaml

@@ -35,7 +35,7 @@ model_credential_schema:
       label:
         en_US: Model Version
       type: text-input
-      required: true
+      required: false
       placeholder:
-        zh_Hans: 在此输入您的模型版本
-        en_US: Enter your model version
+        zh_Hans: 在此输入您的模型版本,默认为最新版本
+        en_US: Enter your model version, default to the latest version

+ 18 - 7
api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py

@@ -17,9 +17,16 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
                 user: Optional[str] = None) -> TextEmbeddingResult:
 
         client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
-        replicate_model_version = f'{model}:{credentials["model_version"]}'
 
-        text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
+        if 'model_version' in credentials:
+            model_version = credentials['model_version']
+        else:
+            model_info = client.models.get(model)
+            model_version = model_info.latest_version.id
+
+        replicate_model_version = f'{model}:{model_version}'
+
+        text_input_key = self._get_text_input_key(model, model_version, client)
 
         embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
                                                                  texts)
@@ -43,14 +50,18 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
         if 'replicate_api_token' not in credentials:
             raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
 
-        if 'model_version' not in credentials:
-            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
-
         try:
             client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
-            replicate_model_version = f'{model}:{credentials["model_version"]}'
 
-            text_input_key = self._get_text_input_key(model, credentials['model_version'], client)
+            if 'model_version' in credentials:
+                model_version = credentials['model_version']
+            else:
+                model_info = client.models.get(model)
+                model_version = model_info.latest_version.id
+
+            replicate_model_version = f'{model}:{model_version}'
+
+            text_input_key = self._get_text_input_key(model, model_version, client)
 
             self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key,
                                                         ['Hello worlds!'])