Pārlūkot izejas kodu

feat: support spark v2 validate (#1086)

takatost 1 gadu atpakaļ
vecāks
revīzija
a7cdb745c1
1 mainītis faili ar 28 papildinājumiem un 7 dzēšanām
  1. 28 7
      api/core/model_providers/providers/spark_provider.py

+ 28 - 7
api/core/model_providers/providers/spark_provider.py

@@ -83,14 +83,15 @@ class SparkProvider(BaseModelProvider):
         if 'api_secret' not in credentials:
             raise CredentialsValidateFailedError('Spark api_secret must be provided.')
 
-        try:
-            credential_kwargs = {
-                'app_id': credentials['app_id'],
-                'api_key': credentials['api_key'],
-                'api_secret': credentials['api_secret'],
-            }
+        credential_kwargs = {
+            'app_id': credentials['app_id'],
+            'api_key': credentials['api_key'],
+            'api_secret': credentials['api_secret'],
+        }
 
+        try:
             chat_llm = ChatSpark(
+                model_name='spark-v2',
                 max_tokens=10,
                 temperature=0.01,
                 **credential_kwargs
@@ -104,7 +105,27 @@ class SparkProvider(BaseModelProvider):
 
             chat_llm(messages)
         except SparkError as ex:
-            raise CredentialsValidateFailedError(str(ex))
+            # try spark v1.5 if v2.1 failed
+            try:
+                chat_llm = ChatSpark(
+                    model_name='spark',
+                    max_tokens=10,
+                    temperature=0.01,
+                    **credential_kwargs
+                )
+
+                messages = [
+                    HumanMessage(
+                        content="ping"
+                    )
+                ]
+
+                chat_llm(messages)
+            except SparkError as ex:
+                raise CredentialsValidateFailedError(str(ex))
+            except Exception as ex:
+                logging.exception('Spark config validation failed')
+                raise ex
         except Exception as ex:
             logging.exception('Spark config validation failed')
             raise ex