|
@@ -1,3 +1,4 @@
|
|
|
+import random
|
|
|
import uuid
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
@@ -8,26 +9,18 @@ from extensions import ext_redis
|
|
|
from models.dataset import Dataset
|
|
|
|
|
|
|
|
|
-def get_sample_text() -> str:
|
|
|
+def get_example_text() -> str:
|
|
|
return 'test_text'
|
|
|
|
|
|
|
|
|
-def get_sample_embedding() -> list[float]:
|
|
|
- return [1.1, 2.2, 3.3]
|
|
|
-
|
|
|
-
|
|
|
-def get_sample_query_vector() -> list[float]:
|
|
|
- return get_sample_embedding()
|
|
|
-
|
|
|
-
|
|
|
-def get_sample_document(sample_dataset_id: str) -> Document:
|
|
|
+def get_example_document(doc_id: str) -> Document:
|
|
|
doc = Document(
|
|
|
- page_content=get_sample_text(),
|
|
|
+ page_content=get_example_text(),
|
|
|
metadata={
|
|
|
- "doc_id": sample_dataset_id,
|
|
|
- "doc_hash": sample_dataset_id,
|
|
|
- "document_id": sample_dataset_id,
|
|
|
- "dataset_id": sample_dataset_id,
|
|
|
+ "doc_id": doc_id,
|
|
|
+ "doc_hash": doc_id,
|
|
|
+ "document_id": doc_id,
|
|
|
+ "dataset_id": doc_id,
|
|
|
}
|
|
|
)
|
|
|
return doc
|
|
@@ -53,49 +46,48 @@ class AbstractTestVector:
|
|
|
self.vector = None
|
|
|
self.dataset_id = str(uuid.uuid4())
|
|
|
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
|
|
|
+ self.example_doc_id = str(uuid.uuid4())
|
|
|
+ self.example_embedding = [1.001 * i for i in range(128)]
|
|
|
|
|
|
def create_vector(self) -> None:
|
|
|
self.vector.create(
|
|
|
- texts=[get_sample_document(self.dataset_id)],
|
|
|
- embeddings=[get_sample_embedding()],
|
|
|
+ texts=[get_example_document(doc_id=self.example_doc_id)],
|
|
|
+ embeddings=[self.example_embedding],
|
|
|
)
|
|
|
|
|
|
def search_by_vector(self):
|
|
|
- hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
|
|
|
- assert len(hits_by_vector) >= 1
|
|
|
+ hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
|
|
|
+ assert len(hits_by_vector) == 1
|
|
|
+ assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
|
|
|
|
|
|
def search_by_full_text(self):
|
|
|
- hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
|
|
|
- assert len(hits_by_full_text) >= 1
|
|
|
+ hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
|
|
+ assert len(hits_by_full_text) == 1
|
|
|
+ assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
|
|
|
|
|
|
def delete_vector(self):
|
|
|
self.vector.delete()
|
|
|
|
|
|
- def delete_by_ids(self):
|
|
|
- self.vector.delete_by_ids([self.dataset_id])
|
|
|
-
|
|
|
- def add_texts(self):
|
|
|
- self.vector.add_texts(
|
|
|
- documents=[
|
|
|
- get_sample_document(str(uuid.uuid4())),
|
|
|
- get_sample_document(str(uuid.uuid4())),
|
|
|
- ],
|
|
|
- embeddings=[
|
|
|
- get_sample_embedding(),
|
|
|
- get_sample_embedding(),
|
|
|
- ],
|
|
|
- )
|
|
|
+ def delete_by_ids(self, ids: list[str]):
|
|
|
+ self.vector.delete_by_ids(ids=ids)
|
|
|
+
|
|
|
+ def add_texts(self) -> list[str]:
|
|
|
+ batch_size = 100
|
|
|
+ documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
|
|
|
+ embeddings = [self.example_embedding] * batch_size
|
|
|
+ self.vector.add_texts(documents=documents, embeddings=embeddings)
|
|
|
+ return [doc.metadata['doc_id'] for doc in documents]
|
|
|
|
|
|
def text_exists(self):
|
|
|
- self.vector.text_exists(self.dataset_id)
|
|
|
+ assert self.vector.text_exists(self.example_doc_id)
|
|
|
|
|
|
- def delete_document_by_id(self):
|
|
|
+ def delete_by_document_id(self):
|
|
|
with pytest.raises(NotImplementedError):
|
|
|
- self.vector.delete_by_document_id(self.dataset_id)
|
|
|
+ self.vector.delete_by_document_id(document_id=self.example_doc_id)
|
|
|
|
|
|
def get_ids_by_metadata_field(self):
|
|
|
with pytest.raises(NotImplementedError):
|
|
|
- self.vector.get_ids_by_metadata_field('key', 'value')
|
|
|
+ self.vector.get_ids_by_metadata_field(key='key', value='value')
|
|
|
|
|
|
def run_all_tests(self):
|
|
|
self.create_vector()
|
|
@@ -103,7 +95,7 @@ class AbstractTestVector:
|
|
|
self.search_by_full_text()
|
|
|
self.text_exists()
|
|
|
self.get_ids_by_metadata_field()
|
|
|
- self.add_texts()
|
|
|
- self.delete_document_by_id()
|
|
|
- self.delete_by_ids()
|
|
|
+ self.delete_by_document_id()
|
|
|
+ added_doc_ids = self.add_texts()
|
|
|
+ self.delete_by_ids(added_doc_ids)
|
|
|
self.delete_vector()
|