
add document lock for multi-thread (#9873)

Jyong 5 months ago
parent
commit
af68084895
1 changed file with 155 additions and 153 deletions
  api/services/dataset_service.py  +155 −153
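
The change takes the document-creation path, which previously ran unguarded, and wraps it in a per-dataset Redis lock (`add_document_lock_dataset_id_<id>`), so concurrent requests against the same dataset are serialized instead of racing. A minimal sketch of the locking pattern, assuming `redis_client` is a plain `redis.Redis` instance; everything except the `redis_client.lock` call itself is illustrative:

import redis

redis_client = redis.Redis(host="localhost", port=6379)

def add_documents(dataset_id: str) -> None:
    lock_name = "add_document_lock_dataset_id_{}".format(dataset_id)
    # redis-py's lock() returns a Lock object usable as a context manager.
    # It blocks by default until the lock is free, so only one caller per
    # dataset proceeds at a time; timeout=600 auto-expires the lock after
    # ten minutes so a crashed holder cannot block the dataset forever.
    with redis_client.lock(lock_name, timeout=600):
        # ... read the current max position, build documents, commit ...
        pass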

@@ -760,166 +760,168 @@ class DocumentService:
                     )
                 db.session.add(dataset_process_rule)
                 db.session.commit()
-            position = DocumentService.get_documents_position(dataset.id)
-            document_ids = []
-            duplicate_document_ids = []
-            if document_data["data_source"]["type"] == "upload_file":
-                upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
-                for file_id in upload_file_list:
-                    file = (
-                        db.session.query(UploadFile)
-                        .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
-                        .first()
-                    )
+            lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
+            with redis_client.lock(lock_name, timeout=600):
+                position = DocumentService.get_documents_position(dataset.id)
+                document_ids = []
+                duplicate_document_ids = []
+                if document_data["data_source"]["type"] == "upload_file":
+                    upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"]
+                    for file_id in upload_file_list:
+                        file = (
+                            db.session.query(UploadFile)
+                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
+                            .first()
+                        )
 
-                    # raise error if file not found
-                    if not file:
-                        raise FileNotExistsError()
+                        # raise error if file not found
+                        if not file:
+                            raise FileNotExistsError()
 
-                    file_name = file.name
-                    data_source_info = {
-                        "upload_file_id": file_id,
-                    }
-                    # check duplicate
-                    if document_data.get("duplicate", False):
-                        document = Document.query.filter_by(
-                            dataset_id=dataset.id,
-                            tenant_id=current_user.current_tenant_id,
-                            data_source_type="upload_file",
-                            enabled=True,
-                            name=file_name,
-                        ).first()
-                        if document:
-                            document.dataset_process_rule_id = dataset_process_rule.id
-                            document.updated_at = datetime.datetime.utcnow()
-                            document.created_from = created_from
-                            document.doc_form = document_data["doc_form"]
-                            document.doc_language = document_data["doc_language"]
-                            document.data_source_info = json.dumps(data_source_info)
-                            document.batch = batch
-                            document.indexing_status = "waiting"
-                            db.session.add(document)
-                            documents.append(document)
-                            duplicate_document_ids.append(document.id)
-                            continue
-                    document = DocumentService.build_document(
-                        dataset,
-                        dataset_process_rule.id,
-                        document_data["data_source"]["type"],
-                        document_data["doc_form"],
-                        document_data["doc_language"],
-                        data_source_info,
-                        created_from,
-                        position,
-                        account,
-                        file_name,
-                        batch,
-                    )
-                    db.session.add(document)
-                    db.session.flush()
-                    document_ids.append(document.id)
-                    documents.append(document)
-                    position += 1
-            elif document_data["data_source"]["type"] == "notion_import":
-                notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
-                exist_page_ids = []
-                exist_document = {}
-                documents = Document.query.filter_by(
-                    dataset_id=dataset.id,
-                    tenant_id=current_user.current_tenant_id,
-                    data_source_type="notion_import",
-                    enabled=True,
-                ).all()
-                if documents:
-                    for document in documents:
-                        data_source_info = json.loads(document.data_source_info)
-                        exist_page_ids.append(data_source_info["notion_page_id"])
-                        exist_document[data_source_info["notion_page_id"]] = document.id
-                for notion_info in notion_info_list:
-                    workspace_id = notion_info["workspace_id"]
-                    data_source_binding = DataSourceOauthBinding.query.filter(
-                        db.and_(
-                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                            DataSourceOauthBinding.provider == "notion",
-                            DataSourceOauthBinding.disabled == False,
-                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                        file_name = file.name
+                        data_source_info = {
+                            "upload_file_id": file_id,
+                        }
+                        # check duplicate
+                        if document_data.get("duplicate", False):
+                            document = Document.query.filter_by(
+                                dataset_id=dataset.id,
+                                tenant_id=current_user.current_tenant_id,
+                                data_source_type="upload_file",
+                                enabled=True,
+                                name=file_name,
+                            ).first()
+                            if document:
+                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.updated_at = datetime.datetime.utcnow()
+                                document.created_from = created_from
+                                document.doc_form = document_data["doc_form"]
+                                document.doc_language = document_data["doc_language"]
+                                document.data_source_info = json.dumps(data_source_info)
+                                document.batch = batch
+                                document.indexing_status = "waiting"
+                                db.session.add(document)
+                                documents.append(document)
+                                duplicate_document_ids.append(document.id)
+                                continue
+                        document = DocumentService.build_document(
+                            dataset,
+                            dataset_process_rule.id,
+                            document_data["data_source"]["type"],
+                            document_data["doc_form"],
+                            document_data["doc_language"],
+                            data_source_info,
+                            created_from,
+                            position,
+                            account,
+                            file_name,
+                            batch,
                         )
-                    ).first()
-                    if not data_source_binding:
-                        raise ValueError("Data source binding not found.")
-                    for page in notion_info["pages"]:
-                        if page["page_id"] not in exist_page_ids:
-                            data_source_info = {
-                                "notion_workspace_id": workspace_id,
-                                "notion_page_id": page["page_id"],
-                                "notion_page_icon": page["page_icon"],
-                                "type": page["type"],
-                            }
-                            document = DocumentService.build_document(
-                                dataset,
-                                dataset_process_rule.id,
-                                document_data["data_source"]["type"],
-                                document_data["doc_form"],
-                                document_data["doc_language"],
-                                data_source_info,
-                                created_from,
-                                position,
-                                account,
-                                page["page_name"],
-                                batch,
+                        db.session.add(document)
+                        db.session.flush()
+                        document_ids.append(document.id)
+                        documents.append(document)
+                        position += 1
+                elif document_data["data_source"]["type"] == "notion_import":
+                    notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"]
+                    exist_page_ids = []
+                    exist_document = {}
+                    documents = Document.query.filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="notion_import",
+                        enabled=True,
+                    ).all()
+                    if documents:
+                        for document in documents:
+                            data_source_info = json.loads(document.data_source_info)
+                            exist_page_ids.append(data_source_info["notion_page_id"])
+                            exist_document[data_source_info["notion_page_id"]] = document.id
+                    for notion_info in notion_info_list:
+                        workspace_id = notion_info["workspace_id"]
+                        data_source_binding = DataSourceOauthBinding.query.filter(
+                            db.and_(
+                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                                DataSourceOauthBinding.provider == "notion",
+                                DataSourceOauthBinding.disabled == False,
+                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                             )
-                            db.session.add(document)
-                            db.session.flush()
-                            document_ids.append(document.id)
-                            documents.append(document)
-                            position += 1
+                        ).first()
+                        if not data_source_binding:
+                            raise ValueError("Data source binding not found.")
+                        for page in notion_info["pages"]:
+                            if page["page_id"] not in exist_page_ids:
+                                data_source_info = {
+                                    "notion_workspace_id": workspace_id,
+                                    "notion_page_id": page["page_id"],
+                                    "notion_page_icon": page["page_icon"],
+                                    "type": page["type"],
+                                }
+                                document = DocumentService.build_document(
+                                    dataset,
+                                    dataset_process_rule.id,
+                                    document_data["data_source"]["type"],
+                                    document_data["doc_form"],
+                                    document_data["doc_language"],
+                                    data_source_info,
+                                    created_from,
+                                    position,
+                                    account,
+                                    page["page_name"],
+                                    batch,
+                                )
+                                db.session.add(document)
+                                db.session.flush()
+                                document_ids.append(document.id)
+                                documents.append(document)
+                                position += 1
+                            else:
+                                exist_document.pop(page["page_id"])
+                    # delete not selected documents
+                    if len(exist_document) > 0:
+                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
+                elif document_data["data_source"]["type"] == "website_crawl":
+                    website_info = document_data["data_source"]["info_list"]["website_info_list"]
+                    urls = website_info["urls"]
+                    for url in urls:
+                        data_source_info = {
+                            "url": url,
+                            "provider": website_info["provider"],
+                            "job_id": website_info["job_id"],
+                            "only_main_content": website_info.get("only_main_content", False),
+                            "mode": "crawl",
+                        }
+                        if len(url) > 255:
+                            document_name = url[:200] + "..."
                         else:
-                            exist_document.pop(page["page_id"])
-                # delete not selected documents
-                if len(exist_document) > 0:
-                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-            elif document_data["data_source"]["type"] == "website_crawl":
-                website_info = document_data["data_source"]["info_list"]["website_info_list"]
-                urls = website_info["urls"]
-                for url in urls:
-                    data_source_info = {
-                        "url": url,
-                        "provider": website_info["provider"],
-                        "job_id": website_info["job_id"],
-                        "only_main_content": website_info.get("only_main_content", False),
-                        "mode": "crawl",
-                    }
-                    if len(url) > 255:
-                        document_name = url[:200] + "..."
-                    else:
-                        document_name = url
-                    document = DocumentService.build_document(
-                        dataset,
-                        dataset_process_rule.id,
-                        document_data["data_source"]["type"],
-                        document_data["doc_form"],
-                        document_data["doc_language"],
-                        data_source_info,
-                        created_from,
-                        position,
-                        account,
-                        document_name,
-                        batch,
-                    )
-                    db.session.add(document)
-                    db.session.flush()
-                    document_ids.append(document.id)
-                    documents.append(document)
-                    position += 1
-            db.session.commit()
+                            document_name = url
+                        document = DocumentService.build_document(
+                            dataset,
+                            dataset_process_rule.id,
+                            document_data["data_source"]["type"],
+                            document_data["doc_form"],
+                            document_data["doc_language"],
+                            data_source_info,
+                            created_from,
+                            position,
+                            account,
+                            document_name,
+                            batch,
+                        )
+                        db.session.add(document)
+                        db.session.flush()
+                        document_ids.append(document.id)
+                        documents.append(document)
+                        position += 1
+                db.session.commit()
 
-            # trigger async task
-            if document_ids:
-                document_indexing_task.delay(dataset.id, document_ids)
-            if duplicate_document_ids:
-                duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
+                # trigger async task
+                if document_ids:
+                    document_indexing_task.delay(dataset.id, document_ids)
+                if duplicate_document_ids:
+                    duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
 
-        return documents, batch
+            return documents, batch
 
     @staticmethod
     def check_documents_upload_quota(count: int, features: FeatureModel):
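
Why the lock matters: `get_documents_position` is read once before the loops and the shared `position` counter is then incremented per document, so without the lock two concurrent requests against the same dataset could both read the same starting position before either commits, producing colliding positions (the duplicate-name check is the same kind of check-then-act race). A hedged illustration of that race; `next_position` below is a hypothetical stand-in for `DocumentService.get_documents_position`, not the real implementation:

import threading
import time

position_store = {"max": 0}  # stands in for the documents table

def next_position() -> int:
    # Hypothetical stand-in: a plain read of the current max position,
    # with no synchronization around it.
    return position_store["max"] + 1

def insert_document(results: list) -> None:
    pos = next_position()        # both threads can read the same value here
    time.sleep(0.01)             # widen the read/write window for the demo
    position_store["max"] = pos  # last writer wins
    results.append(pos)

results: list = []
threads = [threading.Thread(target=insert_document, args=(results,)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # typically [1, 1]: both "documents" got the same position

Holding the Redis lock across the whole read-build-commit sequence, as the diff does, closes that window: the second request cannot read the position counter until the first has committed.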