Quellcode durchsuchen

test: improve vector store tests (#3855)

Bowen Liang vor 1 Jahr
Ursprung
Commit
045827043d

+ 15 - 22
.github/workflows/api-tests.yml

@@ -37,27 +37,6 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4
 
-      - name: Set up Weaviate
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.middleware.yaml
-          services: weaviate
-
-      - name: Set up Qdrant
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.qdrant.yaml
-          services: qdrant
-
-      - name: Set up Milvus
-        uses: hoverkraft-tech/compose-action@v2.0.0
-        with:
-          compose-file: docker/docker-compose.milvus.yaml
-          services: |
-            etcd
-            minio
-            milvus-standalone
-
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
@@ -82,5 +61,19 @@ jobs:
       - name: Run Workflow
         run: dev/pytest/pytest_workflow.sh
 
-      - name: Run Vector Stores
+      - name: Set up Vector Stores (Weaviate, Qdrant and Milvus)
+        uses: hoverkraft-tech/compose-action@v2.0.0
+        with:
+          compose-file: |
+            docker/docker-compose.middleware.yaml
+            docker/docker-compose.qdrant.yaml
+            docker/docker-compose.milvus.yaml
+          services: |
+            weaviate
+            qdrant
+            etcd
+            minio
+            milvus-standalone
+
+      - name: Test Vector Stores
         run: dev/pytest/pytest_vdb.sh

+ 2 - 2
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -124,7 +124,7 @@ class MilvusVector(BaseVector):
             if ids:
                 self._client.delete(collection_name=self._collection_name, pks=ids)
 
-    def delete_by_ids(self, doc_ids: list[str]) -> None:
+    def delete_by_ids(self, ids: list[str]) -> None:
         alias = uuid4().hex
         if self._client_config.secure:
             uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
@@ -136,7 +136,7 @@ class MilvusVector(BaseVector):
         if utility.has_collection(self._collection_name, using=alias):
 
             result = self._client.query(collection_name=self._collection_name,
-                                        filter=f'metadata["doc_id"] in {doc_ids}',
+                                        filter=f'metadata["doc_id"] in {ids}',
                                         output_fields=["id"])
             if result:
                 ids = [item["id"] for item in result]

+ 2 - 2
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -199,10 +199,10 @@ class RelytVector(BaseVector):
         if ids:
             self.delete_by_uuids(ids)
 
-    def delete_by_ids(self, doc_ids: list[str]) -> None:
+    def delete_by_ids(self, ids: list[str]) -> None:
 
         with Session(self.client) as session:
-            ids_str = ','.join(f"'{doc_id}'" for doc_id in doc_ids)
+            ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
             select_statement = sql_text(
                 f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
             )

+ 6 - 6
api/tests/integration_tests/vdb/milvus/test_milvus.py

@@ -1,7 +1,7 @@
 from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
 from tests.integration_tests.vdb.test_vector_store import (
     AbstractTestVector,
-    get_sample_text,
+    get_example_text,
     setup_mock_redis,
 )
 
@@ -21,15 +21,15 @@ class TestMilvusVector(AbstractTestVector):
 
     def search_by_full_text(self):
         # milvus dos not support full text searching yet in < 2.3.x
-        hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
         assert len(hits_by_full_text) == 0
 
-    def delete_document_by_id(self):
-        self.vector.delete_by_document_id(self.dataset_id)
+    def delete_by_document_id(self):
+        self.vector.delete_by_document_id(document_id=self.example_doc_id)
 
     def get_ids_by_metadata_field(self):
-        ids = self.vector.get_ids_by_metadata_field('document_id', self.dataset_id)
-        assert len(ids) >= 1
+        ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
+        assert len(ids) == 1
 
 
 def test_milvus_vector(setup_mock_redis):

+ 34 - 42
api/tests/integration_tests/vdb/test_vector_store.py

@@ -1,3 +1,4 @@
+import random
 import uuid
 from unittest.mock import MagicMock
 
@@ -8,26 +9,18 @@ from extensions import ext_redis
 from models.dataset import Dataset
 
 
-def get_sample_text() -> str:
+def get_example_text() -> str:
     return 'test_text'
 
 
-def get_sample_embedding() -> list[float]:
-    return [1.1, 2.2, 3.3]
-
-
-def get_sample_query_vector() -> list[float]:
-    return get_sample_embedding()
-
-
-def get_sample_document(sample_dataset_id: str) -> Document:
+def get_example_document(doc_id: str) -> Document:
     doc = Document(
-        page_content=get_sample_text(),
+        page_content=get_example_text(),
         metadata={
-            "doc_id": sample_dataset_id,
-            "doc_hash": sample_dataset_id,
-            "document_id": sample_dataset_id,
-            "dataset_id": sample_dataset_id,
+            "doc_id": doc_id,
+            "doc_hash": doc_id,
+            "document_id": doc_id,
+            "dataset_id": doc_id,
         }
     )
     return doc
@@ -53,49 +46,48 @@ class AbstractTestVector:
         self.vector = None
         self.dataset_id = str(uuid.uuid4())
         self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
+        self.example_doc_id = str(uuid.uuid4())
+        self.example_embedding = [1.001 * i for i in range(128)]
 
     def create_vector(self) -> None:
         self.vector.create(
-            texts=[get_sample_document(self.dataset_id)],
-            embeddings=[get_sample_embedding()],
+            texts=[get_example_document(doc_id=self.example_doc_id)],
+            embeddings=[self.example_embedding],
         )
 
     def search_by_vector(self):
-        hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
-        assert len(hits_by_vector) >= 1
+        hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 1
+        assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
 
     def search_by_full_text(self):
-        hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
-        assert len(hits_by_full_text) >= 1
+        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 1
+        assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
 
     def delete_vector(self):
         self.vector.delete()
 
-    def delete_by_ids(self):
-        self.vector.delete_by_ids([self.dataset_id])
-
-    def add_texts(self):
-        self.vector.add_texts(
-            documents=[
-                get_sample_document(str(uuid.uuid4())),
-                get_sample_document(str(uuid.uuid4())),
-            ],
-            embeddings=[
-                get_sample_embedding(),
-                get_sample_embedding(),
-            ],
-        )
+    def delete_by_ids(self, ids: list[str]):
+        self.vector.delete_by_ids(ids=ids)
+
+    def add_texts(self) -> list[str]:
+        batch_size = 100
+        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
+        embeddings = [self.example_embedding] * batch_size
+        self.vector.add_texts(documents=documents, embeddings=embeddings)
+        return [doc.metadata['doc_id'] for doc in documents]
 
     def text_exists(self):
-        self.vector.text_exists(self.dataset_id)
+        assert self.vector.text_exists(self.example_doc_id)
 
-    def delete_document_by_id(self):
+    def delete_by_document_id(self):
         with pytest.raises(NotImplementedError):
-            self.vector.delete_by_document_id(self.dataset_id)
+            self.vector.delete_by_document_id(document_id=self.example_doc_id)
 
     def get_ids_by_metadata_field(self):
         with pytest.raises(NotImplementedError):
-            self.vector.get_ids_by_metadata_field('key', 'value')
+            self.vector.get_ids_by_metadata_field(key='key', value='value')
 
     def run_all_tests(self):
         self.create_vector()
@@ -103,7 +95,7 @@ class AbstractTestVector:
         self.search_by_full_text()
         self.text_exists()
         self.get_ids_by_metadata_field()
-        self.add_texts()
-        self.delete_document_by_id()
-        self.delete_by_ids()
+        self.delete_by_document_id()
+        added_doc_ids = self.add_texts()
+        self.delete_by_ids(added_doc_ids)
         self.delete_vector()

+ 0 - 1
api/tests/integration_tests/vdb/weaviate/test_weaviate.py

@@ -1,5 +1,4 @@
 from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
-from models.dataset import Dataset
 from tests.integration_tests.vdb.test_vector_store import (
     AbstractTestVector,
     setup_mock_redis,