Преглед на файлове

fix: escape double quotation marks in the vector DB search query (#6506)

Sangmin Ahn преди 9 месеца
родител
ревизия
093b8ca475
променени са 2 файла, в които са добавени 12 реда и са изтрити 4 реда
  1. 7 3
      api/core/rag/datasource/retrieval_service.py
  2. 5 1
      api/services/hit_testing_service.py

+ 7 - 3
api/core/rag/datasource/retrieval_service.py

@@ -110,7 +110,7 @@ class RetrievalService:
                 )
 
                 documents = keyword.search(
-                    query,
+                    cls.escape_query_for_search(query),
                     top_k=top_k
                 )
                 all_documents.extend(documents)
@@ -132,7 +132,7 @@ class RetrievalService:
                 )
 
                 documents = vector.search_by_vector(
-                    query,
+                    cls.escape_query_for_search(query),
                     search_type='similarity_score_threshold',
                     top_k=top_k,
                     score_threshold=score_threshold,
@@ -170,7 +170,7 @@ class RetrievalService:
                 )
 
                 documents = vector_processor.search_by_full_text(
-                    query,
+                    cls.escape_query_for_search(query),
                     top_k=top_k
                 )
                 if documents:
@@ -186,3 +186,7 @@ class RetrievalService:
                         all_documents.extend(documents)
             except Exception as e:
                 exceptions.append(str(e))
+
+    @staticmethod
+    def escape_query_for_search(query: str) -> str:
+        return query.replace('"', '\\"')

+ 5 - 1
api/services/hit_testing_service.py

@@ -40,7 +40,7 @@ class HitTestingService:
 
         all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                   dataset_id=dataset.id,
-                                                  query=query,
+                                                  query=cls.escape_query_for_search(query),
                                                   top_k=retrieval_model['top_k'],
                                                   score_threshold=retrieval_model['score_threshold']
                                                   if retrieval_model['score_threshold_enabled'] else None,
@@ -104,3 +104,7 @@ class HitTestingService:
 
         if not query or len(query) > 250:
             raise ValueError('Query is required and cannot exceed 250 characters')
+
+    @staticmethod
+    def escape_query_for_search(query: str) -> str:
+        return query.replace('"', '\\"')