Prechádzať zdrojové kódy

Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
Jyong 1 rok pred
rodič
commit
269a465fc4

+ 124 - 17
api/commands.py

@@ -4,6 +4,7 @@ import math
 import random
 import string
 import time
+import uuid
 
 import click
 from tqdm import tqdm
@@ -23,7 +24,7 @@ from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
 from models.model import Account, AppModelConfig, App
 import secrets
 import base64
@@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
                         kw_index = IndexBuilder.get_index(dataset, 'economy')
                         # delete from vector index
                         if vector_index:
-                            vector_index.delete()
+                            if dataset.collection_binding_id:
+                                vector_index.delete_by_group_id(dataset.id)
+                            else:
+                                if dataset.collection_binding_id:
+                                    vector_index.delete_by_group_id(dataset.id)
+                                else:
+                                    vector_index.delete()
                         kw_index.delete()
                         # update document
                         update_params = {
@@ -346,7 +353,8 @@ def create_qdrant_indexes():
                                     is_valid=True,
                                 )
                                 model_provider = OpenAIProvider(provider=provider)
-                                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                                  model_provider=model_provider)
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +372,8 @@ def create_qdrant_indexes():
                             index.create_qdrant_dataset(dataset)
                             index_struct = {
                                 "type": 'qdrant',
-                                "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+                                "vector_store": {
+                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                             }
                             dataset.index_struct = json.dumps(index_struct)
                             db.session.commit()
@@ -373,7 +382,8 @@ def create_qdrant_indexes():
                             click.echo('passed.')
                     except Exception as e:
                         click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                        fg='red'))
                         continue
 
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +424,8 @@ def update_qdrant_indexes():
                                 is_valid=True,
                             )
                             model_provider = OpenAIProvider(provider=provider)
-                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                              model_provider=model_provider)
                         embeddings = CacheEmbedding(embedding_model)
 
                         from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +446,104 @@ def update_qdrant_indexes():
                             click.echo('passed.')
                     except Exception as e:
                         click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                        fg='red'))
                         continue
 
     click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
 
+
+@click.command('normalization-collections', help='restore all collections in one')
+def normalization_collections():
+    click.echo(click.style('Start normalization collections.', fg='green'))
+    normalization_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            if not dataset.collection_binding_id:
+                try:
+                    click.echo('restore dataset index: {}'.format(dataset.id))
+                    try:
+                        embedding_model = ModelFactory.get_embedding_model(
+                            tenant_id=dataset.tenant_id,
+                            model_provider_name=dataset.embedding_model_provider,
+                            model_name=dataset.embedding_model
+                        )
+                    except Exception:
+                        provider = Provider(
+                            id='provider_id',
+                            tenant_id=dataset.tenant_id,
+                            provider_name='openai',
+                            provider_type=ProviderType.CUSTOM.value,
+                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                            is_valid=True,
+                        )
+                        model_provider = OpenAIProvider(provider=provider)
+                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                          model_provider=model_provider)
+                    embeddings = CacheEmbedding(embedding_model)
+                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                               DatasetCollectionBinding.model_name == embedding_model.name). \
+                        order_by(DatasetCollectionBinding.created_at). \
+                        first()
+
+                    if not dataset_collection_binding:
+                        dataset_collection_binding = DatasetCollectionBinding(
+                            provider_name=embedding_model.model_provider.provider_name,
+                            model_name=embedding_model.name,
+                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                        )
+                        db.session.add(dataset_collection_binding)
+                        db.session.commit()
+
+                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                    index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if index:
+                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
+                    else:
+                        click.echo('passed.')
+
+                    original_index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if original_index:
+                        original_index.delete_original_collection(dataset, dataset_collection_binding)
+                        normalization_count += 1
+                    else:
+                        click.echo('passed.')
+                except Exception as e:
+                    click.echo(
+                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                    fg='red'))
+                    continue
+
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+
+
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def update_app_model_configs(batch_size):
@@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
         .join(App, App.app_model_config_id == AppModelConfig.id) \
         .filter(App.mode == 'completion') \
         .count()
-    
+
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return
@@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
             offset = i * batch_size
             limit = min(batch_size, total_records - offset)
 
-            click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
-            
+            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
+
             data_batch = db.session.query(AppModelConfig) \
                 .join(App, App.app_model_config_id == AppModelConfig.id) \
                 .filter(App.mode == 'completion') \
                 .order_by(App.created_at) \
                 .offset(offset).limit(limit).all()
