Przeglądaj źródła

fix vector db sql injection (#16096)

Jyong 1 miesiąc temu
rodzic
commit
33ba7e659b

+ 4 - 0
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py

@@ -194,6 +194,8 @@ class AnalyticdbVectorBySql:
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         with self._get_cursor() as cur:
             query_vector_str = json.dumps(query_vector)
@@ -220,6 +222,8 @@ class AnalyticdbVectorBySql:
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
             cur.execute(
                 f"""SELECT id, vector, page_content, metadata_, 

+ 2 - 0
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -125,6 +125,8 @@ class MyScaleVector(BaseVector):
 
     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"

+ 4 - 2
api/core/rag/datasource/vdb/opengauss/opengauss.py

@@ -155,7 +155,8 @@ class OpenGauss(BaseVector):
         :return: List of Documents that are nearest to the query vector.
         """
         top_k = kwargs.get("top_k", 4)
-
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
             cur.execute(
                 f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
@@ -174,7 +175,8 @@ class OpenGauss(BaseVector):
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
             cur.execute(
                 f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score

+ 4 - 1
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -171,6 +171,8 @@ class PGVector(BaseVector):
         :return: List of Documents that are nearest to the query vector.
         """
         top_k = kwargs.get("top_k", 4)
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
 
         with self._get_cursor() as cur:
             cur.execute(
@@ -190,7 +192,8 @@ class PGVector(BaseVector):
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-
+        if not isinstance(top_k, int) or top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
         with self._get_cursor() as cur:
             if self.pg_bigm:
                 cur.execute("SET pg_bigm.similarity_limit TO 0.000001")