Jelajahi Sumber

refactor: update the default values of top-k parameter in vdb to be consistent (#9367)

zhuhao 6 bulan lalu
induk
melakukan
86594851cb

+ 1 - 1
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -112,7 +112,7 @@ class ElasticSearchVector(BaseVector):
         self._client.indices.delete(index=self._collection_name)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 10)
+        top_k = kwargs.get("top_k", 4)
         num_candidates = math.ceil(top_k * 1.5)
         knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
 

+ 1 - 1
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -121,7 +121,7 @@ class MyScaleVector(BaseVector):
         return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
 
     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 5)
+        top_k = kwargs.get("top_k", 4)
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"

+ 1 - 9
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -168,14 +168,6 @@ class OracleVector(BaseVector):
                 docs.append(Document(page_content=record[1], metadata=record[0]))
         return docs
 
-    # def get_ids_by_metadata_field(self, key: str, value: str):
-    #    with self._get_cursor() as cur:
-    #        cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
-    #        idss = []
-    #        for record in cur:
-    #            idss.append(record[0])
-    #    return idss
-
     def delete_by_ids(self, ids: list[str]) -> None:
         with self._get_cursor() as cur:
             cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
@@ -192,7 +184,7 @@ class OracleVector(BaseVector):
         :param top_k: The number of nearest neighbors to return, default is 5.
         :return: List of Documents that are nearest to the query vector.
         """
-        top_k = kwargs.get("top_k", 5)
+        top_k = kwargs.get("top_k", 4)
         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"

+ 1 - 13
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -186,7 +186,7 @@ class PGVectoRS(BaseVector):
                         query_vector,
                     ).label("distance"),
                 )
-                .limit(kwargs.get("top_k", 2))
+                .limit(kwargs.get("top_k", 4))
                 .order_by("distance")
             )
             res = session.execute(stmt)
@@ -205,18 +205,6 @@ class PGVectoRS(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        # with Session(self._client) as session:
-        #     select_statement = sql_text(
-        #         f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
-        #     )
-        #     results = session.execute(select_statement).fetchall()
-        # if results:
-        #     docs = []
-        #     for result in results:
-        #         doc = Document(page_content=result[0],
-        #                        metadata=result[1])
-        #         docs.append(doc)
-        #     return docs
         return []
 
 

+ 1 - 1
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -143,7 +143,7 @@ class PGVector(BaseVector):
         :param top_k: The number of nearest neighbors to return, default is 5.
         :return: List of Documents that are nearest to the query vector.
         """
-        top_k = kwargs.get("top_k", 5)
+        top_k = kwargs.get("top_k", 4)
 
         with self._get_cursor() as cur:
             cur.execute(

+ 1 - 1
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -224,7 +224,7 @@ class RelytVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         results = self.similarity_search_with_score_by_vector(
-            k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
+            k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
         )
 
         # Organize results.

+ 1 - 1
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -184,7 +184,7 @@ class TiDBVector(BaseVector):
             self._delete_by_ids(ids)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 5)
+        top_k = kwargs.get("top_k", 4)
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold

+ 1 - 1
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py

@@ -173,7 +173,7 @@ class VikingDBVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
-            query_vector, limit=kwargs.get("top_k", 50)
+            query_vector, limit=kwargs.get("top_k", 4)
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         return self._get_search_res(results, score_threshold)

+ 1 - 1
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -235,7 +235,7 @@ class WeaviateVector(BaseVector):
             query_obj = query_obj.with_where(kwargs.get("where_filter"))
         query_obj = query_obj.with_additional(["vector"])
         properties = ["text"]
-        result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
+        result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
         if "errors" in result:
             raise ValueError(f"Error during query: {result['errors']}")
         docs = []