-            
+
             if not data_batch:
                 click.secho("No more data to migrate.", fg='green')
                 break
@@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
                     app_data = db.session.query(App) \
                         .filter(App.id == data.app_id) \
                         .one()
-                    
+
                     account_data = db.session.query(Account) \
                         .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                         .filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
                 db.session.commit()
 
             except Exception as e:
-                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
+                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
+                            fg='red')
                 continue
-            
-            click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
-            
+
+            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
+
             pbar.update(len(data_batch))
 
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -551,4 +657,5 @@ def register_commands(app):
     app.cli.add_command(clean_unused_dataset_indexes)
     app.cli.add_command(create_qdrant_indexes)
     app.cli.add_command(update_qdrant_indexes)
-    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(normalization_collections)

+ 8 - 0
api/core/index/base.py

@@ -16,6 +16,10 @@ class BaseIndex(ABC):
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError
 
+    @abstractmethod
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        raise NotImplementedError
+
     @abstractmethod
     def add_texts(self, texts: list[Document], **kwargs):
         raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
     def delete_by_ids(self, ids: list[str]) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def delete_by_group_id(self, group_id: str) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def delete_by_document_id(self, document_id: str):
         raise NotImplementedError

+ 32 - 0
api/core/index/keyword_table_index/keyword_table_index.py

@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
 
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = {}
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        dataset_keyword_table = DatasetKeywordTable(
+            dataset_id=self.dataset.id,
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self.dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
+        )
+        db.session.add(dataset_keyword_table)
+        db.session.commit()
+
+        self._save_dataset_keyword_table(keyword_table)
+
+        return self
+
     def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()
 
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
             db.session.delete(dataset_keyword_table)
             db.session.commit()
 
+    def delete_by_group_id(self, group_id: str) -> None:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            db.session.delete(dataset_keyword_table)
+            db.session.commit()
+
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
             '__type__': 'keyword_table',

+ 57 - 1
api/core/index/vector_index/base.py

@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
 
 from core.index.base import BaseIndex
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
 from models.dataset import Document as DatasetDocument
 
 
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
         for node_id in ids:
             vector_store.del_text(node_id)
 
+    def delete_by_group_id(self, group_id: str) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
     def delete(self) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
                 raise e
 
         logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"restore dataset in_one,_dataset {dataset.id}")
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"delete original collection: {dataset.id}")
+
+        self.delete()
+
+        dataset.collection_binding_id = dataset_collection_binding.id
+        db.session.add(dataset)
+        db.session.commit()
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")

+ 13 - 0
api/core/index/vector_index/milvus_vector_index.py

@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
 
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=collection_name,
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:

+ 38 - 2
api/core/index/vector_index/qdrant.py

@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
+from qdrant_client.http.models import PayloadSchemaType
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
 
     CONTENT_KEY = "page_content"
     METADATA_KEY = "metadata"
+    GROUP_KEY = "group_id"
     VECTOR_NAME = None
 
     def __init__(
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
         embeddings: Optional[Embeddings] = None,
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         distance_strategy: str = "COSINE",
         vector_name: Optional[str] = VECTOR_NAME,
         embedding_function: Optional[Callable] = None,  # deprecated
+        is_new_collection: bool = False
     ):
         """Initialize with necessary components."""
         try:
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
         self.collection_name = collection_name
         self.content_payload_key = content_payload_key or self.CONTENT_KEY
         self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
+        self.group_payload_key = group_payload_key or self.GROUP_KEY
         self.vector_name = vector_name or self.VECTOR_NAME
+        self.group_id = group_id
+        self.is_new_collection= is_new_collection
 
         if embedding_function is not None:
             warnings.warn(
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
             batch_size:
                 How many vectors upload per-request.
                 Default: 64
+            group_id:
+                collection group
 
         Returns:
             List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
                 collection_name=self.collection_name, points=points, **kwargs
             )
             added_ids.extend(batch_ids)
-
+        # if is new collection, create payload index on group_id
+        if self.is_new_collection:
+            self.client.create_payload_index(self.collection_name, self.group_payload_key,
+                                             field_schema=PayloadSchemaType.KEYWORD,
+                                             field_type=PayloadSchemaType.KEYWORD)
         return added_ids
 
     @sync_call_fallback
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         batch_size: int = 64,
         shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
             metadata_payload_key:
                 A payload key used to store the metadata of the document.
                 Default: "metadata"
