浏览代码

test: refactor vdb tests by visitor design pattern (#3838)

Bowen Liang 1 年之前
父节点
当前提交
86e7330fa2

+ 19 - 28
api/tests/integration_tests/vdb/milvus/test_milvus.py

@@ -1,38 +1,29 @@
-import uuid
-
 from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
-from models.dataset import Dataset
 from tests.integration_tests.vdb.test_vector_store import (
-    get_sample_document,
-    get_sample_embedding,
-    get_sample_query_vector,
+    AbstractTestVector,
+    get_sample_text,
     setup_mock_redis,
 )
 
 
-def test_milvus_vector(setup_mock_redis) -> None:
-    dataset_id = str(uuid.uuid4())
-    vector = MilvusVector(
-        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
-        config=MilvusConfig(
-            host='localhost',
-            port=19530,
-            user='root',
-            password='Milvus',
+class TestMilvusVector(AbstractTestVector):
+    def __init__(self):
+        super().__init__()
+        self.vector = MilvusVector(
+            collection_name=self.collection_name,
+            config=MilvusConfig(
+                host='localhost',
+                port=19530,
+                user='root',
+                password='Milvus',
+            )
         )
-    )
-
-    # create vector
-    vector.create(
-        texts=[get_sample_document(dataset_id)],
-        embeddings=[get_sample_embedding()],
-    )
 
-    # search by vector
-    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
-    assert len(hits_by_vector) >= 1
+    def search_by_full_text(self):
+        # Milvus does not yet support full-text search in versions < 2.3.x
+        hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
+        assert len(hits_by_full_text) == 0
 
-    # Milvus does not yet support full-text search in versions < 2.3.x
 
-    # delete vector
-    vector.delete()
+def test_milvus_vector(setup_mock_redis):
+    TestMilvusVector().run_all_test()

+ 14 - 31
api/tests/integration_tests/vdb/qdrant/test_qdrant.py

@@ -1,40 +1,23 @@
-import uuid
-
 from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
-from models.dataset import Dataset
 from tests.integration_tests.vdb.test_vector_store import (
-    get_sample_document,
-    get_sample_embedding,
-    get_sample_query_vector,
-    get_sample_text,
+    AbstractTestVector,
     setup_mock_redis,
 )
 
 
-def test_qdrant_vector(setup_mock_redis)-> None:
-    dataset_id = str(uuid.uuid4())
-    vector = QdrantVector(
-        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
-        group_id=dataset_id,
-        config=QdrantConfig(
-            endpoint='http://localhost:6333',
-            api_key='difyai123456',
+class TestQdrantVector(AbstractTestVector):
+    def __init__(self):
+        super().__init__()
+        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self.vector = QdrantVector(
+            collection_name=self.collection_name,
+            group_id=self.dataset_id,
+            config=QdrantConfig(
+                endpoint='http://localhost:6333',
+                api_key='difyai123456',
+            )
         )
-    )
-
-    # create vector
-    vector.create(
-        texts=[get_sample_document(dataset_id)],
-        embeddings=[get_sample_embedding()],
-    )
-
-    # search by vector
-    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
-    assert len(hits_by_vector) >= 1
 
-    # search by full text
-    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
-    assert len(hits_by_full_text) >= 1
 
-    # delete vector
-    vector.delete()
+def test_qdrant_vector(setup_mock_redis):
+    TestQdrantVector().run_all_test()

+ 32 - 0
api/tests/integration_tests/vdb/test_vector_store.py

@@ -1,9 +1,11 @@
+import uuid
 from unittest.mock import MagicMock
 
 import pytest
 
 from core.rag.models.document import Document
 from extensions import ext_redis
+from models.dataset import Dataset
 
 
 def get_sample_text() -> str:
@@ -44,3 +46,33 @@ def setup_mock_redis() -> None:
     mock_redis_lock.__enter__ = MagicMock()
     mock_redis_lock.__exit__ = MagicMock()
     ext_redis.redis_client.lock = mock_redis_lock
+
+
+class AbstractTestVector:
+    def __init__(self):
+        self.vector = None
+        self.dataset_id = str(uuid.uuid4())
+        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id)
+
+    def create_vector(self) -> None:
+        self.vector.create(
+            texts=[get_sample_document(self.dataset_id)],
+            embeddings=[get_sample_embedding()],
+        )
+
+    def search_by_vector(self):
+        hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector())
+        assert len(hits_by_vector) >= 1
+
+    def search_by_full_text(self):
+        hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text())
+        assert len(hits_by_full_text) >= 1
+
+    def delete_vector(self):
+        self.vector.delete()
+
+    def run_all_test(self):
+        self.create_vector()
+        self.search_by_vector()
+        self.search_by_full_text()
+        self.delete_vector()

+ 15 - 32
api/tests/integration_tests/vdb/weaviate/test_weaviate.py

@@ -1,41 +1,24 @@
-import uuid
-
 from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
 from models.dataset import Dataset
 from tests.integration_tests.vdb.test_vector_store import (
-    get_sample_document,
-    get_sample_embedding,
-    get_sample_query_vector,
-    get_sample_text,
+    AbstractTestVector,
     setup_mock_redis,
 )
 
 
-def test_weaviate_vector(setup_mock_redis) -> None:
-    attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
-    dataset_id = str(uuid.uuid4())
-    vector = WeaviateVector(
-        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
-        config=WeaviateConfig(
-            endpoint='http://localhost:8080',
-            api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
-        ),
-        attributes=attributes
-    )
-
-    # create vector
-    vector.create(
-        texts=[get_sample_document(dataset_id)],
-        embeddings=[get_sample_embedding()],
-    )
-
-    # search by vector
-    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
-    assert len(hits_by_vector) >= 1
+class TestWeaviateVector(AbstractTestVector):
+    def __init__(self):
+        super().__init__()
+        self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
+        self.vector = WeaviateVector(
+            collection_name=self.collection_name,
+            config=WeaviateConfig(
+                endpoint='http://localhost:8080',
+                api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
+            ),
+            attributes=self.attributes
+        )
 
-    # search by full text
-    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
-    assert len(hits_by_full_text) >= 1
 
-    # delete vector
-    vector.delete()
+def test_weaviate_vector(setup_mock_redis):
+    TestWeaviateVector().run_all_test()