|
@@ -33,11 +33,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
|
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
|
|
|
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
|
|
|
|
|
- version = credentials['model_version']
|
|
|
+ model_version = ''
|
|
|
+ if 'model_version' in credentials:
|
|
|
+ model_version = credentials['model_version']
|
|
|
|
|
|
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
|
|
model_info = client.models.get(model)
|
|
|
- model_info_version = model_info.versions.get(version)
|
|
|
+
|
|
|
+ if model_version:
|
|
|
+ model_info_version = model_info.versions.get(model_version)
|
|
|
+ else:
|
|
|
+ model_info_version = model_info.latest_version
|
|
|
|
|
|
inputs = {**model_parameters}
|
|
|
|
|
@@ -65,29 +71,35 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
|
|
if 'replicate_api_token' not in credentials:
|
|
|
raise CredentialsValidateFailedError('Replicate Access Token must be provided.')
|
|
|
|
|
|
- if 'model_version' not in credentials:
|
|
|
- raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
|
|
|
+ model_version = ''
|
|
|
+ if 'model_version' in credentials:
|
|
|
+ model_version = credentials['model_version']
|
|
|
|
|
|
if model.count("/") != 1:
|
|
|
raise CredentialsValidateFailedError('Replicate Model Name must be provided, '
|
|
|
'format: {user_name}/{model_name}')
|
|
|
|
|
|
- version = credentials['model_version']
|
|
|
-
|
|
|
try:
|
|
|
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
|
|
model_info = client.models.get(model)
|
|
|
- model_info_version = model_info.versions.get(version)
|
|
|
|
|
|
- self._check_text_generation_model(model_info_version, model, version)
|
|
|
+ if model_version:
|
|
|
+ model_info_version = model_info.versions.get(model_version)
|
|
|
+ else:
|
|
|
+ model_info_version = model_info.latest_version
|
|
|
+
|
|
|
+ self._check_text_generation_model(model_info_version, model, model_version, model_info.description)
|
|
|
except ReplicateError as e:
|
|
|
raise CredentialsValidateFailedError(
|
|
|
- f"Model {model}:{version} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
|
|
+ f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}")
|
|
|
except Exception as e:
|
|
|
raise CredentialsValidateFailedError(str(e))
|
|
|
|
|
|
@staticmethod
|
|
|
- def _check_text_generation_model(model_info_version, model_name, version):
|
|
|
+ def _check_text_generation_model(model_info_version, model_name, version, description):
|
|
|
+ if 'language model' in description.lower():
|
|
|
+ return
|
|
|
+
|
|
|
if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
|
|
|
or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \
|
|
|
or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']:
|
|
@@ -113,11 +125,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
|
|
|
|
|
@classmethod
|
|
|
def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]:
|
|
|
- version = credentials['model_version']
|
|
|
+ model_version = ''
|
|
|
+ if 'model_version' in credentials:
|
|
|
+ model_version = credentials['model_version']
|
|
|
|
|
|
client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30)
|
|
|
model_info = client.models.get(model)
|
|
|
- model_info_version = model_info.versions.get(version)
|
|
|
+
|
|
|
+ if model_version:
|
|
|
+ model_info_version = model_info.versions.get(model_version)
|
|
|
+ else:
|
|
|
+ model_info_version = model_info.latest_version
|
|
|
|
|
|
parameter_rules = []
|
|
|
|