+            group_payload_key:
+                A payload key used to store the content of the document.
+                Default: "group_id"
+            group_id:
+                collection group id
             vector_name:
                 Name of the vector to be used internally in Qdrant.
                 Default: None
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
             distance_func,
             content_payload_key,
             metadata_payload_key,
+            group_payload_key,
+            group_id,
             vector_name,
             shard_number,
             replication_factor,
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         shard_number: Optional[int] = None,
         replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
         vector_size = len(partial_embeddings[0])
         collection_name = collection_name or uuid.uuid4().hex
         distance_func = distance_func.upper()
+        is_new_collection = False
         client = qdrant_client.QdrantClient(
             location=location,
             url=url,
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
                 init_from=init_from,
                 timeout=timeout,  # type: ignore[arg-type]
             )
+            is_new_collection = True
         qdrant = cls(
             client=client,
             collection_name=collection_name,
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
             metadata_payload_key=metadata_payload_key,
             distance_strategy=distance_func,
             vector_name=vector_name,
+            group_id=group_id,
+            group_payload_key=group_payload_key,
+            is_new_collection=is_new_collection
         )
         return qdrant
 
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]],
         content_payload_key: str,
         metadata_payload_key: str,
+        group_id: str,
+        group_payload_key: str
     ) -> List[dict]:
         payloads = []
         for i, text in enumerate(texts):
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
                 {
                     content_payload_key: text,
                     metadata_payload_key: metadata,
+                    group_payload_key: group_id
                 }
             )
 
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
         else:
             out.append(
                 rest.FieldCondition(
-                    key=f"{self.metadata_payload_key}.{key}",
+                    key=key,
                     match=rest.MatchValue(value=value),
                 )
             )
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]] = None,
         ids: Optional[Sequence[str]] = None,
         batch_size: int = 64,
+        group_id: Optional[str] = None,
     ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
         from qdrant_client.http import models as rest
 
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
                         batch_metadatas,
                         self.content_payload_key,
                         self.metadata_payload_key,
+                        self.group_id,
+                        self.group_payload_key
                     ),
                 )
             ]

+ 58 - 20
api/core/index/vector_index/qdrant_vector_index.py

@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
 from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore
 from pydantic import BaseModel
+from qdrant_client.http.models import HnswConfigDiff
 
 from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
 
 
 class QdrantConfig(BaseModel):
     endpoint: str
     api_key: Optional[str]
     root_path: Optional[str]
-    
+
     def to_qdrant_params(self):
         if self.endpoint and self.endpoint.startswith('path:'):
             path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
         return 'qdrant'
 
     def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-
-            return class_prefix
+        if dataset.collection_binding_id:
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                one_or_none()
+            if dataset_collection_binding:
+                return dataset_collection_binding.collection_name
+            else:
+                raise ValueError('Dataset Collection Bindings is not exist!')
+        else:
+            if self.dataset.index_struct_dict:
+                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+                return class_prefix
 
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            dataset_id = dataset.id
+            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
 
     def to_index_struct(self) -> dict:
         return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
             content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
+            **self._client_config.to_qdrant_params()
+        )
+
+        return self
+
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = QdrantVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            collection_name=collection_name,
+            ids=uuids,
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
             **self._client_config.to_qdrant_params()
         )
 
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         if self._vector_store:
             return self._vector_store
         attributes = ['doc_id', 'dataset_id', 'document_id']
-        if self._is_origin():
-            attributes = ['doc_id']
         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='page_content'
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id'
         )
 
     def _get_vector_store_class(self) -> type:
         return QdrantVectorStore
 
     def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
 
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         ))
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
 
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
                 ],
             ))
 
+    def delete_by_group_id(self, group_id: str) -> None:
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=group_id),
+                ),
+            ],
+        ))
+
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
             class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

+ 14 - 0
api/core/index/vector_index/weaviate_vector_index.py

@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
 
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:

+ 4 - 2
api/core/tool/dataset_retriever_tool.py

@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
     return_resource: str
     retriever_from: str
 
-
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
         description = dataset.description
@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
                     query,
                     search_type='similarity_score_threshold',
                     search_kwargs={
-                        'k': self.k
+                        'k': self.k,
+                        'filter': {
+                            'group_id': [dataset.id]
+                        }
                     }
                 )
             else:

+ 5 - 0
api/core/vector_store/qdrant_vector_store.py

@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
 
         self.client.delete_collection(collection_name=self.collection_name)
 
