Browse Source

document limit (#999)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 năm trước cách đây
mục cha
commit
5397799aac

+ 37 - 4
api/controllers/console/datasets/datasets_document.py

@@ -3,7 +3,7 @@ import random
 from datetime import datetime
 from typing import List
 
-from flask import request
+from flask import request, current_app
 from flask_login import current_user
 from core.login.login import login_required
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
@@ -275,7 +275,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('duplicate', type=bool, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -335,7 +336,8 @@ class DatasetInitApi(Resource):
         parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         try:
@@ -483,7 +485,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             indexing_runner = IndexingRunner()
             try:
                 response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                  data_process_rule_dict,  None, dataset_id)
+                                                                  data_process_rule_dict, None, dataset_id)
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     f"No Embedding Model available. Please configure a valid provider "
@@ -855,6 +857,14 @@ class DocumentStatusApi(DocumentResource):
             if not document.archived:
                 raise InvalidActionError('Document is not archived.')
 
+            # check document limit
+            if current_app.config['EDITION'] == 'CLOUD':
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + 1
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
+
             document.archived = False
             document.archived_at = None
             document.archived_by = None
@@ -872,6 +882,10 @@ class DocumentStatusApi(DocumentResource):
 
 
 class DocumentPauseApi(DocumentResource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -901,6 +915,9 @@ class DocumentPauseApi(DocumentResource):
 
 
 class DocumentRecoverApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -926,6 +943,21 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204
 
 
+class DocumentLimitApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """get document limit"""
+        documents_count = DocumentService.get_tenant_documents_count()
+        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+
+        return {
+            'documents_count': documents_count,
+            'documents_limit': tenant_document_count
+                }, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -951,3 +983,4 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentLimitApi, '/datasets/limit')

+ 20 - 3
api/services/dataset_service.py

@@ -394,11 +394,20 @@ class DocumentService:
     def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
                                       account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
                                       created_from: str = 'web'):
+
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
+            count = 0
+            if document_data["data_source"]["type"] == "upload_file":
+                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+                count = len(upload_file_list)
+            elif document_data["data_source"]["type"] == "notion_import":
+                notion_page_list = document_data["data_source"]['info_list']['notion_info_list']['pages']
+                count = len(notion_page_list)
             documents_count = DocumentService.get_tenant_documents_count()
+            total_count = documents_count + count
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if documents_count > tenant_document_count:
+            if total_count > tenant_document_count:
                 raise ValueError(f"over document limit {tenant_document_count}.")
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
@@ -649,12 +658,20 @@ class DocumentService:
 
     @staticmethod
     def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
+        count = 0
+        if document_data["data_source"]["type"] == "upload_file":
+            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
+            count = len(upload_file_list)
+        elif document_data["data_source"]["type"] == "notion_import":
+            notion_page_list = document_data["data_source"]['info_list']['notion_info_list']['pages']
+            count = len(notion_page_list)
         # check document limit
         if current_app.config['EDITION'] == 'CLOUD':
             documents_count = DocumentService.get_tenant_documents_count()
+            total_count = documents_count + count
             tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
-            if documents_count > tenant_document_count:
-                raise ValueError(f"over document limit {tenant_document_count}.")
+            if total_count > tenant_document_count:
+                raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
         embedding_model = ModelFactory.get_embedding_model(
             tenant_id=tenant_id
         )