Przeglądaj źródła

Fix/qdrant data issue (#1203)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok temu
rodzic
commit
724e053732

+ 80 - 79
api/commands.py

@@ -3,12 +3,13 @@ import json
 import math
 import random
 import string
+import threading
 import time
 import uuid
 
 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 
@@ -456,92 +457,92 @@ def update_qdrant_indexes():
 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
-
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
-
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
-                try:
-                    click.echo('restore dataset index: {}'.format(dataset.id))
-                    try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id,
-                            model_provider_name=dataset.embedding_model_provider,
-                            model_name=dataset.embedding_model
-                        )
-                    except Exception:
-                        provider = Provider(
-                            id='provider_id',
-                            tenant_id=dataset.tenant_id,
-                            provider_name='openai',
-                            provider_type=ProviderType.CUSTOM.value,
-                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                            is_valid=True,
-                        )
-                        model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                          model_provider=model_provider)
-                    embeddings = CacheEmbedding(embedding_model)
-                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
-                               DatasetCollectionBinding.model_name == embedding_model.name). \
-                        order_by(DatasetCollectionBinding.created_at). \
-                        first()
-
-                    if not dataset_collection_binding:
-                        dataset_collection_binding = DatasetCollectionBinding(
-                            provider_name=embedding_model.model_provider.provider_name,
-                            model_name=embedding_model.name,
-                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
-                        )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
+        try:
+            click.echo('restore dataset index: {}'.format(dataset.id))
+            try:
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except Exception:
+                provider = Provider(
+                    id='provider_id',
+                    tenant_id=dataset.tenant_id,
+                    provider_name='openai',
+                    provider_type=ProviderType.CUSTOM.value,
+                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                    is_valid=True,
+                )
+                model_provider = OpenAIProvider(provider=provider)
+                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                  model_provider=model_provider)
+            embeddings = CacheEmbedding(embedding_model)
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                       DatasetCollectionBinding.model_name == embedding_model.name). \
+                order_by(DatasetCollectionBinding.created_at). \
+                first()
+
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=embedding_model.model_provider.provider_name,
+                    model_name=embedding_model.name,
+                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
 
-                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
-
-                    index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if index:
-                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
-                    else:
-                        click.echo('passed.')
-
-                    original_index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if original_index:
-                        original_index.delete_original_collection(dataset, dataset_collection_binding)
-                        normalization_count += 1
-                    else:
-                        click.echo('passed.')
-                except Exception as e:
-                    click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
-                    continue
-
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+            index = QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=current_app.config.get('QDRANT_URL'),
+                    api_key=current_app.config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+            if index:
+                # index.delete_by_group_id(dataset.id)
+                index.restore_dataset_in_one(dataset, dataset_collection_binding)
+            else:
+                click.echo('passed.')
+            normalization_count.append(1)
+        except Exception as e:
+            click.echo(
+                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                            fg='red'))
 
 
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')

+ 5 - 3
api/core/index/vector_index/base.py

@@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex):
     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-
-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()
 
     def delete(self) -> None:
         vector_store = self._get_vector_store()
@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
 
         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
 

+ 67 - 64
api/core/index/vector_index/qdrant.py

@@ -1390,70 +1390,12 @@ class Qdrant(VectorStore):
             path=path,
             **kwargs,
         )
