Browse Source

add error msg for hit test (#4704)

Jyong 11 months ago
parent
commit
1b2d862973
1 changed files with 88 additions and 71 deletions
  1. 88 71
      api/core/rag/datasource/retrieval_service.py

+ 88 - 71
api/core/rag/datasource/retrieval_service.py

@@ -33,6 +33,7 @@ class RetrievalService:
             return []
         all_documents = []
         threads = []
+        exceptions = []
         # retrieval_model source with keyword
         if retrival_method == 'keyword_search':
             keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
@@ -40,7 +41,8 @@ class RetrievalService:
                 'dataset_id': dataset_id,
                 'query': query,
                 'top_k': top_k,
-                'all_documents': all_documents
+                'all_documents': all_documents,
+                'exceptions': exceptions,
             })
             threads.append(keyword_thread)
             keyword_thread.start()
@@ -54,7 +56,8 @@ class RetrievalService:
                 'score_threshold': score_threshold,
                 'reranking_model': reranking_model,
                 'all_documents': all_documents,
-                'retrival_method': retrival_method
+                'retrival_method': retrival_method,
+                'exceptions': exceptions,
             })
             threads.append(embedding_thread)
             embedding_thread.start()
@@ -69,7 +72,8 @@ class RetrievalService:
                 'score_threshold': score_threshold,
                 'top_k': top_k,
                 'reranking_model': reranking_model,
-                'all_documents': all_documents
+                'all_documents': all_documents,
+                'exceptions': exceptions,
             })
             threads.append(full_text_index_thread)
             full_text_index_thread.start()
@@ -77,6 +81,10 @@ class RetrievalService:
         for thread in threads:
             thread.join()
 
+        if exceptions:
+            exception_message = ';\n'.join(exceptions)
+            raise Exception(exception_message)
+
         if retrival_method == 'hybrid_search':
             data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
             all_documents = data_post_processor.invoke(
@@ -89,82 +97,91 @@ class RetrievalService:
 
     @classmethod
     def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str,
-                       top_k: int, all_documents: list):
+                       top_k: int, all_documents: list, exceptions: list):
         with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            keyword = Keyword(
-                dataset=dataset
-            )
-
-            documents = keyword.search(
-                query,
-                top_k=top_k
-            )
-            all_documents.extend(documents)
+            try:
+                dataset = db.session.query(Dataset).filter(
+                    Dataset.id == dataset_id
+                ).first()
+
+                keyword = Keyword(
+                    dataset=dataset
+                )
+
+                documents = keyword.search(
+                    query,
+                    top_k=top_k
+                )
+                all_documents.extend(documents)
+            except Exception as e:
+                exceptions.append(str(e))
 
     @classmethod
     def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
                          top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                         all_documents: list, retrival_method: str):
+                         all_documents: list, retrival_method: str, exceptions: list):
         with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector = Vector(
-                dataset=dataset
-            )
-
-            documents = vector.search_by_vector(
-                query,
-                search_type='similarity_score_threshold',
-                top_k=top_k,
-                score_threshold=score_threshold,
-                filter={
-                    'group_id': [dataset.id]
-                }
-            )
-
-            if documents:
-                if reranking_model and retrival_method == 'semantic_search':
-                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
-                    all_documents.extend(data_post_processor.invoke(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
+            try:
+                dataset = db.session.query(Dataset).filter(
+                    Dataset.id == dataset_id
+                ).first()
+
+                vector = Vector(
+                    dataset=dataset
+                )
+
+                documents = vector.search_by_vector(
+                    query,
+                    search_type='similarity_score_threshold',
+                    top_k=top_k,
+                    score_threshold=score_threshold,
+                    filter={
+                        'group_id': [dataset.id]
+                    }
+                )
+
+                if documents:
+                    if reranking_model and retrival_method == 'semantic_search':
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        all_documents.extend(data_post_processor.invoke(
+                            query=query,
+                            documents=documents,
+                            score_threshold=score_threshold,
+                            top_n=len(documents)
+                        ))
+                    else:
+                        all_documents.extend(documents)
+            except Exception as e:
+                exceptions.append(str(e))
 
     @classmethod
     def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
                                top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
-                               all_documents: list, retrival_method: str):
+                               all_documents: list, retrival_method: str, exceptions: list):
         with flask_app.app_context():
-            dataset = db.session.query(Dataset).filter(
-                Dataset.id == dataset_id
-            ).first()
-
-            vector_processor = Vector(
-                dataset=dataset,
-            )
-
-            documents = vector_processor.search_by_full_text(
-                query,
-                top_k=top_k
-            )
-            if documents:
-                if reranking_model and retrival_method == 'full_text_search':
-                    data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
-                    all_documents.extend(data_post_processor.invoke(
-                        query=query,
-                        documents=documents,
-                        score_threshold=score_threshold,
-                        top_n=len(documents)
-                    ))
-                else:
-                    all_documents.extend(documents)
+            try:
+                dataset = db.session.query(Dataset).filter(
+                    Dataset.id == dataset_id
+                ).first()
+
+                vector_processor = Vector(
+                    dataset=dataset,
+                )
+
+                documents = vector_processor.search_by_full_text(
+                    query,
+                    top_k=top_k
+                )
+                if documents:
+                    if reranking_model and retrival_method == 'full_text_search':
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        all_documents.extend(data_post_processor.invoke(
+                            query=query,
+                            documents=documents,
+                            score_threshold=score_threshold,
+                            top_n=len(documents)
+                        ))
+                    else:
+                        all_documents.extend(documents)
+            except Exception as e:
+                exceptions.append(str(e))