Jelajahi Sumber

fix dataset retrival in dataset mode (#3334)

Jyong 1 tahun lalu
induk
melakukan
6164604462

+ 1 - 0
api/core/rag/extractor/csv_extractor.py

@@ -34,6 +34,7 @@ class CSVExtractor(BaseExtractor):
 
     def extract(self) -> list[Document]:
         """Load data into document objects."""
+        docs = []
         try:
             with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                 docs = self._read_from_file(csvfile)

+ 92 - 0
api/core/rag/retrieval/dataset_retrieval.py

@@ -2,6 +2,7 @@ import threading
 from typing import Optional, cast
 
 from flask import Flask, current_app
+from langchain.tools import BaseTool
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
@@ -17,6 +18,8 @@ from core.rag.models.document import Document
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
 from core.rerank.rerank import RerankRunner
+from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
+from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetQuery, DocumentSegment
 from models.dataset import Document as DatasetDocument
@@ -373,3 +376,92 @@ class DatasetRetrieval:
                                                           )
 
                     all_documents.extend(documents)
+
+    def to_dataset_retriever_tool(self, tenant_id: str,
+                                  dataset_ids: list[str],
+                                  retrieve_config: DatasetRetrieveConfigEntity,
+                                  return_resource: bool,
+                                  invoke_from: InvokeFrom,
+                                  hit_callback: DatasetIndexToolCallbackHandler) \
+            -> Optional[list[BaseTool]]:
+        """
+        A dataset tool is a tool that can be used to retrieve information from a dataset
+        :param tenant_id: tenant id
+        :param dataset_ids: dataset ids
+        :param retrieve_config: retrieve config
+        :param return_resource: return resource
+        :param invoke_from: invoke from
+        :param hit_callback: hit callback
+        """
+        tools = []
+        available_datasets = []
+        for dataset_id in dataset_ids:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            # pass if dataset is not available
+            if not dataset:
+                continue
+
+            # pass if dataset is not available
+            if (dataset and dataset.available_document_count == 0
+                    and dataset.available_document_count == 0):
+                continue
+
+            available_datasets.append(dataset)
+
+        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
+            # get retrieval model config
+            default_retrieval_model = {
+                'search_method': 'semantic_search',
+                'reranking_enable': False,
+                'reranking_model': {
+                    'reranking_provider_name': '',
+                    'reranking_model_name': ''
+                },
+                'top_k': 2,
+                'score_threshold_enabled': False
+            }
+
+            for dataset in available_datasets:
+                retrieval_model_config = dataset.retrieval_model \
+                    if dataset.retrieval_model else default_retrieval_model
+
+                # get top k
+                top_k = retrieval_model_config['top_k']
+
+                # get score threshold
+                score_threshold = None
+                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
+                if score_threshold_enabled:
+                    score_threshold = retrieval_model_config.get("score_threshold")
+
+                tool = DatasetRetrieverTool.from_dataset(
+                    dataset=dataset,
+                    top_k=top_k,
+                    score_threshold=score_threshold,
+                    hit_callbacks=[hit_callback],
+                    return_resource=return_resource,
+                    retriever_from=invoke_from.to_source()
+                )
+
+                tools.append(tool)
+        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
+            tool = DatasetMultiRetrieverTool.from_dataset(
+                dataset_ids=[dataset.id for dataset in available_datasets],
+                tenant_id=tenant_id,
+                top_k=retrieve_config.top_k or 2,
+                score_threshold=retrieve_config.score_threshold,
+                hit_callbacks=[hit_callback],
+                return_resource=return_resource,
+                retriever_from=invoke_from.to_source(),
+                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
+                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
+            )
+
+            tools.append(tool)
+
+        return tools