-        try:
-            # Skip any validation in case of forced collection recreate.
-            if force_recreate:
-                raise ValueError
-
-            # Get the vector configuration of the existing collection and vector, if it
-            # was specified. If the old configuration does not match the current one,
-            # an exception is being thrown.
-            collection_info = client.get_collection(collection_name=collection_name)
-            current_vector_config = collection_info.config.params.vectors
-            if isinstance(current_vector_config, dict) and vector_name is not None:
-                if vector_name not in current_vector_config:
-                    raise QdrantException(
-                        f"Existing Qdrant collection {collection_name} does not "
-                        f"contain vector named {vector_name}. Did you mean one of the "
-                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                        f"If you want to recreate the collection, set `force_recreate` "
-                        f"parameter to `True`."
-                    )
-                current_vector_config = current_vector_config.get(
-                    vector_name
-                )  # type: ignore[assignment]
-            elif isinstance(current_vector_config, dict) and vector_name is None:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} uses named vectors. "
-                    f"If you want to reuse it, please set `vector_name` to any of the "
-                    f"existing named vectors: "
-                    f"{', '.join(current_vector_config.keys())}."  # noqa
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-            ):
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} doesn't use named "
-                    f"vectors. If you want to reuse it, please set `vector_name` to "
-                    f"`None`. If you want to recreate the collection, set "
-                    f"`force_recreate` parameter to `True`."
-                )
-
-            # Check if the vector configuration has the same dimensionality.
-            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for vectors with "
-                    f"{current_vector_config.size} "  # type: ignore[union-attr]
-                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-
-            current_distance_func = (
-                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-            )
-            if current_distance_func != distance_func:
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for "
-                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                    f"similarity. Please set `distance_func` parameter to "
-                    f"`{distance_func}` if you want to reuse it. If you want to "
-                    f"recreate the collection, set `force_recreate` parameter to "
-                    f"`True`."
-                )
-        except (UnexpectedResponse, RpcError, ValueError):
+        all_collection_name = []
+        collections_response = client.get_collections()
+        collection_list = collections_response.collections
+        for collection in collection_list:
+            all_collection_name.append(collection.name)
+        if collection_name not in all_collection_name:
             vectors_config = rest.VectorParams(
                 size=vector_size,
                 distance=rest.Distance[distance_func],
@@ -1481,6 +1423,67 @@ class Qdrant(VectorStore):
                 timeout=timeout,  # type: ignore[arg-type]
             )
             is_new_collection = True
+        if force_recreate:
+            raise ValueError
+
+        # Get the vector configuration of the existing collection and vector, if it
+        # was specified. If the old configuration does not match the current one,
+        # an exception is being thrown.
+        collection_info = client.get_collection(collection_name=collection_name)
+        current_vector_config = collection_info.config.params.vectors
+        if isinstance(current_vector_config, dict) and vector_name is not None:
+            if vector_name not in current_vector_config:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} does not "
+                    f"contain vector named {vector_name}. Did you mean one of the "
+                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            current_vector_config = current_vector_config.get(
+                vector_name
+            )  # type: ignore[assignment]
+        elif isinstance(current_vector_config, dict) and vector_name is None:
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} uses named vectors. "
+                f"If you want to reuse it, please set `vector_name` to any of the "
+                f"existing named vectors: "
+                f"{', '.join(current_vector_config.keys())}."  # noqa
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        elif (
+                not isinstance(current_vector_config, dict) and vector_name is not None
+        ):
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} doesn't use named "
+                f"vectors. If you want to reuse it, please set `vector_name` to "
+                f"`None`. If you want to recreate the collection, set "
+                f"`force_recreate` parameter to `True`."
+            )
+
+        # Check if the vector configuration has the same dimensionality.
+        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for vectors with "
+                f"{current_vector_config.size} "  # type: ignore[union-attr]
+                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+
+        current_distance_func = (
+            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+        )
+        if current_distance_func != distance_func:
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for "
+                f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                f"similarity. Please set `distance_func` parameter to "
+                f"`{distance_func}` if you want to reuse it. If you want to "
+                f"recreate the collection, set `force_recreate` parameter to "
+                f"`True`."
+            )
         qdrant = cls(
             client=client,
             collection_name=collection_name,

+ 13 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))
 
+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
 
     def _is_origin(self):
         if self.dataset.index_struct_dict:

+ 2 - 1
api/events/event_handlers/clean_when_dataset_deleted.py

@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)

+ 6 - 4
api/tasks/clean_dataset_task.py

@@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
 
 
 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id
 
     Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
     """
@@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             id=dataset_id,
             tenant_id=tenant_id,
             indexing_technique=indexing_technique,
-            index_struct=index_struct
+            index_struct=index_struct,
+            collection_binding_id=collection_binding_id
         )
-
         documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
         segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
 
@@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
         if dataset.indexing_technique == 'high_quality':
             vector_index = IndexBuilder.get_default_high_quality_index(dataset)
             try:
-                vector_index.delete()
+                vector_index.delete_by_group_id(dataset.id)
             except Exception:
                 logging.exception("Delete doc index failed when dataset deleted.")
 

+ 2 - 2
api/tasks/deal_dataset_vector_index_task.py

@@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
             raise Exception('Dataset not found')
 
         if action == "remove":
-            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-            index.delete()
+            index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+            index.delete_by_group_id(dataset.id)
         elif action == "add":
             dataset_documents = db.session.query(DatasetDocument).filter(
                 DatasetDocument.dataset_id == dataset_id,