|
@@ -2,7 +2,6 @@ import threading
|
|
|
from typing import Optional
|
|
|
|
|
|
from flask import Flask, current_app
|
|
|
-from flask_login import current_user
|
|
|
|
|
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
|
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
|
@@ -27,6 +26,11 @@ class RetrievalService:
|
|
|
@classmethod
|
|
|
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
|
|
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
|
|
|
+ dataset = db.session.query(Dataset).filter(
|
|
|
+ Dataset.id == dataset_id
|
|
|
+ ).first()
|
|
|
+ if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
|
|
+ return []
|
|
|
all_documents = []
|
|
|
threads = []
|
|
|
# retrieval_model source with keyword
|
|
@@ -73,7 +77,7 @@ class RetrievalService:
|
|
|
thread.join()
|
|
|
|
|
|
if retrival_method == 'hybrid_search':
|
|
|
- data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
|
|
|
+ data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
|
|
all_documents = data_post_processor.invoke(
|
|
|
query=query,
|
|
|
documents=all_documents,
|