+    def delete_group(self):
+        self._reload_if_needed()
+
+        self.client.delete_collection(collection_name=self.collection_name)
+
     @classmethod
     def _document_from_scored_point(
             cls,

+ 47 - 0
api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py

@@ -0,0 +1,47 @@
+"""add_dataset_collection_binding
+
+Revision ID: 6e2cfb077b04
+Revises: 77e83833755c
+Create Date: 2023-09-13 22:16:48.027810
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '6e2cfb077b04'
+down_revision = '77e83833755c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_collection_bindings',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('provider_name', sa.String(length=40), nullable=False),
+    sa.Column('model_name', sa.String(length=40), nullable=False),
+    sa.Column('collection_name', sa.String(length=64), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
+    )
+    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
+        batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_column('collection_binding_id')
+
+    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
+        batch_op.drop_index('provider_model_name_idx')
+
+    op.drop_table('dataset_collection_bindings')
+    # ### end Alembic commands ###

+ 18 - 0
api/models/dataset.py

@@ -38,6 +38,8 @@ class Dataset(db.Model):
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
+    collection_binding_id = db.Column(UUID, nullable=True)
+
 
     @property
     def dataset_keyword_table(self):
@@ -445,3 +447,19 @@ class Embedding(db.Model):
 
     def get_embedding(self) -> list[float]:
         return pickle.loads(self.embedding)
+
+
+class DatasetCollectionBinding(db.Model):
+    __tablename__ = 'dataset_collection_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
+        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
+
+    )
+
+    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    provider_name = db.Column(db.String(40), nullable=False)
+    model_name = db.Column(db.String(40), nullable=False)
+    collection_name = db.Column(db.String(64), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+

+ 41 - 3
api/services/dataset_service.py

@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from libs import helper
 from models.account import Account
-from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
+from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
+    DatasetCollectionBinding
 from models.model import UploadFile
 from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
@@ -147,6 +148,7 @@ class DatasetService:
                 action = 'remove'
                 filtered_data['embedding_model'] = None
                 filtered_data['embedding_model_provider'] = None
+                filtered_data['collection_binding_id'] = None
             elif data['indexing_technique'] == 'high_quality':
                 action = 'add'
                 # get embedding model setting
@@ -156,6 +158,11 @@ class DatasetService:
                     )
                     filtered_data['embedding_model'] = embedding_model.name
                     filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
+                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                        embedding_model.model_provider.provider_name,
+                        embedding_model.name
+                    )
+                    filtered_data['collection_binding_id'] = dataset_collection_binding.id
                 except LLMBadRequestError:
                     raise ValueError(
                         f"No Embedding Model available. Please configure a valid provider "
@@ -464,7 +471,11 @@ class DocumentService:
                 )
                 dataset.embedding_model = embedding_model.name
                 dataset.embedding_model_provider = embedding_model.model_provider.provider_name
-
+                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                    embedding_model.model_provider.provider_name,
+                    embedding_model.name
+                )
+                dataset.collection_binding_id = dataset_collection_binding.id
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -720,10 +731,16 @@ class DocumentService:
             if total_count > tenant_document_count:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
         embedding_model = None
+        dataset_collection_binding_id = None
         if document_data['indexing_technique'] == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=tenant_id
             )
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_model.model_provider.provider_name,
+                embedding_model.name
+            )
+            dataset_collection_binding_id = dataset_collection_binding.id
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -732,7 +749,8 @@ class DocumentService:
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
             embedding_model=embedding_model.name if embedding_model else None,
-            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
+            collection_binding_id=dataset_collection_binding_id
         )
 
         db.session.add(dataset)
@@ -1069,3 +1087,23 @@ class SegmentService:
             delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
         db.session.commit()
+
+
+class DatasetCollectionBindingService:
+    @classmethod
+    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
+        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+            filter(DatasetCollectionBinding.provider_name == provider_name,
+                   DatasetCollectionBinding.model_name == model_name). \
+            order_by(DatasetCollectionBinding.created_at). \
+            first()
+
+        if not dataset_collection_binding:
+            dataset_collection_binding = DatasetCollectionBinding(
+                provider_name=provider_name,
+                model_name=model_name,
+                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+            )
+            db.session.add(dataset_collection_binding)
+            db.session.flush()
+        return dataset_collection_binding

+ 4 - 1
api/services/hit_testing_service.py

@@ -47,7 +47,10 @@ class HitTestingService:
             query,
             search_type='similarity_score_threshold',
             search_kwargs={
-                'k': 10
+                'k': 10,
+                'filter': {
+                    'group_id': [dataset.id]
+                }
             }
         )
         end = time.perf_counter()