Procházet zdrojové kódy

make sure validation flow works for all model providers in bedrock (#3250)

Chenhe Gu před 1 rokem
rodič
revize
eb76d7a226

+ 1 - 1
api/core/model_runtime/model_providers/bedrock/bedrock.yaml

@@ -74,7 +74,7 @@ provider_credential_schema:
       label:
         en_US: Available Model Name
         zh_Hans: 可用模型名称
-      type: secret-input
+      type: text-input
       placeholder:
         en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
         zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)

+ 15 - 15
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -402,25 +402,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param credentials: model credentials
         :return:
         """
-
-        if "anthropic.claude-3" in model:
-            try:
-                self._invoke_claude(model=model,
-                                        credentials=credentials,
-                                        prompt_messages=[{"role": "user", "content": "ping"}],
-                                        model_parameters={},
-                                        stop=None,
-                                        stream=False)
-
-            except Exception as ex:
-                raise CredentialsValidateFailedError(str(ex))
-
+        required_params = {}
+        if "anthropic" in model:
+            required_params = {
+                "max_tokens": 32,
+            }
+        elif "ai21" in model:
+            # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again.
+            required_params = {
+                "temperature": 0.7,
+                "topP": 0.9,
+                "maxTokens": 32,
+            }
+            
         try:
             ping_message = UserPromptMessage(content="ping")
-            self._generate(model=model,
+            self._invoke(model=model,
                            credentials=credentials,
                            prompt_messages=[ping_message],
-                           model_parameters={},
+                           model_parameters=required_params,
                            stream=False)
         
         except ClientError as ex: