Prechádzať zdrojové kódy

fix(batch_create_segment_to_index_task): count max_position in memory. (#12929)

-LAN- 3 mesiacov pred
rodič
commit
f91f5c7401

+ 2 - 1
api/models/dataset.py

@@ -13,6 +13,7 @@ from typing import Any, cast
 
 from sqlalchemy import func
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.orm import Mapped
 
 from configs import dify_config
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -515,7 +516,7 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
     tenant_id = db.Column(StringUUID, nullable=False)
     dataset_id = db.Column(StringUUID, nullable=False)
     document_id = db.Column(StringUUID, nullable=False)
-    position = db.Column(db.Integer, nullable=False)
+    position: Mapped[int]
     content = db.Column(db.Text, nullable=False)
     answer = db.Column(db.Text, nullable=True)
     word_count = db.Column(db.Integer, nullable=False)

+ 76 - 61
api/tasks/batch_create_segment_to_index_task.py

@@ -5,7 +5,8 @@ import uuid
 
 import click
 from celery import shared_task  # type: ignore
-from sqlalchemy import func
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
@@ -18,7 +19,12 @@ from services.vector_service import VectorService
 
 @shared_task(queue="dataset")
 def batch_create_segment_to_index_task(
-    job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str
+    job_id: str,
+    content: list,
+    dataset_id: str,
+    document_id: str,
+    tenant_id: str,
+    user_id: str,
 ):
     """
     Async batch create segment to index
@@ -37,71 +43,80 @@ def batch_create_segment_to_index_task(
     indexing_cache_key = "segment_batch_import_{}".format(job_id)
 
     try:
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            raise ValueError("Dataset not exist.")
+        with Session(db.engine) as session:
+            dataset = session.get(Dataset, dataset_id)
+            if not dataset:
+                raise ValueError("Dataset not exist.")
 
-        dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
-        if not dataset_document:
-            raise ValueError("Document not exist.")
+            dataset_document = session.get(Document, document_id)
+            if not dataset_document:
+                raise ValueError("Document not exist.")
 
-        if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
-            raise ValueError("Document is not available.")
-        document_segments = []
-        embedding_model = None
-        if dataset.indexing_technique == "high_quality":
-            model_manager = ModelManager()
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=dataset.tenant_id,
-                provider=dataset.embedding_model_provider,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=dataset.embedding_model,
+            if (
+                not dataset_document.enabled
+                or dataset_document.archived
+                or dataset_document.indexing_status != "completed"
+            ):
+                raise ValueError("Document is not available.")
+            document_segments = []
+            embedding_model = None
+            if dataset.indexing_technique == "high_quality":
+                model_manager = ModelManager()
+                embedding_model = model_manager.get_model_instance(
+                    tenant_id=dataset.tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model,
+                )
+            word_count_change = 0
+            segments_to_insert: list[str] = []
+            max_position_stmt = select(func.max(DocumentSegment.position)).where(
+                DocumentSegment.document_id == dataset_document.id
             )
-        word_count_change = 0
-        segments_to_insert: list[str] = []  # Explicitly type hint the list as List[str]
-        for segment in content:
-            content_str = segment["content"]
-            doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content_str)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
-            max_position = (
-                db.session.query(func.max(DocumentSegment.position))
-                .filter(DocumentSegment.document_id == dataset_document.id)
-                .scalar()
-            )
-            segment_document = DocumentSegment(
-                tenant_id=tenant_id,
-                dataset_id=dataset_id,
-                document_id=document_id,
-                index_node_id=doc_id,
-                index_node_hash=segment_hash,
-                position=max_position + 1 if max_position else 1,
-                content=content_str,
-                word_count=len(content_str),
-                tokens=tokens,
-                created_by=user_id,
-                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-                status="completed",
-                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
-            )
-            if dataset_document.doc_form == "qa_model":
-                segment_document.answer = segment["answer"]
-                segment_document.word_count += len(segment["answer"])
-            word_count_change += segment_document.word_count
-            db.session.add(segment_document)
-            document_segments.append(segment_document)
-            segments_to_insert.append(str(segment))  # Cast to string if needed
-        # update document word count
-        dataset_document.word_count += word_count_change
-        db.session.add(dataset_document)
-        # add index to db
-        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
-        db.session.commit()
+            max_position = session.scalar(max_position_stmt) or 1
+            for segment in content:
+                content_str = segment["content"]
+                doc_id = str(uuid.uuid4())
+                segment_hash = helper.generate_text_hash(content_str)
+                # calc embedding use tokens
+                tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
+                segment_document = DocumentSegment(
+                    tenant_id=tenant_id,
+                    dataset_id=dataset_id,
+                    document_id=document_id,
+                    index_node_id=doc_id,
+                    index_node_hash=segment_hash,
+                    position=max_position,
+                    content=content_str,
+                    word_count=len(content_str),
+                    tokens=tokens,
+                    created_by=user_id,
+                    indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                    status="completed",
+                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                )
+                max_position += 1
+                if dataset_document.doc_form == "qa_model":
+                    segment_document.answer = segment["answer"]
+                    segment_document.word_count += len(segment["answer"])
+                word_count_change += segment_document.word_count
+                session.add(segment_document)
+                document_segments.append(segment_document)
+                segments_to_insert.append(str(segment))  # Cast to string if needed
+            # update document word count
+            dataset_document.word_count += word_count_change
+            session.add(dataset_document)
+            # add index to db
+            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
+            session.commit()
+
         redis_client.setex(indexing_cache_key, 600, "completed")
         end_at = time.perf_counter()
         logging.info(
-            click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green")
+            click.style(
+                "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
+                fg="green",
+            )
         )
     except Exception as e:
         logging.exception("Segments batch created index failed")