Ver código fonte

Feat/dify billing (#1679)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: takatost <takatost@users.noreply.github.com>
Garfield Dai 1 ano atrás
pai
commit
053102f433

+ 7 - 1
api/.env.example

@@ -124,5 +124,11 @@ HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
 HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
 HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
 
+# Stripe configuration
 STRIPE_API_KEY=
-STRIPE_WEBHOOK_SECRET=
+STRIPE_WEBHOOK_SECRET=
+
+# Billing configuration
+BILLING_API_URL=http://127.0.0.1:8000/v1
+BILLING_API_SECRET_KEY=
+STRIPE_WEBHOOK_BILLING_SECRET=

+ 2 - 0
api/controllers/console/__init__.py

@@ -28,3 +28,5 @@ from .universal_chat import chat, conversation, message, parameter, audio
 
 # Import webhook controllers
 from .webhook import stripe
+
+from .billing import billing

+ 0 - 0
api/controllers/console/billing/__init__.py


+ 85 - 0
api/controllers/console/billing/billing.py

@@ -0,0 +1,85 @@
+import stripe
+import os
+
+from flask_restful import Resource, reqparse
+from flask_login import current_user
+from flask import current_app, request
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import only_edition_cloud
+from libs.login import login_required
+from services.billing_service import BillingService
+
+
+class BillingInfo(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+
+        edition = current_app.config['EDITION']
+        if edition != 'CLOUD':
+            return {"enabled": False}
+
+        return BillingService.get_info(current_user.current_tenant_id)
+
+
+class Subscription(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @only_edition_cloud
+    def get(self):
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
+        parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
+        args = parser.parse_args()
+
+        return BillingService.get_subscription(args['plan'], args['interval'], current_user.email, current_user.name, current_user.current_tenant_id)
+
+
+class Invoices(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @only_edition_cloud
+    def get(self):
+
+        return BillingService.get_invoices(current_user.email)
+
+
+class StripeBillingWebhook(Resource):
+
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = os.environ.get('STRIPE_WEBHOOK_BILLING_SECRET', 'STRIPE_WEBHOOK_BILLING_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        BillingService.process_event(event)
+
+        return 'success', 200
+
+
+api.add_resource(BillingInfo, '/billing/info')
+api.add_resource(Subscription, '/billing/subscription')
+api.add_resource(Invoices, '/billing/invoices')
+api.add_resource(StripeBillingWebhook, '/billing/webhook/stripe')

+ 1 - 0
api/controllers/console/datasets/datasets.py

@@ -493,3 +493,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
 api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
 api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
+

+ 55 - 0
api/services/billing_service.py

@@ -0,0 +1,55 @@
+import os
+import requests
+
+from services.dataset_service import DatasetService
+
+
+class BillingService:
+    base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
+    secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
+
+    @classmethod
+    def get_info(cls, tenant_id: str):
+        params = {'tenant_id': tenant_id}
+
+        billing_info = cls._send_request('GET', '/info', params=params)
+
+        vector_size = DatasetService.get_tenant_datasets_usage(tenant_id) / 1024
+        billing_info['vector_space']['size'] = int(vector_size)
+
+        return billing_info
+
+    @classmethod
+    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = '', user_name: str = '', tenant_id: str = ''):
+        params = {
+            'plan': plan,
+            'interval': interval,
+            'prefilled_email': prefilled_email,
+            'user_name': user_name,
+            'tenant_id': tenant_id
+        }
+        return cls._send_request('GET', '/subscription', params=params)
+
+    @classmethod
+    def get_invoices(cls, prefilled_email: str = ''):
+        params = {'prefilled_email': prefilled_email}
+        return cls._send_request('GET', '/invoices', params=params)
+
+    @classmethod
+    def _send_request(cls, method, endpoint, json=None, params=None):
+        headers = {
+            "Content-Type": "application/json",
+            "Billing-Api-Secret-Key": cls.secret_key
+        }
+
+        url = f"{cls.base_url}{endpoint}"
+        response = requests.request(method, url, json=json, params=params, headers=headers)
+
+        return response.json()
+
+    @classmethod
+    def process_event(cls, event: dict):
+        json = {
+            "content": event,
+        }
+        return cls._send_request('POST', '/webhook/stripe', json=json)

+ 32 - 1
api/services/dataset_service.py

@@ -227,6 +227,36 @@ class DatasetService:
         return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
             .order_by(db.desc(AppDatasetJoin.created_at)).all()
 
+    @staticmethod
+    def get_tenant_datasets_usage(tenant_id):
+        # get the high_quality datasets
+        dataset_ids = db.session.query(Dataset.id).filter(Dataset.indexing_technique == 'high_quality',
+                                                          Dataset.tenant_id == tenant_id).all()
+        if not dataset_ids:
+            return 0
+        dataset_ids = [result[0] for result in dataset_ids]
+        document_ids = db.session.query(Document.id).filter(Document.dataset_id.in_(dataset_ids),
+                                                            Document.tenant_id == tenant_id,
+                                                            Document.completed_at.isnot(None),
+                                                            Document.enabled == True,
+                                                            Document.archived == False
+                                                            ).all()
+        if not document_ids:
+            return 0
+        document_ids = [result[0] for result in document_ids]
+        document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids),
+                                                                     DocumentSegment.tenant_id == tenant_id,
+                                                                     DocumentSegment.completed_at.isnot(None),
+                                                                     DocumentSegment.enabled == True,
+                                                                     ).all()
+        if not document_segments:
+            return 0
+
+        total_words_size = sum(document_segment.word_count * 3 for document_segment in document_segments)
+        total_vector_size = 1536 * 4 * len(document_segments)
+
+        return total_words_size + total_vector_size
+
 
 class DocumentService:
     DEFAULT_RULES = {
@@ -488,7 +518,8 @@ class DocumentService:
                         'score_threshold_enabled': False
                     }
 
-                    dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
+                    dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
+                        'retrieval_model') else default_retrieval_model
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))