Ver código fonte

test: add tests covering all methods of vector store (#3849)

Bowen Liang 1 ano atrás
pai
commit
45dd1683fd

+ 6 - 0
api/core/rag/datasource/vdb/vector_base.py

@@ -27,6 +27,12 @@ class BaseVector(ABC):
     def delete_by_ids(self, ids: list[str]) -> None:
         raise NotImplementedError
 
+    def delete_by_document_id(self, document_id: str):
+        raise NotImplementedError
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        raise NotImplementedError
+
     @abstractmethod
     def delete_by_metadata_field(self, key: str, value: str) -> None:
         raise NotImplementedError

+ 8 - 1
api/tests/integration_tests/vdb/milvus/test_milvus.py

@@ -24,6 +24,13 @@ class TestMilvusVector(AbstractTestVector):
         hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
         assert len(hits_by_full_text) == 0
 
+    def delete_document_by_id(self):
+        self.vector.delete_by_document_id(self.dataset_id)
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id)
+        assert len(ids) >= 1
+
 
 def test_milvus_vector(setup_mock_redis):
-    TestMilvusVector().run_all_test()
+    TestMilvusVector().run_all_tests()

+ 1 - 1
api/tests/integration_tests/vdb/qdrant/test_qdrant.py

@@ -20,4 +20,4 @@ class TestQdrantVector(AbstractTestVector):
 
 
 def test_qdrant_vector(setup_mock_redis):
-    TestQdrantVector().run_all_test()
+    TestQdrantVector().run_all_tests()

+ 32 - 1
api/tests/integration_tests/vdb/test_vector_store.py

@@ -71,8 +71,39 @@ class AbstractTestVector:
     def delete_vector(self):
         self.vector.delete()
 
-    def run_all_test(self):
+    def delete_by_ids(self):
+        self.vector.delete_by_ids([self.dataset_id])
+
+    def add_texts(self):
+        self.vector.add_texts(
+            documents=[
+                get_sample_document(str(uuid.uuid4())),
+                get_sample_document(str(uuid.uuid4())),
+            ],
+            embeddings=[
+                get_sample_embedding(),
+                get_sample_embedding(),
+            ],
+        )
+
+    def text_exists(self):
+        self.vector.text_exists(self.dataset_id)
+
+    def delete_document_by_id(self):
+        with pytest.raises(NotImplementedError):
+            self.vector.delete_by_document_id(self.dataset_id)
+
+    def get_ids_by_metadata_field(self):
+        with pytest.raises(NotImplementedError):
+            self.vector.get_ids_by_metadata_field('key', 'value')
+
+    def run_all_tests(self):
         self.create_vector()
         self.search_by_vector()
         self.search_by_full_text()
+        self.text_exists()
+        self.get_ids_by_metadata_field()
+        self.add_texts()
+        self.delete_document_by_id()
+        self.delete_by_ids()
         self.delete_vector()

+ 1 - 1
api/tests/integration_tests/vdb/weaviate/test_weaviate.py

@@ -21,4 +21,4 @@ class TestWeaviateVector(AbstractTestVector):
 
 
 def test_weaviate_vector(setup_mock_redis):
-    TestWeaviateVector().run_all_test()
+    TestWeaviateVector().run_all_tests()