|
@@ -9,12 +9,17 @@ from models.dataset import Dataset, Document
|
|
|
|
|
|
|
|
|
class VectorIndex:
|
|
|
- def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
|
|
|
+ def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
|
|
+ attributes: list = None):
|
|
|
+ if attributes is None:
|
|
|
+ attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
|
|
self._dataset = dataset
|
|
|
self._embeddings = embeddings
|
|
|
- self._vector_index = self._init_vector_index(dataset, config, embeddings)
|
|
|
+ self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
|
|
|
+ self._attributes = attributes
|
|
|
|
|
|
- def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
|
|
|
+ def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
|
|
|
+ attributes: list) -> BaseVectorIndex:
|
|
|
vector_type = config.get('VECTOR_STORE')
|
|
|
|
|
|
if self._dataset.index_struct_dict:
|
|
@@ -33,7 +38,8 @@ class VectorIndex:
|
|
|
api_key=config.get('WEAVIATE_API_KEY'),
|
|
|
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
|
|
|
),
|
|
|
- embeddings=embeddings
|
|
|
+ embeddings=embeddings,
|
|
|
+ attributes=attributes
|
|
|
)
|
|
|
elif vector_type == "qdrant":
|
|
|
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|