|
@@ -2,6 +2,7 @@ import threading
|
|
|
from typing import Optional, cast
|
|
|
|
|
|
from flask import Flask, current_app
|
|
|
+from langchain.tools import BaseTool
|
|
|
|
|
|
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
|
|
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
|
@@ -17,6 +18,8 @@ from core.rag.models.document import Document
|
|
|
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
|
|
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
|
|
from core.rerank.rerank import RerankRunner
|
|
|
+from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
|
|
+from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
|
|
from extensions.ext_database import db
|
|
|
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
|
|
from models.dataset import Document as DatasetDocument
|
|
@@ -373,3 +376,92 @@ class DatasetRetrieval:
|
|
|
)
|
|
|
|
|
|
all_documents.extend(documents)
|
|
|
+
|
|
|
+ def to_dataset_retriever_tool(self, tenant_id: str,
|
|
|
+ dataset_ids: list[str],
|
|
|
+ retrieve_config: DatasetRetrieveConfigEntity,
|
|
|
+ return_resource: bool,
|
|
|
+ invoke_from: InvokeFrom,
|
|
|
+ hit_callback: DatasetIndexToolCallbackHandler) \
|
|
|
+ -> Optional[list[BaseTool]]:
|
|
|
+ """
|
|
|
+ A dataset tool is a tool that can be used to retrieve information from a dataset
|
|
|
+ :param tenant_id: tenant id
|
|
|
+ :param dataset_ids: dataset ids
|
|
|
+ :param retrieve_config: retrieve config
|
|
|
+ :param return_resource: return resource
|
|
|
+ :param invoke_from: invoke from
|
|
|
+ :param hit_callback: hit callback
|
|
|
+ """
|
|
|
+ tools = []
|
|
|
+ available_datasets = []
|
|
|
+ for dataset_id in dataset_ids:
|
|
|
+ # get dataset from dataset id
|
|
|
+ dataset = db.session.query(Dataset).filter(
|
|
|
+ Dataset.tenant_id == tenant_id,
|
|
|
+ Dataset.id == dataset_id
|
|
|
+ ).first()
|
|
|
+
|
|
|
+ # pass if dataset is not available
|
|
|
+ if not dataset:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # pass if dataset is not available
|
|
|
+ if (dataset and dataset.available_document_count == 0
|
|
|
+ and dataset.available_document_count == 0):
|
|
|
+ continue
|
|
|
+
|
|
|
+ available_datasets.append(dataset)
|
|
|
+
|
|
|
+ if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
|
|
+ # get retrieval model config
|
|
|
+ default_retrieval_model = {
|
|
|
+ 'search_method': 'semantic_search',
|
|
|
+ 'reranking_enable': False,
|
|
|
+ 'reranking_model': {
|
|
|
+ 'reranking_provider_name': '',
|
|
|
+ 'reranking_model_name': ''
|
|
|
+ },
|
|
|
+ 'top_k': 2,
|
|
|
+ 'score_threshold_enabled': False
|
|
|
+ }
|
|
|
+
|
|
|
+ for dataset in available_datasets:
|
|
|
+ retrieval_model_config = dataset.retrieval_model \
|
|
|
+ if dataset.retrieval_model else default_retrieval_model
|
|
|
+
|
|
|
+ # get top k
|
|
|
+ top_k = retrieval_model_config['top_k']
|
|
|
+
|
|
|
+ # get score threshold
|
|
|
+ score_threshold = None
|
|
|
+ score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
|
|
+ if score_threshold_enabled:
|
|
|
+ score_threshold = retrieval_model_config.get("score_threshold")
|
|
|
+
|
|
|
+ tool = DatasetRetrieverTool.from_dataset(
|
|
|
+ dataset=dataset,
|
|
|
+ top_k=top_k,
|
|
|
+ score_threshold=score_threshold,
|
|
|
+ hit_callbacks=[hit_callback],
|
|
|
+ return_resource=return_resource,
|
|
|
+ retriever_from=invoke_from.to_source()
|
|
|
+ )
|
|
|
+
|
|
|
+ tools.append(tool)
|
|
|
+ elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
|
|
+ tool = DatasetMultiRetrieverTool.from_dataset(
|
|
|
+ dataset_ids=[dataset.id for dataset in available_datasets],
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ top_k=retrieve_config.top_k or 2,
|
|
|
+ score_threshold=retrieve_config.score_threshold,
|
|
|
+ hit_callbacks=[hit_callback],
|
|
|
+ return_resource=return_resource,
|
|
|
+ retriever_from=invoke_from.to_source(),
|
|
|
+ reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
|
|
|
+ reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
|
|
|
+ )
|
|
|
+
|
|
|
+ tools.append(tool)
|
|
|
+
|
|
|
+ return tools
|