Browse Source

fix: score_threshold handling in vector search methods (#8356)

-LAN- 7 tháng trước cách đây
mục cha
commit
08c486452f

+ 2 - 2
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
 
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = kwargs.get("score_threshold") or 0.0
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
 
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,

+ 1 - 1
api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
 
         ids: list[str] = results["ids"][0]
         documents: list[str] = results["documents"][0]

+ 1 - 1
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
 
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if score > score_threshold:
                 doc.metadata["score"] = score
             docs.append(doc)

+ 1 - 1
api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
         for result in results[0]:
             metadata = result["entity"].get(Field.METADATA_KEY.value)
             metadata["score"] = result["distance"]
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if result["distance"] > score_threshold:
                 doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

+ 1 - 1
api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
 
     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

+ 1 - 1
api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
                 metadata = {}
 
             metadata["score"] = hit["_score"]
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if hit["_score"] > score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

+ 2 - 2
api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
                 [numpy.array(query_vector)],
             )
             docs = []
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             for record in cur:
                 metadata, text, distance = record
                 score = 1 - distance
@@ -212,7 +212,7 @@ class OracleVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         if len(query) > 0:
             # Check which language the query is in
             zh_pattern = re.compile("[\u4e00-\u9fa5]+")

+ 1 - 1
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
             metadata = record.meta
             score = 1 - dis
             metadata["score"] = score
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if score > score_threshold:
                 doc = Document(page_content=record.text, metadata=metadata)
                 docs.append(doc)

+ 1 - 1
api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -144,7 +144,7 @@ class PGVector(BaseVector):
                 (json.dumps(query_vector),),
             )
             docs = []
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             for record in cur:
                 metadata, text, distance = record
                 score = 1 - distance

+ 2 - 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -333,13 +333,13 @@ class QdrantVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             with_payload=True,
             with_vectors=True,
-            score_threshold=kwargs.get("score_threshold", 0.0),
+            score_threshold=float(kwargs.get("score_threshold") or 0.0),
         )
         docs = []
         for result in results:
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if result.score > score_threshold:
                 metadata["score"] = result.score
                 doc = Document(

+ 1 - 1
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
         # Organize results.
         docs = []
         for document, score in results:
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             if 1 - score > score_threshold:
                 docs.append(document)
         return docs

+ 1 - 1
api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             timeout=self._client_config.timeout,
         )
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         return self._get_search_res(res, score_threshold)
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

+ 1 - 1
api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold", 0.0)
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold
 

+ 1 - 1
api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
 
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0)
+            score_threshold = float(kwargs.get("score_threshold") or 0.0)
             # check score threshold
             if score > score_threshold:
                 doc.metadata["score"] = score