浏览代码

Feat/add free provider apply (#829)

takatost 1 年之前
父节点
当前提交
cc52cdc2a9

+ 16 - 0
api/controllers/console/workspace/model_providers.py

@@ -270,6 +270,20 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
         }
 
 
+class ModelProviderFreeQuotaSubmitApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
+        provider_service = ProviderService()
+        result = provider_service.free_quota_submit(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name
+        )
+
+        return result
+
+
 api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
 api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
 api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@@ -283,3 +297,5 @@ api.add_resource(ModelProviderModelParameterRuleApi,
                  '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules')
 api.add_resource(ModelProviderPaymentCheckoutUrlApi,
                  '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
+api.add_resource(ModelProviderFreeQuotaSubmitApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')

+ 0 - 1
api/core/model_providers/providers/spark_provider.py

@@ -3,7 +3,6 @@ import logging
 from json import JSONDecodeError
 from typing import Type
 
-from flask import current_app
 from langchain.schema import HumanMessage
 
 from core.helper import encrypter

+ 2 - 0
api/core/third_party/langchain/llms/spark.py

@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
     app_id: Optional[str] = None
     api_key: Optional[str] = None
     api_secret: Optional[str] = None
+    api_domain: Optional[str] = None
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
             app_id=values["app_id"],
             api_key=values["api_key"],
             api_secret=values["api_secret"],
+            api_domain=values.get('api_domain')
         )
         return values
 

+ 2 - 2
api/core/third_party/spark/spark_llm.py

@@ -16,9 +16,9 @@ import websocket
 
 
 class SparkLLMClient:
-    def __init__(self, app_id: str, api_key: str, api_secret: str):
+    def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
 
-        self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
+        self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
         self.app_id = app_id
         self.ws_url = self.create_url(
             urlparse(self.api_base).netloc,

+ 34 - 0
api/services/provider_service.py

@@ -1,8 +1,12 @@
 import datetime
 import json
+import logging
+import os
 from collections import defaultdict
 from typing import Optional
 
+import requests
+
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from core.model_providers.model_provider_factory import ModelProviderFactory
@@ -509,3 +513,33 @@ class ProviderService:
         # get model parameter rules
         return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
 
+    def free_quota_submit(self, tenant_id: str, provider_name: str):
+        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
+        api_url = os.environ.get("FREE_QUOTA_APPLY_URL")
+
+        headers = {
+            'Content-Type': 'application/json',
+            'Authorization': f"Bearer {api_key}"
+        }
+        response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
+        if not response.ok:
+            logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
+            raise ValueError(f"Error: {response.status_code} ")
+
+        if response.json()["code"] != 'success':
+            raise ValueError(
+                f"error: {response.json()['message']}"
+            )
+
+        rst = response.json()
+
+        if rst['type'] == 'redirect':
+            return {
+                'type': rst['type'],
+                'redirect_url': rst['redirect_url']
+            }
+        else:
+            return {
+                'type': rst['type'],
+                'result': 'success'
+            }