ソースを参照

chore: reuse existing test functions with upstash vdb (#9679)

ice yao 6 ヶ月 前
コミット
ceb2c4f3ef

+ 2 - 2
api/controllers/console/datasets/datasets.py

@@ -619,6 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
                 | VectorType.VIKINGDB
+                | VectorType.UPSTASH
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -630,7 +631,6 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
-                | VectorType.UPSTASH
             ):
                 return {
                     "retrieval_method": [
@@ -658,6 +658,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.PGVECTO_RS
                 | VectorType.BAIDU
                 | VectorType.VIKINGDB
+                | VectorType.UPSTASH
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -669,7 +670,6 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.ORACLE
                 | VectorType.ELASTICSEARCH
                 | VectorType.PGVECTOR
-                | VectorType.UPSTASH
             ):
                 return {
                     "retrieval_method": [

+ 1 - 1
api/core/rag/datasource/vdb/upstash/upstash_vector.py

@@ -117,7 +117,7 @@ class UpstashVectorFactory(AbstractVectorFactory):
             collection_name = class_prefix.lower()
         else:
             dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
             dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name))
 
         return UpstashVector(

+ 4 - 39
api/tests/integration_tests/vdb/upstash/test_upstash_vector.py

@@ -1,27 +1,7 @@
-import time
-import uuid
-
 from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig
 from core.rag.models.document import Document
 from tests.integration_tests.vdb.__mock.upstashvectordb import setup_upstashvector_mock
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
-
-
-def get_example_text() -> str:
-    return "test_text"
-
-
-def get_example_document(doc_id: str) -> Document:
-    doc = Document(
-        page_content=get_example_text(),
-        metadata={
-            "doc_id": doc_id,
-            "doc_hash": doc_id,
-            "document_id": doc_id,
-            "dataset_id": doc_id,
-        },
-    )
-    return doc
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
 
 
 class UpstashVectorTest(AbstractVectorTest):
@@ -34,29 +14,14 @@ class UpstashVectorTest(AbstractVectorTest):
                 token="your-access-token",
             ),
         )
-        self.example_embedding = [1.001 * i for i in range(self.vector._get_index_dimension())]
-
-    def add_texts(self) -> list[str]:
-        batch_size = 1
-        documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
-        embeddings = [self.example_embedding] * batch_size
-        self.vector.add_texts(documents=documents, embeddings=embeddings)
-        return [doc.metadata["doc_id"] for doc in documents]
 
     def get_ids_by_metadata_field(self):
-        print("doc_id", self.example_doc_id)
         ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
         assert len(ids) != 0
 
-    def run_all_tests(self):
-        self.create_vector()
-        time.sleep(1)
-        self.search_by_vector()
-        self.text_exists()
-        self.get_ids_by_metadata_field()
-        added_doc_ids = self.add_texts()
-        self.delete_by_ids(added_doc_ids + [self.example_doc_id])
-        self.delete_vector()
+    def search_by_full_text(self):
+        hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
 
 
 def test_upstash_vector(setup_upstashvector_mock):