Sfoglia il codice sorgente

feat: add spark v3.0 llm support (#1434)

takatost 1 anno fa
parent
commit
076f3289d2

+ 31 - 9
api/core/model_providers/providers/spark_provider.py

@@ -28,14 +28,19 @@ class SparkProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'spark',
-                    'name': 'Spark V1.5',
+                    'id': 'spark-v3',
+                    'name': 'Spark V3.0',
                     'mode': ModelMode.CHAT.value,
                 },
                 {
                     'id': 'spark-v2',
                     'name': 'Spark V2.0',
                     'mode': ModelMode.CHAT.value,
+                },
+                {
+                    'id': 'spark',
+                    'name': 'Spark V1.5',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
@@ -96,7 +101,7 @@ class SparkProvider(BaseModelProvider):
 
         try:
             chat_llm = ChatSpark(
-                model_name='spark-v2',
+                model_name='spark-v3',
                 max_tokens=10,
                 temperature=0.01,
                 **credential_kwargs
@@ -110,10 +115,10 @@ class SparkProvider(BaseModelProvider):
 
             chat_llm(messages)
         except SparkError as ex:
-            # try spark v1.5 if v2.1 failed
+            # try spark v2.1 if v3.1 failed
             try:
                 chat_llm = ChatSpark(
-                    model_name='spark',
+                    model_name='spark-v2',
                     max_tokens=10,
                     temperature=0.01,
                     **credential_kwargs
@@ -127,10 +132,27 @@ class SparkProvider(BaseModelProvider):
 
                 chat_llm(messages)
             except SparkError as ex:
-                raise CredentialsValidateFailedError(str(ex))
-            except Exception as ex:
-                logging.exception('Spark config validation failed')
-                raise ex
+                # try spark v1.5 if v2.1 failed
+                try:
+                    chat_llm = ChatSpark(
+                        model_name='spark',
+                        max_tokens=10,
+                        temperature=0.01,
+                        **credential_kwargs
+                    )
+
+                    messages = [
+                        HumanMessage(
+                            content="ping"
+                        )
+                    ]
+
+                    chat_llm(messages)
+                except SparkError as ex:
+                    raise CredentialsValidateFailedError(str(ex))
+                except Exception as ex:
+                    logging.exception('Spark config validation failed')
+                    raise ex
         except Exception as ex:
             logging.exception('Spark config validation failed')
             raise ex

+ 6 - 0
api/core/model_providers/rules/spark.json

@@ -22,6 +22,12 @@
             "completion": "0.36",
             "unit": "0.0001",
             "currency": "RMB"
+        },
+        "spark-v3": {
+            "prompt": "0.36",
+            "completion": "0.36",
+            "unit": "0.0001",
+            "currency": "RMB"
         }
     }
 }

+ 18 - 2
api/core/third_party/spark/spark_llm.py

@@ -19,9 +19,25 @@ class SparkLLMClient:
     def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
 
         domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
-        api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
 
-        self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
+        model_api_configs = {
+            'spark': {
+                'version': 'v1.1',
+                'chat_domain': 'general'
+            },
+            'spark-v2': {
+                'version': 'v2.1',
+                'chat_domain': 'generalv2'
+            },
+            'spark-v3': {
+                'version': 'v3.1',
+                'chat_domain': 'generalv3'
+            }
+        }
+
+        api_version = model_api_configs[model_name]['version']
+
+        self.chat_domain = model_api_configs[model_name]['chat_domain']
         self.api_base = f"wss://{domain}/{api_version}/chat"
         self.app_id = app_id
         self.ws_url = self.create_url(